Browse Source

chore(mge): add VarNode.value

GitOrigin-RevId: 1dc0d0c711
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
ac3408bfaa
3 changed files with 18 additions and 1 deletions
  1. +4
    -0
      imperative/python/megengine/core/tensor/megbrain_graph.py
  2. +1
    -1
      imperative/python/megengine/jit/tracing.py
  3. +13
    -0
      imperative/python/src/graph_rt.cpp

+ 4
- 0
imperative/python/megengine/core/tensor/megbrain_graph.py View File

@@ -95,6 +95,10 @@ class VarNode(TensorBase):
def shape(self):
return self._node.shape

@property
def value(self):
return self._node.value


class OpNode:
def __init__(self, node: _imperative_rt.OperatorNode):


+ 1
- 1
imperative/python/megengine/jit/tracing.py View File

@@ -399,7 +399,7 @@ class LazyEvalTensor(RawTensor):
return self.__varnode.shape

def numpy(self):
raise RuntimeError("cannot read value during symbolic tracing")
return self.__varnode.value

def _dev_tensor(self):
raise RuntimeError("cannot access data during symbolic tracing")


+ 13
- 0
imperative/python/src/graph_rt.cpp View File

@@ -58,6 +58,19 @@ void init_graph_rt(py::module m) {
return nullptr;
}
return mgr.infer_shape_fallible(v);
})
.def_property_readonly("value", [](cg::VarNode* v) -> py::object {
auto&& mgr = v->owner_graph()->static_infer_manager();
auto&& type = mgr.get_infer_type(v);
using InferType = cg::static_infer::InferType;
if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
return py::none();
}
auto* val = mgr.infer_value_fallible(v);
if (!val) {
return py::none();
}
return py::cast(*val).attr("numpy")();
});

py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(m, "OperatorNode")


Loading…
Cancel
Save