@@ -22,7 +22,14 @@ class Device:
         else:
             self._cn = CompNode(device)
-        self.logical_name = self._cn.logical_name
+        self._logical_name = None
+
+    @property
+    def logical_name(self):
+        if self._logical_name:
+            return self._logical_name
+        self._logical_name = self._cn.logical_name
+        return self._logical_name

     def to_c(self):
         return self._cn
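The hunk above replaces the eagerly computed `logical_name` attribute with a lazily cached property: the `CompNode` is only queried the first time the name is read, and the cached string is reused afterwards. A minimal standalone sketch of the same caching pattern (the class and `_query_backend` helper below are illustrative, not part of the codebase):

```python
class LazyName:
    def __init__(self):
        self._name = None  # nothing computed yet

    @property
    def name(self):
        if self._name is None:        # first access: do the lookup once
            self._name = self._query_backend()
        return self._name             # later accesses reuse the cached value

    def _query_backend(self):
        return "gpu0:0"               # stand-in for the real device query
```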
@@ -39,7 +46,7 @@ class Device:
     def __eq__(self, rhs):
         if not isinstance(rhs, Device):
             rhs = Device(rhs)
-        return str(self._cn) == str(rhs._cn)
+        return self._cn == rhs._cn


 def device(obj):
@@ -28,6 +28,7 @@ from ..ops.builtin import (
 from ..ops.special import Const
 from ..tensor.core import apply
 from ..tensor.function import Function
+from ..tensor.tensor import Tensor
 from ..tensor.tensor_wrapper import TensorWrapper

 _reduce_sum_param = Reduce(mode="SUM").to_c().param[0]
@@ -103,8 +104,8 @@ def default_grad_fn(op, inputs, outputs, input_requires_grad):
 def get_shape(x):
-    (s,) = apply(GetVarShape(), x)
-    return s
+    (s,) = apply(GetVarShape(), x._data)
+    return Tensor(s)


 # override for Elemwise.add
@@ -387,16 +387,19 @@ def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
     if not manager._enabled:
         return
-    opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)
     # register backward method
     # tuple of backward functions corresponding to dy / dx_i
     # None means y is not a function of x_i
-    opnode.backward, output_need_grad = builtin_op_utils.builtin_op_get_backward_fn(
+    backward, output_need_grad = builtin_op_utils.builtin_op_get_backward_fn(
         op, ctx.inputs, ctx.outputs, input_requires_grad
     )
-    assert len(ctx.outputs) == len(output_need_grad)
+    if not any(output_need_grad):
+        return
+    opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)
+    opnode.backward = backward
+    assert len(outputs) == len(output_need_grad)
     outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)]

     opnode.backward_allow_noinput = check_backward_allow_noinput(op)
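Per the comments above, `builtin_op_get_backward_fn` yields a backward description plus an `output_need_grad` mask, and the reordered code only allocates the op node once at least one output actually needs a gradient. As a hedged illustration of the "one backward per input, `None` if independent" idea for an elementwise multiply `y = x0 * x1` (the helper below is illustrative, not the codebase's real API):

```python
def mul_backward_fns(x0, x1):
    # dy/dx_i for y = x0 * x1; an entry would be None if y did not depend on x_i
    return (
        lambda dy: dy * x1,  # gradient flowing to x0
        lambda dy: dy * x0,  # gradient flowing to x1
    )
```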
@@ -55,6 +55,8 @@ class Tensor(TensorBase):


 class ApplyContext:
+    __slots__ = ("inputs", "outputs", "key")
+
     def __init__(self):
         self.inputs = None
         self.outputs = None
@@ -81,7 +83,7 @@ def get_context():

 @apply.register()
 def tensor_apply(op: OpBase, *args: Tensor):
-    data = tuple(i._data if isinstance(i, Tensor) else i for i in args)
+    data = tuple(i._data for i in args)
     # type(Tensor._data) is RawTensor
     # dispatched to apply.add@RawTensor.py if passed Tensor args
     outputs = apply(op, *data)
@@ -90,7 +92,7 @@ def tensor_apply(op: OpBase, *args: Tensor):
     with push_context() as ctx:
         ctx.inputs = args
         ctx.outputs = ret
-        for k in set().union(*(i._extra_data for i in args if isinstance(i, Tensor))):
+        for k in set().union(*(i._extra_data for i in args)):
             ctx.key = k
             data = tuple(
                 i._extra_data.get(k) if isinstance(i, Tensor) else i for i in args
@@ -229,7 +229,7 @@ def mean(
         [3.5]

     """
-    return inp.astype("float32").mean(axis=axis, keepdims=keepdims)
+    return inp.mean(axis=axis, keepdims=keepdims)


 def var(
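With the `astype("float32")` call removed, `mean` presumably follows the input dtype instead of always promoting to float32, so callers that relied on the old promotion would need to cast explicitly. A hedged usage sketch:

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

x = tensor(np.arange(6, dtype="int32").reshape(2, 3))
# cast explicitly if a float32 mean is still wanted under the new behaviour
y = F.mean(x.astype("float32"), axis=1, keepdims=True)
```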
@@ -35,15 +35,14 @@ class _BatchNorm(Module):
         self.track_running_stats = track_running_stats
         self._track_running_stats_saved = track_running_stats
         self.freeze = freeze
+        tshape = (1, self.num_features, 1, 1)
         if self.affine:
-            self.weight = Parameter(np.ones(num_features, dtype=np.float32))
-            self.bias = Parameter(np.zeros(num_features, dtype=np.float32))
+            self.weight = Parameter(np.ones(tshape, dtype=np.float32))
+            self.bias = Parameter(np.zeros(tshape, dtype=np.float32))
         else:
             self.weight = None
             self.bias = None
-        tshape = (1, self.num_features, 1, 1)
         if self.track_running_stats:
             self.running_mean = Tensor(np.zeros(tshape, dtype=np.float32))
             self.running_var = Tensor(np.ones(tshape, dtype=np.float32))
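Storing `weight` and `bias` as `(1, C, 1, 1)` tensors means they already broadcast against `NCHW` activations, so the per-forward `reshape(1, -1, 1, 1)` in the frozen path (next hunk) can be dropped. A small NumPy check of the broadcasting rule this relies on:

```python
import numpy as np

N, C, H, W = 2, 4, 3, 3
x = np.random.rand(N, C, H, W).astype("float32")
weight = np.ones((1, C, 1, 1), dtype="float32")   # same layout as the new Parameter
bias = np.zeros((1, C, 1, 1), dtype="float32")

y = x * weight + bias          # broadcasts over N, H, W without any reshape
assert y.shape == (N, C, H, W)
```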
@@ -86,10 +85,8 @@ class _BatchNorm(Module):
             inp = inp.reshape(new_shape)

         if self.freeze and self.training and self._track_running_stats_saved:
-            scale = self.weight.reshape(1, -1, 1, 1) * (
-                self.running_var + self.eps
-            ) ** (-0.5)
-            bias = self.bias.reshape(1, -1, 1, 1) - self.running_mean * scale
+            scale = self.weight * (self.running_var + self.eps) ** (-0.5)
+            bias = self.bias - self.running_mean * scale
             return inp * scale.detach() + bias.detach()

         if self.training and self.track_running_stats:
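The frozen branch above folds batch norm into a single affine transform of the input. With running mean μ, running variance σ², weight γ and bias β, the code's `scale` and `bias` correspond to

```latex
\begin{aligned}
\text{scale} &= \gamma \,(\sigma^2 + \varepsilon)^{-1/2},\\
\text{bias}  &= \beta - \mu \cdot \text{scale},\\
y &= x \cdot \text{scale} + \text{bias}
   = \gamma\,\frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta .
\end{aligned}
```

Both factors are detached, so neither the affine parameters nor the running statistics receive gradients through this path.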
@@ -276,7 +273,7 @@ class BatchNorm2d(_BatchNorm):
        m = M.BatchNorm2d(4)
        inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32"))
        oup = m(inp)
-       print(m.weight.numpy(), m.bias.numpy())
+       print(m.weight.numpy().flatten(), m.bias.numpy().flatten())
        # Without Learnable Parameters
        m = M.BatchNorm2d(4, affine=False)
        oup = m(inp)
@@ -55,6 +55,14 @@ def _is_module(obj):
     return isinstance(obj, Module)


+def _get_XNorm_typeclass():
+    from .batchnorm import _BatchNorm
+
+    XNorm_types = []
+    XNorm_types.append(_BatchNorm)
+    return tuple(XNorm_types)
+
+
 class Module(metaclass=ABCMeta):
     """
     Base Module class.
@@ -393,6 +401,18 @@ class Module(metaclass=ABCMeta):
         return offset

     def state_dict(self, rst=None, prefix="", keep_var=False):
+        _rst = self._state_dict(rst=rst, prefix=prefix, keep_var=keep_var)
+        rst = OrderedDict()
+        XNorm_typeclass = _get_XNorm_typeclass()
+        for (module_type, k), v in _rst.items():
+            # for performance reasons, parameters in XNorm (e.g. BatchNorm2d) are kept as 4-dim tensors,
+            # but they are reshaped to 1-dim tensors before being returned by `state_dict()`
+            if issubclass(module_type, XNorm_typeclass):
+                v = v.reshape(-1)
+            rst[k] = v
+        return rst
+
+    def _state_dict(self, rst=None, prefix="", keep_var=False):
         r"""
         Returns a dictionary containing whole states of the module.
         """
@@ -400,15 +420,16 @@ class Module(metaclass=ABCMeta):
         def is_state(obj):
             return _is_parameter(obj) or _is_buffer(obj)

+        module_type = self.__class__
         if rst is None:
             rst = OrderedDict()

         for k, v in self._flatten(recursive=False, with_key=True, predicate=is_state):
             assert prefix + k not in rst, "duplicated state: {}".format(k)
             if keep_var:
-                rst[prefix + k] = v
+                rst[(module_type, prefix + k)] = v
             else:
-                rst[prefix + k] = v.numpy()
+                rst[(module_type, prefix + k)] = v.numpy()

         for k, submodule in self._flatten(
             recursive=False,
@@ -507,13 +528,14 @@ class Module(metaclass=ABCMeta):
         Advance state_dict load through callable ``closure`` whose signature is
         ``closure(key: str, var: Tensor) -> Union[np.ndarray, None]``
         """
+        XNorm_typeclass = _get_XNorm_typeclass()
         assert callable(closure), "closure must be a function"

         loaded = []
         skipped = []

-        local_state_dict = self.state_dict(keep_var=True)
-        for k, var in local_state_dict.items():
+        local_state_dict = self._state_dict(keep_var=True)
+        for (module_type, k), var in local_state_dict.items():
             to_be_load = closure(k, var)
             if to_be_load is None:
                 skipped.append(k)
@@ -523,11 +545,27 @@ class Module(metaclass=ABCMeta):
             ), "closure should return a `np.ndarray`, now `{}` get {}".format(
                 k, to_be_load
             )
-            assert make_shape_tuple(var.shape) == make_shape_tuple(
-                to_be_load.shape
-            ), "param `{}` shape mismatch, should be {}, get {}".format(
-                k, var.shape, to_be_load.shape
-            )
+            var_shape = make_shape_tuple(var.shape)
+            to_be_load_shape = make_shape_tuple(to_be_load.shape)
+            if var_shape != to_be_load_shape:
+                # weight and bias in BatchNorm1d, BatchNorm2d and SyncBatchNorm were 1-dim tensors
+                # in v1.0; since v1.1 they are 4-dim tensors. The special rule below preserves
+                # backward compatibility with old checkpoints.
+                if issubclass(module_type, XNorm_typeclass):
+                    if np.prod(var_shape) == np.prod(to_be_load_shape):
+                        to_be_load = to_be_load.reshape(var_shape)
+                    else:
+                        raise ValueError(
+                            "param `{}` size mismatch, should be {}, get {}".format(
                                k, np.prod(var_shape), np.prod(to_be_load_shape)
+                            )
+                        )
+                else:
+                    raise ValueError(
+                        "param `{}` shape mismatch, should be {}, get {}".format(
+                            k, var_shape, to_be_load_shape
+                        )
+                    )
             var._reset(type(var)(to_be_load, dtype=to_be_load.dtype, device=var.device))
             loaded.append(k)
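The new branch only relaxes the shape check for norm layers: when the element counts match, a 1-dim checkpoint entry (v1.0 layout) is reshaped into the 4-dim parameter (v1.1 layout); any other mismatch still raises. A standalone sketch of that rule, independent of the Module machinery (the helper name is ours, for illustration only):

```python
import numpy as np

def adapt_bn_param(var_shape, loaded, is_norm_layer):
    """Reshape a checkpoint array to var_shape when only the layout differs."""
    if loaded.shape == tuple(var_shape):
        return loaded
    if is_norm_layer and np.prod(var_shape) == np.prod(loaded.shape):
        return loaded.reshape(var_shape)        # e.g. (4,) -> (1, 4, 1, 1)
    raise ValueError("shape mismatch: {} vs {}".format(var_shape, loaded.shape))

# v1.0 checkpoints stored BatchNorm weight/bias as 1-dim arrays
old_weight = np.ones(4, dtype="float32")
new_weight = adapt_bn_param((1, 4, 1, 1), old_weight, is_norm_layer=True)
assert new_weight.shape == (1, 4, 1, 1)
```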
@@ -193,7 +193,11 @@ def run_train(
         net.state_dict().items(), checkpoint["net_updated"].items()
     ):
         assert param[0] == param_ref[0]
-        np.testing.assert_allclose(param[1], param_ref[1], atol=max_err)
+        if "bn" in param[0]:
+            ref = param_ref[1].reshape(param[1].shape)
+            np.testing.assert_allclose(param[1], ref, atol=max_err)
+        else:
+            np.testing.assert_allclose(param[1], param_ref[1], atol=max_err)


 def run_eval(
@@ -188,7 +188,11 @@ def run_test(
         net.state_dict().items(), checkpoint["net_updated"].items()
     ):
         assert param[0] == param_ref[0]
-        np.testing.assert_allclose(param[1], param_ref[1], atol=max_err)
+        if "bn" in param[0]:
+            ref = param_ref[1].reshape(param[1].shape)
+            np.testing.assert_allclose(param[1], ref, atol=max_err)
+        else:
+            np.testing.assert_allclose(param[1], param_ref[1], atol=max_err)

     procs = []
     for rank in range(p_num):
@@ -107,7 +107,7 @@ private:
     //! level 2: both device and user side errors are async;
     //! level 1: user side errors are sync;
     //! level 0: both sync.
-    int m_async_level = 1;
+    int m_async_level = 2;
 };

 } // namespace mgb::imperative::interpreter::intl
@@ -94,7 +94,7 @@ private:
     cg::OperatorNodeBase* m_cur_opr = nullptr;
     std::unique_ptr<ProxyGraphImpl> m_graph;
-    size_t m_max_op_cnt = 1000;
+    size_t m_max_op_cnt = 100;
     std::unique_ptr<ExecEnv> m_env;
     std::unique_ptr<StaticInferManager> m_static_infer_manager;
     std::unique_ptr<SeqCompNodeOptimizer> m_seq_comp_node_optimizer;
@@ -120,12 +120,12 @@ make_backward_graph(const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs,
         const SmallVector<bool>& input_requires_grad,
         const SmallVector<bool>& output_has_grad) {
-    auto&& graph = ProxyGraph::get_default_graph();
     auto hash_key = get_backward_graph_hash_key(def, inputs, input_requires_grad, output_has_grad);
     auto&& iter = backward_graph_cache.find(hash_key);
     if (iter != backward_graph_cache.end()) {
         return iter->second;
     }
+    auto&& graph = ProxyGraph::get_default_graph();
     auto res = graph->make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
     backward_graph_cache.emplace(hash_key, res);
     return res;
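Moving `ProxyGraph::get_default_graph()` below the cache lookup means the proxy graph is only touched on a cache miss. The same look-up-before-acquire pattern, sketched in Python for consistency with the other examples here (names are illustrative, not the codebase's):

```python
_cache = {}

def make_backward_graph(key):
    hit = _cache.get(key)
    if hit is not None:
        return hit                    # fast path: no graph acquisition at all
    graph = get_default_graph()       # acquire the heavier resource only on a miss
    res = graph.make_backward_graph(key)
    _cache[key] = res
    return res
```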