Browse Source

fix(subgraph): fix hand-written backward for serval jit-elemwise ops

GitOrigin-RevId: ea3a40d96e
tags/v1.9.0
Megvii Engine Team 3 years ago
parent
commit
7bd848ce04
2 changed files with 37 additions and 13 deletions
  1. +15
    -13
      imperative/python/megengine/functional/nn.py
  2. +22
    -0
      imperative/src/impl/ops/utility.cpp

+ 15
- 13
imperative/python/megengine/functional/nn.py View File

@@ -909,9 +909,9 @@ def _get_prelu_op(dtype=None, device=None):
min_0 = f("min", inp, c(0))
oup = f("fma3", min_0, weight, max_0)
(oup_grad,) = yield (oup,)
inp_grad_0 = f("cond_leq_mov", inp, c(0), oup_grad)
inp_grad_0 = f("cond_leq_mov", c(0), inp, oup_grad)
inp_grad_1 = f("*", oup_grad, weight)
inp_grad_1 = f("cond_leq_mov", c(0), inp, inp_grad_1)
inp_grad_1 = f("cond_leq_mov", inp, c(0), inp_grad_1)
inp_grad = f("+", inp_grad_0, inp_grad_1)
weight_grad = f("*", oup_grad, min_0)
yield (inp_grad, weight_grad)
@@ -925,7 +925,7 @@ def prelu(inp: Tensor, weight: Tensor) -> Tensor:
Refer to :class:`~.PReLU` for more information.
"""
prelu = _get_prelu_op(dtype=inp.dtype, device=inp.device)
(oup,) = prelu(inp, weight)
(oup,) = prelu(inp, broadcast_to(weight, inp.shape))
return oup


@@ -947,7 +947,7 @@ def _get_leagk_relu_op(negative_slope, *, dtype=None, device=None):
(oup_grad,) = yield (oup,)
inp_grad_0 = f("cond_leq_mov", c(0), inp, oup_grad)
inp_grad_1 = f("*", oup_grad, c(negative_slope))
inp_grad_1 = f("cond_leq_mov", inp, c(negative_slope), inp_grad_1)
inp_grad_1 = f("cond_leq_mov", inp, c(0), inp_grad_1)
inp_grad = f("+", inp_grad_0, inp_grad_1)
yield (inp_grad,)

@@ -994,13 +994,14 @@ def _get_softplus_op(dtype=None, device=None):
(inp,) = inputs[0:1]
neg_abs = f("-", f("abs", inp))
exp = f("exp", neg_abs)
oup = f("log1p", exp)
oup = f("+", oup, f("relu", inp))
oup0 = f("log1p", exp)
oup1 = f("relu", inp)
oup = f("+", oup0, oup1)
(oup_grad,) = yield (oup,)
inp_grad_0 = f("switch_gt0", inp, oup_grad)
inp_grad_0 = f("switch_gt0", oup1, oup_grad)
inp_grad_1 = oup_grad
inp_grad_1 = f("/", oup_grad, f("+", exp, c(1)))
inp_grad_1 = f("*", oup_grad, exp)
inp_grad_1 = f("*", inp_grad_1, exp)
inp_grad_1 = f("-", inp_grad_1)
inp_grad_1 = f("abs_grad", inp, inp_grad_1)
inp_grad = f("+", inp_grad_0, inp_grad_1)
@@ -1098,16 +1099,17 @@ def _get_logsigmoid_op(dtype=None, device=None):
(inp,) = inputs[0:1]
neg_abs = f("-", f("abs", inp))
exp = f("exp", neg_abs)
oup = f("log1p", exp)
oup = f("+", oup, f("relu", f("-", inp)))
oup0 = f("log1p", exp)
oup1 = f("relu", f("-", inp))
oup = f("+", oup0, oup1)
oup = f("-", oup)
(oup_grad,) = yield (oup,)
oup_grad = f("-", oup_grad)
inp_grad_0 = f("switch_gt0", inp, oup_grad)
inp_grad_0 = f("switch_gt0", oup1, oup_grad)
inp_grad_0 = f("-", inp_grad_0)
inp_grad_1 = oup_grad
inp_grad_1 = f("/", oup_grad, f("+", exp, c(1)))
inp_grad_1 = f("*", oup_grad, exp)
inp_grad_1 = f("/", inp_grad_1, f("+", exp, c(1)))
inp_grad_1 = f("*", inp_grad_1, exp)
inp_grad_1 = f("-", inp_grad_1)
inp_grad_1 = f("abs_grad", inp, inp_grad_1)
inp_grad = f("+", inp_grad_0, inp_grad_1)


+ 22
- 0
imperative/src/impl/ops/utility.cpp View File

@@ -726,6 +726,28 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {

auto infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
TensorShape shape;
DType dtype = input_descs[0].layout.dtype;
CompNode comp_node = input_descs[0].comp_node;
for (auto&& desc : input_descs) {
if (desc.layout.ndim) {
shape = desc.layout;
break;
}
}
for (size_t i = 0; i < input_descs.size(); ++i) {
if (input_descs[i].layout.ndim) {
mgb_assert(
input_descs[i].layout.eq_shape(shape),
"inputs of JITFusionOp should have same shapes");
}
mgb_assert(
input_descs[i].layout.dtype == dtype,
"inputs of JITFusionOp should have same dtypes");
mgb_assert(
input_descs[i].comp_node == comp_node,
"inputs of JITFusionOp should have same devices");
}
return OpDef::infer_output_attrs_fallible(
*def.cast_final_safe<JITFusionOp>().op, input_descs);
}


Loading…
Cancel
Save