You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor_cache.py 993 B

12345678910111213141516171819202122232425262728293031323334
  1. from ..core._imperative_rt.core2 import Const
  2. from ..jit.tracing import is_tracing
  3. small_tensor_cache = {}
  4. def _get_scalar_tensor_with_value(value, dtype=None, device=None):
  5. global small_tensor_cache
  6. if is_tracing():
  7. ret = Const(value, dtype, device, None)
  8. else:
  9. cache_key = (value, dtype, device)
  10. if cache_key not in small_tensor_cache:
  11. ret = Const(value, dtype, device, None)
  12. small_tensor_cache[cache_key] = ret
  13. else:
  14. ret = small_tensor_cache[cache_key]
  15. return ret
  16. def get_scalar_zero(dtype=None, device=None):
  17. return _get_scalar_tensor_with_value(0, dtype, device)
  18. def get_scalar_zero_point_five(dtype=None, device=None):
  19. return _get_scalar_tensor_with_value(0.5, dtype, device)
  20. def get_scalar_one(dtype=None, device=None):
  21. return _get_scalar_tensor_with_value(1, dtype, device)
  22. def get_scalar_two(dtype=None, device=None):
  23. return _get_scalar_tensor_with_value(2, dtype, device)