You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

_config.py 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import os
  10. from contextlib import contextmanager
  11. from ._imperative_rt.core2 import _clear_algorithm_cache, get_option, set_option
# Module-level backing state read and written by the config properties and
# by _reset_execution_config in this file.
__compute_mode = "default"
__conv_format = "default"
_benchmark_kernel = False
_deterministic_kernel = False

# Names exported by a star-import of this module. The underscore-prefixed
# entries are listed explicitly because a star-import would otherwise skip them.
__all__ = [
    "benchmark_kernel",
    "deterministic_kernel",
    "async_level",
    "disable_memory_forwarding",
    "_compute_mode",
    "_conv_format",
    "_override",
]
  25. @property
  26. def benchmark_kernel(mod):
  27. r"""Whether or not run possible algorithms on real device to find the best one. The default option is false,
  28. which means use heuristic to choose the fastest algorithm.
  29. Examples:
  30. .. code-block::
  31. import megengine as mge
  32. mge.config.benchmark_kernel = True
  33. """
  34. return _benchmark_kernel
  35. @benchmark_kernel.setter
  36. def benchmark_kernel(mod, option: bool):
  37. global _benchmark_kernel
  38. # try different strategy, then clear algorithm cache
  39. if option != _benchmark_kernel:
  40. _clear_algorithm_cache()
  41. _benchmark_kernel = option
@property
def deterministic_kernel(mod):
    r"""Whether or not the fastest algorithm chosen is reproducible.

    The default option is false, which means the algorithm is not reproducible.

    Examples:

        .. code-block::

            import megengine as mge
            mge.config.deterministic_kernel = True
    """
    return _deterministic_kernel


@deterministic_kernel.setter
def deterministic_kernel(mod, option: bool):
    # Only the module-level flag is updated here; no other call is made.
    global _deterministic_kernel
    _deterministic_kernel = option
  56. @property
  57. def async_level(mod) -> int:
  58. r"""Get or set config whether raise error exactly when invoking op. The default level is 2,
  59. which means both device and user side errors are async.
  60. Examples:
  61. .. code-block::
  62. import megengine as mge
  63. mge.config.async_level = 2
  64. """
  65. return get_option("async_level")
  66. @async_level.setter
  67. def async_level(mod, level: int):
  68. assert level >= 0 and level <= 2, "async_level should be 0, 1 or 2"
  69. set_option("async_level", level)
  70. @property
  71. def disable_memory_forwarding(mod) -> bool:
  72. r"""Get or set config whether to disable memory forwarding. The default option is false,
  73. which means storage may be shared among tensors.
  74. Examples:
  75. .. code-block::
  76. import megengine as mge
  77. mge.config.disable_memory_forwarding = False
  78. """
  79. return bool(get_option("disable_memory_forwarding"))
  80. @disable_memory_forwarding.setter
  81. def disable_memory_forwarding(mod, disable: bool):
  82. set_option("disable_memory_forwarding", disable)
@property
def _compute_mode(mod):
    r"""Get or set the precision of intermediate results.

    The default option is "default", which means that no special requirements
    will be placed on. When set to 'float32', it would be used for accumulator
    and intermediate result, but only effective when input and output are of
    float16 dtype.

    Examples:

        .. code-block::

            import megengine as mge
            mge.config._compute_mode = "default"
    """
    return __compute_mode


@_compute_mode.setter
def _compute_mode(mod, _compute_mode: str):
    # The value is stored as given; no validation is performed here.
    global __compute_mode
    __compute_mode = _compute_mode
@property
def _conv_format(mod):
    r"""Get or set convolution data/filter/output layout format.

    The default option is "default", which means that no special format will
    be placed on. The layout definitions are:

    * ``NCHW`` layout: ``{N, C, H, W}``
    * ``NHWC`` layout: ``{N, H, W, C}``
    * ``NHWCD4`` layout: ``{N, H, (C + 3) / 4, W, 4}``
    * ``NHWCD4I`` layout: with ``align_axis = 2``
    * ``NCHW4`` layout: ``{N, C/4, H, W, 4}``
    * ``NCHW88`` layout: ``{N, C/8, H, W, 8}``
    * ``CHWN4`` layout: ``{C/4, H, W, N, 4}``
    * ``NCHW64`` layout: ``{N, C/64, H, W, 64}``

    Examples:

        .. code-block::

            import megengine as mge
            mge.config._conv_format = "NHWC"
    """
    return __conv_format


@_conv_format.setter
def _conv_format(mod, format: str):
    # The value is stored as given; it is not checked against the layouts
    # listed in the getter's docstring.
    global __conv_format
    __conv_format = format
  121. def _reset_execution_config(
  122. benchmark_kernel=None,
  123. deterministic_kernel=None,
  124. async_level=None,
  125. compute_mode=None,
  126. conv_format=None,
  127. ):
  128. global _benchmark_kernel, _deterministic_kernel, _async_level, __compute_mode, __conv_format
  129. orig_flags = (
  130. _benchmark_kernel,
  131. _deterministic_kernel,
  132. get_option("async_level"),
  133. __compute_mode,
  134. __conv_format,
  135. )
  136. if benchmark_kernel is not None:
  137. _benchmark_kernel = benchmark_kernel
  138. if deterministic_kernel is not None:
  139. _deterministic_kernel = deterministic_kernel
  140. if async_level is not None:
  141. set_option("async_level", async_level)
  142. if compute_mode is not None:
  143. __compute_mode = compute_mode
  144. if conv_format is not None:
  145. __conv_format = conv_format
  146. return orig_flags
  147. @contextmanager
  148. def _override(
  149. benchmark_kernel=None,
  150. deterministic_kernel=None,
  151. async_level=None,
  152. compute_mode=None,
  153. conv_format=None,
  154. ):
  155. r"""A context manager that users can opt in by attaching the decorator to set
  156. the config of the global variable.
  157. Examples:
  158. .. code-block::
  159. import megengine as mge
  160. @mge.config._override(
  161. benchmark_kernel = True,
  162. deterministic_kernel = Fasle,
  163. async_level=2,
  164. compute_mode="float32",
  165. conv_format="NHWC",
  166. )
  167. def train():
  168. """
  169. orig_flags = _reset_execution_config(
  170. benchmark_kernel, deterministic_kernel, async_level, compute_mode, conv_format,
  171. )
  172. try:
  173. yield
  174. finally:
  175. # recover the previous values
  176. _reset_execution_config(*orig_flags)
  177. def _get_actual_op_param(function_param, config_param):
  178. return function_param if config_param == "default" else config_param