|
|
@@ -178,8 +178,10 @@ class Function: |
|
|
|
return self._default_rule(*args), self.backward |
|
|
|
|
|
|
|
def __call__(self, *args): |
|
|
|
from ...tensor import Tensor |
|
|
|
|
|
|
|
for arg in args: |
|
|
|
if not isinstance(arg, core2.Tensor): |
|
|
|
if not isinstance(arg, Tensor): |
|
|
|
raise TypeError( |
|
|
|
"op Function expect type Tensor as inputs, got {}".format(type(arg)) |
|
|
|
) |
|
|
@@ -191,6 +193,8 @@ class Function: |
|
|
|
grad = Grad.key2grad[grad_key] |
|
|
|
group = [ref() for ref in grad._group] |
|
|
|
|
|
|
|
origin_args = [Tensor(arg) for arg in args] |
|
|
|
|
|
|
|
for grad in group: |
|
|
|
grad.suppress() |
|
|
|
outputs, backward = self._grad_rule(*args) |
|
|
@@ -199,13 +203,13 @@ class Function: |
|
|
|
|
|
|
|
def normalized_backward(*output_grads): |
|
|
|
input_grads = backward(*output_grads) |
|
|
|
if isinstance(input_grads, core2.Tensor) or input_grads is None: |
|
|
|
if isinstance(input_grads, Tensor) or input_grads is None: |
|
|
|
input_grads = (input_grads,) |
|
|
|
return input_grads |
|
|
|
|
|
|
|
if self.__single_output: |
|
|
|
outputs = (outputs,) |
|
|
|
outputs = core2.set_grad(normalized_backward, args, outputs) |
|
|
|
outputs = core2.set_grad(normalized_backward, origin_args, outputs) |
|
|
|
if self.__single_output: |
|
|
|
(outputs,) = outputs |
|
|
|
return outputs |
|
|
|