From c3d5b61f1c7a37ec14a366a099335cc169625800 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 30 Apr 2020 09:48:27 +0000 Subject: [PATCH] feat(mgb/opr): add property params for python operator class GitOrigin-RevId: af6da0e0ac18fa8d81c2415dfe982ae6ce71451a --- python_module/src/cpp/megbrain_wrap.cpp | 5 +++++ python_module/src/cpp/megbrain_wrap.h | 12 +++++++++++- python_module/src/swig/operator.i | 4 ++++ python_module/src/swig/operator.py | 5 +++++ 4 files changed, 25 insertions(+), 1 deletion(-) diff --git a/python_module/src/cpp/megbrain_wrap.cpp b/python_module/src/cpp/megbrain_wrap.cpp index 4d4ab7a1..efcea280 100644 --- a/python_module/src/cpp/megbrain_wrap.cpp +++ b/python_module/src/cpp/megbrain_wrap.cpp @@ -880,6 +880,11 @@ SymbolVar SharedScalar::_as_sym_var(CompGraph &cg, mgb::CompNode &cn) { ssprintf("SharedScalar@%p", m_val.get())); } +/* =============== Operator =============== */ + +const std::unique_ptr Operator::sm_opr_footprint_ptr{ + std::make_unique()}; + /* ================= misc ================= */ SymbolVar fill_retain_dtype(SymbolVar var, PyObject *value) { diff --git a/python_module/src/cpp/megbrain_wrap.h b/python_module/src/cpp/megbrain_wrap.h index 37c773e3..e1484e4e 100644 --- a/python_module/src/cpp/megbrain_wrap.h +++ b/python_module/src/cpp/megbrain_wrap.h @@ -17,6 +17,8 @@ #include "megbrain/graph.h" #include "megbrain/opr/io.h" +#include "megbrain/plugin/opr_footprint.h" + #include #include @@ -441,16 +443,24 @@ class SharedScalar { */ class Operator { mgb::cg::OperatorNodeBase* m_operator_node; + std::string m_params; + + static const std::unique_ptr sm_opr_footprint_ptr; public: Operator() : m_operator_node(nullptr){}; Operator(mgb::cg::OperatorNodeBase* operator_node) - : m_operator_node(operator_node) {} + : m_operator_node(operator_node), + m_params(std::move( + (sm_opr_footprint_ptr->calc_footprint(m_operator_node)).param->to_string())) + {} size_t id() const { return m_operator_node->id(); } const std::string& name() const { return m_operator_node->name(); } + const std::string& params() const { return m_params; } + const std::shared_ptr get_owner_graph() const { return m_operator_node->owner_graph()->shared_from_this(); } diff --git a/python_module/src/swig/operator.i b/python_module/src/swig/operator.i index 3248caa3..e7bd0911 100644 --- a/python_module/src/swig/operator.i +++ b/python_module/src/swig/operator.i @@ -58,6 +58,10 @@ public: return $self->name(); } + const std::string& _get_params() const { + return $self->params(); + } + SymbolVarArray _get_inputs() { return $self->inputs(); } diff --git a/python_module/src/swig/operator.py b/python_module/src/swig/operator.py index 1de6aff8..a1eb095e 100644 --- a/python_module/src/swig/operator.py +++ b/python_module/src/swig/operator.py @@ -20,6 +20,11 @@ def name(self): return self._get_name() @property +def params(self): + import json + return json.loads(self._get_params()) + +@property def inputs(self): return tuple(self._get_inputs())