You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

convert_format.py 1.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. from copy import deepcopy
  2. from .. import functional as F
  3. from ..core import _config
  4. from ..module import Module
  5. from ..tensor import Tensor
  6. def _is_nchw_format(param: Tensor):
  7. # TODO: use better condition
  8. return (param.ndim == 4 or param.ndim == 5) and param.format != "nhwc"
  9. def convert_tensor_format(x: Tensor, inplace: bool = True):
  10. """Convert NCHW Tensor to NHWC Tensor."""
  11. if not _is_nchw_format(x):
  12. return x
  13. if x.ndim != 4 and x.ndim != 5:
  14. raise ValueError("Unsupport tensor ndim {}".format(x.ndim))
  15. if x.format != "nhwc":
  16. # hostvalue should still be valid, so no d2h cost.
  17. data = x.numpy()
  18. if inplace:
  19. # reset will destroy existed backward grad
  20. x[...] = Tensor(data, format="nhwc")
  21. else:
  22. # use mge interface to maintain grad
  23. x = Tensor(data, format="nhwc")
  24. return x
  25. def convert_module_format(module: Module, inplace: bool = True):
  26. """Convert NCHW Module to NHWC Module."""
  27. if not inplace:
  28. module = deepcopy(module)
  29. for name, param in module.named_tensors():
  30. convert_tensor_format(param, inplace=True)
  31. return module