You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

conftest.py 4.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import os
  2. import platform
  3. import sys
  4. import pytest
  5. from megengine.core import _config as config
  6. from megengine.core import _trace_option as trace_option
  7. from megengine.core import get_option
  8. from megengine.core._imperative_rt.core2 import (
  9. _get_amp_dtype_autocast,
  10. _get_amp_high_prec_dtype,
  11. _get_amp_low_prec_dtype,
  12. _get_convert_inputs,
  13. )
  14. from megengine.core.tensor import amp
  15. from megengine.device import get_device_count
  16. sys.path.append(os.path.join(os.path.dirname(__file__), "helpers"))
  17. _ngpu = get_device_count("gpu")
  18. @pytest.fixture(autouse=True)
  19. def skip_by_ngpu(request):
  20. if request.node.get_closest_marker("require_ngpu"):
  21. require_ngpu = int(request.node.get_closest_marker("require_ngpu").args[0])
  22. if require_ngpu > _ngpu:
  23. pytest.skip("skipped for ngpu unsatisfied: {}".format(require_ngpu))
  24. @pytest.fixture(autouse=True)
  25. def skip_distributed(request):
  26. if request.node.get_closest_marker("distributed_isolated"):
  27. if platform.system() in ("Windows", "Darwin"):
  28. pytest.skip(
  29. "skipped for distributed unsupported at platform: {}".format(
  30. platform.system()
  31. )
  32. )
  33. @pytest.fixture(autouse=True)
  34. def run_around_tests():
  35. env_vars1 = {
  36. "symbolic_shape": trace_option.use_symbolic_shape(),
  37. "async_level": get_option("async_level"),
  38. "enable_drop": get_option("enable_drop"),
  39. "max_recompute_time": get_option("max_recompute_time"),
  40. "catch_worker_execption": get_option("catch_worker_execption"),
  41. "enable_host_compute": get_option("enable_host_compute"),
  42. # "record_computing_path": get_option("record_computing_path"),
  43. "disable_memory_forwarding": get_option("disable_memory_forwarding"),
  44. "enable_dtr_auto_drop": get_option("enable_dtr_auto_drop"),
  45. "enable_dtr_sqrt_sampling": get_option("enable_dtr_sqrt_sampling"),
  46. "dtr_eviction_threshold": get_option("dtr_eviction_threshold"),
  47. "dtr_evictee_minimum_size": get_option("dtr_evictee_minimum_size"),
  48. "benchmark_kernel": config.benchmark_kernel,
  49. "deterministic_kernel": config.deterministic_kernel,
  50. "compute_mode": config._compute_mode,
  51. "conv_format": config._conv_format,
  52. "amp_enabled": amp.enabled,
  53. "convert_inputs": _get_convert_inputs(),
  54. "amp_dtype_autocast": _get_amp_dtype_autocast(),
  55. "amp_high_prec_dtype": _get_amp_high_prec_dtype(),
  56. "amp_low_prec_dtype": _get_amp_low_prec_dtype(),
  57. }
  58. yield
  59. env_vars2 = {
  60. "symbolic_shape": trace_option.use_symbolic_shape(),
  61. "async_level": get_option("async_level"),
  62. "enable_drop": get_option("enable_drop"),
  63. "max_recompute_time": get_option("max_recompute_time"),
  64. "catch_worker_execption": get_option("catch_worker_execption"),
  65. "enable_host_compute": get_option("enable_host_compute"),
  66. # "record_computing_path": get_option("record_computing_path"),
  67. "disable_memory_forwarding": get_option("disable_memory_forwarding"),
  68. "enable_dtr_auto_drop": get_option("enable_dtr_auto_drop"),
  69. "enable_dtr_sqrt_sampling": get_option("enable_dtr_sqrt_sampling"),
  70. "dtr_eviction_threshold": get_option("dtr_eviction_threshold"),
  71. "dtr_evictee_minimum_size": get_option("dtr_evictee_minimum_size"),
  72. "benchmark_kernel": config.benchmark_kernel,
  73. "deterministic_kernel": config.deterministic_kernel,
  74. "compute_mode": config._compute_mode,
  75. "conv_format": config._conv_format,
  76. "amp_enabled": amp.enabled,
  77. "convert_inputs": _get_convert_inputs(),
  78. "amp_dtype_autocast": _get_amp_dtype_autocast(),
  79. "amp_high_prec_dtype": _get_amp_high_prec_dtype(),
  80. "amp_low_prec_dtype": _get_amp_low_prec_dtype(),
  81. }
  82. for key in env_vars1:
  83. assert (
  84. env_vars1[key] == env_vars2[key]
  85. ), "{} have been changed after test".format(key)