You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_autocast.py 1.3 kB

12345678910111213141516171819202122232425262728293031323334
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from megengine import amp
  9. from megengine.core.tensor import amp as origin_amp
  10. def test_grad_scaler():
  11. def check(enabled, low, high):
  12. assert amp.enabled == enabled
  13. assert origin_amp._enabled == enabled
  14. assert amp.low_prec_dtype == low
  15. assert origin_amp._get_amp_low_prec_dtype() == low
  16. assert amp.high_prec_dtype == high
  17. assert origin_amp._get_amp_high_prec_dtype() == high
  18. origin_enabled = amp.enabled
  19. origin_high = amp.high_prec_dtype
  20. origin_low = amp.low_prec_dtype
  21. with amp.autocast(low_prec_dtype="float16", high_prec_dtype="float32"):
  22. check(True, "float16", "float32")
  23. check(origin_enabled, origin_low, origin_high)
  24. amp.enabled = True
  25. amp.high_prec_dtype = "float32"
  26. amp.low_prec_dtype = "float16"
  27. check(True, "float16", "float32")
  28. amp.enabled = origin_enabled
  29. amp.high_prec_dtype = origin_high
  30. amp.low_prec_dtype = origin_low
  31. check(origin_enabled, origin_low, origin_high)