Browse Source

chore(mge/imperative): fix Graph.make_const

GitOrigin-RevId: 0f4c62aebf
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
e6e29748b9
2 changed files with 11 additions and 2 deletions
  1. +9
    -0
      imperative/python/megengine/core/tensor/megbrain_graph.py
  2. +2
    -2
      imperative/python/src/graph_rt.cpp

+ 9
- 0
imperative/python/megengine/core/tensor/megbrain_graph.py View File

@@ -11,6 +11,8 @@ import threading
import weakref
from concurrent.futures import Future, ThreadPoolExecutor

import numpy as np

from .. import _imperative_rt
from .._imperative_rt.ops import BackwardGraph
from .._wrap import device as as_device
@@ -32,6 +34,8 @@ class Graph(_imperative_rt.ComputingGraph):
wrapper, cache = VarNode, self._var_cache
elif type(obj) is _imperative_rt.OperatorNode:
wrapper, cache = OpNode, self._op_cache
else:
raise TypeError(type(obj))
if obj not in cache:
cache[obj] = wrapper(obj)
return cache[obj]
@@ -62,6 +66,11 @@ class Graph(_imperative_rt.ComputingGraph):
assert dtype is None and device is None
return self._wrap(_imperative_rt.make_shared(self, data))
else:
data = np.asarray(data, dtype=dtype)
if data.dtype == np.float64:
data = data.astype(np.float32)
elif data.dtype == np.int64:
data = data.astype(np.int32)
device = as_device(device).to_c()
return self._wrap(_imperative_rt.make_const(self, data, device, dtype))



+ 2
- 2
imperative/python/src/graph_rt.cpp View File

@@ -181,10 +181,10 @@ void init_graph_rt(py::module m) {

m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) {
if (!cn.valid()) {
throw py::type_error("device must not be None");
cn = CompNode::load("xpux");
}
auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype);
opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
});

m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback,


Loading…
Cancel
Save