|
|
@@ -909,9 +909,9 @@ def _get_prelu_op(dtype=None, device=None): |
|
|
|
min_0 = f("min", inp, c(0)) |
|
|
|
oup = f("fma3", min_0, weight, max_0) |
|
|
|
(oup_grad,) = yield (oup,) |
|
|
|
inp_grad_0 = f("cond_leq_mov", inp, c(0), oup_grad) |
|
|
|
inp_grad_0 = f("cond_leq_mov", c(0), inp, oup_grad) |
|
|
|
inp_grad_1 = f("*", oup_grad, weight) |
|
|
|
inp_grad_1 = f("cond_leq_mov", c(0), inp, inp_grad_1) |
|
|
|
inp_grad_1 = f("cond_leq_mov", inp, c(0), inp_grad_1) |
|
|
|
inp_grad = f("+", inp_grad_0, inp_grad_1) |
|
|
|
weight_grad = f("*", oup_grad, min_0) |
|
|
|
yield (inp_grad, weight_grad) |
|
|
@@ -925,7 +925,7 @@ def prelu(inp: Tensor, weight: Tensor) -> Tensor: |
|
|
|
Refer to :class:`~.PReLU` for more information. |
|
|
|
""" |
|
|
|
prelu = _get_prelu_op(dtype=inp.dtype, device=inp.device) |
|
|
|
(oup,) = prelu(inp, weight) |
|
|
|
(oup,) = prelu(inp, broadcast_to(weight, inp.shape)) |
|
|
|
return oup |
|
|
|
|
|
|
|
|
|
|
@@ -947,7 +947,7 @@ def _get_leagk_relu_op(negative_slope, *, dtype=None, device=None): |
|
|
|
(oup_grad,) = yield (oup,) |
|
|
|
inp_grad_0 = f("cond_leq_mov", c(0), inp, oup_grad) |
|
|
|
inp_grad_1 = f("*", oup_grad, c(negative_slope)) |
|
|
|
inp_grad_1 = f("cond_leq_mov", inp, c(negative_slope), inp_grad_1) |
|
|
|
inp_grad_1 = f("cond_leq_mov", inp, c(0), inp_grad_1) |
|
|
|
inp_grad = f("+", inp_grad_0, inp_grad_1) |
|
|
|
yield (inp_grad,) |
|
|
|
|
|
|
@@ -994,13 +994,14 @@ def _get_softplus_op(dtype=None, device=None): |
|
|
|
(inp,) = inputs[0:1] |
|
|
|
neg_abs = f("-", f("abs", inp)) |
|
|
|
exp = f("exp", neg_abs) |
|
|
|
oup = f("log1p", exp) |
|
|
|
oup = f("+", oup, f("relu", inp)) |
|
|
|
oup0 = f("log1p", exp) |
|
|
|
oup1 = f("relu", inp) |
|
|
|
oup = f("+", oup0, oup1) |
|
|
|
(oup_grad,) = yield (oup,) |
|
|
|
inp_grad_0 = f("switch_gt0", inp, oup_grad) |
|
|
|
inp_grad_0 = f("switch_gt0", oup1, oup_grad) |
|
|
|
inp_grad_1 = oup_grad |
|
|
|
inp_grad_1 = f("/", oup_grad, f("+", exp, c(1))) |
|
|
|
inp_grad_1 = f("*", oup_grad, exp) |
|
|
|
inp_grad_1 = f("*", inp_grad_1, exp) |
|
|
|
inp_grad_1 = f("-", inp_grad_1) |
|
|
|
inp_grad_1 = f("abs_grad", inp, inp_grad_1) |
|
|
|
inp_grad = f("+", inp_grad_0, inp_grad_1) |
|
|
@@ -1098,16 +1099,17 @@ def _get_logsigmoid_op(dtype=None, device=None): |
|
|
|
(inp,) = inputs[0:1] |
|
|
|
neg_abs = f("-", f("abs", inp)) |
|
|
|
exp = f("exp", neg_abs) |
|
|
|
oup = f("log1p", exp) |
|
|
|
oup = f("+", oup, f("relu", f("-", inp))) |
|
|
|
oup0 = f("log1p", exp) |
|
|
|
oup1 = f("relu", f("-", inp)) |
|
|
|
oup = f("+", oup0, oup1) |
|
|
|
oup = f("-", oup) |
|
|
|
(oup_grad,) = yield (oup,) |
|
|
|
oup_grad = f("-", oup_grad) |
|
|
|
inp_grad_0 = f("switch_gt0", inp, oup_grad) |
|
|
|
inp_grad_0 = f("switch_gt0", oup1, oup_grad) |
|
|
|
inp_grad_0 = f("-", inp_grad_0) |
|
|
|
inp_grad_1 = oup_grad |
|
|
|
inp_grad_1 = f("/", oup_grad, f("+", exp, c(1))) |
|
|
|
inp_grad_1 = f("*", oup_grad, exp) |
|
|
|
inp_grad_1 = f("/", inp_grad_1, f("+", exp, c(1))) |
|
|
|
inp_grad_1 = f("*", inp_grad_1, exp) |
|
|
|
inp_grad_1 = f("-", inp_grad_1) |
|
|
|
inp_grad_1 = f("abs_grad", inp, inp_grad_1) |
|
|
|
inp_grad = f("+", inp_grad_0, inp_grad_1) |
|
|
|