- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- import copy
- from abc import ABCMeta, abstractmethod
- from typing import Iterable, Tuple, Union
-
- import megengine._internal as mgb
-
- from .tensor import Tensor
-
-
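- # The craniotome below implements the gradient override used by ``Function``:
- # it takes the original inputs followed by the forward outputs as its input
- # vars, forwards the trailing outputs unchanged, and delegates gradient
- # computation to the user-supplied ``grad_func`` (``Function.backward``).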
- class _OverrideGradientCraniotome(mgb.craniotome.CraniotomeBase):
-     __nr_inputs__ = None
-     __nr_outputs__ = None
-     __expand_single_outputs__ = False
-     __allow_duplicate__ = False
-
-     grad_func = None
-
-     def setup(self, nr_inputs, nr_outputs, grad_func):
-         # The op consumes both the original inputs and the forward outputs.
-         self.__nr_inputs__ = nr_inputs + nr_outputs
-         self.__nr_outputs__ = nr_outputs
-         self.grad_func = grad_func
-
-     def infer_shape(self, inp_shapes):
-         return inp_shapes[-self.__nr_outputs__ :]
-
-     def init_output_dtype(self, input_dtypes):
-         return input_dtypes[-self.__nr_outputs__ :]
-
-     def execute(self, inputs, outputs):
-         # Identity forward: copy the trailing ``nr_outputs`` inputs to the outputs.
-         for ivar, ovar in zip(inputs[-self.__nr_outputs__ :], outputs):
-             ovar.set_value(ivar)
-
-     def grad(self, wrt_idx, inputs, outputs, out_grad):
-         # TODO: Make sure grad_values really have values in eager mode.
-         # Porting to the new imperative engine would solve this, but if that
-         # doesn't happen, EagerEvalManager should be changed.
-         grads = self.grad_func(
-             *(Tensor(x) if x is not None else None for x in out_grad)
-         )
-         # pylint: disable=literal-comparison
-         if isinstance(grads, Tensor) or grads is None or grads is 0:
-             grads = (grads,)
-         assert (
-             len(grads) == self.__nr_inputs__ - self.__nr_outputs__
-         ), "Function.backward should return a tuple with len = {}, got {}".format(
-             self.__nr_inputs__ - self.__nr_outputs__, len(grads)
-         )
-         # pylint: disable=literal-comparison
-         return (
-             list(x._symvar if x is not None and x is not 0 else 0 for x in grads)
-             + [0] * self.__nr_outputs__
-         )
-
-     def get_serialize_params(self):
-         raise NotImplementedError("Serialization of Function is not implemented")
-
-
- class Function(metaclass=ABCMeta):
-     """
-     Defines a block of operations with customizable differentiation.
-
-     The computation should be defined in the ``forward`` method, with gradient
-     computation defined in the ``backward`` method.
-
-     Each instance of ``Function`` should be used only once during a forward pass.
-
-     Examples:
-
-     .. testcode::
-
-         import megengine.functional as F
-
-         class Sigmoid(Function):
-             def forward(self, x):
-                 y = 1 / (1 + F.exp(-x))
-                 self.save_for_backward(y)
-                 return y
-
-             def backward(self, output_grads):
-                 (y,) = self.saved_tensors
-                 return output_grads * y * (1 - y)
-
-     """
-
-     _has_saved_state = False
-     saved_tensors = None
-
-     def __init__(self):
-         self.saved_tensors = ()
-
-     @abstractmethod
-     def forward(self, *inputs: Iterable[Tensor]) -> Union[Tuple[Tensor], Tensor]:
-         """
-         Applies operations to ``inputs`` and returns the results. It must be overridden by all subclasses.
-         Users can call :meth:`~.function.Function.save_for_backward` in this method to save tensors.
-
-         :param inputs: input tensors.
-         :return: a tuple of Tensor, or a single Tensor.
-
-         .. note::
-
-             This method should return a tuple of Tensor or a single Tensor representing the output
-             of the function.
-         """
-         raise NotImplementedError
-
-     @abstractmethod
-     def backward(
-         self, *output_grads: Iterable[Union[Tensor, None]]
-     ) -> Union[Tuple[Tensor], Tensor]:
-         """
-         Computes the gradient of the forward function. It must be overridden by all subclasses.
-
-         :param output_grads: gradients of the outputs returned by :meth:`~.function.Function.forward`.
-
-         .. note::
-
-             If some of the outputs are not related to the loss function, the corresponding
-             values in ``output_grads`` will be ``None``.
-
-         .. note::
-
-             This method should return a tuple containing the gradients of all inputs, in the same order
-             as the ``inputs`` argument of :meth:`~.function.Function.forward` . A single ``Tensor`` could be
-             returned instead if there is only one input. If users want to stop the propagation of some
-             gradients, the corresponding returned values should be set to ``None`` .
-
-         """
-         raise NotImplementedError
-
-     def save_for_backward(self, *tensors: Iterable[Tensor]):
-         """
-         Saves tensors needed for gradient computation. This method should be called only
-         once in :meth:`~.function.Function.forward`; additional calls will replace the
-         previously saved values.
-
-         The saved tensors can be accessed through the ``saved_tensors`` attribute.
-         """
-         self.saved_tensors = tensors
-
-     def __deepcopy__(self, memo):
-         """
-         Defines how the operator is deep-copied; ``saved_tensors`` are shared with
-         the copy rather than being deep-copied.
-         """
-         cls = self.__class__
-         result = cls.__new__(cls)
-         # Temporarily detach saved_tensors so deepcopy skips them, then share the
-         # original tuple with the copy.
-         tmp = self.saved_tensors
-         self.saved_tensors = None
-         memo[id(self)] = result
-         for k, v in self.__dict__.items():
-             setattr(result, k, copy.deepcopy(v, memo))
-         setattr(result, "saved_tensors", tmp)
-         self.saved_tensors = tmp
-         return result
-
-     def __call__(self, *inputs):
-         assert (
-             not self._has_saved_state
-         ), "A Function instance should not be called multiple times"
-         outputs = self.forward(*inputs)
-         if isinstance(outputs, Tensor):
-             outputs = (outputs,)
-         self._has_saved_state = True
-         # Feed both the inputs and the forward outputs to the craniotome, which
-         # forwards the outputs unchanged but overrides their gradient with
-         # ``self.backward``.
-         sv = (x._symvar for x in inputs + outputs)
-         outputs = _OverrideGradientCraniotome.make(
-             *sv, nr_inputs=len(inputs), nr_outputs=len(outputs), grad_func=self.backward
-         )
-         outputs = tuple(map(Tensor, outputs))
-         if len(outputs) == 1:
-             outputs = outputs[0]
-         return outputs
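-
- # Usage sketch (illustrative): assuming ``x`` is a megengine Tensor and
- # ``Sigmoid`` is the subclass shown in the ``Function`` docstring above,
- #
- #     y = Sigmoid()(x)
- #
- # runs ``forward`` and registers ``backward`` as the gradient rule for ``y``;
- # a fresh ``Sigmoid`` instance is required for every call.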