@@ -12,9 +12,26 @@ import itertools
 import numpy as np
 
 from .._imperative_rt import TensorAttr, imperative
-from ..ops.builtin import Elemwise, GetVarShape, OpDef, OprAttr, Reduce, Reshape
+from ..ops.builtin import (
+    Broadcast,
+    Elemwise,
+    GetVarShape,
+    IndexingMultiAxisVec,
+    IndexingSetMultiAxisVec,
+    OpDef,
+    OprAttr,
+    Reduce,
+    Reshape,
+    SetSubtensor,
+    Subtensor,
+)
 from ..ops.special import Const
 from ..tensor.core import apply
 from ..tensor.function import Function
 from ..tensor.tensor_wrapper import TensorWrapper
 
+
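+# Serialized params for the op modes that get hand-written gradients below;
+# the dispatcher compares an OprAttr's param blob against these.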
+_elemwise_add_param = Elemwise(mode="add").to_c().param
+_reduce_sum_param = Reduce(mode="SUM").to_c().param[0]
 
 
 @functools.singledispatch
@@ -22,19 +39,17 @@ def builtin_op_get_backward_fn(op: OpDef, inputs, outputs, input_requires_grad):
     assert 0
 
 
-_elemwise_add_param = Elemwise(mode="add").to_c().param
-
-
 @builtin_op_get_backward_fn.register(OpDef)
 def _(op: OpDef, inputs, outputs, input_requires_grad):
-    if (
-        isinstance(op, OprAttr)
-        and op.type == "Elemwise"
-        and op.param == _elemwise_add_param
-    ):
-        grad_fn = elemwise_grad_fn
-    elif isinstance(op, OprAttr) and op.type == Reshape.name:
-        grad_fn = reshape_grad_fn
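+    # Dispatch order: per-type overrides in _oprAttr_grad_fn first, then
+    # param-matched special cases (Elemwise add, Reduce sum), else default_grad_fn.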
+    if isinstance(op, OprAttr):
+        grad_fn = _oprAttr_grad_fn.get(op.type, None)
+        if grad_fn is None:
+            if op.type == Elemwise.name and op.param == _elemwise_add_param:
+                grad_fn = elemwise_add_grad_fn
+            elif op.type == Reduce.name and op.param[0] == _reduce_sum_param:
+                grad_fn = reduce_sum_grad_fn
+            else:
+                grad_fn = default_grad_fn
     else:
         grad_fn = default_grad_fn
     return grad_fn(op, inputs, outputs, input_requires_grad)
@@ -73,6 +88,7 @@ def default_grad_fn(op, inputs, outputs, input_requires_grad):
     save_for_backward = tuple(
         val for val, mask in zip(inputs + outputs, intput_output_mask) if mask
     )
 
     del inputs
     del outputs
+
@@ -85,13 +101,14 @@ def default_grad_fn(op, inputs, outputs, input_requires_grad):
     return backward, output_grad_mask
 
 
-# override for elemwise
-def elemwise_grad_fn(op, inputs, outputs, input_requires_grad):
-    assert len(inputs) == len(input_requires_grad) == 2
-
-    def get_shape(x):
-        (s,) = apply(GetVarShape(), x)
-        return s
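+# get_shape is hoisted to module level so the grad fns below can share it
+# instead of each defining a local helper.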
+def get_shape(x):
+    (s,) = apply(GetVarShape(), x)
+    return s
+
+
+# override for Elemwise.add
+def elemwise_add_grad_fn(op, inputs, outputs, input_requires_grad):
+    assert len(inputs) == len(input_requires_grad) == 2
 
     input_shapes = [
         get_shape(x) if i else None for i, x in zip(input_requires_grad, inputs)
@@ -110,13 +127,10 @@ def elemwise_grad_fn(op, inputs, outputs, input_requires_grad):
     return backward, [True]
 
 
 # override for Reshape
 def reshape_grad_fn(op, inputs, outputs, input_requires_grad):
     assert len(inputs) == len(input_requires_grad) == 2
 
-    def get_shape(x):
-        (s,) = apply(GetVarShape(), x)
-        return s
-
     input_shapes = [
         get_shape(x) if i else None for i, x in zip(input_requires_grad, inputs)
     ]
@@ -132,3 +146,78 @@ def reshape_grad_fn(op, inputs, outputs, input_requires_grad):
     )
 
     return backward, [True]
+
+
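+# Gradient of Subtensor: scatter dy back into a zero tensor of the input's
+# shape with SetSubtensor, reusing the forward op's slicing params.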
+# override for Subtensor
+def subtensor_grad_fn(op, inputs, outputs, input_requires_grad):
+    grad_op = OprAttr()
+    grad_op.type = SetSubtensor.name
+    grad_op.param = op.param
+
+    input_shape = get_shape(inputs[0])
+    params = inputs[1:]
+
+    def make_grad(grad_op, dy):
+        grad = (
+            TensorWrapper(0, dtype=dy.dtype, device=dy.device)
+            .broadcast(TensorWrapper(input_shape))
+            .__wrapped__
+        )
+        (dx,) = apply(grad_op, grad, dy, *params)
+        return dx
+
+    def backward(dy):
+        return tuple(
+            make_grad(grad_op, dy) if mask else None for mask in input_requires_grad
+        )
+
+    return backward, [True]
+
+
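+# Gradient of IndexingMultiAxisVec: same pattern as Subtensor, with
+# IndexingSetMultiAxisVec writing dy into the zero-filled, input-shaped tensor.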
+# override for IndexingMultiAxisVec
+def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad):
+    grad_op = OprAttr()
+    grad_op.type = IndexingSetMultiAxisVec.name
+    grad_op.param = op.param
+
+    input_shape = get_shape(inputs[0])
+    params = inputs[1:]
+
+    def make_grad(grad_op, dy):
+        grad = (
+            TensorWrapper(0, dtype=dy.dtype, device=dy.device)
+            .broadcast(TensorWrapper(input_shape))
+            .__wrapped__
+        )
+        (dx,) = apply(grad_op, grad, dy, *params)
+        return dx
+
+    def backward(dy):
+        return tuple(
+            make_grad(grad_op, dy) if mask else None for mask in input_requires_grad
+        )
+
+    return backward, [True]
+
+
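+# Gradient of a full Reduce(SUM): broadcast dy back to the input's shape.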
+# override for Reduce.sum
+def reduce_sum_grad_fn(op, inputs, outputs, input_requires_grad):
+    assert len(inputs) == len(input_requires_grad) == 1
+    input_shape = get_shape(inputs[0])
+
+    def broadcast_to(dy, s):
+        (dx,) = apply(Broadcast(), dy, s)
+        return dx
+
+    def backward(dy):
+        return (broadcast_to(dy, input_shape) if input_requires_grad[0] else None,)
+
+    return backward, [True]
+
+
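+# Type-name lookup table consulted first by the dispatcher above; Broadcast
+# maps to elemwise_add_grad_fn, since the add gradient is just dy reduced
+# back to each input's shape.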
+_oprAttr_grad_fn = {
+    Reshape.name: reshape_grad_fn,
+    Subtensor.name: subtensor_grad_fn,
+    IndexingMultiAxisVec.name: indexingMultiAxisVec_grad_fn,
+    Broadcast.name: elemwise_add_grad_fn,
+}