- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- import collections.abc
- from typing import Iterable, Optional, Union
-
- import megengine._internal as mgb
-
- from ..core.graph import get_default_graph
- from ..core.tensor import Tensor, wrap_io_tensor
- from ..jit import barrier, mark_impure, trace
-
-
- @wrap_io_tensor
- def grad(
- target: Tensor,
- wrt: Union[Tensor, Iterable[Tensor]],
- warn_mid_wrt: bool = True,
- use_virtual_grad: Optional[bool] = None,
- return_zero_for_nodep: bool = True,
- ) -> Union[Tensor, Iterable[Optional[Tensor]], None]:
- r"""Compute the symbolic gradient of ``target`` with repect to ``wrt``.
-
- ``wrt`` can either be a single tensor or a sequence of tensors.
-
- :param target: ``grad`` target tensor
- :param wrt: with respect to which to compute the gradient
- :param warn_mid_wrt: whether to warn if ``wrt`` is not an endpoint
- :param use_virtual_grad: whether to use a virtual ``grad`` opr, so the forward
- graph can be optimized before ``grad`` is applied; if ``None``, a virtual
- ``grad`` opr is used when ``graph_opt_level >= 2``
- :param return_zero_for_nodep: whether to return a zero-valued :class:`~.Tensor`
- rather than ``None`` when ``target`` does not depend on ``wrt``; must be
- ``True`` when a virtual ``grad`` opr is used
- :return: :math:`\partial\text{target} / \partial\text{wrt}`
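- 
- A minimal usage sketch (``loss``, ``w1`` and ``w2`` are hypothetical
- tensors; assumes ``loss`` was built from them in the current graph)::
- 
-     dw1 = grad(loss, w1)             # a single Tensor
-     dw1, dw2 = grad(loss, [w1, w2])  # one gradient per ``wrt`` entry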
- """
- if not isinstance(wrt, mgb.SymbolVar):
- assert isinstance(wrt, collections.abc.Iterable)
- wrt = [w._symvar for w in wrt]
-
- return mgb.grad(target, wrt, warn_mid_wrt, use_virtual_grad, return_zero_for_nodep)
-
-
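- # Cache of (alpha, beta, bias, config) SharedScalars keyed by
- # (graph id, dest id, delta id), so that repeated add_update calls on the
- # same tensors reuse one opr and merely refresh the scalar values.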
- _add_update_cache = {} # type: dict
-
- _dummy = mgb.SharedScalar(0)
-
-
- def add_update(
- dest: Tensor,
- delta: Tensor,
- *,
- alpha: Union[Tensor, float, int] = 1.0,
- beta: Union[Tensor, float, int] = 1.0,
- bias: Union[Tensor, float, int] = 0.0
- ):
- r"""Inplace modify ``dest`` as follows:
-
- .. math::
- \text{dest} = \text{alpha} \cdot \text{dest} + \text{beta} \cdot \text{delta} + \text{bias}
-
- :param dest: input data that will be inplace modified.
- :param delta: update value that will be added to ``dest``.
- :param alpha: weight ratio of ``dest``. Default: 1.0
- :param beta: weight ratio of ``delta``. Default: 1.0
- :param bias: bias value added to the result. Default: 0.0
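- 
- A minimal sketch of an SGD-style step (``param``, ``g`` and ``lr`` are
- hypothetical names; assumes ``param`` and ``g`` have the same shape)::
- 
-     # param = 1.0 * param + (-lr) * g + 0.0
-     add_update(param, g, beta=-lr)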
- """
-
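- # Fold Tensor-valued coefficients into ``delta`` so that only plain scalars
- # reach the underlying add_update opr, whose alpha/beta/bias are SharedScalars.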
- if isinstance(beta, Tensor) or isinstance(alpha, Tensor):
- delta *= beta
- beta = 1.0
- if isinstance(alpha, Tensor):
- delta += (alpha - 1.0) * dest
- alpha = 1.0
- if isinstance(bias, Tensor):
- delta += bias
- bias = 0.0
-
- comp_graph = dest._comp_graph or get_default_graph()
- comp_node = dest._comp_node
-
- if not isinstance(delta, Tensor):
- _delta = mgb.make_immutable(
- value=delta, comp_node=comp_node, comp_graph=comp_graph
- )
- else:
- _delta = delta._attach(comp_graph)
-
- _dest = dest._attach(comp_graph)
-
- # use (graph, dest, delta) as the cache key, so repeated calls reuse one opr instead of adding the same delta to dest twice in a static graph
- key = (comp_graph._id(), _dest.id, _delta.id)
- if key in _add_update_cache:
- _alpha, _beta, _bias, config = _add_update_cache[key]
- mgb.mgb._mgb.SharedScalar__set(_alpha, alpha)
- mgb.mgb._mgb.SharedScalar__set(_beta, beta)
- mgb.mgb._mgb.SharedScalar__set(_bias, bias)
- else:
- _alpha = mgb.SharedScalar(alpha)
- _beta = mgb.SharedScalar(beta)
- _bias = mgb.SharedScalar(bias)
- config = mgb.helper.gen_config(None, comp_node, None)
- _add_update_cache[key] = (_alpha, _beta, _bias, config)
-
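- # ``barrier`` and ``mark_impure`` come from the jit module: the former
- # orders the read of ``_delta`` against other impure oprs under trace, and
- # the latter keeps this side-effecting opr from being pruned even though
- # its output may go unused.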
- u = mgb.mgb._Opr.add_update(
- _dest, barrier(_delta), _alpha, _beta, _bias, _dummy, config
- )
- mark_impure(u)
-
- if trace._active_instance:
- dest._override_symvar_during_trace(trace._active_instance, u)
-
- return Tensor(u)
-
-
- @wrap_io_tensor
- def add_extra_vardep(oup: Tensor, dep: Tensor):
- r"""Explicitly set the dependency that tensor ``oup`` depends on tensor ``dep``.
- """
- return mgb.config.add_extra_vardep(oup, dep)