diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py
index 818d119e..cc4a4f4e 100644
--- a/imperative/python/megengine/core/tensor/megbrain_graph.py
+++ b/imperative/python/megengine/core/tensor/megbrain_graph.py
@@ -11,6 +11,8 @@ import threading
 import weakref
 from concurrent.futures import Future, ThreadPoolExecutor
 
+import numpy as np
+
 from .. import _imperative_rt
 from .._imperative_rt.ops import BackwardGraph
 from .._wrap import device as as_device
@@ -32,6 +34,8 @@ class Graph(_imperative_rt.ComputingGraph):
             wrapper, cache = VarNode, self._var_cache
         elif type(obj) is _imperative_rt.OperatorNode:
             wrapper, cache = OpNode, self._op_cache
+        else:
+            raise TypeError(type(obj))
         if obj not in cache:
             cache[obj] = wrapper(obj)
         return cache[obj]
@@ -62,6 +66,11 @@ class Graph(_imperative_rt.ComputingGraph):
             assert dtype is None and device is None
             return self._wrap(_imperative_rt.make_shared(self, data))
         else:
+            data = np.asarray(data, dtype=dtype)
+            if data.dtype == np.float64:
+                data = data.astype(np.float32)
+            elif data.dtype == np.int64:
+                data = data.astype(np.int32)
             device = as_device(device).to_c()
             return self._wrap(_imperative_rt.make_const(self, data, device, dtype))
 
diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp
index 38b2b3b5..67de3508 100644
--- a/imperative/python/src/graph_rt.cpp
+++ b/imperative/python/src/graph_rt.cpp
@@ -181,10 +181,10 @@ void init_graph_rt(py::module m) {
 
     m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) {
         if (!cn.valid()) {
-            throw py::type_error("device must not be None");
+            cn = CompNode::load("xpux");
         }
         auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype);
-        opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
+        return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
     });
 
     m.def("input_callback", [input_callback](std::function callback,
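
Note (not part of the patch): on the Python side, Graph.make_const's non-shared branch now narrows NumPy's default 64-bit dtypes before building a graph constant, presumably to match MegEngine's float32/int32 tensor defaults; on the C++ side, make_const falls back to the "xpux" comp node when none is given and now actually returns the created node instead of dropping it. Below is a minimal standalone sketch of the dtype coercion only, under those assumptions; the helper name narrow_default_dtypes is hypothetical and not part of the MegEngine API.

import numpy as np

def narrow_default_dtypes(data, dtype=None):
    # Hypothetical helper mirroring the patched branch: np.asarray yields
    # float64/int64 for plain Python scalars and lists, and those two dtypes
    # are narrowed to float32/int32 (even if requested explicitly).
    data = np.asarray(data, dtype=dtype)
    if data.dtype == np.float64:
        data = data.astype(np.float32)
    elif data.dtype == np.int64:
        data = data.astype(np.int32)
    return data

assert narrow_default_dtypes(1.0).dtype == np.float32
assert narrow_default_dtypes([1, 2, 3]).dtype == np.int32  # int64 narrowed (already int32 on Windows)
assert narrow_default_dtypes(1.0, dtype="float16").dtype == np.float16  # other dtypes pass through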