You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_autocast.py 991 B

123456789101112131415161718192021222324252627
  1. from megengine import amp
  2. from megengine.core.tensor import amp as origin_amp
  def test_autocast():
      """Check that ``amp``'s public settings mirror the core AMP state.

      Two ways of configuring automatic mixed precision are exercised:

      * the ``amp.autocast(...)`` context manager. Its dtypes must be visible
        inside the ``with`` body, and the previous settings must be back in
        effect once the body exits.
      * direct assignment to ``amp.enabled``, ``amp.high_prec_dtype`` and
        ``amp.low_prec_dtype``.

      The AMP configuration is global. It is therefore restored in a
      ``finally`` clause, so a failing assertion cannot leak mixed-precision
      settings into tests that run afterwards in the same process.
      """

      def check(enabled, low, high):
          # Every public attribute on ``amp`` must agree with the matching
          # private flag / getter in ``megengine.core.tensor.amp``.
          assert amp.enabled == enabled
          assert origin_amp._enabled == enabled
          assert amp.low_prec_dtype == low
          assert origin_amp._get_amp_low_prec_dtype() == low
          assert amp.high_prec_dtype == high
          assert origin_amp._get_amp_high_prec_dtype() == high

      # Snapshot the current global settings so they can be put back.
      origin_enabled = amp.enabled
      origin_high = amp.high_prec_dtype
      origin_low = amp.low_prec_dtype

      try:
          with amp.autocast(low_prec_dtype="float16", high_prec_dtype="float32"):
              check(True, "float16", "float32")
          # Exiting the context must bring back the snapshot values.
          check(origin_enabled, origin_low, origin_high)

          amp.enabled = True
          amp.high_prec_dtype = "float32"
          amp.low_prec_dtype = "float16"
          check(True, "float16", "float32")
      finally:
          # Always undo the manual assignments above, even if a check failed.
          amp.enabled = origin_enabled
          amp.high_prec_dtype = origin_high
          amp.low_prec_dtype = origin_low
      check(origin_enabled, origin_low, origin_high)