GitOrigin-RevId: 5ced9e1a31
release-1.10
@@ -23,23 +23,17 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
     if not _is_nchw_format(x):
         return x

-    if x.ndim == 4:
-        pattern = (0, 2, 3, 1)
-    elif x.ndim == 5:
-        pattern = (0, 1, 3, 4, 2)
-    else:
+    if x.ndim != 4 and x.ndim != 5:
         raise ValueError("Unsupport tensor ndim {}".format(x.ndim))
-    # TODO: use initialization from tensor after fixing format setting
     if x.format != "nhwc":
+        # hostvalue should still be valid, so no d2h cost.
+        data = x.numpy()
         if inplace:
-            # hostvalue should still be valid, so no d2h cost.
-            data = x.numpy()
             # reset will destroy existed backward grad
             x[...] = Tensor(data, format="nhwc")
         else:
             # use mge interface to maintain grad
-            x = F.transpose(x, pattern)
-            x.format = "nhwc"
+            x = Tensor(data, format="nhwc")
     return x
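For reference, a minimal usage sketch of the rewritten helper above (not part of the diff). It assumes convert_tensor_format is exported from megengine.amp, alongside the convert_module_format used in the test hunk further down:

    import numpy as np

    import megengine.amp as amp
    from megengine import Tensor

    x = Tensor(np.random.rand(2, 3, 4, 4).astype("float32"))  # 4D tensor, not yet "nhwc"

    # Copy mode: the tensor is rebuilt from its host value and tagged "nhwc".
    y = amp.convert_tensor_format(x, inplace=False)
    print(y.format)  # expected: "nhwc"

    # In-place mode: x[...] = Tensor(data, format="nhwc") overwrites x and, as the
    # comment in the hunk notes, destroys any existing backward grad.
    amp.convert_tensor_format(x, inplace=True)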
@@ -181,7 +181,6 @@ def _reset_execution_config(
     deterministic_kernel=None,
     async_level=None,
     compute_mode=None,
-    auto_format_convert=None,
 ):
     global _benchmark_kernel, _deterministic_kernel, __compute_mode
     orig_flags = (
@@ -189,7 +188,6 @@ def _reset_execution_config(
         _deterministic_kernel,
         get_option("async_level"),
         __compute_mode,
-        get_auto_format_convert(),
     )
     if benchmark_kernel is not None:
         _benchmark_kernel = benchmark_kernel
@@ -199,8 +197,6 @@ def _reset_execution_config(
         set_option("async_level", async_level)
     if compute_mode is not None:
         __compute_mode = compute_mode
-    if auto_format_convert is not None:
-        set_auto_format_convert(auto_format_convert)
     return orig_flags
@@ -211,7 +207,6 @@ def _override(
     deterministic_kernel=None,
     async_level=None,
     compute_mode=None,
-    auto_format_convert=None,
 ):
     r"""A context manager that users can opt in by attaching the decorator to set
     the config of the global variable.
@@ -227,7 +222,6 @@ def _override(
             deterministic_kernel = Fasle,
             async_level=2,
             compute_mode="float32",
-            auto_format_convert=True,
         )
         def train():
     """
@@ -236,7 +230,6 @@ def _override(
         deterministic_kernel=deterministic_kernel,
         async_level=async_level,
         compute_mode=compute_mode,
-        auto_format_convert=auto_format_convert,
     )
     try:
         yield
@@ -1206,9 +1206,9 @@ def batch_norm(
         if x is None:
             x = Const(value, inp.dtype, inp.device)
-            x.format = inp.format
             shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
             (result,) = apply(builtin.Broadcast(), x, shape)
+            result.format = inp.format
             return result
         else:
             assert x_ndim == 1
@@ -274,7 +274,6 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
         return x
     # set x's format to use FormatTransformation rule for Broadcast.
-    x.format = inp.format
     return broadcast_to(x, inp.shape)
@@ -91,14 +91,13 @@ class Optimizer(metaclass=ABCMeta):
         else:
             param_group["params"] = list(param_group["params"])

-        with _config._override(auto_format_convert=False):
-            for param in param_group["params"]:
-                if not isinstance(param, Parameter):
-                    raise TypeError(
-                        "optimizer can only optimize Parameters, but one of the params is "
-                        + str(type(param))
-                    )
-                param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))
+        for param in param_group["params"]:
+            if not isinstance(param, Parameter):
+                raise TypeError(
+                    "optimizer can only optimize Parameters, but one of the params is "
+                    + str(type(param))
+                )
+            param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))

         for name, default in self._defaults.items():
             if default is required and name not in param_group:
@@ -121,8 +120,7 @@ class Optimizer(metaclass=ABCMeta):
     def _add_state(self, param, state_name, initializer=None):
         if initializer is None:
-            with _config._override(auto_format_convert=False):
-                initializer = np.zeros(param.shape, dtype=np.float32)
+            initializer = np.zeros(param.shape, dtype=np.float32)
         state_dict = self._state.setdefault(param, {})
         assert state_name not in state_dict
         state = Tensor(initializer, no_cache=True, format=param.format)
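A standalone sketch of the state-initialization pattern that remains once the override is dropped (illustrative only; the parameter below is a stand-in created for the example, not something taken from the diff):

    import numpy as np

    from megengine import Tensor

    # A format-tagged tensor standing in for an optimizer Parameter.
    param = Tensor(np.zeros((1, 2, 2, 3), dtype=np.float32), format="nhwc")

    # Same recipe as _add_state above: allocate zeros from the reported shape,
    # then tag the state tensor with the parameter's format.
    initializer = np.zeros(param.shape, dtype=np.float32)
    state = Tensor(initializer, no_cache=True, format=param.format)
    print(state.format)  # expected: "nhwc"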
@@ -10,7 +10,8 @@ import pytest
 import megengine.functional as F
 import megengine.module as M
-from megengine import Parameter, Tensor, amp, config
+from megengine import Parameter, Tensor, amp
+from megengine.core._config import set_auto_format_convert


 class MyModule(M.Module):
@@ -56,5 +57,6 @@ def test_convert_module(is_inplace):
     m = amp.convert_module_format(m, is_inplace)
     for name, param in m.named_tensors():
         assert param.format == "nhwc"
-        with config._override(auto_format_convert=False):
-            assert param.shape == expected_shape[name], name
+        set_auto_format_convert(False)
+        assert param.shape == expected_shape[name], name
+        set_auto_format_convert(True)
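The test now flips the global flag with bare set_auto_format_convert calls and restores it by hand, so a failing assertion in between would leave the flag disabled. A hedged sketch of a save/restore wrapper built only from the getter and setter that appear in this diff (the helper name raw_shape_view is hypothetical, not part of the codebase):

    import contextlib

    from megengine.core._config import get_auto_format_convert, set_auto_format_convert


    @contextlib.contextmanager
    def raw_shape_view():
        # Hypothetical helper: temporarily disable automatic format conversion so
        # shapes are reported as stored, then restore the previous value even if
        # the body raises.
        orig = get_auto_format_convert()
        set_auto_format_convert(False)
        try:
            yield
        finally:
            set_auto_format_convert(orig)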
@@ -19,6 +19,9 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
         const std::string& scope) const {
     std::vector<int32_t> pattern;
     Format format = tensor.format();
+    if (format == target)
+        return as(tensor, target);
     if (format == FT::NHWC && (target == FT::NCHW || target == FT::DEFAULT)) {
         // FIXME(czh): temporary fast path for group conv 5D weight.
         if (tensor.value().shape().cast<ShapeValue>().ndim == 5) {
@@ -618,7 +621,7 @@ ValueRefList FormatTransformation::apply_transformation(
     } else if (auto* _op = op.as<SetFormat>()) {
         auto&& inp_ref = inputs[0].as_ref(m_value_type);
         mgb_assert(inp_ref, "Cannot set format for non-format Tensor.");
-        return {m_value_type.make(inp_ref->value(), _op->format())};
+        return {to(*inp_ref, _op->format().type(), "")};
     } else if (op.is<Operator::IdentityLike>()) {
         auto&& inp_ref = inputs[0].as_ref(m_value_type);
         if (inp_ref) {
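At the Python level this changes what a plain format assignment does: SetFormat is now routed through to() instead of re-wrapping the value with a fresh format tag. A minimal sketch of that entry point (shape reporting is deliberately left out, since how shapes are displayed depends on the auto-format-convert flag):

    import numpy as np

    from megengine import Tensor

    x = Tensor(np.random.rand(1, 3, 4, 4).astype("float32"))
    # Dispatches SetFormat; with this change the underlying value goes through the
    # same to() conversion path used elsewhere in FormatTransformation.
    x.format = "nhwc"
    print(x.format)  # expected: "nhwc"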