You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

_config.py 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import os
  10. from contextlib import contextmanager
  11. __compute_mode = "default"
  12. __conv_format = "default"
  13. _benchmark_kernel = False
  14. _deterministic_kernel = False
  15. _async_level = os.getenv("MEGENGINE_INTERP_ASYNC_LEVEL", 2)
  16. __all__ = [
  17. "benchmark_kernel",
  18. "deterministic_kernel",
  19. "async_level",
  20. "_compute_mode",
  21. "_conv_format",
  22. "_override",
  23. ]
  24. @property
  25. def benchmark_kernel(mod):
  26. r"""Whether or not run possible algorithms on real device to find the best one. The default option is false,
  27. which means use heuristic to choose the fastest algorithm.
  28. Examples:
  29. .. code-block::
  30. import megengine as mge
  31. mge.config.benchmark_kernel = True
  32. """
  33. return _benchmark_kernel
  34. @benchmark_kernel.setter
  35. def benchmark_kernel(mod, option: bool):
  36. global _benchmark_kernel
  37. _benchmark_kernel = option
  38. @property
  39. def deterministic_kernel(mod):
  40. r"""Whether or not the fastest algorithm choosed is reproducible. The default option is false,
  41. which means the algorithm is not reproducible.
  42. Examples:
  43. .. code-block::
  44. import megengine as mge
  45. mge.config.deterministic_kernel = True
  46. """
  47. return _deterministic_kernel
  48. @deterministic_kernel.setter
  49. def deterministic_kernel(mod, option: bool):
  50. global _deterministic_kernel
  51. _deterministic_kernel = option
  52. @property
  53. def async_level(mod) -> int:
  54. r"""Get or set config whether raise error exactly when invoking op. The default level is 2,
  55. which means both device and user side errors are async.
  56. Examples:
  57. .. code-block::
  58. import megengine as mge
  59. mge.config.async_level = 2
  60. """
  61. return _async_level
  62. @async_level.setter
  63. def async_level(mod, level: int):
  64. global _async_level
  65. _async_level = level
  66. @property
  67. def _compute_mode(mod):
  68. r"""Get or set the precision of intermediate results. The default option is "default",
  69. which means that no special requirements will be placed on. When set to 'float32', it
  70. would be used for accumulator and intermediate result, but only effective when input and
  71. output are of float16 dtype.
  72. Examples:
  73. .. code-block::
  74. import megengine as mge
  75. mge.config._compute_mode = "default"
  76. """
  77. return __compute_mode
  78. @_compute_mode.setter
  79. def _compute_mode(mod, _compute_mode: str):
  80. global __compute_mode
  81. __compute_mode = _compute_mode
  82. @property
  83. def _conv_format(mod):
  84. r"""Get or set convolution data/filter/output layout format. The default option is "default",
  85. which means that no special format will be placed on. There are all layout definitions
  86. ``NCHW`` layout: ``{N, C, H, W}``
  87. ``NHWC`` layout: ``{N, H, W, C}``
  88. ``NHWCD4`` layout: ``{N, H, (C + 3) / 4, W, 4}``
  89. ``NHWCD4I`` layout: with ``align_axis = 2``
  90. ``NCHW4`` layout: ``{N, C/4, H, W, 4}``
  91. ``NCHW88`` layout: ``{N, C/8, H, W, 8}``
  92. ``CHWN4`` layout: ``{C/4, H, W, N, 4}``
  93. ``NCHW64`` layout: ``{N, C/64, H, W, 64}``
  94. Examples:
  95. .. code-block::
  96. import megengine as mge
  97. mge.config._conv_format = "NHWC"
  98. """
  99. return __conv_format
  100. @_conv_format.setter
  101. def _conv_format(mod, format: str):
  102. global __conv_format
  103. __conv_format = format
  104. def _reset_execution_config(
  105. benchmark_kernel=None,
  106. deterministic_kernel=None,
  107. async_level=None,
  108. compute_mode=None,
  109. conv_format=None,
  110. ):
  111. global _benchmark_kernel, _deterministic_kernel, _async_level, __compute_mode, __conv_format
  112. orig_flags = (
  113. _benchmark_kernel,
  114. _deterministic_kernel,
  115. _async_level,
  116. __compute_mode,
  117. __conv_format,
  118. )
  119. if benchmark_kernel is not None:
  120. _benchmark_kernel = benchmark_kernel
  121. if deterministic_kernel is not None:
  122. _deterministic_kernel = deterministic_kernel
  123. if async_level is not None:
  124. _async_level = async_level
  125. if compute_mode is not None:
  126. __compute_mode = compute_mode
  127. if conv_format is not None:
  128. __conv_format = conv_format
  129. return orig_flags
  130. @contextmanager
  131. def _override(
  132. benchmark_kernel=None,
  133. deterministic_kernel=None,
  134. async_level=None,
  135. compute_mode=None,
  136. conv_format=None,
  137. ):
  138. r"""A context manager that users can opt in by attaching the decorator to set
  139. the config of the global variable.
  140. Examples:
  141. .. code-block::
  142. import megengine as mge
  143. @mge.config._override(
  144. benchmark_kernel = True,
  145. deterministic_kernel = Fasle,
  146. async_level=2,
  147. compute_mode="float32",
  148. conv_format="NHWC",
  149. )
  150. def train():
  151. """
  152. orig_flags = _reset_execution_config(
  153. benchmark_kernel, deterministic_kernel, async_level, compute_mode, conv_format,
  154. )
  155. try:
  156. yield
  157. finally:
  158. # recover the previous values
  159. _reset_execution_config(*orig_flags)
  160. def _get_actual_op_param(function_param, config_param):
  161. return function_param if config_param == "default" else config_param

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台