You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tm_config.py 1.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. import contextlib
  2. from ..core._imperative_rt.core2 import (
  3. is_tracing_module,
  4. set_module_tracing,
  5. unset_module_tracing,
  6. )
  7. _enable_expr_checker = False
  8. _enable_default_checker = True
  9. def _get_expr_checker():
  10. return _enable_expr_checker
  11. def _get_default_checker():
  12. return _enable_default_checker
  13. def enable_expr_checker():
  14. r"""Call this function to check the result of each expr during tracing."""
  15. global _enable_expr_checker
  16. _enable_expr_checker = True
  17. _enable_default_checker = False
  18. def disable_default_checker():
  19. r"""Call this function to disable checking the final output of the model after tracing."""
  20. global _enable_default_checker
  21. _enable_default_checker = False
  22. _enable_graph_surgery_mode = False
  23. def _graph_surgery_mode():
  24. return _enable_graph_surgery_mode
  25. def _set_graph_surgery_mode(mode: bool):
  26. global _enable_graph_surgery_mode
  27. pre_mode = _enable_graph_surgery_mode
  28. _enable_graph_surgery_mode = mode
  29. return pre_mode
  30. @contextlib.contextmanager
  31. def _exclude_from_trace():
  32. is_tracing = is_tracing_module()
  33. if is_tracing:
  34. unset_module_tracing()
  35. yield
  36. if is_tracing:
  37. set_module_tracing()