GitOrigin-RevId: 0345454070
release-1.10
@@ -97,7 +97,7 @@ class Optimizer(metaclass=ABCMeta):
                     "optimizer can only optimize Parameters, but one of the params is "
                     + str(type(param))
                 )
-            param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))
+            param[...] = Tensor(param.numpy(), no_cache=True)
         for name, default in self._defaults.items():
             if default is required and name not in param_group:
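Note: the hunk above swaps the private `_reset` call for an in-place slice assignment, so a `Parameter` keeps its Python identity while its storage is rebuilt from a fresh tensor. A minimal sketch of that effect, using only the public API that already appears in this diff (`Parameter`, `Tensor`, `no_cache=True`):

```python
# Hedged sketch (not part of the diff): refresh a Parameter in place via slice
# assignment, mirroring the new `param[...] = Tensor(...)` line above, so the
# Python object identity of the parameter is preserved.
import numpy as np

from megengine import Parameter, Tensor

param = Parameter(np.ones((2, 2), dtype=np.float32))
before = id(param)
param[...] = Tensor(param.numpy(), no_cache=True)  # overwrite contents in place
assert id(param) == before  # still the same Parameter object
```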
@@ -525,7 +525,34 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
     }
     mgb_assert(tup.size() == 7);
     if (auto* t = try_cast(tup[0].ptr())) {
-        m_tensor = t->m_tensor->copy();
+        m_tensor = t->m_tensor;
+        // TODO: merge the two paths in arg parse
+        if (!tup[1].is_none()) {
+            auto dtype = tup[1].cast<DType>();
+            mgb_assert(
+                    dtype == m_tensor->dtype(), "dtype mismatch: %s vs %s",
+                    dtype.name(), m_tensor->dtype().name());
+        }
+        if (!tup[2].is_none()) {
+            auto device = as_comp_node(tup[2]);
+            mgb_assert(
+                    device == m_tensor->comp_node(), "device mismatch: %s vs %s",
+                    device.to_string().c_str(),
+                    m_tensor->comp_node().to_string().c_str());
+        }
+        mgb_assert(!tup[3].cast<bool>(), "expect is_const == False, got True");
+        bool no_cache = tup[4].cast<bool>();
+        if (no_cache) {
+            // always copy because it's hard to tell whether this tensor is cached
+            m_tensor = m_tensor->copy();
+        }
+        // ignore name
+        if (!tup[6].is_none()) {
+            Format format = tup[6].cast<std::string>();
+            mgb_assert(
+                    format == m_tensor->format(), "format mismatch: %s vs %s",
+                    format.to_string().c_str(), m_tensor->format().to_string().c_str());
+        }
     } else {
         auto data = tup[0];
         DType dtype = tup[1].cast<DType>();
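Note: with the new branch above, constructing a `Tensor` from an existing `Tensor` keeps the underlying handle and validates any explicit `dtype`/`device`/`format` argument instead of converting. A hedged sketch of the user-visible behavior this implies (the `RuntimeError` on mismatch matches the new test later in this diff):

```python
# Hedged sketch of the user-visible behavior implied by the new branch:
# Tensor(tensor) reuses the underlying tensor, and a conflicting dtype (or
# device/format) argument raises instead of being applied as a cast.
import numpy as np

from megengine import Tensor

x = Tensor(np.arange(4, dtype=np.int32))
y = Tensor(x)  # no conversion arguments: shares x's underlying tensor
assert y.dtype == np.int32
try:
    Tensor(x, dtype=np.float32)  # dtype mismatch is rejected, not cast
except RuntimeError as err:
    print("rejected:", err)
```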
@@ -1030,7 +1057,7 @@ void init_tensor(py::module m) {
             try {
                 self.compiled->compile();
             } catch (const std::exception& e) {
-                mgb_log_error(e.what());
+                mgb_log_error("error in trace: %s", e.what());
             }
         }
         // register transformations
@@ -47,7 +47,7 @@ public:
     ~Tensor() = default;
-    inline Tensor copy() { return *this; }
+    inline Tensor copy() { return Tensor(imperative::apply(DupTensor(), data())[0]); }
     inline DType dtype() { return *data().dtype(); }
     inline CompNode comp_node() { return *data().device(); }
@@ -5,7 +5,9 @@ import numpy as np
 import pytest
 from utils import get_var_value, make_tensor
+from megengine import _full_sync
 from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8
+from megengine.device import get_default_device
 from megengine.tensor import Parameter, Tensor
 from megengine.utils.network import Network
@@ -220,3 +222,16 @@ def test_tensor_from_bool():
     assert x.dtype == np.bool_
     x = Tensor([True, False])
     assert x.dtype == np.bool_
+
+
+def test_tensor_construct_tensor():
+    x = Tensor(0, dtype=np.float32, device="xpu0:1", name="MyName")
+    assert Tensor(x.astype(np.int32)).dtype == np.int32
+    with pytest.raises(RuntimeError):
+        Tensor(x.astype(np.int32), dtype=np.float32)
+    assert Tensor(x).name == ""
+    assert Tensor(x, name="MyName2").name == "MyName2"
+    with pytest.raises(RuntimeError):
+        assert Tensor(x.to("xpu0:2"), device="xpu0:1").device == "xpu0:1"
+    assert Tensor(x.to("xpu0:2")).device == x.to("xpu0:2").device
+    _full_sync()
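Note: the new test ends with `_full_sync()`. A short sketch of the same pattern, under the assumption that `_full_sync` blocks until all queued device work has completed:

```python
# Hedged sketch (assumption: _full_sync blocks until queued device work has
# finished): flush asynchronous work before a test returns so device-side
# failures surface inside the test rather than in a later, unrelated one.
import numpy as np

from megengine import Tensor, _full_sync

t = Tensor(np.ones((8,), dtype=np.float32)) * 2
_full_sync()  # wait for pending device work
assert t.numpy().sum() == 16.0
```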
@@ -126,6 +126,11 @@ ValueRefList InterpreterTransformation::apply_transformation(
         } else {
             return {ValueRef()};
         }
+    } else if (op.is<DupTensor>()) {
+        auto& input = inputs[0].cast(m_value_type);
+        DeviceTensorND dev_tensor;
+        dev_tensor.copy_from(m_channel->get_dev_tensor(input.handle()->handle()));
+        return {m_value_type.make(share_handle(m_channel->put(dev_tensor, {})))};
     } else {
         return op.fallback(inputs);
     }
@@ -196,5 +196,10 @@ public:
     std::string to_string() const override;
 };
 
+class DupTensor final : public OperatorImpl<DupTensor, Operator::IdentityLike> {
+public:
+    std::string to_string() const override { return "DupTensor"; }
+};
+
 } // namespace imperative
 } // namespace mgb
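Note: `DupTensor` is the identity-like operator that backs the new `Tensor::copy()` and the interpreter branch above. A hedged Python-level sketch of how the two construction modes introduced in this diff differ, assuming the copy path is reached via `no_cache=True`:

```python
# Hedged sketch (assumed behavior, not taken verbatim from the diff): with
# no_cache=True the constructor goes through Tensor::copy(), which now
# dispatches DupTensor and duplicates the device tensor; without it the new
# tensor simply shares the source handle.
import numpy as np

from megengine import Tensor

x = Tensor(np.arange(4, dtype=np.float32))
shared = Tensor(x)                 # shares x's underlying tensor
copied = Tensor(x, no_cache=True)  # device-level duplicate via DupTensor
assert np.array_equal(copied.numpy(), x.numpy())  # same values, own storage
```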