
graph.py 4.1 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections.abc
from typing import Iterable, Optional, Union

import megengine._internal as mgb

from ..core.graph import get_default_graph
from ..core.tensor import Tensor, wrap_io_tensor
from ..jit import barrier, mark_impure


@wrap_io_tensor
def grad(
    target: Tensor,
    wrt: Union[Tensor, Iterable[Tensor]],
    warn_mid_wrt: bool = True,
    use_virtual_grad: Optional[bool] = None,
    return_zero_for_nodep: bool = True,
) -> Union[Tensor, Iterable[Optional[Tensor]], None]:
    r"""Compute the symbolic gradient.

    :param target: grad target var
    :param wrt: with respect to which to compute the grad
    :param warn_mid_wrt: whether to warn if ``wrt`` is not an endpoint
    :param use_virtual_grad: whether to use the virtual grad opr, so the forward
        graph can be optimized before applying grad; if ``None`` is given,
        virtual grad is used if ``graph_opt_level >= 2``
    :param return_zero_for_nodep: if ``target`` does not depend on ``wrt``, set
        to True to return a zero-valued :class:`~.Tensor` rather than ``None``;
        can't be set to False when using the virtual grad opr.
    :return: :math:`\partial\text{target} / \partial\text{wrt}`
    """
    if not isinstance(wrt, mgb.SymbolVar):
        # a single Tensor is already a SymbolVar; anything else must be
        # an iterable of Tensors, which we unwrap to their symvars
        assert isinstance(wrt, collections.abc.Iterable)
        wrt = [w._symvar for w in wrt]
    return mgb.grad(target, wrt, warn_mid_wrt, use_virtual_grad, return_zero_for_nodep)
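
For reference, a minimal usage sketch, assuming this module is re-exported as ``megengine.functional`` (the file lives under ``megengine/functional/``) and that ``mge.tensor`` constructs a ``Tensor``; the variable names are illustrative, not from the source:

# A hedged sketch: compute d(sum(x**2))/dx, expected to equal 2 * x.
import megengine as mge
import megengine.functional as F

x = mge.tensor([[1.0, 2.0], [3.0, 4.0]])   # hypothetical input tensor
y = F.sum(x ** 2)                          # scalar grad target
dx = F.grad(y, x, use_virtual_grad=False)  # same shape as x
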
_add_update_cache = {}  # type: dict

_dummy = mgb.SharedScalar(0)


def add_update(
    dest: Tensor,
    delta: Tensor,
    *,
    alpha: Union[Tensor, float, int] = 1.0,
    beta: Union[Tensor, float, int] = 1.0,
    bias: Union[Tensor, float, int] = 0.0
):
    r"""Inplace modify ``dest`` as follows:

    .. math::

        dest = alpha * dest + beta * delta + bias

    :param dest: input data that will be inplace modified.
    :param delta: update value that will be added to ``dest``.
    :param alpha: weight ratio of ``dest``. Default: 1.0
    :param beta: weight ratio of ``delta``. Default: 1.0
    :param bias: bias value appended to the result. Default: 0.0
    """
    # Fold Tensor-valued coefficients into ``delta`` so that only plain
    # numbers remain for the SharedScalars below.
    if isinstance(beta, Tensor) or isinstance(alpha, Tensor):
        delta *= beta
        beta = 1.0
    if isinstance(alpha, Tensor):
        delta += (alpha - 1.0) * dest
        alpha = 1.0
    if isinstance(bias, Tensor):
        delta += bias
        bias = 0.0

    comp_graph = dest._comp_graph or get_default_graph()
    comp_node = dest._comp_node

    if not isinstance(delta, Tensor):
        _delta = mgb.make_immutable(
            value=delta, comp_node=comp_node, comp_graph=comp_graph
        )
    else:
        _delta = delta._attach(comp_graph)

    _dest = dest._attach(comp_graph)

    # Use (graph, dest, delta) as the key so that the same delta is not
    # added to dest more than once in a static graph; on a cache hit we
    # only refresh the scalar coefficients.
    key = (comp_graph._id(), _dest.id, _delta.id)
    if key in _add_update_cache:
        _alpha, _beta, _bias, config = _add_update_cache[key]
        mgb.mgb._mgb.SharedScalar__set(_alpha, alpha)
        mgb.mgb._mgb.SharedScalar__set(_beta, beta)
        mgb.mgb._mgb.SharedScalar__set(_bias, bias)
    else:
        _alpha = mgb.SharedScalar(alpha)
        _beta = mgb.SharedScalar(beta)
        _bias = mgb.SharedScalar(bias)
        config = mgb.helper.gen_config(None, comp_node, None)
        _add_update_cache[key] = (_alpha, _beta, _bias, config)

    u = mgb.mgb._Opr.add_update(
        _dest, barrier(_delta), _alpha, _beta, _bias, _dummy, config
    )
    mark_impure(u)
    return Tensor(u)
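
The typical client of ``add_update`` is an optimizer step. A sketch of a plain SGD update, assuming ``param`` and ``grad_w`` are existing Tensors on the same graph and ``lr`` is a Python float (all names illustrative):

# param <- alpha * param + beta * delta + bias
# with alpha=1.0, beta=-lr, bias=0.0, i.e. param <- param - lr * grad_w
lr = 0.01
updated = add_update(param, grad_w, alpha=1.0, beta=-lr, bias=0.0)

Because ``(graph, dest, delta)`` keys the cache, calling this in a loop over the same static graph reuses one ``add_update`` opr and merely resets its scalar coefficients.
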
@wrap_io_tensor
def add_extra_vardep(oup: Tensor, dep: Tensor):
    r"""Explicitly add a dependency so that tensor ``oup`` depends on tensor ``dep``."""
    return mgb.config.add_extra_vardep(oup, dep)
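
``add_extra_vardep`` is useful when an impure opr such as the ``add_update`` above must execute before some output is read. A hedged sketch, continuing the hypothetical names from the previous example:

# Force ``loss`` to be computed only after the parameter update has
# taken effect, by threading an explicit dependency through the graph.
loss = add_extra_vardep(loss, updated)
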

The MegEngine package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU installed along with its driver. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.