
refactor(mge): reopen passed assertions

GitOrigin-RevId: e0276e73e3
Branch: release-1.2
Megvii Engine Team, 4 years ago
Commit: de0742be25
5 changed files with 38 additions and 31 deletions
  1. +3 -0   imperative/python/megengine/core/tensor/megbrain_graph.py
  2. +2 -0   imperative/python/megengine/distributed/functional.py
  3. +25 -29 imperative/python/megengine/jit/tracing.py
  4. +8 -1   imperative/python/src/tensor.cpp
  5. +0 -1   imperative/python/test/unit/autodiff/test_grad_manger.py

imperative/python/megengine/core/tensor/megbrain_graph.py  (+3 -0)

@@ -450,6 +450,9 @@ def _unwrap(x):


 def apply_normal_varnode(op: OpDef, *args: VarNode):
+    # for PyOp like RemoteSend/Recv
+    if getattr(op, "op", None):
+        op = op.op
     outputs = _imperative_rt.invoke_op(op, _unwrap(args))
     return _wrap(outputs)




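Note: RemoteSend/RemoteRecv are Python-level ops whose real OpDef lives in an `op` attribute, and the added branch hands that inner OpDef to invoke_op. A minimal sketch of the unwrap pattern, using a hypothetical wrapper class rather than the actual MegEngine types:

    # Hypothetical wrapper: the runtime only understands the inner OpDef,
    # so apply_normal_varnode() has to unwrap it before calling invoke_op().
    class PyOpWrapper:
        def __init__(self, inner_opdef):
            self.op = inner_opdef  # the OpDef the imperative runtime executes

    def unwrap_opdef(op):
        # mirrors the added getattr check in apply_normal_varnode
        return op.op if getattr(op, "op", None) else op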
imperative/python/megengine/distributed/functional.py  (+2 -0)

@@ -292,6 +292,8 @@ def remote_recv(
     op = RemoteRecv()
     op.key = key
     op.cn = device
+    if isinstance(shape, Tensor):
+        shape = shape.numpy()
     op.shape = shape
     op.dtype = dtype
     op.addr, op.port = get_mm_server_addr()


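With the added isinstance branch, remote_recv accepts the expected shape either as a plain tuple or as a Tensor, converting the Tensor to a numpy array before storing it on the OpDef. A small sketch of that normalization with a duck-typed stand-in for Tensor (normalize_shape is illustrative, not a library function):

    import numpy as np

    def normalize_shape(shape):
        # mirrors the new branch in remote_recv: Tensor-like shapes are
        # materialized to numpy before being assigned to op.shape
        if hasattr(shape, "numpy"):  # stand-in for isinstance(shape, Tensor)
            shape = shape.numpy()
        return tuple(int(d) for d in np.asarray(shape).ravel())

    assert normalize_shape((2, 3)) == (2, 3)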
imperative/python/megengine/jit/tracing.py  (+25 -29)

@@ -234,20 +234,21 @@ class trace:
                    )
                info.data_setter.set_value(x._dev_tensor())
            else:
-               pass
-               # if x.__class__ is not CompiledTensorProxy:
-               #     if x not in self._tensor_remaps:
-               #         raise TraceMismatchError(
-               #             "unexpected capture: trying to use an external tensor as "
-               #             "input, but that input was an internal tensor last time"
-               #         )
-               #     else:
-               #         x = self._tensor_remaps[x]
-               #         if x._CompiledTensorProxy__handle != h:
-               #             raise TraceMismatchError(
-               #                 "mis-wiring: input edge to an data flow "
-               #                 "graph node is different from last time"
-               #             )
+               if x.mixin_handle == -1:
+                   if x._handle not in self._tensor_remaps:
+                       raise TraceMismatchError(
+                           "unexpected capture: trying to use an external tensor as "
+                           "input, but that input was an internal tensor last time"
+                       )
+                   else:
+                       x.mixin_handle = self._tensor_remaps[
+                           x._handle
+                       ]._CompiledTensorProxy__handle
+               if x.mixin_handle != h:
+                   raise TraceMismatchError(
+                       "mis-wiring: input edge to an data flow "
+                       "graph node is different from last time"
+                   )

        self._pc += 1
        outputs = []
@@ -268,14 +269,11 @@ class trace:
            op_, ihandles, ohandles = record
            assert isinstance(op_, str) and op_ == "Const"

-           # TODO : assert on const value
-           # eq = value == self._tinfo[ohandles[0]].bound_data.numpy()
-           # if not isinstance(eq, bool):
-           #     eq = all(eq)
-           # if not eq:
-           #     raise TraceMismatchError(
-           #         "const tensor violated: got a different tensor this time"
-           #     )
+           eq = np.all(np.atleast_1d(value) == self._tinfo[ohandles[0]].bound_data.numpy())
+           if not eq:
+               raise TraceMismatchError(
+                   "const tensor violated: got a different tensor this time"
+               )

            self._pc += 1
            (h,) = ohandles
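The re-enabled check compares the constant recorded during tracing with the value seen on this run. A brief illustration (not from the repository) of why np.atleast_1d plus np.all is used: a bare == on array constants yields an element-wise array rather than a bool, while scalars still reduce cleanly:

    import numpy as np

    # Scalar const: comparison gives a single numpy bool; np.all reduces it as-is.
    assert np.all(np.atleast_1d(3.0) == np.float32(3.0))

    # Array const: comparison gives an element-wise array; np.all collapses it.
    recorded = np.array([1, 2, 3])
    assert np.all(np.atleast_1d(np.array([1, 2, 3])) == recorded)
    assert not np.all(np.atleast_1d(np.array([1, 2, 4])) == recorded)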
@@ -750,7 +748,6 @@ class trace:
                    dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k
                )

-       set_tracing()
        for op, ihandles, ohandles in self._seq:
            if isinstance(op, str) and op == "Const":
                assert len(ihandles) == 0
@@ -776,7 +773,6 @@ class trace:
                ovars = G.apply_normal_varnode(op, *ivars)
            assert len(ovars) == len(ohandles)
            h2v.update(zip(ohandles, ovars))
-       unset_tracing()

        dest_vars = []
        for i, h in enumerate(self._output_bindings):
@@ -843,7 +839,7 @@ class trace:
                if x.device != info.device:
                    raise TypeError("args[%d].device different from last time" % i)
                info.data_setter.set_value(x._dev_tensor())
-               self._tensor_remaps[x] = CompiledTensorProxy(h)
+               self._tensor_remaps[x._handle] = CompiledTensorProxy(h)

        kwargs_tensors = {}
        for k, x in kwargs.items():
@@ -870,7 +866,7 @@ class trace:
                if x.device != info.device:
                    raise TypeError("kwargs[%s].device different from last time" % k)
                info.data_setter.set_value(x._dev_tensor())
-               self._tensor_remaps[x] = CompiledTensorProxy(h)
+               self._tensor_remaps[x._handle] = CompiledTensorProxy(h)

    def _process_outputs(self, outputs):
        output_names = None
@@ -1000,8 +996,8 @@ class CompiledTensorProxy:
    def __del__(self):
        if self.__tensor.shape_read and self.__shape is not None:
            self.__info.shape_reader.drop_value()
-       # if self.__tensor.value_read and self.__value is not None:
-       #     self.__info.value_reader.drop_value()
+       if self.__tensor.value_read and self.__value is not None:
+           self.__info.value_reader.drop_value()
        if self.__tensor.data_read and self.__data is not None:
            self.__info.data_reader.drop_value()


@@ -1047,7 +1043,7 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
    outputs = [RawTensor(o) for o in ovars]

    if require_links:
-       active_trace._lazy_eval_links = (outputs[0]._varnode,)
+       active_trace._lazy_eval_links = (G.VarNode(outputs[0]._varnode),)

    active_trace._lazy_eval_tensors.update([TensorWeakRef(o) for o in outputs])
    return outputs


imperative/python/src/tensor.cpp  (+8 -1)

@@ -760,7 +760,14 @@ void init_tensor(py::module m) {
    m.attr("skip_tracing") = &skip_tracing;

    py::class_<SharedHandle>(m, "SharedHandle")
-       .def(py::init<const SharedHandle&>());
+       .def(py::init<const SharedHandle&>())
+       .def("__eq__", [](SharedHandle &thish, SharedHandle &thath) {
+           return (thish.get() == thath.get());
+       })
+       .def("__hash__", [](SharedHandle &sh) {
+           return reinterpret_cast<int64_t>(sh.get());
+       })
+       ;

    m.def("set_tracing", &set_tracing);
    m.def("unset_tracing", &unset_tracing);


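The new __eq__/__hash__ bindings compare and hash a SharedHandle by its underlying pointer, which is what lets the tracing.py changes above key self._tensor_remaps by x._handle: two Python wrappers around the same handle now land on the same dict entry. A minimal sketch of that behaviour with a plain-Python stand-in (FakeSharedHandle is illustrative only):

    # Two wrapper objects over the same underlying pointer compare equal and
    # hash identically, so dict lookups keyed by the handle succeed.
    class FakeSharedHandle:
        def __init__(self, ptr):
            self._ptr = ptr                    # stands in for SharedHandle::get()

        def __eq__(self, other):
            return self._ptr == other._ptr     # mirrors the new __eq__ binding

        def __hash__(self):
            return self._ptr                   # mirrors the new __hash__ binding

    tensor_remaps = {}
    h1, h2 = FakeSharedHandle(0x7F00), FakeSharedHandle(0x7F00)
    tensor_remaps[h1] = "CompiledTensorProxy(...)"
    assert h2 in tensor_remaps                 # a distinct wrapper finds the same entry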
imperative/python/test/unit/autodiff/test_grad_manger.py  (+0 -1)

@@ -141,7 +141,6 @@ def test_regression_1762():
    )
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
-@pytest.mark.skip(reason="FIXME: remote_send/recv")
def test_remote_grad():
    @dist.launcher
    def worker():

