You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

convert_format.py 1.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from copy import deepcopy
  9. from .. import functional as F
  10. from ..module import Module
  11. from ..tensor import Tensor
  12. def _is_nchw_format(param: Tensor):
  13. # TODO: use better condition
  14. return (len(param.shape) == 4 or len(param.shape) == 5) and param.format != "nhwc"
  15. def convert_tensor_format(x: Tensor, inplace: bool = True):
  16. """Convert NCHW Tensor to NHWC Tensor."""
  17. if x.ndim == 4:
  18. pattern = (0, 2, 3, 1)
  19. elif x.ndim == 5:
  20. pattern = (0, 1, 3, 4, 2)
  21. else:
  22. raise ValueError("Unsupport tensor ndim {}".format(x.ndim))
  23. # TODO: use initialization from tensor after fixing format setting
  24. if inplace:
  25. x[...] = Tensor(x.numpy().transpose(*pattern), format="nhwc")
  26. else:
  27. x = Tensor(x.numpy().transpose(*pattern), format="nhwc")
  28. return x
  29. def convert_module_format(module: Module, inplace: bool = True):
  30. """Convert NCHW Module to NHWC Module."""
  31. if not inplace:
  32. module = deepcopy(module)
  33. for name, param in module.named_tensors():
  34. if _is_nchw_format(param):
  35. # hostvalue should still be valid, so no d2h cost.
  36. convert_tensor_format(param, inplace=True)
  37. return module