
fix(tensor): check args when constructing a tensor from an existing tensor

GitOrigin-RevId: 0345454070
Branch: release-1.10
Author: Megvii Engine Team, 3 years ago
Commit: 2484cd2741
6 changed files with 56 additions and 4 deletions:
  1. imperative/python/megengine/optimizer/optimizer.py (+1, -1)
  2. imperative/python/src/tensor.cpp (+29, -2)
  3. imperative/python/src/tensor.h (+1, -1)
  4. imperative/python/test/unit/core/test_tensor_wrapper.py (+15, -0)
  5. imperative/src/impl/transformations/eval.cpp (+5, -0)
  6. imperative/src/include/megbrain/imperative/basic_operators.h (+5, -0)
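The user-visible change: `Tensor(existing_tensor, ...)` no longer copies eagerly; it reuses the source's storage and validates any explicitly passed `dtype`, `device`, or `format` against it, while `no_cache=True` still forces a real copy. A minimal sketch of the resulting behavior, assuming a MegEngine build that includes this patch:

```python
from megengine import Tensor

x = Tensor([1.0, 2.0, 3.0])
y = Tensor(x)                  # reuses x's underlying storage, no eager copy
z = Tensor(x, no_cache=True)   # explicitly requests an independent copy
```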

imperative/python/megengine/optimizer/optimizer.py (+1, -1)

@@ -97,7 +97,7 @@ class Optimizer(metaclass=ABCMeta):
                     "optimizer can only optimize Parameters, but one of the params is "
                     + str(type(param))
                 )
-            param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))
+            param[...] = Tensor(param.numpy(), no_cache=True)
 
         for name, default in self._defaults.items():
             if default is required and name not in param_group:
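The optimizer now refreshes parameter values with an in-place slice assignment rather than `_reset`-ing the parameter to a freshly constructed tensor. A hedged sketch of the pattern with a hypothetical 2x2 parameter:

```python
import numpy as np
from megengine import Parameter, Tensor

param = Parameter(np.zeros((2, 2), dtype=np.float32))
# In-place value refresh, as the patched optimizer does; the parameter
# object itself (and its dtype/device/format) is left untouched.
param[...] = Tensor(param.numpy(), no_cache=True)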


imperative/python/src/tensor.cpp (+29, -2)

@@ -525,7 +525,34 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
     }
     mgb_assert(tup.size() == 7);
     if (auto* t = try_cast(tup[0].ptr())) {
-        m_tensor = t->m_tensor->copy();
+        m_tensor = t->m_tensor;
+        // TODO: merge two path in arg parse
+        if (!tup[1].is_none()) {
+            auto dtype = tup[1].cast<DType>();
+            mgb_assert(
+                    dtype == m_tensor->dtype(), "dtype mismatch: %s vs %s",
+                    dtype.name(), m_tensor->dtype().name());
+        }
+        if (!tup[2].is_none()) {
+            auto device = as_comp_node(tup[2]);
+            mgb_assert(
+                    device == m_tensor->comp_node(), "device mismatch: %s vs %s",
+                    device.to_string().c_str(),
+                    m_tensor->comp_node().to_string().c_str());
+        }
+        mgb_assert(!tup[3].cast<bool>(), "expect is_const == False, got True");
+        bool no_cache = tup[4].cast<bool>();
+        if (no_cache) {
+            // always copy because it's hard to tell whether this tensor is cached
+            m_tensor = m_tensor->copy();
+        }
+        // ignore name
+        if (!tup[6].is_none()) {
+            Format format = tup[6].cast<std::string>();
+            mgb_assert(
+                    format == m_tensor->format(), "format mismatch: %s vs %s",
+                    format.to_string().c_str(), m_tensor->format().to_string().c_str());
+        }
     } else {
         auto data = tup[0];
         DType dtype = tup[1].cast<DType>();
@@ -1030,7 +1057,7 @@ void init_tensor(py::module m) {
             try {
                 self.compiled->compile();
             } catch (const std::exception& e) {
-                mgb_log_error(e.what());
+                mgb_log_error("error in trace: %s", e.what());
             }
         }
         // register transformations
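These assertions surface in Python as `RuntimeError` (the usual translation of a failed `mgb_assert`). A sketch mirroring the new unit test further below; the `xpu0:1`/`xpu0:2` device strings follow the test's convention and assume a multi-device build:

```python
import numpy as np
import pytest
from megengine import Tensor

x = Tensor(0, dtype=np.float32, device="xpu0:1")
Tensor(x, dtype=np.float32)                  # OK: dtype matches the source
with pytest.raises(RuntimeError):
    Tensor(x, dtype=np.int32)                # dtype mismatch: float32 vs int32
with pytest.raises(RuntimeError):
    Tensor(x.to("xpu0:2"), device="xpu0:1")  # device mismatch
```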


imperative/python/src/tensor.h (+1, -1)

@@ -47,7 +47,7 @@ public:
 
     ~Tensor() = default;
 
-    inline Tensor copy() { return *this; }
+    inline Tensor copy() { return Tensor(imperative::apply(DupTensor(), data())[0]); }
 
     inline DType dtype() { return *data().dtype(); }
     inline CompNode comp_node() { return *data().device(); }
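Previously `copy()` returned `*this`, i.e. another handle onto the same data; it now dispatches a `DupTensor` op through the transformation stack, so the result owns its own device buffer. The observable consequence, sketched with the public Python API only:

```python
import numpy as np
from megengine import Tensor

x = Tensor(np.zeros(4, dtype=np.float32))
y = Tensor(x, no_cache=True)   # no_cache routes through Tensor::copy()
x[0] = 42.0                    # mutate the source in place
assert y.numpy()[0] == 0.0     # the duplicate is unaffected
```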


imperative/python/test/unit/core/test_tensor_wrapper.py (+15, -0)

@@ -5,7 +5,9 @@ import numpy as np
 import pytest
 from utils import get_var_value, make_tensor
 
+from megengine import _full_sync
 from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8
+from megengine.device import get_default_device
 from megengine.tensor import Parameter, Tensor
 from megengine.utils.network import Network
 
@@ -220,3 +222,16 @@ def test_tensor_from_bool():
     assert x.dtype == np.bool_
     x = Tensor([True, False])
     assert x.dtype == np.bool_
+
+
+def test_tensor_construct_tensor():
+    x = Tensor(0, dtype=np.float32, device="xpu0:1", name="MyName")
+    assert Tensor(x.astype(np.int32)).dtype == np.int32
+    with pytest.raises(RuntimeError):
+        Tensor(x.astype(np.int32), dtype=np.float32)
+    assert Tensor(x).name == ""
+    assert Tensor(x, name="MyName2").name == "MyName2"
+    with pytest.raises(RuntimeError):
+        assert Tensor(x.to("xpu0:2"), device="xpu0:1").device == "xpu0:1"
+    assert Tensor(x.to("xpu0:2")).device == x.to("xpu0:2").device
+    _full_sync()

imperative/src/impl/transformations/eval.cpp (+5, -0)

@@ -126,6 +126,11 @@ ValueRefList InterpreterTransformation::apply_transformation(
         } else {
             return {ValueRef()};
         }
+    } else if (op.is<DupTensor>()) {
+        auto& input = inputs[0].cast(m_value_type);
+        DeviceTensorND dev_tensor;
+        dev_tensor.copy_from(m_channel->get_dev_tensor(input.handle()->handle()));
+        return m_value_type.make(share_handle(m_channel->put(dev_tensor, {})));
     } else {
         return op.fallback(inputs);
     }
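The interpreter realizes `DupTensor` by `copy_from`-ing the source's device buffer into a fresh `DeviceTensorND` and `put`-ting it back into the channel, so the duplicate keeps the source's dtype and device. Sketched from the Python side (same hedge as above regarding the `xpu` device names):

```python
import numpy as np
from megengine import Tensor

x = Tensor(np.ones(3, dtype=np.int32), device="xpu0:1")
y = Tensor(x, no_cache=True)           # duplicated via the DupTensor path
assert y.dtype == x.dtype              # dtype preserved
assert str(y.device) == str(x.device)  # device preserved
```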


imperative/src/include/megbrain/imperative/basic_operators.h (+5, -0)

@@ -196,5 +196,10 @@ public:
     std::string to_string() const override;
 };
 
+class DupTensor final : public OperatorImpl<DupTensor, Operator::IdentityLike> {
+public:
+    std::string to_string() const override { return "DupTensor"; }
+};
+
 } // namespace imperative
 } // namespace mgb
