You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

debug_param.py 4.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. # -*- coding: utf-8 -*-
  2. import os
  3. from ..core import _config
  4. from ..core._imperative_rt.core2 import _clear_algorithm_cache
  5. from ..core.ops import builtin
  6. from ..logger import get_logger
  7. from ..utils.deprecation import deprecated
  8. Strategy = builtin.ops.Convolution.Strategy
  9. if os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY") != None:
  10. get_logger().warning(
  11. "Environment variable `MEGENGINE_CONV_EXECUTION_STRATEGY` is deprecated, please use `MEGENGINE_EXECUTION_STRATEGY`"
  12. )
  13. _valid_string_option = {
  14. "REPRODUCIBLE": Strategy.REPRODUCIBLE,
  15. "HEURISTIC": Strategy.HEURISTIC,
  16. "PROFILE": Strategy.PROFILE,
  17. }
  18. def get_execution_strategy() -> Strategy:
  19. r"""Returns the execution strategy of :class:`~.module.Conv2d` and :func:`~.matmul`
  20. See :func:`~.set_execution_strategy` for possible return values
  21. """
  22. strategy = Strategy(0)
  23. if _config._benchmark_kernel:
  24. strategy |= Strategy.PROFILE
  25. else:
  26. strategy |= Strategy.HEURISTIC
  27. if _config._deterministic_kernel:
  28. strategy |= Strategy.REPRODUCIBLE
  29. return strategy
  30. def set_execution_strategy(option):
  31. r"""Sets the execution strategy of :class:`~.module.Conv2d` and :func:`~.matmul`
  32. Args:
  33. option: Decides how :class:`~.module.Conv2d` and :func:`~.matmul` algorithms are chosen.
  34. Available strategy values:
  35. * "HEURISTIC": uses heuristic to choose the fastest algorithm.
  36. * "PROFILE": runs possible algorithms on a real device to find the best one.
  37. * "REPRODUCIBLE": uses algorithms that are reproducible.
  38. The default strategy is "HEURISTIC", these options can be combined to
  39. form a combination option, e.g. PROFILE_REPRODUCIBLE is a combination
  40. of "PROFILE" and "REPRODUCIBLE", which means using the fastest profiling
  41. result that is also reproducible.
  42. Available values string:
  43. * "HEURISTIC" uses heuristic to choose the fastest algorithm.
  44. * "PROFILE" runs possible algorithms on a real device to find the best one.
  45. * "PROFILE_REPRODUCIBLE" uses the fastest profiling result that is also reproducible.
  46. * "HEURISTIC_REPRODUCIBLE" uses heuristic to choose the fastest algorithm that is also reproducible.
  47. The default strategy is "HEURISTIC".
  48. It can also be set through the environment variable ``MEGENGINE_EXECUTION_STRATEGY``.
  49. """
  50. _benchmark_kernel = False
  51. _deterministic_kernel = False
  52. if isinstance(option, Strategy):
  53. _benchmark_kernel = (
  54. True if option & _valid_string_option["PROFILE"] != Strategy(0) else False
  55. )
  56. _deterministic_kernel = (
  57. True
  58. if option & _valid_string_option["REPRODUCIBLE"] != Strategy(0)
  59. else False
  60. )
  61. if _benchmark_kernel != _config._benchmark_kernel:
  62. _clear_algorithm_cache()
  63. _config._benchmark_kernel = _benchmark_kernel
  64. _config._deterministic_kernel = _deterministic_kernel
  65. return
  66. assert isinstance(option, str)
  67. for opt in option.split("_"):
  68. if not opt in _valid_string_option:
  69. raise ValueError(
  70. "Valid option can only be one of {}, or combine them with '_'.".format(
  71. _valid_string_option.keys()
  72. )
  73. )
  74. _benchmark_kernel |= _valid_string_option[opt] == Strategy.PROFILE
  75. _deterministic_kernel |= _valid_string_option[opt] == Strategy.REPRODUCIBLE
  76. if _benchmark_kernel != _config._benchmark_kernel:
  77. _clear_algorithm_cache()
  78. _config._benchmark_kernel = _benchmark_kernel
  79. _config._deterministic_kernel = _deterministic_kernel
  80. @deprecated(version="1.3", reason="use get_execution_strategy() instead")
  81. def get_conv_execution_strategy() -> str:
  82. return get_execution_strategy()
  83. @deprecated(version="1.3", reason="use set_execution_strategy() instead")
  84. def set_conv_execution_strategy(option: str):
  85. return set_execution_strategy(option)