diff --git a/imperative/python/megengine/optimizer/optimizer.py b/imperative/python/megengine/optimizer/optimizer.py
index 412579da..0a09496b 100644
--- a/imperative/python/megengine/optimizer/optimizer.py
+++ b/imperative/python/megengine/optimizer/optimizer.py
@@ -97,7 +97,7 @@ class Optimizer(metaclass=ABCMeta):
                     "optimizer can only optimize Parameters, but one of the params is "
                     + str(type(param))
                 )
-            param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))
+            param[...] = Tensor(param.numpy(), no_cache=True)
 
         for name, default in self._defaults.items():
             if default is required and name not in param_group:
diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index 8f8bf036..e76ca5a3 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -525,7 +525,34 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
     }
     mgb_assert(tup.size() == 7);
     if (auto* t = try_cast(tup[0].ptr())) {
-        m_tensor = t->m_tensor->copy();
+        m_tensor = t->m_tensor;
+        // TODO: merge two path in arg parse
+        if (!tup[1].is_none()) {
+            auto dtype = tup[1].cast<DType>();
+            mgb_assert(
+                    dtype == m_tensor->dtype(), "dtype mismatch: %s vs %s",
+                    dtype.name(), m_tensor->dtype().name());
+        }
+        if (!tup[2].is_none()) {
+            auto device = as_comp_node(tup[2]);
+            mgb_assert(
+                    device == m_tensor->comp_node(), "device mismatch: %s vs %s",
+                    device.to_string().c_str(),
+                    m_tensor->comp_node().to_string().c_str());
+        }
+        mgb_assert(!tup[3].cast<bool>(), "expect is_const == False, got True");
+        bool no_cache = tup[4].cast<bool>();
+        if (no_cache) {
+            // always copy because it's hard to tell whether this tensor is cached
+            m_tensor = m_tensor->copy();
+        }
+        // ignore name
+        if (!tup[6].is_none()) {
+            Format format = tup[6].cast<std::string>();
+            mgb_assert(
+                    format == m_tensor->format(), "format mismatch: %s vs %s",
+                    format.to_string().c_str(), m_tensor->format().to_string().c_str());
+        }
     } else {
         auto data = tup[0];
         DType dtype = tup[1].cast<DType>();
@@ -1030,7 +1057,7 @@ void init_tensor(py::module m) {
                 try {
                     self.compiled->compile();
                 } catch (const std::exception& e) {
-                    mgb_log_error(e.what());
+                    mgb_log_error("error in trace: %s", e.what());
                 }
             }
             // register transformations
diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h
index 1f849f2b..47783234 100644
--- a/imperative/python/src/tensor.h
+++ b/imperative/python/src/tensor.h
@@ -47,7 +47,7 @@ public:
 
     ~Tensor() = default;
 
-    inline Tensor copy() { return *this; }
+    inline Tensor copy() { return Tensor(imperative::apply(DupTensor(), data())[0]); }
 
     inline DType dtype() { return *data().dtype(); }
     inline CompNode comp_node() { return *data().device(); }
diff --git a/imperative/python/test/unit/core/test_tensor_wrapper.py b/imperative/python/test/unit/core/test_tensor_wrapper.py
index 154f8945..fce911cf 100644
--- a/imperative/python/test/unit/core/test_tensor_wrapper.py
+++ b/imperative/python/test/unit/core/test_tensor_wrapper.py
@@ -5,7 +5,9 @@ import numpy as np
 import pytest
 from utils import get_var_value, make_tensor
 
+from megengine import _full_sync
 from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8
+from megengine.device import get_default_device
 from megengine.tensor import Parameter, Tensor
 from megengine.utils.network import Network
 
@@ -220,3 +222,16 @@ def test_tensor_from_bool():
     assert x.dtype == np.bool_
     x = Tensor([True, False])
     assert x.dtype == np.bool_
+
+
+def test_tensor_construct_tensor():
+    x = Tensor(0, dtype=np.float32, device="xpu0:1", name="MyName")
+    assert Tensor(x.astype(np.int32)).dtype == np.int32
+    with pytest.raises(RuntimeError):
+        Tensor(x.astype(np.int32), dtype=np.float32)
+    assert Tensor(x).name == ""
+    assert Tensor(x, name="MyName2").name == "MyName2"
+    with pytest.raises(RuntimeError):
+        assert Tensor(x.to("xpu0:2"), device="xpu0:1").device == "xpu0:1"
+    assert Tensor(x.to("xpu0:2")).device == x.to("xpu0:2").device
+    _full_sync()
diff --git a/imperative/src/impl/transformations/eval.cpp b/imperative/src/impl/transformations/eval.cpp
index 314c78e9..ae9604be 100644
--- a/imperative/src/impl/transformations/eval.cpp
+++ b/imperative/src/impl/transformations/eval.cpp
@@ -126,6 +126,11 @@ ValueRefList InterpreterTransformation::apply_transformation(
         } else {
             return {ValueRef()};
         }
+    } else if (op.is<DupTensor>()) {
+        auto& input = inputs[0].cast(m_value_type);
+        DeviceTensorND dev_tensor;
+        dev_tensor.copy_from(m_channel->get_dev_tensor(input.handle()->handle()));
+        return m_value_type.make(share_handle(m_channel->put(dev_tensor, {})));
     } else {
         return op.fallback(inputs);
     }
diff --git a/imperative/src/include/megbrain/imperative/basic_operators.h b/imperative/src/include/megbrain/imperative/basic_operators.h
index 5d785857..f12527e0 100644
--- a/imperative/src/include/megbrain/imperative/basic_operators.h
+++ b/imperative/src/include/megbrain/imperative/basic_operators.h
@@ -196,5 +196,10 @@ public:
     std::string to_string() const override;
 };
 
+class DupTensor final : public OperatorImpl<DupTensor, Operator::IdentityLike> {
+public:
+    std::string to_string() const override { return "DupTensor"; }
+};
+
 }  // namespace imperative
 }  // namespace mgb
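
Note on the semantics above: Tensor(t) no longer deep-copies t. The new wrapper shares t's underlying tensor, any dtype/device/format argument must agree with t (a mismatch raises RuntimeError, as the new test exercises), and a real copy now happens only for no_cache=True, which Tensor::copy() lowers to the new DupTensor operator. Below is a minimal sketch of the resulting Python-level behavior; it is illustrative only, not part of the patch, and the dtype values are arbitrary:

    import numpy as np
    from megengine import Tensor

    x = Tensor(0, dtype=np.float32)

    y = Tensor(x)                 # cheap: shares x's underlying tensor handle
    z = Tensor(x, no_cache=True)  # forces a real copy through the DupTensor op

    # Constructing from a Tensor validates arguments instead of converting:
    try:
        Tensor(x, dtype=np.int32)  # dtype mismatch -> RuntimeError
    except RuntimeError as exc:
        print("rejected:", exc)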