You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

_config.py 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. # -*- coding: utf-8 -*-
  2. import os
  3. from contextlib import contextmanager
  4. from ._imperative_rt.core2 import (
  5. _clear_algorithm_cache,
  6. get_auto_format_convert,
  7. get_option,
  8. set_auto_format_convert,
  9. set_option,
  10. )
  11. # use "default" to distinguish it from None in _reset_execution_config
  12. __compute_mode = "default"
  13. __conv_format = "default"
  14. __bn_format = "default"
  15. _benchmark_kernel = False
  16. _deterministic_kernel = False
  17. __all__ = [
  18. "benchmark_kernel",
  19. "deterministic_kernel",
  20. "async_level",
  21. "disable_memory_forwarding",
  22. "_compute_mode",
  23. "_conv_format",
  24. "_bn_format",
  25. "_auto_format_convert",
  26. "_override",
  27. ]
  28. @property
  29. def benchmark_kernel(mod):
  30. r"""Whether or not run possible algorithms on real device to find the best one. The default option is false,
  31. which means use heuristic to choose the fastest algorithm.
  32. Examples:
  33. .. code-block::
  34. import megengine as mge
  35. mge.config.benchmark_kernel = True
  36. """
  37. return _benchmark_kernel
  38. @benchmark_kernel.setter
  39. def benchmark_kernel(mod, option: bool):
  40. global _benchmark_kernel
  41. # try different strategy, then clear algorithm cache
  42. if option != _benchmark_kernel:
  43. _clear_algorithm_cache()
  44. _benchmark_kernel = option
  45. @property
  46. def deterministic_kernel(mod):
  47. r"""Whether or not the fastest algorithm choosed is reproducible. The default option is false,
  48. which means the algorithm is not reproducible.
  49. Examples:
  50. .. code-block::
  51. import megengine as mge
  52. mge.config.deterministic_kernel = True
  53. """
  54. return _deterministic_kernel
  55. @deterministic_kernel.setter
  56. def deterministic_kernel(mod, option: bool):
  57. global _deterministic_kernel
  58. _deterministic_kernel = option
  59. @property
  60. def async_level(mod) -> int:
  61. r"""Get or set config whether raise error exactly when invoking op. The default level is 2,
  62. which means both device and user side errors are async.
  63. Examples:
  64. .. code-block::
  65. import megengine as mge
  66. mge.config.async_level = 2
  67. """
  68. return get_option("async_level")
  69. @async_level.setter
  70. def async_level(mod, level: int):
  71. assert level >= 0 and level <= 2, "async_level should be 0, 1 or 2"
  72. set_option("async_level", level)
  73. @property
  74. def disable_memory_forwarding(mod) -> bool:
  75. r"""Get or set config whether to disable memory forwarding. The default option is false,
  76. which means storage may be shared among tensors.
  77. Examples:
  78. .. code-block::
  79. import megengine as mge
  80. mge.config.disable_memory_forwarding = False
  81. """
  82. return bool(get_option("disable_memory_forwarding"))
  83. @disable_memory_forwarding.setter
  84. def disable_memory_forwarding(mod, disable: bool):
  85. set_option("disable_memory_forwarding", disable)
  86. @property
  87. def _compute_mode(mod):
  88. r"""Get or set the precision of intermediate results for conv, matmul. The default
  89. option is None and will fallback to "default". When set to "float32", it will
  90. trigger mixed precision computation on TensorCore, but only effective when input and
  91. output are of float16 dtype.
  92. Examples:
  93. .. code-block::
  94. import megengine as mge
  95. mge.config._compute_mode = "float32"
  96. """
  97. return __compute_mode
  98. @_compute_mode.setter
  99. def _compute_mode(mod, _compute_mode: str):
  100. global __compute_mode
  101. __compute_mode = _compute_mode
  102. @property
  103. def _conv_format(mod):
  104. r"""Get or set convolution data/filter/output layout format. The default option is None,
  105. which means that no special format will be placed on. There are all layout definitions
  106. ``NCHW`` layout: ``{N, C, H, W}``
  107. ``NHWC`` layout: ``{N, H, W, C}``
  108. ``NHWCD4`` layout: ``{N, H, (C + 3) / 4, W, 4}``
  109. ``NHWCD4I`` layout: with ``align_axis = 2``
  110. ``NCHW4`` layout: ``{N, C/4, H, W, 4}``
  111. ``NCHW88`` layout: ``{N, C/8, H, W, 8}``
  112. ``CHWN4`` layout: ``{C/4, H, W, N, 4}``
  113. ``NCHW64`` layout: ``{N, C/64, H, W, 64}``
  114. Examples:
  115. .. code-block::
  116. import megengine as mge
  117. mge.config._conv_format = "NHWC"
  118. """
  119. return __conv_format
  120. @_conv_format.setter
  121. def _conv_format(mod, format: str):
  122. global __conv_format
  123. __conv_format = format
  124. @property
  125. def _bn_format(mod):
  126. r"""Get or set batchnorm param layout format. The default option is None and will
  127. fallback to "dim_1c11" which corresponds to NCHW format. When set to "dim_111c",
  128. param format of batchnorm will be changed to NHWC.
  129. Examples:
  130. .. code-block::
  131. import megengine as mge
  132. mge.config._bn_format = "dim_111c"
  133. """
  134. return __bn_format
  135. @_bn_format.setter
  136. def _bn_format(mod, format: str):
  137. global __bn_format
  138. __bn_format = format
  139. @property
  140. def _auto_format_convert(mod):
  141. r"""Automatically convert indexing params' order for NCHW Tensor to NHWC order.
  142. The default value is False, which means no convert.
  143. Examples:
  144. .. code-block::
  145. import megengine as mge
  146. mge.config._auto_format_convert = True
  147. """
  148. return get_auto_format_convert()
  149. @_auto_format_convert.setter
  150. def _auto_format_convert(mod, option: bool):
  151. set_auto_format_convert(option)
  152. def _reset_execution_config(
  153. benchmark_kernel=None,
  154. deterministic_kernel=None,
  155. async_level=None,
  156. compute_mode=None,
  157. conv_format=None,
  158. bn_format=None,
  159. auto_format_convert=None,
  160. ):
  161. global _benchmark_kernel, _deterministic_kernel, __compute_mode, __conv_format, __bn_format
  162. orig_flags = (
  163. _benchmark_kernel,
  164. _deterministic_kernel,
  165. get_option("async_level"),
  166. __compute_mode,
  167. __conv_format,
  168. __bn_format,
  169. get_auto_format_convert(),
  170. )
  171. if benchmark_kernel is not None:
  172. _benchmark_kernel = benchmark_kernel
  173. if deterministic_kernel is not None:
  174. _deterministic_kernel = deterministic_kernel
  175. if async_level is not None:
  176. set_option("async_level", async_level)
  177. if compute_mode is not None:
  178. __compute_mode = compute_mode
  179. if conv_format is not None:
  180. __conv_format = conv_format
  181. if bn_format is not None:
  182. __bn_format = bn_format
  183. if auto_format_convert is not None:
  184. set_auto_format_convert(auto_format_convert)
  185. return orig_flags
  186. @contextmanager
  187. def _override(
  188. benchmark_kernel=None,
  189. deterministic_kernel=None,
  190. async_level=None,
  191. compute_mode=None,
  192. conv_format=None,
  193. bn_format=None,
  194. auto_format_convert=None,
  195. ):
  196. r"""A context manager that users can opt in by attaching the decorator to set
  197. the config of the global variable.
  198. Examples:
  199. .. code-block::
  200. import megengine as mge
  201. @mge.config._override(
  202. benchmark_kernel = True,
  203. deterministic_kernel = Fasle,
  204. async_level=2,
  205. compute_mode="float32",
  206. conv_format="NHWC",
  207. bn_format="dim_111c",
  208. auto_format_convert=True,
  209. )
  210. def train():
  211. """
  212. orig_flags = _reset_execution_config(
  213. benchmark_kernel,
  214. deterministic_kernel,
  215. async_level,
  216. compute_mode,
  217. conv_format,
  218. bn_format,
  219. auto_format_convert,
  220. )
  221. try:
  222. yield
  223. finally:
  224. # recover the previous values
  225. _reset_execution_config(*orig_flags)
  226. def _get_actual_op_param(function_param, config_param):
  227. return function_param if config_param is "default" else config_param