diff --git a/imperative/python/megengine/core/autodiff/builtin_op_utils.py b/imperative/python/megengine/core/autodiff/builtin_op_utils.py index 6ed12afb..db9605ca 100644 --- a/imperative/python/megengine/core/autodiff/builtin_op_utils.py +++ b/imperative/python/megengine/core/autodiff/builtin_op_utils.py @@ -30,7 +30,6 @@ from ..tensor.core import apply from ..tensor.function import Function from ..tensor.tensor_wrapper import TensorWrapper -_elemwise_add_param = Elemwise(mode="add").to_c().param _reduce_sum_param = Reduce(mode="SUM").to_c().param[0] @@ -44,12 +43,12 @@ def _(op: OpDef, inputs, outputs, input_requires_grad): if isinstance(op, OprAttr): grad_fn = _oprAttr_grad_fn.get(op.type, None) if grad_fn is None: - if op.type == Elemwise.name and op.param == _elemwise_add_param: - grad_fn = elemwise_add_grad_fn - elif op.type == Reduce.name and op.param[0] == _reduce_sum_param: + if op.type == Reduce.name and op.param[0] == _reduce_sum_param: grad_fn = reduce_sum_grad_fn else: grad_fn = default_grad_fn + elif isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD: + grad_fn = elemwise_add_grad_fn else: grad_fn = default_grad_fn return grad_fn(op, inputs, outputs, input_requires_grad) @@ -158,11 +157,8 @@ def subtensor_grad_fn(op, inputs, outputs, input_requires_grad): params = inputs[1:] def make_grad(grad_op, dy): - grad = ( - TensorWrapper(0, dtype=dy.dtype, device=dy.device) - ._broadcast(TensorWrapper(input_shape)) - .__wrapped__ - ) + (_z,) = Const(0, dtype=dy.dtype, device=dy.device)(dy) + (grad,) = apply(Broadcast(), _z, input_shape) (dx,) = apply(grad_op, grad, dy, *params) return dx @@ -184,11 +180,8 @@ def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad): params = inputs[1:] def make_grad(grad_op, dy): - grad = ( - TensorWrapper(0, dtype=dy.dtype, device=dy.device) - ._broadcast(TensorWrapper(input_shape)) - .__wrapped__ - ) + (_z,) = Const(0, dtype=dy.dtype, device=dy.device)(dy) + (grad,) = apply(Broadcast(), _z, input_shape) (dx,) = apply(grad_op, grad, dy, *params) return dx diff --git a/imperative/python/megengine/core/autodiff/grad.py b/imperative/python/megengine/core/autodiff/grad.py index 7562fbb4..12a46b97 100644 --- a/imperative/python/megengine/core/autodiff/grad.py +++ b/imperative/python/megengine/core/autodiff/grad.py @@ -47,7 +47,7 @@ def get_grad_managers(): def add(a, b): - (c,) = apply(Elemwise(mode="add"), a, b) + (c,) = apply(Elemwise(Elemwise.Mode.ADD), a, b) return c diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py index 374e0975..6b41c4c8 100644 --- a/imperative/python/megengine/core/tensor/tensor_wrapper.py +++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py @@ -13,7 +13,7 @@ import numpy as np from .._trace_option import use_symbolic_shape from ..ops import builtin -from ..ops.builtin import GetVarShape +from ..ops.builtin import Elemwise, GetVarShape from ..ops.special import Const from . 
import utils from .core import OpBase, TensorBase, TensorWrapperBase, apply @@ -23,10 +23,12 @@ from .raw_tensor import RawTensor, as_raw_tensor from .tensor import Tensor from .utils import make_shape_tuple as _make_shape_tuple +_ElwMod = Elemwise.Mode + def _elwise(*args, mode): - op = builtin.Elemwise(mode=mode) - if mode in ("TRUE_DIV", "POW"): + op = builtin.Elemwise(mode) + if mode in (_ElwMod.TRUE_DIV, _ElwMod.POW): args = tuple( map( lambda x: x.astype("float32") @@ -272,53 +274,53 @@ class ArrayMethodMixin(abc.ABC): __hash__ = None # due to __eq__ diviates from python convention - __lt__ = lambda self, value: _elwise(self, value, mode="LT").astype("bool") - __le__ = lambda self, value: _elwise(self, value, mode="LEQ").astype("bool") - __gt__ = lambda self, value: _elwise(value, self, mode="LT").astype("bool") - __ge__ = lambda self, value: _elwise(value, self, mode="LEQ").astype("bool") - __eq__ = lambda self, value: _elwise(self, value, mode="EQ").astype("bool") + __lt__ = lambda self, value: _elwise(self, value, mode=_ElwMod.LT).astype("bool") + __le__ = lambda self, value: _elwise(self, value, mode=_ElwMod.LEQ).astype("bool") + __gt__ = lambda self, value: _elwise(value, self, mode=_ElwMod.LT).astype("bool") + __ge__ = lambda self, value: _elwise(value, self, mode=_ElwMod.LEQ).astype("bool") + __eq__ = lambda self, value: _elwise(self, value, mode=_ElwMod.EQ).astype("bool") __ne__ = lambda self, value: _elwise( - _elwise(self, value, mode="EQ").astype("bool"), mode="NOT" + _elwise(self, value, mode=_ElwMod.EQ).astype("bool"), mode=_ElwMod.NOT, ) - __neg__ = _unary_elwise("NEGATE") + __neg__ = _unary_elwise(_ElwMod.NEGATE) __pos__ = lambda self: self - __abs__ = _unary_elwise("ABS") - __invert__ = _logical_unary_elwise("NOT") - __round__ = _unary_elwise("ROUND") + __abs__ = _unary_elwise(_ElwMod.ABS) + __invert__ = _logical_unary_elwise(_ElwMod.NOT) + __round__ = _unary_elwise(_ElwMod.ROUND) __trunc__ = _todo - __floor__ = _unary_elwise("FLOOR") - __ceil__ = _unary_elwise("CEIL") + __floor__ = _unary_elwise(_ElwMod.FLOOR) + __ceil__ = _unary_elwise(_ElwMod.CEIL) - __add__ = _binary_elwise("ADD") - __sub__ = _binary_elwise("SUB") - __mul__ = _binary_elwise("MUL") + __add__ = _binary_elwise(_ElwMod.ADD) + __sub__ = _binary_elwise(_ElwMod.SUB) + __mul__ = _binary_elwise(_ElwMod.MUL) __matmul__ = lambda self, other: _matmul(self, other) - __truediv__ = _binary_elwise("TRUE_DIV") - __floordiv__ = _binary_elwise("FLOOR_DIV") - __mod__ = _binary_elwise("MOD") + __truediv__ = _binary_elwise(_ElwMod.TRUE_DIV) + __floordiv__ = _binary_elwise(_ElwMod.FLOOR_DIV) + __mod__ = _binary_elwise(_ElwMod.MOD) # __divmode__ - __pow__ = _binary_elwise("POW") - __lshift__ = _binary_elwise("SHL") - __rshift__ = _binary_elwise("SHR") - __and__ = _logical_binary_elwise("AND") - __or__ = _logical_binary_elwise("OR") - __xor__ = _logical_binary_elwise("XOR") - - __radd__ = _binary_elwise("ADD", rev=1) - __rsub__ = _binary_elwise("SUB", rev=1) - __rmul__ = _binary_elwise("MUL", rev=1) + __pow__ = _binary_elwise(_ElwMod.POW) + __lshift__ = _binary_elwise(_ElwMod.SHL) + __rshift__ = _binary_elwise(_ElwMod.SHR) + __and__ = _logical_binary_elwise(_ElwMod.AND) + __or__ = _logical_binary_elwise(_ElwMod.OR) + __xor__ = _logical_binary_elwise(_ElwMod.XOR) + + __radd__ = _binary_elwise(_ElwMod.ADD, rev=1) + __rsub__ = _binary_elwise(_ElwMod.SUB, rev=1) + __rmul__ = _binary_elwise(_ElwMod.MUL, rev=1) __rmatmul__ = lambda self, other: _matmul(other, self) - __rtruediv__ = _binary_elwise("TRUE_DIV", rev=1) - 
__rfloordiv__ = _binary_elwise("FLOOR_DIV", rev=1)
-    __rmod__ = _binary_elwise("MOD", rev=1)
+    __rtruediv__ = _binary_elwise(_ElwMod.TRUE_DIV, rev=1)
+    __rfloordiv__ = _binary_elwise(_ElwMod.FLOOR_DIV, rev=1)
+    __rmod__ = _binary_elwise(_ElwMod.MOD, rev=1)
     # __rdivmode__
-    __rpow__ = _binary_elwise("POW", rev=1)
-    __rlshift__ = _binary_elwise("SHL", rev=1)
-    __rrshift__ = _binary_elwise("SHR", rev=1)
-    __rand__ = _logical_binary_elwise("AND", rev=1)
-    __ror__ = _logical_binary_elwise("OR", rev=1)
-    __rxor__ = _logical_binary_elwise("XOR", rev=1)
+    __rpow__ = _binary_elwise(_ElwMod.POW, rev=1)
+    __rlshift__ = _binary_elwise(_ElwMod.SHL, rev=1)
+    __rrshift__ = _binary_elwise(_ElwMod.SHR, rev=1)
+    __rand__ = _logical_binary_elwise(_ElwMod.AND, rev=1)
+    __ror__ = _logical_binary_elwise(_ElwMod.OR, rev=1)
+    __rxor__ = _logical_binary_elwise(_ElwMod.XOR, rev=1)
 
     __iadd__ = _inplace(__add__)
     __isub__ = _inplace(__sub__)
diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py
index 686ddf4c..ddf288db 100644
--- a/imperative/python/megengine/functional/elemwise.py
+++ b/imperative/python/megengine/functional/elemwise.py
@@ -10,6 +10,7 @@ import functools
 
 from ..core.ops import builtin
+from ..core.ops.builtin import Elemwise
 from ..core.tensor import megbrain_graph, utils
 from ..core.tensor.core import apply
 from ..device import get_default_device
@@ -72,7 +73,7 @@ __all__ = [
 
 def _elwise(*args, mode):
-    op = builtin.Elemwise(mode=mode)
+    op = builtin.Elemwise(mode)
     tensor_args = list(
         filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args)
     )
@@ -128,67 +129,67 @@ def add(x, y):
     [ 6.  8. 10.]]
 
     """
-    return _elwise(x, y, mode="add")
+    return _elwise(x, y, mode=Elemwise.Mode.ADD)
 
 
 def sub(x, y):
     """Element-wise `subtraction`."""
-    return _elwise(x, y, mode="sub")
+    return _elwise(x, y, mode=Elemwise.Mode.SUB)
 
 
 def mul(x, y):
     """Element-wise `multiplication`."""
-    return _elwise(x, y, mode="mul")
+    return _elwise(x, y, mode=Elemwise.Mode.MUL)
 
 
 def div(x, y):
     """Element-wise `(x / y)`."""
-    return _elwise(x, y, mode="true_div")
+    return _elwise(x, y, mode=Elemwise.Mode.TRUE_DIV)
 
 
 def floor_div(x, y):
     """Element-wise `floor(x / y)`."""
-    return _elwise(x, y, mode="floor_divide")
+    return _elwise(x, y, mode=Elemwise.Mode.FLOOR_DIV)
 
 
 def neg(x):
     """Element-wise `negation`."""
-    return _elwise(x, mode="negate")
+    return _elwise(x, mode=Elemwise.Mode.NEGATE)
 
 
 def pow(x, y):
     """Element-wise `power`."""
-    return _elwise(x, y, mode="pow")
+    return _elwise(x, y, mode=Elemwise.Mode.POW)
 
 
 def mod(x, y):
     """Element-wise `remainder of division`."""
-    return _elwise(x, y, mode="mod")
+    return _elwise(x, y, mode=Elemwise.Mode.MOD)
 
 
 def abs(x):
     """Element-wise `absolute value`."""
-    return _elwise(x, mode="abs")
+    return _elwise(x, mode=Elemwise.Mode.ABS)
 
 
 def exp(x):
     """Element-wise `exponential`."""
-    return _elwise(x, mode="exp")
+    return _elwise(x, mode=Elemwise.Mode.EXP)
 
 
 def expm1(x):
     """Element-wise `exp(x)-1`."""
-    return _elwise(x, mode="expm1")
+    return _elwise(x, mode=Elemwise.Mode.EXPM1)
 
 
 def log(x):
     """Element-wise `logarithm (base e)`."""
-    return _elwise(x, mode="log")
+    return _elwise(x, mode=Elemwise.Mode.LOG)
 
 
 def log1p(x):
     """Element-wise `log(x+1) (base e)`."""
-    return _elwise(x, mode="log1p")
+    return _elwise(x, mode=Elemwise.Mode.LOG1P)
 
 
 def sqrt(x: Tensor) -> Tensor:
@@ -253,27 +254,27 @@ def square(x: Tensor) -> Tensor:
 
 def round(x):
     """Element-wise `rounding to int`."""
-    return _elwise(x, mode="round")
+    return _elwise(x, 
mode=Elemwise.Mode.ROUND) def ceil(x): """Element-wise `ceiling`.""" - return _elwise(x, mode="ceil") + return _elwise(x, mode=Elemwise.Mode.CEIL) def floor(x): """Element-wise `floor`.""" - return _elwise(x, mode="floor") + return _elwise(x, mode=Elemwise.Mode.FLOOR) def maximum(x, y): """Element-wise `maximum of array elements`.""" - return _elwise(x, y, mode="max") + return _elwise(x, y, mode=Elemwise.Mode.MAX) def minimum(x, y): """Element-wise `minimum of array elements`.""" - return _elwise(x, y, mode="min") + return _elwise(x, y, mode=Elemwise.Mode.MIN) # trigonometric functions @@ -305,12 +306,12 @@ def cos(x): [-0.99 -0.6536 0.2837]] """ - return _elwise(x, mode="cos") + return _elwise(x, mode=Elemwise.Mode.COS) def sin(x): """Element-wise `sine`.""" - return _elwise(x, mode="sin") + return _elwise(x, mode=Elemwise.Mode.SIN) def tan(x): @@ -320,22 +321,22 @@ def tan(x): def acos(x): """Element-wise `inverse cosine`.""" - return _elwise(x, mode="acos") + return _elwise(x, mode=Elemwise.Mode.ACOS) def asin(x): """Element-wise `inverse sine`.""" - return _elwise(x, mode="asin") + return _elwise(x, mode=Elemwise.Mode.ASIN) def atan(x): """Element-wise `inverse tangent`.""" - return _elwise(x, 1, mode="atan2") + return _elwise(x, 1, mode=Elemwise.Mode.ATAN2) def atan2(y, x): """Element-wise `2-argument arctangent`.""" - return _elwise(y, x, mode="atan2") + return _elwise(y, x, mode=Elemwise.Mode.ATAN2) def cosh(x): @@ -351,7 +352,7 @@ def sinh(x): def tanh(x): r"""Element-wise `hyperbolic tangent`.""" - return _elwise(x, mode="tanh") + return _elwise(x, mode=Elemwise.Mode.TANH) def asinh(x): @@ -399,12 +400,12 @@ def left_shift(x, y): [12 16 20]] """ - return _elwise(x, y, mode="shl") + return _elwise(x, y, mode=Elemwise.Mode.SHL) def right_shift(x, y): """Element-wise `bitwise binary: x >> y`.""" - return _elwise(x, y, mode="shr") + return _elwise(x, y, mode=Elemwise.Mode.SHR) # logical functions @@ -412,22 +413,22 @@ def right_shift(x, y): def logical_and(x, y): """Element-wise `logical and: x && y`.""" - return _elwise(x, y, mode="AND") + return _elwise(x, y, mode=Elemwise.Mode.AND) def logical_not(x): """Element-wise `logical not: ~x`.""" - return _elwise(x, mode="NOT") + return _elwise(x, mode=Elemwise.Mode.NOT) def logical_or(x, y): """Element-wise `logical or: x || y`.""" - return _elwise(x, y, mode="OR") + return _elwise(x, y, mode=Elemwise.Mode.OR) def logical_xor(x, y): """Element-wise `logical xor: x ^ y`.""" - return _elwise(x, y, mode="XOR") + return _elwise(x, y, mode=Elemwise.Mode.XOR) # comparison functions @@ -461,7 +462,7 @@ def equal(x, y): [1. 1. 1.]] """ - return _elwise(x, y, mode="eq") + return _elwise(x, y, mode=Elemwise.Mode.EQ) def not_equal(x, y): @@ -471,22 +472,22 @@ def not_equal(x, y): def less(x, y): """Element-wise `(x < y)`.""" - return _elwise(x, y, mode="lt") + return _elwise(x, y, mode=Elemwise.Mode.LT) def less_equal(x, y): """Element-wise `(x <= y)`.""" - return _elwise(x, y, mode="leq") + return _elwise(x, y, mode=Elemwise.Mode.LEQ) def greater(x, y): """Element-wise `(x > y)`.""" - return _elwise(y, x, mode="lt") + return _elwise(y, x, mode=Elemwise.Mode.LT) def greater_equal(x, y): """Element-wise `(x >= y)`.""" - return _elwise(y, x, mode="leq") + return _elwise(y, x, mode=Elemwise.Mode.LEQ) # other functions @@ -515,7 +516,7 @@ def hswish(x): [0. 0.6667 1.6667 3. 4. 
] """ - return _elwise(x, mode="h_swish") + return _elwise(x, mode=Elemwise.Mode.H_SWISH) def hsigmoid(x): @@ -525,7 +526,7 @@ def hsigmoid(x): def relu(x): """Element-wise `max(x, 0)`.""" - return _elwise(x, mode="relu") + return _elwise(x, mode=Elemwise.Mode.RELU) def relu6(x): @@ -535,7 +536,7 @@ def relu6(x): def sigmoid(x): """Element-wise `1 / ( 1 + exp( -x ) )`.""" - return _elwise(x, mode="sigmoid") + return _elwise(x, mode=Elemwise.Mode.SIGMOID) def clip(x: Tensor, lower=None, upper=None) -> Tensor: diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index b4a8547d..7fd28771 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -12,6 +12,7 @@ from typing import Optional, Sequence, Tuple, Union from ..core._imperative_rt import CompNode from ..core.ops import builtin from ..core.ops._internal import param_defs as P +from ..core.ops.builtin import BatchNorm from ..core.ops.special import Const from ..core.tensor import megbrain_graph, utils from ..core.tensor.core import TensorBase, TensorWrapperBase, apply @@ -643,19 +644,22 @@ def batch_norm( if inp.ndim != 4: raise NotImplementedError("batch_norm for ndim != 4") - def full_value(value): - C = inp.shape[1] - (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp) - return broadcast_to(x, [1, C, 1, 1]) - - def expand_or_full(x, value): - if x is None: - return full_value(value) - return expand_dims(x, [0, 2, 3]) + C = inp.shape[1] def make_full_if_none(x, value): if x is None: - return full(shape=(1, inp.shape[1], 1, 1), value=value) + (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp) + shape = utils.astensor1d( + (1, C, 1, 1), inp, dtype="int32", device=inp.device + ) + (result,) = apply(builtin.Broadcast(), x, shape) + return result + elif x.ndim == 1: + shape = utils.astensor1d( + (1, C, 1, 1), inp, dtype="int32", device=inp.device + ) + (result,) = apply(builtin.Reshape(), x, shape) + return result return x has_mean = running_mean is not None @@ -674,19 +678,25 @@ def batch_norm( inp, weight, bias, running_mean, running_var ) - weight = expand_or_full(weight, 1) - bias = expand_or_full(bias, 0) + weight = make_full_if_none(weight, 1) + bias = make_full_if_none(bias, 0) if not training: - op = builtin.BatchNorm(fwd_mode="INFERENCE", epsilon=eps, param_dim="DIM_1C11") + op = builtin.BatchNorm( + BatchNorm.ParamDim.DIM_1C11, BatchNorm.FwdMode.INFERENCE, eps, 1.0, 1.0, 0.0 + ) ret = apply(op, inp, weight, bias, running_mean, running_var)[-1] return ret else: op = builtin.BatchNorm( - avg_factor=1 - momentum, epsilon=eps, param_dim="DIM_1C11" + BatchNorm.ParamDim.DIM_1C11, + BatchNorm.FwdMode.TRAINING, + eps, + 1.0 - momentum, + 1.0, + 0.0, ) - if has_mean or has_var: running_mean = make_full_if_none(running_mean, 0) running_var = make_full_if_none(running_var, 1) @@ -708,7 +718,7 @@ def batch_norm( else: return inp, new_mean, new_var else: - _, _, inp, = apply(op, inp, weight, bias) + (_, _, inp,) = apply(op, inp, weight, bias) return inp diff --git a/imperative/python/megengine/module/batchnorm.py b/imperative/python/megengine/module/batchnorm.py index 9f2d7bd1..5f404a03 100644 --- a/imperative/python/megengine/module/batchnorm.py +++ b/imperative/python/megengine/module/batchnorm.py @@ -72,14 +72,15 @@ class _BatchNorm(Module): self.track_running_stats == False ), "track_running_stats can not be initilized to False and changed to True later" - _ndims = len(inp.shape) + inp_shape = inp.shape + _ndims = 
diff --git a/imperative/python/megengine/module/batchnorm.py b/imperative/python/megengine/module/batchnorm.py
index 9f2d7bd1..5f404a03 100644
--- a/imperative/python/megengine/module/batchnorm.py
+++ b/imperative/python/megengine/module/batchnorm.py
@@ -72,14 +72,15 @@ class _BatchNorm(Module):
             self.track_running_stats == False
         ), "track_running_stats can not be initilized to False and changed to True later"
 
-        _ndims = len(inp.shape)
+        inp_shape = inp.shape
+        _ndims = len(inp_shape)
         if _ndims != 4:
-            origin_shape = inp.shape
+            origin_shape = inp_shape
             if _ndims == 2:
-                n, c = inp.shape[0], inp.shape[1]
+                n, c = inp_shape[0], inp_shape[1]
                 new_shape = (n, c, 1, 1)
             elif _ndims == 3:
-                n, c, h = inp.shape[0], inp.shape[1], inp.shape[2]
+                n, c, h = inp_shape[0], inp_shape[1], inp_shape[2]
                 new_shape = (n, c, h, 1)
 
             inp = inp.reshape(new_shape)
@@ -150,17 +151,18 @@ class SyncBatchNorm(_BatchNorm):
     def forward(self, inp):
         self._check_input_ndim(inp)
 
-        _ndims = len(inp.shape)
+        inp_shape = inp.shape
+        _ndims = len(inp_shape)
         if _ndims != 4:
             new_shape = Tensor([1, 1, 1, 1], device=inp.device)
-            origin_shape = inp.shape
+            origin_shape = inp_shape
             if _ndims == 2:
                 new_shape[:2] = origin_shape[:2]
             elif _ndims == 3:
                 new_shape[:3] = origin_shape[:3]
             else:
                 raise ValueError(
-                    "expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape))
+                    "expected 2D, 3D or 4D input (got {}D input)".format(len(inp_shape))
                 )
 
             inp = inp.reshape(new_shape)
diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp
index c6453795..e6afd85c 100644
--- a/imperative/python/src/ops.cpp
+++ b/imperative/python/src/ops.cpp
@@ -19,6 +19,8 @@
 #include "megbrain/imperative/ops/io_remote.h"
 #include "megbrain/imperative/ops/cond_take.h"
 #include "megbrain/imperative/ops/nms.h"
+#include "megbrain/imperative/ops/elemwise.h"
+#include "megbrain/imperative/ops/batch_norm.h"
 
 namespace py = pybind11;
 
@@ -117,4 +119,91 @@ void init_ops(py::module m) {
         .def(py::init())
         .def_readwrite("iou_thresh", &NMSKeep::iou_thresh)
         .def_readwrite("max_output", &NMSKeep::max_output);
+
+    py::class_<Elemwise, std::shared_ptr<Elemwise>, OpDef> elemwise(m, "Elemwise");
+    elemwise.def(py::init<Elemwise::Mode>())
+        .def_readwrite("mode", &Elemwise::mode);
+
+#define V(m) .value(#m, Elemwise::Mode::m)
+    py::enum_<Elemwise::Mode>(elemwise, "Mode")
+        V(RELU)
+        V(ABS)
+        V(ACOS)
+        V(ASIN)
+        V(CEIL)
+        V(COS)
+        V(EXP)
+        V(EXPM1)
+        V(FLOOR)
+        V(LOG)
+        V(LOG1P)
+        V(NEGATE)
+        V(SIGMOID)
+        V(SIN)
+        V(TANH)
+        V(ABS_GRAD)
+        V(ADD)
+        V(FLOOR_DIV)
+        V(MAX)
+        V(MIN)
+        V(MOD)
+        V(MUL)
+        V(POW)
+        V(SIGMOID_GRAD)
+        V(SUB)
+        V(SWITCH_GT0)
+        V(TANH_GRAD)
+        V(TRUE_DIV)
+        V(LOG_SUM_EXP)
+        V(LT)
+        V(LEQ)
+        V(EQ)
+        V(SHL)
+        V(SHR)
+        V(COND_LEQ_MOV)
+        V(FUSE_MUL_ADD3)
+        V(FUSE_MUL_ADD4)
+        V(FUSE_ADD_RELU)
+        V(FUSE_ADD_SIGMOID)
+        V(FUSE_ADD_TANH)
+        V(FAST_TANH)
+        V(FAST_TANH_GRAD)
+        V(ROUND)
+        V(RMULH)
+        V(ATAN2)
+        V(ERF)
+        V(ERFINV)
+        V(ERFC)
+        V(ERFCINV)
+        V(H_SWISH)
+        V(H_SWISH_GRAD)
+        V(FUSE_ADD_H_SWISH)
+        V(NOT)
+        V(AND)
+        V(OR)
+        V(XOR);
+#undef V
+
+    py::class_<BatchNorm, std::shared_ptr<BatchNorm>, OpDef> batchnorm(m, "BatchNorm");
+    batchnorm.def(py::init<const BatchNorm::Param::ParamDim&, const BatchNorm::Param::FwdMode&,
+                           double, double, float, float>())
+        .def_readwrite("param_dim", &BatchNorm::param_dim)
+        .def_readwrite("fwd_mode", &BatchNorm::fwd_mode)
+        .def_readwrite("epsilon", &BatchNorm::epsilon)
+        .def_readwrite("avg_factor", &BatchNorm::avg_factor)
+        .def_readwrite("scale", &BatchNorm::scale)
+        .def_readwrite("bias", &BatchNorm::bias);
+
+#define V(m) .value(#m, BatchNorm::Param::ParamDim::m)
+    py::enum_<BatchNorm::Param::ParamDim>(batchnorm, "ParamDim")
+        V(DIM_11HW)
+        V(DIM_1CHW)
+        V(DIM_1C11);
+#undef V
+
+#define V(m) .value(#m, BatchNorm::Param::FwdMode::m)
+    py::enum_<BatchNorm::Param::FwdMode>(batchnorm, "FwdMode")
+        V(TRAINING)
+        V(INFERENCE);
+#undef V
 }
diff --git a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py
index a910da26..30a0cc23 100644
--- a/imperative/python/test/unit/core/test_autodiff.py
+++ b/imperative/python/test/unit/core/test_autodiff.py
@@ -27,7 +27,7 @@ from megengine.functional.distributed import remote_recv, remote_send
 
 def _elwise(mode):
-    op = Elemwise(mode=mode)
+    op = 
Elemwise(mode) def f(*args): (result,) = apply(op, *args) @@ -36,10 +36,10 @@ def _elwise(mode): return f -add = _elwise("add") -mul = _elwise("mul") -cos = _elwise("cos") -relu = _elwise("relu") +add = _elwise(Elemwise.Mode.ADD) +mul = _elwise(Elemwise.Mode.MUL) +cos = _elwise(Elemwise.Mode.COS) +relu = _elwise(Elemwise.Mode.RELU) def as_tensor(x): @@ -255,7 +255,7 @@ def test_elemwise_relu(): def test_elemwise_relu_backward_fn(): - op = Elemwise(mode="relu").to_c() + op = Elemwise(Elemwise.Mode.RELU) attr = TensorAttr() attr.dtype = "float32" attr.comp_node = "xpux" diff --git a/imperative/python/test/unit/core/test_imperative_rt.py b/imperative/python/test/unit/core/test_imperative_rt.py index 959a08c4..bc622faf 100644 --- a/imperative/python/test/unit/core/test_imperative_rt.py +++ b/imperative/python/test/unit/core/test_imperative_rt.py @@ -17,7 +17,7 @@ def elemwise(*args, mode): from megengine.core.ops.builtin import Elemwise from megengine.core._imperative_rt.imperative import apply_op - return apply_op(Elemwise(mode=mode).to_c(), args) + return apply_op(Elemwise(mode), args) def test_basic_interface(): @@ -37,13 +37,15 @@ def test_basic_interface(): def test_opr_attr(): from megengine.core.ops.builtin import Elemwise - assert Elemwise(mode="add") == Elemwise(mode="add") + assert Elemwise(Elemwise.Mode.ADD) == Elemwise(Elemwise.Mode.ADD) def test_simple_arith(): + from megengine.core.ops.builtin import Elemwise + x = np.random.rand(10).astype("float32") xx = megengine.core._imperative_rt.put(x) - (yy,) = elemwise(xx, xx, mode="mul") + (yy,) = elemwise(xx, xx, mode=Elemwise.Mode.MUL) np.testing.assert_allclose(x * x, megengine.core._imperative_rt.get_value(yy)) megengine.core._imperative_rt.delete(xx) megengine.core._imperative_rt.delete(yy) @@ -64,7 +66,7 @@ def test_raw_tensor(): x = np.random.rand(10).astype("float32") xx = as_raw_tensor(x) - (yy,) = apply(Elemwise(mode="mul"), xx, xx) + (yy,) = apply(Elemwise(Elemwise.Mode.MUL), xx, xx) np.testing.assert_allclose(x * x, yy.numpy()) - (yy,) = apply(Elemwise(mode="mul"), xx, xx) + (yy,) = apply(Elemwise(Elemwise.Mode.MUL), xx, xx) np.testing.assert_allclose(x * x, yy.numpy()) diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index 25e5b978..4fbdd2bc 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -17,6 +17,7 @@ import megengine.functional as F from megengine import cgtools, tensor from megengine.core._trace_option import set_symbolic_shape from megengine.core.ops import builtin as ops +from megengine.core.ops.builtin import Elemwise from megengine.core.tensor.core import apply from megengine.core.tensor.raw_tensor import as_raw_tensor from megengine.functional import exp, log @@ -28,7 +29,7 @@ def test_trace(): @trace(symbolic=symbolic) def f(x): - op = ops.Elemwise(mode="negate") + op = ops.Elemwise(Elemwise.Mode.NEGATE) (y,) = apply(op, x) return y @@ -44,7 +45,7 @@ def test_exclude_from_trace(): @trace(symbolic=symbolic) def f(x): - neg = ops.Elemwise(mode="negate") + neg = ops.Elemwise(Elemwise.Mode.NEGATE) (x,) = apply(neg, x) with exclude_from_trace(): if i % 2: @@ -65,7 +66,7 @@ def test_print_in_trace(): @trace(symbolic=symbolic) def f(x): nonlocal buf - neg = ops.Elemwise(mode="negate") + neg = ops.Elemwise(Elemwise.Mode.NEGATE) (x,) = apply(neg, x) buf = x.numpy() (x,) = apply(neg, x) @@ -85,7 +86,7 @@ def test_print_in_trace(): def test_dump(): @trace(symbolic=True, capture_as_const=True) def f(a, b): - op 
= ops.Elemwise(mode="add") + op = ops.Elemwise(Elemwise.Mode.ADD) (y,) = apply(op, a, b) return y @@ -111,7 +112,7 @@ def test_capture_dump(): @trace(symbolic=True, capture_as_const=True) def f(x): - op = ops.Elemwise(mode="mul") + op = ops.Elemwise(Elemwise.Mode.MUL) (y,) = apply(op, x, a) return y @@ -133,7 +134,7 @@ def test_dump_volatile(): @trace(symbolic=True, capture_as_const=True) def f(x): - op = ops.Elemwise(mode="mul") + op = ops.Elemwise(Elemwise.Mode.MUL) (y,) = apply(op, x, p) return y @@ -159,7 +160,7 @@ def test_trace_profiler(): @trace(symbolic=symbolic, profiling=True) def f(x): - op = ops.Elemwise(mode="negate") + op = ops.Elemwise(Elemwise.Mode.NEGATE) (y,) = apply(op, x) return y diff --git a/imperative/src/impl/ops/batch_norm.cpp b/imperative/src/impl/ops/batch_norm.cpp new file mode 100644 index 00000000..07e41899 --- /dev/null +++ b/imperative/src/impl/ops/batch_norm.cpp @@ -0,0 +1,84 @@ +/** + * \file imperative/src/impl/ops/batch_norm.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "megbrain/imperative/ops/batch_norm.h" +#include "../op_trait.h" + +namespace mgb { +namespace imperative { + +namespace { + +std::shared_ptr make_from_op_node(cg::OperatorNodeBase* node_) { + auto* node = &node_->cast_final_safe(); + auto&& param = node->param(); + return BatchNorm::make(param.param_dim, param.fwd_mode, param.epsilon, + param.avg_factor, param.scale, param.bias); +} + +cg::OperatorNodeBase* apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& bn_opr = def.cast_final_safe(); + size_t nr_inp = inputs.size(); + mgb_assert(nr_inp == 3 ||nr_inp == 5, + "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp); + if (nr_inp == 3) { + return opr::BatchNorm::make( + inputs[0], inputs[1], inputs[2], + {bn_opr.param_dim, bn_opr.fwd_mode, bn_opr.epsilon, bn_opr.avg_factor, bn_opr.scale, bn_opr.bias})[0] + .node()->owner_opr(); + } else { + return opr::BatchNorm::make( + inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], + {bn_opr.param_dim, bn_opr.fwd_mode, bn_opr.epsilon, bn_opr.avg_factor, bn_opr.scale, bn_opr.bias})[0] + .node()->owner_opr(); + } +} + +SmallVector infer_output_attrs_fallible( + const OpDef& def, + const SmallVector& inputs) { + auto&& op_def = def.cast_final_safe(); + size_t nr_inp = inputs.size(); + mgb_assert(nr_inp == 3 ||nr_inp == 5, + "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp); + // need running mean/variance + bool need_stat = (nr_inp == 5) && op_def.fwd_mode == BatchNorm::Param::FwdMode::TRAINING; + size_t nr_out = need_stat? 
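 
[Reviewer note] For callers migrating off the string API: every old mode string maps to an Elemwise.Mode member by uppercasing, with one irregular name visible in this patch ("floor_divide" becomes Elemwise.Mode.FLOOR_DIV). A hypothetical compatibility shim, shown only to document the mapping (mode_from_str is not part of this patch):

    from megengine.core.ops.builtin import Elemwise

    def mode_from_str(mode: str) -> Elemwise.Mode:
        # hypothetical helper: "add" -> Elemwise.Mode.ADD, "true_div" -> TRUE_DIV, ...
        aliases = {"floor_divide": "FLOOR_DIV"}  # the one irregular name
        return getattr(Elemwise.Mode, aliases.get(mode, mode.upper()))

    assert mode_from_str("add") == Elemwise.Mode.ADD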
diff --git a/imperative/src/impl/ops/batch_norm.cpp b/imperative/src/impl/ops/batch_norm.cpp
new file mode 100644
index 00000000..07e41899
--- /dev/null
+++ b/imperative/src/impl/ops/batch_norm.cpp
@@ -0,0 +1,84 @@
+/**
+ * \file imperative/src/impl/ops/batch_norm.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "megbrain/imperative/ops/batch_norm.h"
+#include "../op_trait.h"
+
+namespace mgb {
+namespace imperative {
+
+namespace {
+
+std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
+    auto* node = &node_->cast_final_safe<opr::BatchNorm>();
+    auto&& param = node->param();
+    return BatchNorm::make(param.param_dim, param.fwd_mode, param.epsilon,
+                           param.avg_factor, param.scale, param.bias);
+}
+
+cg::OperatorNodeBase* apply_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs) {
+    auto&& bn_opr = def.cast_final_safe<BatchNorm>();
+    size_t nr_inp = inputs.size();
+    mgb_assert(nr_inp == 3 || nr_inp == 5,
+        "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp);
+    if (nr_inp == 3) {
+        return opr::BatchNorm::make(
+            inputs[0], inputs[1], inputs[2],
+            {bn_opr.param_dim, bn_opr.fwd_mode, bn_opr.epsilon, bn_opr.avg_factor, bn_opr.scale, bn_opr.bias})[0]
+            .node()->owner_opr();
+    } else {
+        return opr::BatchNorm::make(
+            inputs[0], inputs[1], inputs[2], inputs[3], inputs[4],
+            {bn_opr.param_dim, bn_opr.fwd_mode, bn_opr.epsilon, bn_opr.avg_factor, bn_opr.scale, bn_opr.bias})[0]
+            .node()->owner_opr();
+    }
+}
+
+SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
+        const OpDef& def,
+        const SmallVector<LogicalTensorDesc>& inputs) {
+    auto&& op_def = def.cast_final_safe<BatchNorm>();
+    size_t nr_inp = inputs.size();
+    mgb_assert(nr_inp == 3 || nr_inp == 5,
+        "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp);
+    // need running mean/variance
+    bool need_stat = (nr_inp == 5) && op_def.fwd_mode == BatchNorm::Param::FwdMode::TRAINING;
+    size_t nr_out = need_stat ? 5 : 3;
+    SmallVector<LogicalTensorDesc> out_shapes(nr_out);
+    auto&& i0 = inputs[0];
+    auto&& i1 = inputs[1];
+    size_t i = 0;
+    if (!need_stat) {
+        out_shapes[0] = out_shapes[1] = {TensorLayout({0}, i0.layout.dtype, i0.layout.format), i0.comp_node};
+        i = 2;
+    }
+    for (; i < nr_out-1; ++ i) {
+        out_shapes[i] = {i1.layout, i1.comp_node};
+    }
+    out_shapes[nr_out-1] = {i0.layout, i0.comp_node};
+    return out_shapes;
+}
+
+OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm)
+    .make_from_op_node(make_from_op_node)
+    .apply_on_var_node(apply_on_var_node)
+    .infer_output_attrs_fallible(infer_output_attrs_fallible)
+    .fallback();
+} // anonymous namespace
+
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchNorm);
+
+} // namespace imperative
+} // namespace mgb
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/imperative/src/impl/ops/elemwise.cpp b/imperative/src/impl/ops/elemwise.cpp
new file mode 100644
index 00000000..edf55acb
--- /dev/null
+++ b/imperative/src/impl/ops/elemwise.cpp
@@ -0,0 +1,78 @@
+/**
+ * \file imperative/src/impl/ops/elemwise.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "megbrain/imperative/ops/elemwise.h"
+#include "../op_trait.h"
+
+namespace mgb {
+namespace imperative {
+
+namespace {
+
+std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
+    auto* node = &node_->cast_final_safe<opr::Elemwise>();
+    return Elemwise::make(node->param().mode);
+}
+
+cg::OperatorNodeBase* apply_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs) {
+    auto&& elemwise_opr = def.cast_final_safe<Elemwise>();
+    return opr::Elemwise::make(inputs, elemwise_opr.mode).node()->owner_opr();
+}
+
+SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
+        const OpDef& def,
+        const SmallVector<LogicalTensorDesc>& inputs) {
+    auto&& op_def = def.cast_final_safe<Elemwise>();
+    auto trait = Elemwise::ModeTrait::from_mode(op_def.mode);
+    mgb_assert(inputs.size() == trait.arity,
+        "%s expects %u inputs; got %zu actually", trait.name,
+        trait.arity, inputs.size());
+    TensorShapeArray inp_shapes;
+    DType out_dt;
+    CompNode out_cn;
+    for (size_t i = 0; i < inputs.size(); ++ i) {
+        auto &&t = inputs[i];
+        if (!i) {
+            out_cn = t.comp_node;
+            out_dt = t.layout.dtype;
+        } else {
+            mgb_assert(t.comp_node == out_cn);
+            mgb_assert(t.layout.dtype == out_dt);
+        }
+        if (t.layout.ndim > 0) {
+            inp_shapes.push_back(t.layout);
+        } else {
+            TensorLayout out_layout;
+            out_layout.ndim = 0;
+            out_layout.dtype = out_dt;
+            return {{out_layout, out_cn}};
+        }
+    }
+
+    auto&& out_shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
+    return {{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}};
+}
+
+OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
+    .make_from_op_node(make_from_op_node)
+    .apply_on_var_node(apply_on_var_node)
+    .infer_output_attrs_fallible(infer_output_attrs_fallible)
+    .fallback();
+} // anonymous namespace
+
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(Elemwise);
+
+} // namespace imperative
+} // namespace mgb
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/imperative/src/include/megbrain/imperative/ops/batch_norm.h b/imperative/src/include/megbrain/imperative/ops/batch_norm.h
new file mode 100644
index 00000000..0fc2fb3e
--- /dev/null
+++ b/imperative/src/include/megbrain/imperative/ops/batch_norm.h
@@ -0,0 +1,70 @@
+/**
+ * \file imperative/src/include/megbrain/imperative/ops/batch_norm.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "megbrain/opr/dnn/batch_norm.h"
+#include "megbrain/imperative/op_def.h"
+#include "megbrain/utils/hash.h"
+
+namespace mgb::imperative {
+
+class BatchNorm : public OpDefImplBase<BatchNorm> {
+    MGB_DYN_TYPE_OBJ_FINAL_DECL;
+public:
+    using Param = opr::BatchNorm::Param;
+
+    Param::ParamDim param_dim;
+    Param::FwdMode fwd_mode;
+    double epsilon;
+    double avg_factor;
+    float scale;
+    float bias;
+
+    BatchNorm() = default;
+
+    BatchNorm(const Param::ParamDim& param_dim_, const Param::FwdMode& fwd_mode_,
+              double epsilon_, double avg_factor_, float scale_, float bias_)
+        : param_dim(param_dim_),
+          fwd_mode(fwd_mode_),
+          epsilon(epsilon_),
+          avg_factor(avg_factor_),
+          scale(scale_),
+          bias(bias_) {}
+
+    size_t hash() const override {
+        XXHash xxhash{};
+        auto append = [&xxhash](auto field) {
+            auto hash_val = HashTrait<decltype(field)>::eval(field);
+            xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val));
+        };
+        append(param_dim);
+        append(fwd_mode);
+        append(epsilon);
+        append(avg_factor);
+        append(scale);
+        append(bias);
+        return xxhash.digest();
+    }
+
+    bool is_same_st(const Hashable& rhs_) const override {
+        auto&& rhs = static_cast<const BatchNorm&>(rhs_);
+        return rhs.param_dim == param_dim
+            && rhs.fwd_mode == fwd_mode
+            && rhs.epsilon == epsilon
+            && rhs.avg_factor == avg_factor
+            && rhs.scale == scale
+            && rhs.bias == bias;
+    }
+
+};
+
+} // namespace mgb::imperative
diff --git a/imperative/src/include/megbrain/imperative/ops/elemwise.h b/imperative/src/include/megbrain/imperative/ops/elemwise.h
new file mode 100644
index 00000000..5878f08f
--- /dev/null
+++ b/imperative/src/include/megbrain/imperative/ops/elemwise.h
@@ -0,0 +1,42 @@
+/**
+ * \file imperative/src/include/megbrain/imperative/ops/elemwise.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "megbrain/opr/basic_arith.h"
+#include "megbrain/imperative/op_def.h"
+
+namespace mgb::imperative {
+
+class Elemwise : public OpDefImplBase<Elemwise> {
+    MGB_DYN_TYPE_OBJ_FINAL_DECL;
+public:
+    using Mode = opr::Elemwise::Mode;
+    using ModeTrait = megdnn::Elemwise::ModeTrait;
+
+    Mode mode;
+
+    Elemwise() = default;
+
+    Elemwise(const Mode& mode_): mode(mode_) {}
+
+    size_t hash() const override {
+        return hash_pair_combine(mgb::hash(mode), reinterpret_cast<size_t>(dyn_typeinfo()));
+    }
+
+    bool is_same_st(const Hashable& rhs_) const override {
+        auto&& rhs = static_cast<const Elemwise&>(rhs_);
+        return rhs.mode == mode;
+    }
+
+};
+
+} // namespace mgb::imperative
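 
[Reviewer note] End-to-end sketch of the new enum-based API, mirroring the updated tests (test_raw_tensor and test_opr_attr above); nothing here uses surface beyond what this patch adds:

    import numpy as np
    from megengine.core.ops.builtin import Elemwise
    from megengine.core.tensor.core import apply
    from megengine.core.tensor.raw_tensor import as_raw_tensor

    x = as_raw_tensor(np.arange(4, dtype="float32"))
    (y,) = apply(Elemwise(Elemwise.Mode.ADD), x, x)
    np.testing.assert_allclose(y.numpy(), 2 * np.arange(4, dtype="float32"))

    # OpDef identity is value-based (see is_same_st/hash in the new headers):
    assert Elemwise(Elemwise.Mode.ADD) == Elemwise(Elemwise.Mode.ADD)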