You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

internal.py 1.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from typing import Union
  10. import megengine._internal as mgb
  11. from ..core.tensor import Tensor, tensor
  def add_update_fastpath(
      dest: Tensor,
      delta: Tensor,
      *,
      alpha: Union[Tensor, float, int] = 1.0,
      beta: Union[Tensor, float, int] = 1.0,
      bias: Union[Tensor, float, int] = 0.0
  ):
      """Fast path used ONLY by optimizers to update parameters.

      It bypasses the computing graph and launches the dnn/add_update kernel
      directly, which makes it more efficient than functional/add_update.

      Any coefficient passed as a ``Tensor`` is folded into ``delta`` and then
      reset to its neutral value (1.0 for ``alpha``/``beta``, 0.0 for
      ``bias``), so the kernel only ever receives non-Tensor coefficients.

      NOTE(review): the folding below is consistent with the kernel computing
      ``dest = alpha * dest + beta * delta + bias`` -- confirm against the
      kernel implementation.

      :param dest: tensor whose underlying value is handed to the kernel as
          the first argument; this same object is returned.
      :param delta: update term; if it is not a ``Tensor`` once the folding
          is done, it is wrapped with ``tensor`` using ``dest``'s device and
          dtype.
      :param alpha: coefficient forwarded to the kernel (folded when a Tensor).
      :param beta: coefficient forwarded to the kernel (folded when a Tensor,
          or whenever ``alpha`` is a Tensor).
      :param bias: additive term forwarded to the kernel (folded when a Tensor).
      :return: ``dest``.
      """
      alpha_is_tensor = isinstance(alpha, Tensor)

      # Augmented assignments are kept on purpose: whether they mutate the
      # caller's ``delta`` depends on Tensor's in-place operator support.
      if isinstance(beta, Tensor) or alpha_is_tensor:
          # ``beta`` must be applied before the ``alpha`` term is merged in,
          # otherwise that term would be scaled by ``beta`` as well.
          delta *= beta
          beta = 1.0

      if alpha_is_tensor:
          delta += (alpha - 1.0) * dest
          alpha = 1.0

      if isinstance(bias, Tensor):
          delta += bias
          bias = 0.0

      if not isinstance(delta, Tensor):
          delta = tensor(delta, device=dest.device, dtype=dest.dtype)

      def unwrap(t):
          # Return the object the kernel operates on: the stored value when
          # there is one, otherwise the eager value of the symbolic variable.
          stored = t._Tensor__val
          if stored is not None:
              assert isinstance(stored, mgb.SharedND)
              return stored
          symbol = t._Tensor__sym
          assert isinstance(symbol, mgb.SymbolVar)
          return symbol.eager_val

      kernel = mgb.mgb._add_update_fastpath
      kernel(unwrap(dest), unwrap(delta), alpha, beta, bias)
      return dest

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台