
perf(mge/imperative): misc optimizations

GitOrigin-RevId: bbe7a10b00
release-1.1
Megvii Engine Team 4 years ago
parent commit 94dba16ff4
12 changed files with 90 additions and 34 deletions:

1. imperative/python/megengine/core/_wrap.py (+9 / -2)
2. imperative/python/megengine/core/autodiff/builtin_op_utils.py (+3 / -2)
3. imperative/python/megengine/core/autodiff/grad.py (+7 / -4)
4. imperative/python/megengine/core/tensor/tensor.py (+4 / -2)
5. imperative/python/megengine/functional/math.py (+1 / -1)
6. imperative/python/megengine/module/batchnorm.py (+6 / -9)
7. imperative/python/megengine/module/module.py (+47 / -9)
8. imperative/python/test/integration/test_correctness.py (+5 / -1)
9. imperative/python/test/integration/test_dp_correctness.py (+5 / -1)
10. imperative/src/impl/interpreter_impl.h (+1 / -1)
11. imperative/src/impl/proxy_graph.h (+1 / -1)
12. imperative/src/impl/proxy_graph_detail.cpp (+1 / -1)

imperative/python/megengine/core/_wrap.py (+9 / -2)

@@ -22,7 +22,14 @@ class Device:
         else:
             self._cn = CompNode(device)
 
-        self.logical_name = self._cn.logical_name
+        self._logical_name = None
+
+    @property
+    def logical_name(self):
+        if self._logical_name:
+            return self._logical_name
+        self._logical_name = self._cn.logical_name
+        return self._logical_name
 
     def to_c(self):
         return self._cn
@@ -39,7 +46,7 @@ class Device:
     def __eq__(self, rhs):
         if not isinstance(rhs, Device):
             rhs = Device(rhs)
-        return str(self._cn) == str(rhs._cn)
+        return self._cn == rhs._cn
 
 
 def device(obj):
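
Both hunks cut per-call Python overhead: `logical_name` is now looked up on the `CompNode` once and memoized, and `__eq__` compares the `CompNode` objects directly instead of stringifying both sides. A minimal sketch of the same memoized-property pattern, with a toy stand-in for the `CompNode.logical_name` lookup (the class and the lookup below are illustrative only, not MegEngine API):

    class LazyName:
        """Cache a derived attribute on first access, mirroring Device.logical_name above."""

        def __init__(self, raw_name):
            self._raw_name = raw_name
            self._logical_name = None  # computed lazily on first access

        @property
        def logical_name(self):
            if self._logical_name is None:
                # stand-in for the CompNode.logical_name call, which is the costly part
                self._logical_name = self._raw_name.lower().replace(" ", "_")
            return self._logical_name

    d = LazyName("GPU 0")
    assert d.logical_name == "gpu_0"             # computed on first access
    assert d.logical_name is d.logical_name      # later accesses reuse the cached string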


imperative/python/megengine/core/autodiff/builtin_op_utils.py (+3 / -2)

@@ -28,6 +28,7 @@ from ..ops.builtin import (
 from ..ops.special import Const
 from ..tensor.core import apply
 from ..tensor.function import Function
+from ..tensor.tensor import Tensor
 from ..tensor.tensor_wrapper import TensorWrapper
 
 _reduce_sum_param = Reduce(mode="SUM").to_c().param[0]
@@ -103,8 +104,8 @@ def default_grad_fn(op, inputs, outputs, input_requires_grad):
 
 
 def get_shape(x):
-    (s,) = apply(GetVarShape(), x)
-    return s
+    (s,) = apply(GetVarShape(), x._data)
+    return Tensor(s)
 
 
 # override for Elemwise.add


imperative/python/megengine/core/autodiff/grad.py (+7 / -4)

@@ -387,16 +387,19 @@ def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
     if not manager._enabled:
         return
 
-    opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)
-
     # register backward method
     # tuple of backward functions corresponding to dy / dx_i
     # None means y is not a function of x_i
-    opnode.backward, output_need_grad = builtin_op_utils.builtin_op_get_backward_fn(
+    backward, output_need_grad = builtin_op_utils.builtin_op_get_backward_fn(
         op, ctx.inputs, ctx.outputs, input_requires_grad
     )
+    assert len(ctx.outputs) == len(output_need_grad)
+    if not any(output_need_grad):
+        return
+
+    opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)
+    opnode.backward = backward
 
-    assert len(outputs) == len(output_need_grad)
     outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)]
 
     opnode.backward_allow_noinput = check_backward_allow_noinput(op)
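
With this reordering, the backward functions and `output_need_grad` flags are computed first, and the comparatively expensive `_new_opnode` call only happens when at least one output actually participates in backprop; ops whose outputs need no gradient now return early. A small self-contained sketch of the guard-before-allocate pattern (names are illustrative, not the MegEngine tracer API):

    class _Node:
        backward = None


    def record_op(backward_fns, outputs_need_grad):
        """Build an autodiff node only when some output participates in backprop."""
        if not any(outputs_need_grad):
            return None                  # cheap early exit: nothing to record
        node = _Node()                   # the costly allocation happens only on this path
        node.backward = backward_fns
        return node


    assert record_op((None, None), [False, False]) is None       # skipped entirely
    assert record_op((None, None), [False, True]) is not None    # node recorded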


imperative/python/megengine/core/tensor/tensor.py (+4 / -2)

@@ -55,6 +55,8 @@ class Tensor(TensorBase):
 
 
 class ApplyContext:
+    __slots__ = ("inputs", "outputs", "key")
+
     def __init__(self):
         self.inputs = None
         self.outputs = None
@@ -81,7 +83,7 @@ def get_context():
 
 @apply.register()
 def tensor_apply(op: OpBase, *args: Tensor):
-    data = tuple(i._data if isinstance(i, Tensor) else i for i in args)
+    data = tuple(i._data for i in args)
     # type(Tensor._data) is RawTensor
     # dispached to apply.add@RawTensor.py if passed Tensor args
     outputs = apply(op, *data)
@@ -90,7 +92,7 @@ def tensor_apply(op: OpBase, *args: Tensor):
     with push_context() as ctx:
         ctx.inputs = args
         ctx.outputs = ret
-        for k in set().union(*(i._extra_data for i in args if isinstance(i, Tensor))):
+        for k in set().union(*(i._extra_data for i in args)):
             ctx.key = k
             data = tuple(
                 i._extra_data.get(k) if isinstance(i, Tensor) else i for i in args
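
`__slots__` keeps every `ApplyContext` from carrying a per-instance `__dict__`, and the `isinstance` checks disappear because `tensor_apply` is only dispatched when the arguments are already `Tensor`s. A self-contained illustration of what `__slots__` buys on a frequently allocated class (the class names below are illustrative, not MegEngine code):

    import sys


    class PlainCtx:
        def __init__(self):
            self.inputs = None
            self.outputs = None
            self.key = None


    class SlottedCtx:
        __slots__ = ("inputs", "outputs", "key")

        def __init__(self):
            self.inputs = None
            self.outputs = None
            self.key = None


    plain, slotted = PlainCtx(), SlottedCtx()
    print(sys.getsizeof(plain.__dict__))  # a dict is allocated for every plain instance
    print(hasattr(slotted, "__dict__"))   # False: attributes live in fixed slots instead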


imperative/python/megengine/functional/math.py (+1 / -1)

@@ -229,7 +229,7 @@ def mean(
        [3.5]
 
     """
-    return inp.astype("float32").mean(axis=axis, keepdims=keepdims)
+    return inp.mean(axis=axis, keepdims=keepdims)
 
 
 def var(


imperative/python/megengine/module/batchnorm.py (+6 / -9)

@@ -35,15 +35,14 @@ class _BatchNorm(Module):
         self.track_running_stats = track_running_stats
         self._track_running_stats_saved = track_running_stats
         self.freeze = freeze
+        tshape = (1, self.num_features, 1, 1)
         if self.affine:
-            self.weight = Parameter(np.ones(num_features, dtype=np.float32))
-            self.bias = Parameter(np.zeros(num_features, dtype=np.float32))
+            self.weight = Parameter(np.ones(tshape, dtype=np.float32))
+            self.bias = Parameter(np.zeros(tshape, dtype=np.float32))
         else:
             self.weight = None
             self.bias = None
 
-        tshape = (1, self.num_features, 1, 1)
-
         if self.track_running_stats:
             self.running_mean = Tensor(np.zeros(tshape, dtype=np.float32))
             self.running_var = Tensor(np.ones(tshape, dtype=np.float32))
@@ -86,10 +85,8 @@ class _BatchNorm(Module):
         inp = inp.reshape(new_shape)
 
         if self.freeze and self.training and self._track_running_stats_saved:
-            scale = self.weight.reshape(1, -1, 1, 1) * (
-                self.running_var + self.eps
-            ) ** (-0.5)
-            bias = self.bias.reshape(1, -1, 1, 1) - self.running_mean * scale
+            scale = self.weight * (self.running_var + self.eps) ** (-0.5)
+            bias = self.bias - self.running_mean * scale
             return inp * scale.detach() + bias.detach()
 
         if self.training and self.track_running_stats:
@@ -276,7 +273,7 @@ class BatchNorm2d(_BatchNorm):
         m = M.BatchNorm2d(4)
         inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32"))
         oup = m(inp)
-        print(m.weight.numpy(), m.bias.numpy())
+        print(m.weight.numpy().flatten(), m.bias.numpy().flatten())
         # Without Learnable Parameters
         m = M.BatchNorm2d(4, affine=False)
         oup = m(inp)
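
Storing `weight` and `bias` as broadcast-ready `(1, C, 1, 1)` tensors removes the per-forward `reshape` in the frozen path; the docstring example flattens them only for readable printing. A usage sketch of the resulting shapes, assuming a MegEngine build that includes this change (the printed shapes are expectations, not captured output):

    import numpy as np
    import megengine as mge
    import megengine.module as M

    m = M.BatchNorm2d(4)
    print(m.weight.numpy().shape)       # expected (1, 4, 1, 1): broadcast-ready storage
    print(m.weight.numpy().flatten())   # expected [1. 1. 1. 1.]

    inp = mge.tensor(np.random.rand(2, 4, 8, 8).astype("float32"))
    out = m(inp)
    print(out.shape)                    # expected (2, 4, 8, 8): forward behaviour unchanged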


imperative/python/megengine/module/module.py (+47 / -9)

@@ -55,6 +55,14 @@ def _is_module(obj):
     return isinstance(obj, Module)
 
 
+def _get_XNorm_typeclass():
+    from .batchnorm import _BatchNorm
+
+    XNorm_types = []
+    XNorm_types.append(_BatchNorm)
+    return tuple(XNorm_types)
+
+
 class Module(metaclass=ABCMeta):
     """
     Base Module class.
@@ -393,6 +401,18 @@ class Module(metaclass=ABCMeta):
         return offset
 
     def state_dict(self, rst=None, prefix="", keep_var=False):
+        _rst = self._state_dict(rst=rst, prefix=prefix, keep_var=keep_var)
+        rst = OrderedDict()
+        XNorm_typeclass = _get_XNorm_typeclass()
+        for (module_type, k), v in _rst.items():
+            # for performance reasons, parameters in XNorm (e.g. BatchNorm2d) are 4-dim tensors;
+            # however, they are reshaped to 1-dim tensors before being returned by `state_dict()`
+            if issubclass(module_type, XNorm_typeclass):
+                v = v.reshape(-1)
+            rst[k] = v
+        return rst
+
+    def _state_dict(self, rst=None, prefix="", keep_var=False):
         r"""
         Returns a dictionary containing whole states of the module.
         """
@@ -400,15 +420,16 @@ class Module(metaclass=ABCMeta):
         def is_state(obj):
             return _is_parameter(obj) or _is_buffer(obj)
 
+        module_type = self.__class__
         if rst is None:
             rst = OrderedDict()
 
         for k, v in self._flatten(recursive=False, with_key=True, predicate=is_state):
             assert prefix + k not in rst, "duplicated state: {}".format(k)
             if keep_var:
-                rst[prefix + k] = v
+                rst[(module_type, prefix + k)] = v
             else:
-                rst[prefix + k] = v.numpy()
+                rst[(module_type, prefix + k)] = v.numpy()
 
         for k, submodule in self._flatten(
             recursive=False,
@@ -507,13 +528,14 @@ class Module(metaclass=ABCMeta):
         Advance state_dict load through callable ``closure`` whose signature is
         ``closure(key: str, var: Tensor) -> Union[np.ndarry, None]``
         """
+        XNorm_typeclass = _get_XNorm_typeclass()
         assert callable(closure), "closure must be a function"
 
         loaded = []
         skipped = []
 
-        local_state_dict = self.state_dict(keep_var=True)
-        for k, var in local_state_dict.items():
+        local_state_dict = self._state_dict(keep_var=True)
+        for (module_type, k), var in local_state_dict.items():
             to_be_load = closure(k, var)
             if to_be_load is None:
                 skipped.append(k)
@@ -523,11 +545,27 @@ class Module(metaclass=ABCMeta):
             ), "closure should return a `np.ndarray`, now `{}` get {}".format(
                 k, to_be_load
             )
-            assert make_shape_tuple(var.shape) == make_shape_tuple(
-                to_be_load.shape
-            ), "param `{}` shape mismatch, should be {}, get {}".format(
-                k, var.shape, to_be_load.shape
-            )
+            var_shape = make_shape_tuple(var.shape)
+            to_be_load_shape = make_shape_tuple(to_be_load.shape)
+            if var_shape != to_be_load_shape:
+                # weight and bias in BatchNorm1d, BatchNorm2d and SyncBatchNorm are 1-dim tensors
+                # in v1.0; since v1.1 they are 4-dim tensors. The following special rule for these
+                # modules preserves backward compatibility.
+                if issubclass(module_type, XNorm_typeclass):
+                    if np.prod(var_shape) == np.prod(to_be_load_shape):
+                        to_be_load = to_be_load.reshape(var_shape)
+                    else:
+                        raise ValueError(
+                            "param `{}` size mismatch, should be {}, get {}".format(
+                                k, np.prod(var_shape), np.prod(to_be_load_shape)
+                            )
+                        )
+                else:
+                    raise ValueError(
+                        "param `{}` shape mismatch, should be {}, get {}".format(
+                            k, var_shape, to_be_load_shape
+                        )
+                    )
             var._reset(type(var)(to_be_load, dtype=to_be_load.dtype, device=var.device))
             loaded.append(k)
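
The split between `state_dict()` and `_state_dict()` keeps checkpoints in the old layout: parameters of `_BatchNorm` subclasses are flattened to 1-dim on the way out, and `load_state_dict()` reshapes any size-matching entry back on the way in, so v1.0 checkpoints still load. A hedged round-trip sketch (key names and shapes assume a MegEngine build with this change):

    import megengine.module as M


    class Net(M.Module):
        def __init__(self):
            super().__init__()
            self.bn = M.BatchNorm2d(4)

        def forward(self, x):
            return self.bn(x)


    net = Net()
    sd = net.state_dict()
    print(sd["bn.weight"].shape)   # expected (4,): serialized flat, as in v1.0 checkpoints
    print(net.bn.weight.shape)     # expected (1, 4, 1, 1): in-memory broadcast-ready layout

    # a 1-dim checkpoint entry is accepted because load_state_dict() reshapes it
    # when the element counts match
    net.load_state_dict(sd)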




imperative/python/test/integration/test_correctness.py (+5 / -1)

@@ -193,7 +193,11 @@ def run_train(
             net.state_dict().items(), checkpoint["net_updated"].items()
         ):
             assert param[0] == param_ref[0]
-            np.testing.assert_allclose(param[1], param_ref[1], atol=max_err)
+            if "bn" in param[0]:
+                ref = param_ref[1].reshape(param[1].shape)
+                np.testing.assert_allclose(param[1], ref, atol=max_err)
+            else:
+                np.testing.assert_allclose(param[1], param_ref[1], atol=max_err)
 
 
 def run_eval(


imperative/python/test/integration/test_dp_correctness.py (+5 / -1)

@@ -188,7 +188,11 @@ def run_test(
             net.state_dict().items(), checkpoint["net_updated"].items()
         ):
             assert param[0] == param_ref[0]
-            np.testing.assert_allclose(param[1], param_ref[1], atol=max_err)
+            if "bn" in param[0]:
+                ref = param_ref[1].reshape(param[1].shape)
+                np.testing.assert_allclose(param[1], ref, atol=max_err)
+            else:
+                np.testing.assert_allclose(param[1], param_ref[1], atol=max_err)
 
     procs = []
     for rank in range(p_num):


imperative/src/impl/interpreter_impl.h (+1 / -1)

@@ -107,7 +107,7 @@ private:
     //! level 2: both device and user side errors are async;
     //! level 1: user side errors are sync;
     //! level 0: both sync.
-    int m_async_level = 1;
+    int m_async_level = 2;
 };
 
 } // namespace mgb::imperative::interpreter::intl

imperative/src/impl/proxy_graph.h (+1 / -1)

@@ -94,7 +94,7 @@ private:
 
     cg::OperatorNodeBase* m_cur_opr = nullptr;
    std::unique_ptr<ProxyGraphImpl> m_graph;
-    size_t m_max_op_cnt = 1000;
+    size_t m_max_op_cnt = 100;
    std::unique_ptr<ExecEnv> m_env;
    std::unique_ptr<StaticInferManager> m_static_infer_manager;
    std::unique_ptr<SeqCompNodeOptimizer> m_seq_comp_node_optimizer;


imperative/src/impl/proxy_graph_detail.cpp (+1 / -1)

@@ -120,12 +120,12 @@ make_backward_graph(const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs,
         const SmallVector<bool>& input_requires_grad,
         const SmallVector<bool>& output_has_grad) {
-    auto&& graph = ProxyGraph::get_default_graph();
     auto hash_key = get_backward_graph_hash_key(def, inputs, input_requires_grad, output_has_grad);
     auto&& iter = backward_graph_cache.find(hash_key);
     if (iter != backward_graph_cache.end()) {
         return iter->second;
     }
+    auto&& graph = ProxyGraph::get_default_graph();
     auto res = graph->make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
     backward_graph_cache.emplace(hash_key, res);
     return res;

