
feat(imperative/amp): add dimshuffle in set_format for nhwc

GitOrigin-RevId: 5ced9e1a31
release-1.10
Megvii Engine Team 3 years ago
commit 261a5bce23
7 changed files with 22 additions and 33 deletions
  1. imperative/python/megengine/amp/convert_format.py (+4, -10)
  2. imperative/python/megengine/core/_config.py (+0, -7)
  3. imperative/python/megengine/functional/nn.py (+1, -1)
  4. imperative/python/megengine/functional/tensor.py (+0, -1)
  5. imperative/python/megengine/optimizer/optimizer.py (+8, -10)
  6. imperative/python/test/unit/amp/test_convert_format.py (+5, -3)
  7. imperative/src/impl/transformations/format.cpp (+4, -1)

imperative/python/megengine/amp/convert_format.py (+4, -10)

@@ -23,23 +23,17 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
     if not _is_nchw_format(x):
         return x
 
-    if x.ndim == 4:
-        pattern = (0, 2, 3, 1)
-    elif x.ndim == 5:
-        pattern = (0, 1, 3, 4, 2)
-    else:
+    if x.ndim != 4 and x.ndim != 5:
         raise ValueError("Unsupport tensor ndim {}".format(x.ndim))
-    # TODO: use initialization from tensor after fixing format setting
     if x.format != "nhwc":
+        # hostvalue should still be valid, so no d2h cost.
+        data = x.numpy()
         if inplace:
-            # hostvalue should still be valid, so no d2h cost.
-            data = x.numpy()
             # reset will destroy existed backward grad
             x[...] = Tensor(data, format="nhwc")
         else:
             # use mge interface to maintain grad
-            x = F.transpose(x, pattern)
-            x.format = "nhwc"
+            x = Tensor(data, format="nhwc")
     return x
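The dimshuffle that the removed `else` branch performed via `F.transpose(x, pattern)` is no longer spelled out in Python; constructing `Tensor(data, format="nhwc")` now triggers the same permutation inside the `SetFormat` transformation (see the format.cpp hunk below). A minimal NumPy sketch of that layout change, using the patterns from the removed code (illustrative only, not MegEngine internals):

```python
import numpy as np

def nchw_to_nhwc(data: np.ndarray) -> np.ndarray:
    # (0, 2, 3, 1): N, C, H, W -> N, H, W, C for ordinary 4D tensors
    if data.ndim == 4:
        return data.transpose(0, 2, 3, 1)
    # (0, 1, 3, 4, 2): keep the leading group dim for 5D group-conv weights
    if data.ndim == 5:
        return data.transpose(0, 1, 3, 4, 2)
    raise ValueError("Unsupport tensor ndim {}".format(data.ndim))

x = np.zeros((2, 3, 8, 8), dtype=np.float32)  # NCHW input
print(nchw_to_nhwc(x).shape)                  # (2, 8, 8, 3), i.e. NHWC
```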






imperative/python/megengine/core/_config.py (+0, -7)

@@ -181,7 +181,6 @@ def _reset_execution_config(
     deterministic_kernel=None,
     async_level=None,
     compute_mode=None,
-    auto_format_convert=None,
 ):
     global _benchmark_kernel, _deterministic_kernel, __compute_mode
     orig_flags = (
@@ -189,7 +188,6 @@
         _deterministic_kernel,
         get_option("async_level"),
         __compute_mode,
-        get_auto_format_convert(),
     )
     if benchmark_kernel is not None:
         _benchmark_kernel = benchmark_kernel
@@ -199,8 +197,6 @@
         set_option("async_level", async_level)
     if compute_mode is not None:
         __compute_mode = compute_mode
-    if auto_format_convert is not None:
-        set_auto_format_convert(auto_format_convert)
 
     return orig_flags
@@ -211,7 +207,6 @@ def _override(
     deterministic_kernel=None,
     async_level=None,
     compute_mode=None,
-    auto_format_convert=None,
 ):
     r"""A context manager that users can opt in by attaching the decorator to set
     the config of the global variable.
@@ -227,7 +222,6 @@ def _override(
             deterministic_kernel = Fasle,
             async_level=2,
             compute_mode="float32",
-            auto_format_convert=True,
         )
         def train():
     """
@@ -236,7 +230,6 @@ def _override(
         deterministic_kernel=deterministic_kernel,
        async_level=async_level,
         compute_mode=compute_mode,
-        auto_format_convert=auto_format_convert,
     )
     try:
         yield
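With `auto_format_convert` dropped from `_reset_execution_config` and `_override`, format conversion is no longer toggled through this config context. A minimal sketch of the remaining knobs, mirroring the docstring example shown above (the values are illustrative only):

```python
import megengine.core._config as _config

# _override can be attached as a decorator, as the docstring example shows;
# auto_format_convert is no longer an accepted keyword here.
@_config._override(
    benchmark_kernel=True,
    deterministic_kernel=False,
    async_level=2,
    compute_mode="float32",
)
def train():
    ...  # training step runs with the overridden execution config

train()
```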


imperative/python/megengine/functional/nn.py (+1, -1)

@@ -1206,9 +1206,9 @@ def batch_norm(
 
         if x is None:
             x = Const(value, inp.dtype, inp.device)
-            x.format = inp.format
             shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
             (result,) = apply(builtin.Broadcast(), x, shape)
+            result.format = inp.format
             return result
         else:
             assert x_ndim == 1
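The only change here is where the format tag is applied: the broadcast result, rather than the scalar constant fed into `Broadcast`, now inherits `inp.format`. A rough user-level sketch of the same pattern using the public `broadcast_to` (illustrative; the real code uses the internal `Const`/`apply(builtin.Broadcast(), ...)` path):

```python
import numpy as np
import megengine as mge
import megengine.functional as F

inp = mge.Tensor(np.zeros((2, 3, 8, 8), dtype="float32"), format="nhwc")
ones = mge.Tensor(1.0, dtype="float32")

# Broadcast first, then tag the *result* with the input's format,
# mirroring `result.format = inp.format` in the hunk above.
res = F.broadcast_to(ones, inp.shape)
res.format = inp.format
print(res.format)  # expected: "nhwc"
```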


imperative/python/megengine/functional/tensor.py (+0, -1)

@@ -274,7 +274,6 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
         return x
 
     # set x's format to use FormatTransformation rule for Broadcast.
-    x.format = inp.format
     return broadcast_to(x, inp.shape)
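Dropping `x.format = inp.format` means `full_like` now relies on the FormatTransformation rule for `Broadcast` (together with the nn.py change above) instead of pre-tagging the constant. A hedged usage sketch, assuming the Broadcast rule handles the propagation:

```python
import numpy as np
import megengine as mge
import megengine.functional as F

inp = mge.Tensor(np.zeros((2, 3, 8, 8), dtype="float32"), format="nhwc")
out = F.full_like(inp, 1.0)

# Format propagation is now left to the Broadcast rule rather than set here.
print(out.shape, out.format)
```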






imperative/python/megengine/optimizer/optimizer.py (+8, -10)

@@ -91,14 +91,13 @@ class Optimizer(metaclass=ABCMeta):
         else:
             param_group["params"] = list(param_group["params"])
 
-        with _config._override(auto_format_convert=False):
-            for param in param_group["params"]:
-                if not isinstance(param, Parameter):
-                    raise TypeError(
-                        "optimizer can only optimize Parameters, but one of the params is "
-                        + str(type(param))
-                    )
-                param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))
+        for param in param_group["params"]:
+            if not isinstance(param, Parameter):
+                raise TypeError(
+                    "optimizer can only optimize Parameters, but one of the params is "
+                    + str(type(param))
+                )
+            param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))
 
         for name, default in self._defaults.items():
             if default is required and name not in param_group:
@@ -121,8 +120,7 @@
 
     def _add_state(self, param, state_name, initializer=None):
         if initializer is None:
-            with _config._override(auto_format_convert=False):
-                initializer = np.zeros(param.shape, dtype=np.float32)
+            initializer = np.zeros(param.shape, dtype=np.float32)
         state_dict = self._state.setdefault(param, {})
         assert state_name not in state_dict
         state = Tensor(initializer, no_cache=True, format=param.format)
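With the `_override(auto_format_convert=False)` wrappers gone, parameter re-wrapping and state allocation read shape and format straight off the parameter. A minimal sketch of the two patterns kept by these hunks (the parameter itself is illustrative):

```python
import numpy as np
from megengine import Parameter, Tensor

param = Parameter(np.ones((16, 3, 3, 3), dtype=np.float32))

# add_param_group path: re-wrap the parameter, keeping its format.
param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))

# _add_state path: allocate zeros from param.shape, then tag the new
# state tensor with the parameter's format.
initializer = np.zeros(param.shape, dtype=np.float32)
state = Tensor(initializer, no_cache=True, format=param.format)
print(state.shape, state.format)
```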


imperative/python/test/unit/amp/test_convert_format.py (+5, -3)

@@ -10,7 +10,8 @@ import pytest
 
 import megengine.functional as F
 import megengine.module as M
-from megengine import Parameter, Tensor, amp, config
+from megengine import Parameter, Tensor, amp
+from megengine.core._config import set_auto_format_convert
 
 
 class MyModule(M.Module):
@@ -56,5 +57,6 @@ def test_convert_module(is_inplace):
     m = amp.convert_module_format(m, is_inplace)
     for name, param in m.named_tensors():
         assert param.format == "nhwc"
-        with config._override(auto_format_convert=False):
-            assert param.shape == expected_shape[name], name
+        set_auto_format_convert(False)
+        assert param.shape == expected_shape[name], name
+        set_auto_format_convert(True)
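The test now flips the global flag with `set_auto_format_convert` around the shape check instead of entering `config._override(auto_format_convert=False)`. A condensed sketch of that pattern (a single Conv2d stands in for the test's MyModule):

```python
import megengine.module as M
from megengine import amp
from megengine.core._config import set_auto_format_convert

m = amp.convert_module_format(M.Conv2d(3, 16, 3), False)
for name, param in m.named_tensors():
    assert param.format == "nhwc"
    set_auto_format_convert(False)
    print(name, param.shape)  # physical NHWC layout while conversion is off
    set_auto_format_convert(True)
```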

imperative/src/impl/transformations/format.cpp (+4, -1)

@@ -19,6 +19,9 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
         const std::string& scope) const {
     std::vector<int32_t> pattern;
     Format format = tensor.format();
+    if (format == target)
+        return as(tensor, target);
+
     if (format == FT::NHWC && (target == FT::NCHW || target == FT::DEFAULT)) {
         // FIXME(czh): temporary fast path for group conv 5D weight.
         if (tensor.value().shape().cast<ShapeValue>().ndim == 5) {
@@ -618,7 +621,7 @@
     } else if (auto* _op = op.as<SetFormat>()) {
         auto&& inp_ref = inputs[0].as_ref(m_value_type);
         mgb_assert(inp_ref, "Cannot set format for non-format Tensor.");
-        return {m_value_type.make(inp_ref->value(), _op->format())};
+        return {to(*inp_ref, _op->format().type(), "")};
     } else if (op.is<Operator::IdentityLike>()) {
         auto&& inp_ref = inputs[0].as_ref(m_value_type);
         if (inp_ref) {
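Routing `SetFormat` through `to()` is what actually adds the dimshuffle: previously the op only re-tagged the value with the new format. The short-circuit added to `to()` keeps the no-op case (format already equal to the target) cheap. A hedged Python-level sketch of the observable effect (the exact physical shape is an assumption based on the test expectations):

```python
import numpy as np
import megengine as mge
from megengine.core._config import set_auto_format_convert

x = mge.Tensor(np.arange(24, dtype="float32").reshape(1, 2, 3, 4))  # NCHW data
x.format = "nhwc"  # SetFormat now goes through to(), which dimshuffles

set_auto_format_convert(False)
print(x.shape)     # expected physical layout: (1, 3, 4, 2)
set_auto_format_convert(True)
print(x.shape)     # logical view remains (1, 2, 3, 4)
```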

