You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

conftest.py 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import os
  2. import platform
  3. import sys
  4. import pytest
  5. from megengine.core import _config as config
  6. from megengine.core import _trace_option as trace_option
  7. from megengine.core import get_option
  8. from megengine.core._imperative_rt.core2 import (
  9. _get_amp_dtype_autocast,
  10. _get_amp_high_prec_dtype,
  11. _get_amp_low_prec_dtype,
  12. _get_convert_inputs,
  13. )
  14. from megengine.core.tensor import amp
  15. from megengine.device import get_device_count
  16. sys.path.append(os.path.join(os.path.dirname(__file__), "helpers"))
  17. _ngpu = get_device_count("gpu")
  18. @pytest.fixture(autouse=True)
  19. def skip_by_ngpu(request):
  20. if request.node.get_closest_marker("require_ngpu"):
  21. require_ngpu = int(request.node.get_closest_marker("require_ngpu").args[0])
  22. if require_ngpu > _ngpu:
  23. pytest.skip("skipped for ngpu unsatisfied: {}".format(require_ngpu))
  24. @pytest.fixture(autouse=True)
  25. def skip_distributed(request):
  26. if request.node.get_closest_marker("distributed_isolated"):
  27. if platform.system() in ("Windows", "Darwin"):
  28. pytest.skip(
  29. "skipped for distributed unsupported at platform: {}".format(
  30. platform.system()
  31. )
  32. )
  33. @pytest.fixture(autouse=True)
  34. def run_around_tests():
  35. env_vars1 = {
  36. "symbolic_shape": trace_option.use_symbolic_shape(),
  37. "async_level": get_option("async_level"),
  38. "enable_drop": get_option("enable_drop"),
  39. "max_recompute_time": get_option("max_recompute_time"),
  40. "catch_worker_execption": get_option("catch_worker_execption"),
  41. "enable_host_compute": get_option("enable_host_compute"),
  42. # "record_computing_path": get_option("record_computing_path"),
  43. "disable_memory_forwarding": get_option("disable_memory_forwarding"),
  44. "enable_dtr_auto_drop": get_option("enable_dtr_auto_drop"),
  45. "enable_dtr_sqrt_sampling": get_option("enable_dtr_sqrt_sampling"),
  46. "dtr_eviction_threshold": get_option("dtr_eviction_threshold"),
  47. "dtr_evictee_minimum_size": get_option("dtr_evictee_minimum_size"),
  48. "benchmark_kernel": config.benchmark_kernel,
  49. "deterministic_kernel": config.deterministic_kernel,
  50. "compute_mode": config._compute_mode,
  51. "amp_enabled": amp.enabled,
  52. "convert_inputs": _get_convert_inputs(),
  53. "amp_dtype_autocast": _get_amp_dtype_autocast(),
  54. "amp_high_prec_dtype": _get_amp_high_prec_dtype(),
  55. "amp_low_prec_dtype": _get_amp_low_prec_dtype(),
  56. }
  57. yield
  58. env_vars2 = {
  59. "symbolic_shape": trace_option.use_symbolic_shape(),
  60. "async_level": get_option("async_level"),
  61. "enable_drop": get_option("enable_drop"),
  62. "max_recompute_time": get_option("max_recompute_time"),
  63. "catch_worker_execption": get_option("catch_worker_execption"),
  64. "enable_host_compute": get_option("enable_host_compute"),
  65. # "record_computing_path": get_option("record_computing_path"),
  66. "disable_memory_forwarding": get_option("disable_memory_forwarding"),
  67. "enable_dtr_auto_drop": get_option("enable_dtr_auto_drop"),
  68. "enable_dtr_sqrt_sampling": get_option("enable_dtr_sqrt_sampling"),
  69. "dtr_eviction_threshold": get_option("dtr_eviction_threshold"),
  70. "dtr_evictee_minimum_size": get_option("dtr_evictee_minimum_size"),
  71. "benchmark_kernel": config.benchmark_kernel,
  72. "deterministic_kernel": config.deterministic_kernel,
  73. "compute_mode": config._compute_mode,
  74. "amp_enabled": amp.enabled,
  75. "convert_inputs": _get_convert_inputs(),
  76. "amp_dtype_autocast": _get_amp_dtype_autocast(),
  77. "amp_high_prec_dtype": _get_amp_high_prec_dtype(),
  78. "amp_low_prec_dtype": _get_amp_low_prec_dtype(),
  79. }
  80. for key in env_vars1:
  81. assert (
  82. env_vars1[key] == env_vars2[key]
  83. ), "{} have been changed after test".format(key)