You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

convert_format.py 1.7 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from copy import deepcopy
  9. from .. import functional as F
  10. from ..core import _config
  11. from ..module import Module
  12. from ..tensor import Tensor
  13. def _is_nchw_format(param: Tensor):
  14. # TODO: use better condition
  15. return (param.ndim == 4 or param.ndim == 5) and param.format != "nhwc"
  16. def convert_tensor_format(x: Tensor, inplace: bool = True):
  17. """Convert NCHW Tensor to NHWC Tensor."""
  18. if not _is_nchw_format(x):
  19. return x
  20. if x.ndim == 4:
  21. pattern = (0, 2, 3, 1)
  22. elif x.ndim == 5:
  23. pattern = (0, 1, 3, 4, 2)
  24. else:
  25. raise ValueError("Unsupport tensor ndim {}".format(x.ndim))
  26. # TODO: use initialization from tensor after fixing format setting
  27. if x.format != "nhwc":
  28. if inplace:
  29. # hostvalue should still be valid, so no d2h cost.
  30. data = x.numpy()
  31. # reset will destroy existed backward grad
  32. x[...] = Tensor(data, format="nhwc")
  33. else:
  34. # use mge interface to maintain grad
  35. x = F.transpose(x, pattern)
  36. x.format = "nhwc"
  37. return x
  38. def convert_module_format(module: Module, inplace: bool = True):
  39. """Convert NCHW Module to NHWC Module."""
  40. if not inplace:
  41. module = deepcopy(module)
  42. for name, param in module.named_tensors():
  43. convert_tensor_format(param, inplace=True)
  44. return module