
graph.py 4.3 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections.abc
from typing import Iterable, Optional, Union

import megengine._internal as mgb

from ..core.graph import get_default_graph
from ..core.tensor import Tensor, wrap_io_tensor
from ..jit import barrier, mark_impure, trace
@wrap_io_tensor
def grad(
    target: Tensor,
    wrt: Union[Tensor, Iterable[Tensor]],
    warn_mid_wrt: bool = True,
    use_virtual_grad: Optional[bool] = None,
    return_zero_for_nodep: bool = True,
) -> Union[Tensor, Iterable[Optional[Tensor]], None]:
    r"""Compute the symbolic gradient of ``target`` with respect to ``wrt``.

    ``wrt`` can either be a single tensor or a sequence of tensors.

    :param target: ``grad`` target tensor
    :param wrt: with respect to which to compute the gradient
    :param warn_mid_wrt: whether to give a warning if ``wrt`` is not an endpoint
    :param use_virtual_grad: whether to use the virtual ``grad`` opr, so the forward
        graph can be optimized before applying ``grad``; if ``None`` is given, the
        virtual ``grad`` opr is used if ``graph_opt_level >= 2``
    :param return_zero_for_nodep: if ``target`` does not depend on ``wrt``, set to True
        to return a zero-valued :class:`~.Tensor` rather than ``None``; can't be set
        to False when using the virtual ``grad`` opr
    :return: :math:`\partial\text{target} / \partial\text{wrt}`
    """
    if not isinstance(wrt, mgb.SymbolVar):
        # after @wrap_io_tensor, a single tensor arrives as a SymbolVar;
        # anything else must be an iterable of tensors
        assert isinstance(wrt, collections.abc.Iterable)
        wrt = [w._symvar for w in wrt]
    return mgb.grad(target, wrt, warn_mid_wrt, use_virtual_grad, return_zero_for_nodep)
_add_update_cache = {}  # type: dict

_dummy = mgb.SharedScalar(0)


def add_update(
    dest: Tensor,
    delta: Tensor,
    *,
    alpha: Union[Tensor, float, int] = 1.0,
    beta: Union[Tensor, float, int] = 1.0,
    bias: Union[Tensor, float, int] = 0.0
):
    r"""Modify ``dest`` inplace as follows:

    .. math::

        dest = alpha * dest + beta * delta + bias

    :param dest: input data that will be modified inplace.
    :param delta: update value that will be added to ``dest``.
    :param alpha: weight ratio of ``dest``. Default: 1.0
    :param beta: weight ratio of ``delta``. Default: 1.0
    :param bias: bias value appended to the result. Default: 0.0
    """
    # Fold Tensor-valued coefficients into ``delta`` so that only plain
    # scalars reach the underlying add_update opr.
    if isinstance(beta, Tensor) or isinstance(alpha, Tensor):
        delta *= beta
        beta = 1.0
    if isinstance(alpha, Tensor):
        delta += (alpha - 1.0) * dest
        alpha = 1.0
    if isinstance(bias, Tensor):
        delta += bias
        bias = 0.0

    comp_graph = dest._comp_graph or get_default_graph()
    comp_node = dest._comp_node

    if not isinstance(delta, Tensor):
        _delta = mgb.make_immutable(
            value=delta, comp_node=comp_node, comp_graph=comp_graph
        )
    else:
        _delta = delta._attach(comp_graph)

    _dest = dest._attach(comp_graph)

    # Use (graph, dest, delta) as the cache key so that the same delta is
    # not added to dest more than once in a static graph.
    key = (comp_graph._id(), _dest.id, _delta.id)
    if key in _add_update_cache:
        _alpha, _beta, _bias, config = _add_update_cache[key]
        mgb.mgb._mgb.SharedScalar__set(_alpha, alpha)
        mgb.mgb._mgb.SharedScalar__set(_beta, beta)
        mgb.mgb._mgb.SharedScalar__set(_bias, bias)
    else:
        _alpha = mgb.SharedScalar(alpha)
        _beta = mgb.SharedScalar(beta)
        _bias = mgb.SharedScalar(bias)
        config = mgb.helper.gen_config(None, comp_node, None)
        _add_update_cache[key] = (_alpha, _beta, _bias, config)

    u = mgb.mgb._Opr.add_update(
        _dest, barrier(_delta), _alpha, _beta, _bias, _dummy, config
    )
    mark_impure(u)

    if trace._active_instance:
        dest._override_symvar_during_trace(trace._active_instance, u)
    return Tensor(u)
@wrap_io_tensor
def add_extra_vardep(oup: Tensor, dep: Tensor):
    r"""Explicitly set the dependency that tensor ``oup`` depends on tensor ``dep``."""
    return mgb.config.add_extra_vardep(oup, dep)
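
As a quick orientation, here is a minimal usage sketch of ``grad``. It assumes this module is re-exported as ``megengine.functional`` (as in MegEngine 0.x) and that ``megengine.tensor`` is available; the names ``x`` and ``y`` are illustrative, not part of this file.

import numpy as np
import megengine as mge
import megengine.functional as F  # assumed re-export path for this module

x = mge.tensor(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
y = (x * x).sum()   # scalar target
dx = F.grad(y, x)   # symbolic dy/dx; here dx should equal 2 * x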

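Because ``add_update`` computes ``dest = alpha * dest + beta * delta + bias`` inplace, it is the natural building block for optimizer steps: with ``alpha=1.0`` and ``beta=-lr`` it performs a vanilla SGD update. Below is a hedged sketch under the same MegEngine 0.x assumptions as above; ``param`` and ``g`` are illustrative stand-ins for a parameter and its computed gradient.

import numpy as np
import megengine as mge
import megengine.functional as F  # assumed re-export path for this module

param = mge.Parameter(np.array([1.0, 2.0, 3.0], dtype=np.float32))
g = mge.tensor(np.array([0.5, 0.5, 0.5], dtype=np.float32))
lr = 0.01
F.add_update(param, g, alpha=1.0, beta=-lr)   # param <- param - lr * g
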
The MegEngine package ships with the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has GPU hardware and that the driver is properly installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.