# -*- coding: utf-8 -*- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import functools import heapq import itertools import typing import weakref import numpy as np from .._imperative_rt import core2, ops from ..ops.builtin import Elemwise, OpDef, RemoteSend from ..ops.special import Const _grad_count = 0 _grad_manager_dict = weakref.WeakValueDictionary() def get_grad_managers(): return [_grad_manager_dict[key] for key in _grad_manager_dict] class GradKey(core2.GradKey): def __init__(self, name=None): if name: self.name = name def backward(self, ys, dys): return core2.backward(self, ys, dys) class Grad: def __init__(self, name=None): global _grad_count if name is None: name = "grad_%d" % _grad_count _grad_count += 1 self._refkeeper = [] self._impl = GradKey(name) _grad_manager_dict[self._name] = self @property def _priority(self): return self._impl.priority @_priority.setter def _priority(self, priority): self._impl.priority = priority @property def _name(self): return self._impl.name def _is_attached_to(self, tensor): return self._impl.is_attached_to(tensor) def wrt(self, *tensors, callback=None): for x in tensors: self._impl.attach(x, callback) return self def __call__(self, ys, dys): from collections.abc import Sequence if not isinstance(ys, Sequence): ys = [ys] if not isinstance(dys, Sequence): dys = [dys] self._impl.backward(ys, dys) self._refkeeper = None def __enter__(self): return self def __exit__(self, _1, _2, _3): self._refkeeper = None del self._impl class Function(ops.PyOpBase): r"""Defines a block of operations with customizable differentiation. The computation should be defined in ``forward`` method, with gradient computation defined in ``backward`` method. Each instance of ``Function`` should be used only once during forwardding. Examples: .. code-block:: class Sigmoid(Function): def forward(self, x): y = 1 / (1 + F.exp(-x)) self.y = y return y def backward(self, dy): y = self.y """ def forward(self, *args, **kwargs): r"""Applies operations to ``inputs`` and returns results. It must be overriden by all subclasses. Args: input: input tensors. Returns: a tuple of Tensor or a single Tensor. Note: * This method should return a tuple of Tensor or a single Tensor representing the output of the function. * positional arguments should all be Tensor """ raise NotImplementedError def backward(self, *output_grads): r"""Compute the gradient of the forward function. It must be overriden by all subclasses. Args: output_grads: gradients of outputs that are returned by :meth:`forward`. Note: * In case when some tensors of outputs are not related to loss function, the corresponding values in ``output_grads`` would be ``None``. * This method should return a tuple which containing the gradients of all inputs, in the same order as the ``inputs`` argument of :meth:`forward` . A ``Tensor`` could be returned instead if there is only one input. If users want to stop the propagation of some gradients, the corresponding returned values should be set ``None`` . """ raise NotImplementedError def _default_rule(self, *args): ret = self.forward(*args) self.__single_output = isinstance(ret, core2.Tensor) return ret def _grad_rule(self, *args): return self._default_rule(*args), self.backward def __call__(self, *args): ret = core2.apply(self, *args) if self.__single_output: (ret,) = ret return ret def __getstate__(self): return self.__dict__ def __setstate__(self, state): self.__dict__.update(state)