From 7bd848ce049c2f917ccb0a84e0f71e1b1965576e Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 29 Dec 2021 15:00:29 +0800
Subject: [PATCH] fix(subgraph): fix hand-written backward for several jit-elemwise ops

GitOrigin-RevId: ea3a40d96efb6dd083fa4278c041837aff3833d0
---
 imperative/python/megengine/functional/nn.py | 28 +++++++++++++++-------------
 imperative/src/impl/ops/utility.cpp          | 22 ++++++++++++++++++++++
 2 files changed, 37 insertions(+), 13 deletions(-)

diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 6729bcb8..2c6eb32d 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -909,9 +909,9 @@ def _get_prelu_op(dtype=None, device=None):
         min_0 = f("min", inp, c(0))
         oup = f("fma3", min_0, weight, max_0)
         (oup_grad,) = yield (oup,)
-        inp_grad_0 = f("cond_leq_mov", inp, c(0), oup_grad)
+        inp_grad_0 = f("cond_leq_mov", c(0), inp, oup_grad)
         inp_grad_1 = f("*", oup_grad, weight)
-        inp_grad_1 = f("cond_leq_mov", c(0), inp, inp_grad_1)
+        inp_grad_1 = f("cond_leq_mov", inp, c(0), inp_grad_1)
         inp_grad = f("+", inp_grad_0, inp_grad_1)
         weight_grad = f("*", oup_grad, min_0)
         yield (inp_grad, weight_grad)
@@ -925,7 +925,7 @@ def prelu(inp: Tensor, weight: Tensor) -> Tensor:
     Refer to :class:`~.PReLU` for more information.
     """
     prelu = _get_prelu_op(dtype=inp.dtype, device=inp.device)
-    (oup,) = prelu(inp, weight)
+    (oup,) = prelu(inp, broadcast_to(weight, inp.shape))
     return oup
 
 
@@ -947,7 +947,7 @@ def _get_leagk_relu_op(negative_slope, *, dtype=None, device=None):
         (oup_grad,) = yield (oup,)
         inp_grad_0 = f("cond_leq_mov", c(0), inp, oup_grad)
         inp_grad_1 = f("*", oup_grad, c(negative_slope))
-        inp_grad_1 = f("cond_leq_mov", inp, c(negative_slope), inp_grad_1)
+        inp_grad_1 = f("cond_leq_mov", inp, c(0), inp_grad_1)
         inp_grad = f("+", inp_grad_0, inp_grad_1)
         yield (inp_grad,)
 
@@ -994,13 +994,14 @@ def _get_softplus_op(dtype=None, device=None):
         (inp,) = inputs[0:1]
         neg_abs = f("-", f("abs", inp))
         exp = f("exp", neg_abs)
-        oup = f("log1p", exp)
-        oup = f("+", oup, f("relu", inp))
+        oup0 = f("log1p", exp)
+        oup1 = f("relu", inp)
+        oup = f("+", oup0, oup1)
         (oup_grad,) = yield (oup,)
-        inp_grad_0 = f("switch_gt0", inp, oup_grad)
+        inp_grad_0 = f("switch_gt0", oup1, oup_grad)
         inp_grad_1 = oup_grad
         inp_grad_1 = f("/", oup_grad, f("+", exp, c(1)))
-        inp_grad_1 = f("*", oup_grad, exp)
+        inp_grad_1 = f("*", inp_grad_1, exp)
         inp_grad_1 = f("-", inp_grad_1)
         inp_grad_1 = f("abs_grad", inp, inp_grad_1)
         inp_grad = f("+", inp_grad_0, inp_grad_1)
@@ -1098,16 +1099,17 @@ def _get_logsigmoid_op(dtype=None, device=None):
         (inp,) = inputs[0:1]
         neg_abs = f("-", f("abs", inp))
         exp = f("exp", neg_abs)
-        oup = f("log1p", exp)
-        oup = f("+", oup, f("relu", f("-", inp)))
+        oup0 = f("log1p", exp)
+        oup1 = f("relu", f("-", inp))
+        oup = f("+", oup0, oup1)
         oup = f("-", oup)
         (oup_grad,) = yield (oup,)
         oup_grad = f("-", oup_grad)
-        inp_grad_0 = f("switch_gt0", inp, oup_grad)
+        inp_grad_0 = f("switch_gt0", oup1, oup_grad)
         inp_grad_0 = f("-", inp_grad_0)
         inp_grad_1 = oup_grad
-        inp_grad_1 = f("/", oup_grad, f("+", exp, c(1)))
-        inp_grad_1 = f("*", oup_grad, exp)
+        inp_grad_1 = f("/", inp_grad_1, f("+", exp, c(1)))
+        inp_grad_1 = f("*", inp_grad_1, exp)
         inp_grad_1 = f("-", inp_grad_1)
         inp_grad_1 = f("abs_grad", inp, inp_grad_1)
         inp_grad = f("+", inp_grad_0, inp_grad_1)
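
For reference, the corrected PReLU backward above sums the max-branch gradient, masked to inp >= 0, with the min-branch gradient oup_grad * weight, masked to inp <= 0, and computes the weight gradient as oup_grad * min(inp, 0). The NumPy sketch below mirrors that computation outside the subgraph API; it is illustrative only and assumes cond_leq_mov(a, b, g) keeps g where a <= b and is zero elsewhere, with the function name chosen here for clarity.

    import numpy as np

    def prelu_backward_reference(inp, weight, oup_grad):
        # gradient through max(inp, 0): passes where 0 <= inp
        inp_grad_0 = oup_grad * (0 <= inp)
        # gradient through min(inp, 0) * weight: passes where inp <= 0
        inp_grad_1 = oup_grad * weight * (inp <= 0)
        inp_grad = inp_grad_0 + inp_grad_1
        # d(oup)/d(weight) = min(inp, 0)
        weight_grad = oup_grad * np.minimum(inp, 0)
        return inp_grad, weight_grad
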
diff --git a/imperative/src/impl/ops/utility.cpp b/imperative/src/impl/ops/utility.cpp
index 39940ac7..b3c42438 100644
--- a/imperative/src/impl/ops/utility.cpp
+++ b/imperative/src/impl/ops/utility.cpp
@@ -726,6 +726,28 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
 
 auto infer_output_attrs_fallible(
         const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
+    TensorShape shape;
+    DType dtype = input_descs[0].layout.dtype;
+    CompNode comp_node = input_descs[0].comp_node;
+    for (auto&& desc : input_descs) {
+        if (desc.layout.ndim) {
+            shape = desc.layout;
+            break;
+        }
+    }
+    for (size_t i = 0; i < input_descs.size(); ++i) {
+        if (input_descs[i].layout.ndim) {
+            mgb_assert(
+                    input_descs[i].layout.eq_shape(shape),
+                    "inputs of JITFusionOp should have same shapes");
+        }
+        mgb_assert(
+                input_descs[i].layout.dtype == dtype,
+                "inputs of JITFusionOp should have same dtypes");
+        mgb_assert(
+                input_descs[i].comp_node == comp_node,
+                "inputs of JITFusionOp should have same devices");
+    }
     return OpDef::infer_output_attrs_fallible(
             *def.cast_final_safe<JITFusionOp>().op, input_descs);
 }
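
The softplus and logsigmoid hunks both build on the numerically stable identity softplus(x) = log1p(exp(-|x|)) + relu(x), with logsigmoid(x) = -softplus(-x); the derivative of that form is sigmoid(x), which is what the corrected backward is intended to produce. The finite-difference check below is an illustrative NumPy sketch of that identity, not part of the patch.

    import numpy as np

    def softplus(x):
        # stable form used by the subgraph: log1p(exp(-|x|)) + relu(x)
        return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0)

    x = np.linspace(-5.0, 5.0, 11)
    eps = 1e-6
    numeric = (softplus(x + eps) - softplus(x - eps)) / (2 * eps)
    analytic = 1.0 / (1.0 + np.exp(-x))  # sigmoid(x), the expected gradient
    assert np.allclose(numeric, analytic, atol=1e-5)
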