You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

conftest.py 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import os
  9. import platform
  10. import sys
  11. import pytest
  12. from megengine.core import _config as config
  13. from megengine.core import _trace_option as trace_option
  14. from megengine.core import get_option
  15. from megengine.core._imperative_rt.core2 import (
  16. _get_amp_dtype_autocast,
  17. _get_amp_high_prec_dtype,
  18. _get_amp_low_prec_dtype,
  19. _get_convert_inputs,
  20. )
  21. from megengine.core.tensor import amp
  22. from megengine.device import get_device_count
  23. sys.path.append(os.path.join(os.path.dirname(__file__), "helpers"))
  24. _ngpu = get_device_count("gpu")
  25. @pytest.fixture(autouse=True)
  26. def skip_by_ngpu(request):
  27. if request.node.get_closest_marker("require_ngpu"):
  28. require_ngpu = int(request.node.get_closest_marker("require_ngpu").args[0])
  29. if require_ngpu > _ngpu:
  30. pytest.skip("skipped for ngpu unsatisfied: {}".format(require_ngpu))
  31. @pytest.fixture(autouse=True)
  32. def skip_distributed(request):
  33. if request.node.get_closest_marker("distributed_isolated"):
  34. if platform.system() in ("Windows", "Darwin"):
  35. pytest.skip(
  36. "skipped for distributed unsupported at platform: {}".format(
  37. platform.system()
  38. )
  39. )
  40. @pytest.fixture(autouse=True)
  41. def run_around_tests():
  42. env_vars1 = {
  43. "symbolic_shape": trace_option.use_symbolic_shape(),
  44. "async_level": get_option("async_level"),
  45. "enable_drop": get_option("enable_drop"),
  46. "max_recompute_time": get_option("max_recompute_time"),
  47. "catch_worker_execption": get_option("catch_worker_execption"),
  48. "enable_host_compute": get_option("enable_host_compute"),
  49. # "record_computing_path": get_option("record_computing_path"),
  50. "disable_memory_forwarding": get_option("disable_memory_forwarding"),
  51. "enable_dtr_auto_drop": get_option("enable_dtr_auto_drop"),
  52. "enable_dtr_sqrt_sampling": get_option("enable_dtr_sqrt_sampling"),
  53. "dtr_eviction_threshold": get_option("dtr_eviction_threshold"),
  54. "dtr_evictee_minimum_size": get_option("dtr_evictee_minimum_size"),
  55. "benchmark_kernel": config.benchmark_kernel,
  56. "deterministic_kernel": config.deterministic_kernel,
  57. "compute_mode": config._compute_mode,
  58. "conv_format": config._conv_format,
  59. "amp_enabled": amp.enabled,
  60. "convert_inputs": _get_convert_inputs(),
  61. "amp_dtype_autocast": _get_amp_dtype_autocast(),
  62. "amp_high_prec_dtype": _get_amp_high_prec_dtype(),
  63. "amp_low_prec_dtype": _get_amp_low_prec_dtype(),
  64. }
  65. yield
  66. env_vars2 = {
  67. "symbolic_shape": trace_option.use_symbolic_shape(),
  68. "async_level": get_option("async_level"),
  69. "enable_drop": get_option("enable_drop"),
  70. "max_recompute_time": get_option("max_recompute_time"),
  71. "catch_worker_execption": get_option("catch_worker_execption"),
  72. "enable_host_compute": get_option("enable_host_compute"),
  73. # "record_computing_path": get_option("record_computing_path"),
  74. "disable_memory_forwarding": get_option("disable_memory_forwarding"),
  75. "enable_dtr_auto_drop": get_option("enable_dtr_auto_drop"),
  76. "enable_dtr_sqrt_sampling": get_option("enable_dtr_sqrt_sampling"),
  77. "dtr_eviction_threshold": get_option("dtr_eviction_threshold"),
  78. "dtr_evictee_minimum_size": get_option("dtr_evictee_minimum_size"),
  79. "benchmark_kernel": config.benchmark_kernel,
  80. "deterministic_kernel": config.deterministic_kernel,
  81. "compute_mode": config._compute_mode,
  82. "conv_format": config._conv_format,
  83. "amp_enabled": amp.enabled,
  84. "convert_inputs": _get_convert_inputs(),
  85. "amp_dtype_autocast": _get_amp_dtype_autocast(),
  86. "amp_high_prec_dtype": _get_amp_high_prec_dtype(),
  87. "amp_low_prec_dtype": _get_amp_low_prec_dtype(),
  88. }
  89. for key in env_vars1:
  90. assert (
  91. env_vars1[key] == env_vars2[key]
  92. ), "{} have been changed after test".format(key)