diff --git a/imperative/python/megengine/core/_wrap.py b/imperative/python/megengine/core/_wrap.py
index 538518a1..c9a422fd 100644
--- a/imperative/python/megengine/core/_wrap.py
+++ b/imperative/python/megengine/core/_wrap.py
@@ -22,7 +22,14 @@ class Device:
         else:
             self._cn = CompNode(device)

-        self.logical_name = self._cn.logical_name
+        self._logical_name = None
+
+    @property
+    def logical_name(self):
+        if self._logical_name:
+            return self._logical_name
+        self._logical_name = self._cn.logical_name
+        return self._logical_name

     def to_c(self):
         return self._cn
@@ -39,7 +46,7 @@ class Device:
     def __eq__(self, rhs):
         if not isinstance(rhs, Device):
             rhs = Device(rhs)
-        return str(self._cn) == str(rhs._cn)
+        return self._cn == rhs._cn


 def device(obj):
diff --git a/imperative/python/megengine/core/autodiff/builtin_op_utils.py b/imperative/python/megengine/core/autodiff/builtin_op_utils.py
index 356e21f7..07177e26 100644
--- a/imperative/python/megengine/core/autodiff/builtin_op_utils.py
+++ b/imperative/python/megengine/core/autodiff/builtin_op_utils.py
@@ -28,6 +28,7 @@ from ..ops.builtin import (
 from ..ops.special import Const
 from ..tensor.core import apply
 from ..tensor.function import Function
+from ..tensor.tensor import Tensor
 from ..tensor.tensor_wrapper import TensorWrapper

 _reduce_sum_param = Reduce(mode="SUM").to_c().param[0]
@@ -103,8 +104,8 @@ def default_grad_fn(op, inputs, outputs, input_requires_grad):


 def get_shape(x):
-    (s,) = apply(GetVarShape(), x)
-    return s
+    (s,) = apply(GetVarShape(), x._data)
+    return Tensor(s)


 # override for Elemwise.add
diff --git a/imperative/python/megengine/core/autodiff/grad.py b/imperative/python/megengine/core/autodiff/grad.py
index 81ea827e..ae998761 100644
--- a/imperative/python/megengine/core/autodiff/grad.py
+++ b/imperative/python/megengine/core/autodiff/grad.py
@@ -387,16 +387,19 @@ def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
     if not manager._enabled:
         return

-    opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)
-
     # register backward method
     # tuple of backward functions corresponding to dy / dx_i
     # None means y is not a function of x_i
-    opnode.backward, output_need_grad = builtin_op_utils.builtin_op_get_backward_fn(
+    backward, output_need_grad = builtin_op_utils.builtin_op_get_backward_fn(
         op, ctx.inputs, ctx.outputs, input_requires_grad
     )
+    assert len(ctx.outputs) == len(output_need_grad)
+    if not any(output_need_grad):
+        return
+
+    opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)
+    opnode.backward = backward

-    assert len(outputs) == len(output_need_grad)
     outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)]

     opnode.backward_allow_noinput = check_backward_allow_noinput(op)
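The _wrap.py change above turns logical_name into a lazily cached property, so the CompNode query runs only on first access. A minimal standalone sketch of that caching pattern, assuming a hypothetical LazyDevice class (not MegEngine code):

class LazyDevice:
    def __init__(self, comp_node):
        self._cn = comp_node            # stands in for the wrapped CompNode
        self._logical_name = None       # nothing computed yet

    @property
    def logical_name(self):
        # compute once, then serve the cached string on every later access
        if self._logical_name is None:
            self._logical_name = self._query_logical_name()
        return self._logical_name

    def _query_logical_name(self):
        # placeholder for the real (comparatively expensive) CompNode query
        return str(self._cn)
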
diff --git a/imperative/python/megengine/core/tensor/tensor.py b/imperative/python/megengine/core/tensor/tensor.py
index 7780ea19..b60e30bc 100644
--- a/imperative/python/megengine/core/tensor/tensor.py
+++ b/imperative/python/megengine/core/tensor/tensor.py
@@ -55,6 +55,8 @@ class Tensor(TensorBase):


 class ApplyContext:
+    __slots__ = ("inputs", "outputs", "key")
+
     def __init__(self):
         self.inputs = None
         self.outputs = None
@@ -81,7 +83,7 @@ def get_context():

 @apply.register()
 def tensor_apply(op: OpBase, *args: Tensor):
-    data = tuple(i._data if isinstance(i, Tensor) else i for i in args)
+    data = tuple(i._data for i in args)
     # type(Tensor._data) is RawTensor
     # dispached to apply.add@RawTensor.py if passed Tensor args
     outputs = apply(op, *data)
@@ -90,7 +92,7 @@ def tensor_apply(op: OpBase, *args: Tensor):
     with push_context() as ctx:
         ctx.inputs = args
         ctx.outputs = ret
-        for k in set().union(*(i._extra_data for i in args if isinstance(i, Tensor))):
+        for k in set().union(*(i._extra_data for i in args)):
             ctx.key = k
             data = tuple(
                 i._extra_data.get(k) if isinstance(i, Tensor) else i for i in args
diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py
index fb7f1856..60c3c899 100644
--- a/imperative/python/megengine/functional/math.py
+++ b/imperative/python/megengine/functional/math.py
@@ -229,7 +229,7 @@ def mean(
         [3.5]

     """
-    return inp.astype("float32").mean(axis=axis, keepdims=keepdims)
+    return inp.mean(axis=axis, keepdims=keepdims)


 def var(
diff --git a/imperative/python/megengine/module/batchnorm.py b/imperative/python/megengine/module/batchnorm.py
index 5f404a03..864162f1 100644
--- a/imperative/python/megengine/module/batchnorm.py
+++ b/imperative/python/megengine/module/batchnorm.py
@@ -35,15 +35,14 @@ class _BatchNorm(Module):
         self.track_running_stats = track_running_stats
         self._track_running_stats_saved = track_running_stats
         self.freeze = freeze
+        tshape = (1, self.num_features, 1, 1)
         if self.affine:
-            self.weight = Parameter(np.ones(num_features, dtype=np.float32))
-            self.bias = Parameter(np.zeros(num_features, dtype=np.float32))
+            self.weight = Parameter(np.ones(tshape, dtype=np.float32))
+            self.bias = Parameter(np.zeros(tshape, dtype=np.float32))
         else:
             self.weight = None
             self.bias = None

-        tshape = (1, self.num_features, 1, 1)
-
         if self.track_running_stats:
             self.running_mean = Tensor(np.zeros(tshape, dtype=np.float32))
             self.running_var = Tensor(np.ones(tshape, dtype=np.float32))
@@ -86,10 +85,8 @@ class _BatchNorm(Module):
             inp = inp.reshape(new_shape)

         if self.freeze and self.training and self._track_running_stats_saved:
-            scale = self.weight.reshape(1, -1, 1, 1) * (
-                self.running_var + self.eps
-            ) ** (-0.5)
-            bias = self.bias.reshape(1, -1, 1, 1) - self.running_mean * scale
+            scale = self.weight * (self.running_var + self.eps) ** (-0.5)
+            bias = self.bias - self.running_mean * scale
             return inp * scale.detach() + bias.detach()

         if self.training and self.track_running_stats:
@@ -276,7 +273,7 @@ class BatchNorm2d(_BatchNorm):
         m = M.BatchNorm2d(4)
         inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32"))
         oup = m(inp)
-        print(m.weight.numpy(), m.bias.numpy())
+        print(m.weight.numpy().flatten(), m.bias.numpy().flatten())
         # Without Learnable Parameters
         m = M.BatchNorm2d(4, affine=False)
         oup = m(inp)
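The batchnorm.py change above keeps weight and bias in the (1, C, 1, 1) layout, so the frozen-BN path can scale an NCHW input by plain broadcasting instead of reshaping on every forward. A minimal numpy sketch of that broadcasting with made-up sizes (not MegEngine code):

import numpy as np

N, C, H, W = 2, 4, 3, 3
eps = 1e-5
inp = np.random.rand(N, C, H, W).astype("float32")

weight = np.ones((1, C, 1, 1), dtype="float32")       # same layout as the new Parameter
bias = np.zeros((1, C, 1, 1), dtype="float32")
running_mean = np.zeros((1, C, 1, 1), dtype="float32")
running_var = np.ones((1, C, 1, 1), dtype="float32")

scale = weight * (running_var + eps) ** (-0.5)        # stays (1, C, 1, 1), no reshape needed
out = inp * scale + (bias - running_mean * scale)     # broadcasts over N, H and W
assert out.shape == (N, C, H, W)
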
diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py
index c2336729..e400a6c4 100644
--- a/imperative/python/megengine/module/module.py
+++ b/imperative/python/megengine/module/module.py
@@ -55,6 +55,14 @@ def _is_module(obj):
     return isinstance(obj, Module)


+def _get_XNorm_typeclass():
+    from .batchnorm import _BatchNorm
+
+    XNorm_types = []
+    XNorm_types.append(_BatchNorm)
+    return tuple(XNorm_types)
+
+
 class Module(metaclass=ABCMeta):
     """
     Base Module class.
@@ -393,6 +401,18 @@ class Module(metaclass=ABCMeta):
         return offset

     def state_dict(self, rst=None, prefix="", keep_var=False):
+        _rst = self._state_dict(rst=rst, prefix=prefix, keep_var=keep_var)
+        rst = OrderedDict()
+        XNorm_typeclass = _get_XNorm_typeclass()
+        for (module_type, k), v in _rst.items():
+            # for performance reasons, parameters in XNorm (e.g., BatchNorm2d) are 4-dim tensors;
+            # however, they are reshaped to 1-dim tensors before being returned by `state_dict()`
+            if issubclass(module_type, XNorm_typeclass):
+                v = v.reshape(-1)
+            rst[k] = v
+        return rst
+
+    def _state_dict(self, rst=None, prefix="", keep_var=False):
         r"""
         Returns a dictionary containing whole states of the module.
         """
@@ -400,15 +420,16 @@ class Module(metaclass=ABCMeta):

         def is_state(obj):
             return _is_parameter(obj) or _is_buffer(obj)

+        module_type = self.__class__
         if rst is None:
             rst = OrderedDict()

         for k, v in self._flatten(recursive=False, with_key=True, predicate=is_state):
             assert prefix + k not in rst, "duplicated state: {}".format(k)
             if keep_var:
-                rst[prefix + k] = v
+                rst[(module_type, prefix + k)] = v
             else:
-                rst[prefix + k] = v.numpy()
+                rst[(module_type, prefix + k)] = v.numpy()

         for k, submodule in self._flatten(
             recursive=False,
@@ -507,13 +528,14 @@ class Module(metaclass=ABCMeta):
         Advance state_dict load through callable ``closure`` whose signature is
         ``closure(key: str, var: Tensor) -> Union[np.ndarry, None]``
         """
+        XNorm_typeclass = _get_XNorm_typeclass()
         assert callable(closure), "closure must be a function"

         loaded = []
         skipped = []

-        local_state_dict = self.state_dict(keep_var=True)
-        for k, var in local_state_dict.items():
+        local_state_dict = self._state_dict(keep_var=True)
+        for (module_type, k), var in local_state_dict.items():
             to_be_load = closure(k, var)
             if to_be_load is None:
                 skipped.append(k)
@@ -523,11 +545,27 @@ class Module(metaclass=ABCMeta):
             ), "closure should return a `np.ndarray`, now `{}` get {}".format(
                 k, to_be_load
             )
-            assert make_shape_tuple(var.shape) == make_shape_tuple(
-                to_be_load.shape
-            ), "param `{}` shape mismatch, should be {}, get {}".format(
-                k, var.shape, to_be_load.shape
-            )
+            var_shape = make_shape_tuple(var.shape)
+            to_be_load_shape = make_shape_tuple(to_be_load.shape)
+            if var_shape != to_be_load_shape:
+                # weight and bias in BatchNorm1d, BatchNorm2d and SyncBatchNorm are 1-dim tensors in v1.0, and
+                # since v1.1 they are 4-dim tensors. The following special rule for these modules preserves
+                # backward compatibility.
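The load_state_dict closure above accepts a shape mismatch only for norm-layer parameters whose element counts agree, which lets v1.0 checkpoints (1-dim weight/bias) load into the new 4-dim parameters. A standalone sketch of that rule, assuming a hypothetical adapt_loaded_param helper with an is_norm_param flag standing in for the issubclass check (not MegEngine code):

import numpy as np

def adapt_loaded_param(key, loaded, target_shape, is_norm_param):
    # `is_norm_param` stands in for issubclass(module_type, XNorm_typeclass)
    if loaded.shape == tuple(target_shape):
        return loaded
    if is_norm_param and loaded.size == int(np.prod(target_shape)):
        # v1.0 checkpoints store 1-dim norm params; reshape them to the new 4-dim layout
        return loaded.reshape(target_shape)
    raise ValueError(
        "param `{}` shape mismatch, should be {}, got {}".format(
            key, tuple(target_shape), loaded.shape
        )
    )

# e.g. a v1.0 BatchNorm weight of shape (4,) loaded into a (1, 4, 1, 1) parameter
adapted = adapt_loaded_param("bn.weight", np.ones(4, "float32"), (1, 4, 1, 1), True)
assert adapted.shape == (1, 4, 1, 1)
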
+                if issubclass(module_type, XNorm_typeclass):
+                    if np.prod(var_shape) == np.prod(to_be_load_shape):
+                        to_be_load = to_be_load.reshape(var_shape)
+                    else:
+                        raise ValueError(
+                            "param `{}` size mismatch, should be {}, get {}".format(
+                                k, np.prod(var_shape), np.prod(to_be_load_shape)
+                            )
+                        )
+                else:
+                    raise ValueError(
+                        "param `{}` shape mismatch, should be {}, get {}".format(
+                            k, var_shape, to_be_load_shape
+                        )
+                    )
             var._reset(type(var)(to_be_load, dtype=to_be_load.dtype, device=var.device))

             loaded.append(k)
diff --git a/imperative/python/test/integration/test_correctness.py b/imperative/python/test/integration/test_correctness.py
index 092bd047..4ef30789 100644
--- a/imperative/python/test/integration/test_correctness.py
+++ b/imperative/python/test/integration/test_correctness.py
@@ -193,7 +193,11 @@ def run_train(
         net.state_dict().items(), checkpoint["net_updated"].items()
     ):
         assert param[0] == param_ref[0]
-        np.testing.assert_allclose(param[1], param_ref[1], atol=max_err)
+        if "bn" in param[0]:
+            ref = param_ref[1].reshape(param[1].shape)
+            np.testing.assert_allclose(param[1], ref, atol=max_err)
+        else:
+            np.testing.assert_allclose(param[1], param_ref[1], atol=max_err)


 def run_eval(
diff --git a/imperative/python/test/integration/test_dp_correctness.py b/imperative/python/test/integration/test_dp_correctness.py
index 3491cf5f..34d152c6 100644
--- a/imperative/python/test/integration/test_dp_correctness.py
+++ b/imperative/python/test/integration/test_dp_correctness.py
@@ -188,7 +188,11 @@ def run_test(
         net.state_dict().items(), checkpoint["net_updated"].items()
     ):
         assert param[0] == param_ref[0]
-        np.testing.assert_allclose(param[1], param_ref[1], atol=max_err)
+        if "bn" in param[0]:
+            ref = param_ref[1].reshape(param[1].shape)
+            np.testing.assert_allclose(param[1], ref, atol=max_err)
+        else:
+            np.testing.assert_allclose(param[1], param_ref[1], atol=max_err)

     procs = []
     for rank in range(p_num):
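Both test changes above compare a stored 1-dim reference against the current 4-dim "bn" parameter by reshaping the reference first. A minimal numpy sketch of that comparison with made-up values (not MegEngine code):

import numpy as np

param = np.full((1, 4, 1, 1), 0.5, dtype="float32")   # current "bn" parameter layout
param_ref = np.full((4,), 0.5, dtype="float32")       # reference stored as a 1-dim array
np.testing.assert_allclose(param, param_ref.reshape(param.shape), atol=1e-6)
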
diff --git a/imperative/src/impl/interpreter_impl.h b/imperative/src/impl/interpreter_impl.h
index 508e7c46..393933db 100644
--- a/imperative/src/impl/interpreter_impl.h
+++ b/imperative/src/impl/interpreter_impl.h
@@ -107,7 +107,7 @@ private:
     //! level 2: both device and user side errors are async;
     //! level 1: user side errors are sync;
     //! level 0: both sync.
-    int m_async_level = 1;
+    int m_async_level = 2;
 };

 } // namespace mgb::imperative::interpreter::intl
diff --git a/imperative/src/impl/proxy_graph.h b/imperative/src/impl/proxy_graph.h
index d4b94645..8a0e16a6 100644
--- a/imperative/src/impl/proxy_graph.h
+++ b/imperative/src/impl/proxy_graph.h
@@ -94,7 +94,7 @@ private:
     cg::OperatorNodeBase* m_cur_opr = nullptr;
     std::unique_ptr m_graph;
-    size_t m_max_op_cnt = 1000;
+    size_t m_max_op_cnt = 100;
     std::unique_ptr m_env;
     std::unique_ptr m_static_infer_manager;
     std::unique_ptr m_seq_comp_node_optimizer;
diff --git a/imperative/src/impl/proxy_graph_detail.cpp b/imperative/src/impl/proxy_graph_detail.cpp
index 163f38dd..207c09a5 100644
--- a/imperative/src/impl/proxy_graph_detail.cpp
+++ b/imperative/src/impl/proxy_graph_detail.cpp
@@ -120,12 +120,12 @@ make_backward_graph(const OpDef& def,
         const SmallVector& inputs,
         const SmallVector& input_requires_grad,
         const SmallVector& output_has_grad) {
-    auto&& graph = ProxyGraph::get_default_graph();
     auto hash_key = get_backward_graph_hash_key(def, inputs, input_requires_grad, output_has_grad);
     auto&& iter = backward_graph_cache.find(hash_key);
     if (iter != backward_graph_cache.end()) {
         return iter->second;
     }
+    auto&& graph = ProxyGraph::get_default_graph();
     auto res = graph->make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
     backward_graph_cache.emplace(hash_key, res);
     return res;
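The proxy_graph_detail.cpp change above only reorders the function: the backward-graph cache is consulted before the default proxy graph is acquired, so a cache hit skips the graph lookup entirely. A minimal Python sketch of that ordering, with hypothetical build_fn and get_default_graph stand-ins for the C++ calls (not the actual implementation):

_backward_graph_cache = {}

def make_backward_graph(hash_key, build_fn, get_default_graph):
    # `build_fn` and `get_default_graph` are illustrative stand-ins only
    cached = _backward_graph_cache.get(hash_key)
    if cached is not None:
        return cached                   # hit: no graph acquisition at all
    graph = get_default_graph()         # miss: pay for the graph only now
    res = build_fn(graph)
    _backward_graph_cache[hash_key] = res
    return res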