Browse Source

feat(mgb/opr): add property params for python operator class

GitOrigin-RevId: af6da0e0ac
tags/v0.5.0
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
c3d5b61f1c
4 changed files with 25 additions and 1 deletions
  1. +5
    -0
      python_module/src/cpp/megbrain_wrap.cpp
  2. +11
    -1
      python_module/src/cpp/megbrain_wrap.h
  3. +4
    -0
      python_module/src/swig/operator.i
  4. +5
    -0
      python_module/src/swig/operator.py

+ 5
- 0
python_module/src/cpp/megbrain_wrap.cpp View File

@@ -880,6 +880,11 @@ SymbolVar SharedScalar::_as_sym_var(CompGraph &cg, mgb::CompNode &cn) {
ssprintf("SharedScalar@%p", m_val.get())); ssprintf("SharedScalar@%p", m_val.get()));
} }


/* =============== Operator =============== */

const std::unique_ptr<mgb::OprFootprint> Operator::sm_opr_footprint_ptr{
std::make_unique<mgb::OprFootprint>()};

/* ================= misc ================= */ /* ================= misc ================= */


SymbolVar fill_retain_dtype(SymbolVar var, PyObject *value) { SymbolVar fill_retain_dtype(SymbolVar var, PyObject *value) {


+ 11
- 1
python_module/src/cpp/megbrain_wrap.h View File

@@ -17,6 +17,8 @@
#include "megbrain/graph.h" #include "megbrain/graph.h"
#include "megbrain/opr/io.h" #include "megbrain/opr/io.h"


#include "megbrain/plugin/opr_footprint.h"

#include <map> #include <map>
#include <string> #include <string>


@@ -441,16 +443,24 @@ class SharedScalar {
*/ */
class Operator { class Operator {
mgb::cg::OperatorNodeBase* m_operator_node; mgb::cg::OperatorNodeBase* m_operator_node;
std::string m_params;

static const std::unique_ptr<mgb::OprFootprint> sm_opr_footprint_ptr;


public: public:
Operator() : m_operator_node(nullptr){}; Operator() : m_operator_node(nullptr){};
Operator(mgb::cg::OperatorNodeBase* operator_node) Operator(mgb::cg::OperatorNodeBase* operator_node)
: m_operator_node(operator_node) {}
: m_operator_node(operator_node),
m_params(std::move(
(sm_opr_footprint_ptr->calc_footprint(m_operator_node)).param->to_string()))
{}


size_t id() const { return m_operator_node->id(); } size_t id() const { return m_operator_node->id(); }


const std::string& name() const { return m_operator_node->name(); } const std::string& name() const { return m_operator_node->name(); }


const std::string& params() const { return m_params; }

const std::shared_ptr<mgb::ComputingGraph> get_owner_graph() const { const std::shared_ptr<mgb::ComputingGraph> get_owner_graph() const {
return m_operator_node->owner_graph()->shared_from_this(); return m_operator_node->owner_graph()->shared_from_this();
} }


+ 4
- 0
python_module/src/swig/operator.i View File

@@ -58,6 +58,10 @@ public:
return $self->name(); return $self->name();
} }


const std::string& _get_params() const {
return $self->params();
}

SymbolVarArray _get_inputs() { SymbolVarArray _get_inputs() {
return $self->inputs(); return $self->inputs();
} }


+ 5
- 0
python_module/src/swig/operator.py View File

@@ -20,6 +20,11 @@ def name(self):
return self._get_name() return self._get_name()


@property @property
def params(self):
import json
return json.loads(self._get_params())

@property
def inputs(self): def inputs(self):
return tuple(self._get_inputs()) return tuple(self._get_inputs())




Loading…
Cancel
Save