You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

_config.py 7.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. # -*- coding: utf-8 -*-
  2. import os
  3. from contextlib import contextmanager
  4. from ._imperative_rt.core2 import (
  5. _clear_algorithm_cache,
  6. get_auto_format_convert,
  7. get_option,
  8. set_auto_format_convert,
  9. set_option,
  10. )
  11. # use "default" to distinguish it from None in _reset_execution_config
  12. __compute_mode = "default"
  13. _benchmark_kernel = False
  14. _deterministic_kernel = False
  15. _benchmark_with_subprocess = False
  16. __all__ = [
  17. "benchmark_kernel",
  18. "benchmark_with_subprocess",
  19. "deterministic_kernel",
  20. "async_level",
  21. "disable_memory_forwarding",
  22. "_compute_mode",
  23. "_auto_format_convert",
  24. "_override",
  25. ]
  26. @property
  27. def benchmark_kernel(mod):
  28. r"""Whether or not run possible algorithms on real device to find the best one. The default option is false,
  29. which means use heuristic to choose the fastest algorithm.
  30. Examples:
  31. .. code-block::
  32. import megengine as mge
  33. mge.config.benchmark_kernel = True
  34. """
  35. return _benchmark_kernel
  36. @benchmark_kernel.setter
  37. def benchmark_kernel(mod, option: bool):
  38. global _benchmark_kernel
  39. # try different strategy, then clear algorithm cache
  40. if option != _benchmark_kernel:
  41. _clear_algorithm_cache()
  42. _benchmark_kernel = option
  43. @property
  44. def deterministic_kernel(mod):
  45. r"""Whether or not the fastest algorithm choosed is reproducible. The default option is false,
  46. which means the algorithm is not reproducible.
  47. Examples:
  48. .. code-block::
  49. import megengine as mge
  50. mge.config.deterministic_kernel = True
  51. """
  52. return _deterministic_kernel
  53. @deterministic_kernel.setter
  54. def deterministic_kernel(mod, option: bool):
  55. global _deterministic_kernel
  56. _deterministic_kernel = option
  57. @property
  58. def benchmark_with_subprocess(mod):
  59. r"""Whether or not run possible algorithms on real device to find the best one. The default option is false,
  60. which means use heuristic to choose the fastest algorithm.
  61. Examples:
  62. .. code-block::
  63. import megengine as mge
  64. mge.config.benchmark_with_subprocess = True
  65. """
  66. return _benchmark_with_subprocess
  67. @benchmark_with_subprocess.setter
  68. def benchmark_with_subprocess(mod, option: bool):
  69. if option:
  70. import sys
  71. from ._imperative_rt.utils import _set_fork_exec_path_for_timed_func
  72. _set_fork_exec_path_for_timed_func(
  73. sys.executable,
  74. os.path.join(
  75. os.path.dirname(__file__), "../utils", "_timed_func_fork_exec_entry.py"
  76. ),
  77. )
  78. @property
  79. def async_level(mod) -> int:
  80. r"""Get or set config whether raise error exactly when invoking op. The default level is 2,
  81. which means both device and user side errors are async.
  82. Examples:
  83. .. code-block::
  84. import megengine as mge
  85. mge.config.async_level = 2
  86. """
  87. return get_option("async_level")
  88. @async_level.setter
  89. def async_level(mod, level: int):
  90. assert level >= 0 and level <= 2, "async_level should be 0, 1 or 2"
  91. set_option("async_level", level)
  92. @property
  93. def disable_memory_forwarding(mod) -> bool:
  94. r"""Get or set config whether to disable memory forwarding. The default option is false,
  95. which means storage may be shared among tensors.
  96. Examples:
  97. .. code-block::
  98. import megengine as mge
  99. mge.config.disable_memory_forwarding = False
  100. """
  101. return bool(get_option("disable_memory_forwarding"))
  102. @disable_memory_forwarding.setter
  103. def disable_memory_forwarding(mod, disable: bool):
  104. set_option("disable_memory_forwarding", disable)
  105. @property
  106. def _compute_mode(mod):
  107. r"""Get or set the precision of intermediate results for conv, matmul. The default
  108. option is None and will fallback to "default". When set to "float32", it will
  109. trigger mixed precision computation on TensorCore, but only effective when input and
  110. output are of float16 dtype.
  111. Examples:
  112. .. code-block::
  113. import megengine as mge
  114. mge.config._compute_mode = "float32"
  115. """
  116. return __compute_mode
  117. @_compute_mode.setter
  118. def _compute_mode(mod, _compute_mode: str):
  119. global __compute_mode
  120. __compute_mode = _compute_mode
  121. @property
  122. def _bn_format(mod):
  123. r"""Get or set batchnorm param layout format. The default option is None and will
  124. fallback to "dim_1c11" which corresponds to NCHW format. When set to "dim_111c",
  125. param format of batchnorm will be changed to NHWC.
  126. Examples:
  127. .. code-block::
  128. import megengine as mge
  129. mge.config._bn_format = "dim_111c"
  130. """
  131. return __bn_format
  132. @_bn_format.setter
  133. def _bn_format(mod, format: str):
  134. global __bn_format
  135. __bn_format = format
  136. @property
  137. def _auto_format_convert(mod):
  138. r"""Automatically convert indexing params' order for NCHW Tensor to NHWC order.
  139. The default value is False, which means no convert.
  140. Examples:
  141. .. code-block::
  142. import megengine as mge
  143. mge.config._auto_format_convert = True
  144. """
  145. return get_auto_format_convert()
  146. @_auto_format_convert.setter
  147. def _auto_format_convert(mod, option: bool):
  148. set_auto_format_convert(option)
  149. def _reset_execution_config(
  150. benchmark_kernel=None,
  151. deterministic_kernel=None,
  152. async_level=None,
  153. compute_mode=None,
  154. ):
  155. global _benchmark_kernel, _deterministic_kernel, __compute_mode
  156. orig_flags = (
  157. _benchmark_kernel,
  158. _deterministic_kernel,
  159. get_option("async_level"),
  160. __compute_mode,
  161. )
  162. if benchmark_kernel is not None:
  163. _benchmark_kernel = benchmark_kernel
  164. if deterministic_kernel is not None:
  165. _deterministic_kernel = deterministic_kernel
  166. if async_level is not None:
  167. set_option("async_level", async_level)
  168. if compute_mode is not None:
  169. __compute_mode = compute_mode
  170. return orig_flags
  171. @contextmanager
  172. def _override(
  173. benchmark_kernel=None,
  174. deterministic_kernel=None,
  175. async_level=None,
  176. compute_mode=None,
  177. ):
  178. r"""A context manager that users can opt in by attaching the decorator to set
  179. the config of the global variable.
  180. Examples:
  181. .. code-block::
  182. import megengine as mge
  183. @mge.config._override(
  184. benchmark_kernel = True,
  185. deterministic_kernel = Fasle,
  186. async_level=2,
  187. compute_mode="float32",
  188. )
  189. def train():
  190. """
  191. orig_flags = _reset_execution_config(
  192. benchmark_kernel=benchmark_kernel,
  193. deterministic_kernel=deterministic_kernel,
  194. async_level=async_level,
  195. compute_mode=compute_mode,
  196. )
  197. try:
  198. yield
  199. finally:
  200. # recover the previous values
  201. _reset_execution_config(*orig_flags)
  202. def _get_actual_op_param(function_param, config_param):
  203. return function_param if config_param == "default" else config_param