# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import weakref

from .._imperative_rt import core2

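# Module-level bookkeeping: a counter used to name anonymous Grad instances and a
# weak registry of every live Grad, exposed through ``get_grad_managers``.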
_grad_count = 0
_grad_manager_dict = weakref.WeakValueDictionary()


def get_grad_managers():
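    """Return every live :class:`Grad` instance currently registered in this module."""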
    return [_grad_manager_dict[key] for key in _grad_manager_dict]


class GradKey(core2.GradKey):
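    # Thin wrapper over ``core2.GradKey`` from the imperative runtime; it adds an
    # optional human-readable name and a ``backward`` helper that delegates to
    # ``core2.backward``.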
    def __init__(self, name=None):
        if name:
            self.name = name

    def backward(self, ys, dys):
        return core2.backward(self, ys, dys)


class Grad:
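    # ``stack`` holds the groups of Grad instances that are currently entered
    # (each group is a list of weakrefs); ``grouping`` marks that newly entered
    # instances should join the group on top of the stack, and ``key2grad`` maps
    # the underlying GradKey back to its owning Grad.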
    stack = []
    grouping = False
    key2grad = weakref.WeakValueDictionary()

    def __init__(self, name=None):
        global _grad_count
        if name is None:
            name = "grad_%d" % _grad_count
            _grad_count += 1
        self._refkeeper = []
        self._impl = GradKey(name)
        Grad.key2grad[self._impl] = self
        _grad_manager_dict[self._name] = self
        self._group = [weakref.ref(self)]

    @property
    def _name(self):
        return self._impl.name

    def _is_attached_to(self, tensor):
        return self._impl.is_attached_to(tensor)

    def wrt(self, *tensors, callback=None):
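        """Declares the tensors whose gradients should be recorded; the optional
        ``callback`` is forwarded to the underlying ``attach`` call. Returns ``self``
        so calls can be chained.
        """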
        for x in tensors:
            self._impl.attach(x, callback)
        return self

    def __call__(self, ys, dys):
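        """Runs backward from outputs ``ys`` seeded with gradients ``dys``.

        Single tensors are promoted to one-element sequences, and every other
        Grad in the same group is suppressed while this one propagates.
        """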
        from collections.abc import Sequence

        if not isinstance(ys, Sequence):
            ys = [ys]

        if not isinstance(dys, Sequence):
            dys = [dys]

        group = [ref() for ref in self._group]

        for grad in group:
            if grad is self:
                continue
            grad.suppress()

        self._impl.backward(ys, dys)

        for grad in group:
            if grad is self:
                continue
            grad.resume()

        self._refkeeper = None
        return None

    def __enter__(self):
        ref = weakref.ref(self)
        self._impl.enter()
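        # Under grouping, join the group currently on top of the stack instead of
        # pushing a new single-member group.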
        if Grad.grouping:
            group = Grad.stack[-1]
            self._group = group
            group.append(ref)
        else:
            Grad.stack.append(self._group)
        return self

    def __exit__(self, _1, _2, _3):
        self._impl.exit()
        self._refkeeper = None
        del Grad.key2grad[self._impl]
        self._impl = None
        self._group.remove(weakref.ref(self))
        if len(self._group) == 0:
            Grad.stack.remove(self._group)

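    # begin_group/end_group bracket a region in which every Grad that is entered
    # joins the group on top of ``Grad.stack``; members of one group suppress each
    # other while any one of them runs backward.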
    @staticmethod
    def begin_group():
        assert not Grad.grouping
        Grad.grouping = True

    @staticmethod
    def end_group():
        group = Grad.stack[-1]
        assert len(group) > 0
        assert Grad.grouping
        Grad.grouping = False

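    # suppress/resume pause and restart recording on the underlying GradKey; they
    # become no-ops once the Grad has exited and ``_impl`` has been dropped.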
    def suppress(self):
        if self._impl is not None:
            self._impl.suppress()

    def resume(self):
        if self._impl is not None:
            self._impl.resume()


class Function:
    r"""Defines a block of operations with customizable differentiation.

    The computation should be defined in the ``forward`` method, with gradient
    computation defined in the ``backward`` method.

    Each instance of ``Function`` should be used only once during forwarding.

    Examples:

        .. code-block::

            class Sigmoid(Function):
                def forward(self, x):
                    y = 1 / (1 + F.exp(-x))
                    self.y = y
                    return y

                def backward(self, dy):
                    y = self.y
                    return dy * y * (1 - y)
    """

    def forward(self, *args, **kwargs):
        r"""Applies operations to ``inputs`` and returns results. It must be overridden by all subclasses.

        Args:
            input: input tensors.

        Returns:
            a tuple of Tensors or a single Tensor.

        Note:
            * This method should return a tuple of Tensors or a single Tensor representing
              the output of the function.
            * All positional arguments should be Tensors.
        """
        raise NotImplementedError

    def backward(self, *output_grads):
        r"""Computes the gradient of the forward function. It must be overridden by all subclasses.

        Args:
            output_grads: gradients of the outputs that are returned by :meth:`forward`.

        Note:
            * If some of the outputs are not related to the loss function, the corresponding
              values in ``output_grads`` will be ``None``.
            * This method should return a tuple containing the gradients of all inputs, in the
              same order as the ``inputs`` argument of :meth:`forward`. A single ``Tensor`` can
              be returned instead if there is only one input. To stop the propagation of certain
              gradients, return ``None`` in the corresponding positions.
        """
        raise NotImplementedError

    def _default_rule(self, *args):
        ret = self.forward(*args)
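        # Remember whether ``forward`` returned a single Tensor so that __call__ can
        # wrap and unwrap the outputs consistently.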
        self.__single_output = isinstance(ret, core2.Tensor)
        return ret

    def _grad_rule(self, *args):
        return self._default_rule(*args), self.backward

    def __call__(self, *args):
        for arg in args:
            if not isinstance(arg, core2.Tensor):
                raise TypeError(
                    "op Function expects type Tensor as inputs, got {}".format(type(arg))
                )

        grad_key = core2.get_grad_key(args)
        if grad_key is None:
            return self._default_rule(*args)

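        # At least one input is attached to a Grad: run forward with recording
        # suppressed, then register ``backward`` as the gradient rule for the outputs.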
        grad = Grad.key2grad[grad_key]
        group = [ref() for ref in grad._group]

        for grad in group:
            grad.suppress()
        outputs, backward = self._grad_rule(*args)
        for grad in reversed(group):
            grad.resume()

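        # ``backward`` may return a single Tensor or None; normalize it to a tuple
        # before handing it to core2.set_grad.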
        def normalized_backward(*output_grads):
            input_grads = backward(*output_grads)
            if isinstance(input_grads, core2.Tensor) or input_grads is None:
                input_grads = (input_grads,)
            return input_grads

        if self.__single_output:
            outputs = (outputs,)
        for grad in reversed(group):
            if grad._impl is None:
                continue
            outputs = core2.set_grad(grad._impl, normalized_backward, args, outputs)
        if self.__single_output:
            (outputs,) = outputs
        return outputs

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)