GitOrigin-RevId: 57a3b9d418
release-1.10
@@ -51,14 +51,7 @@ class _Hashable: | |||||
return self.value == o.value | return self.value == o.value | ||||
def _matmul( | |||||
inp1, | |||||
inp2, | |||||
transpose_a=False, | |||||
transpose_b=False, | |||||
compute_mode="default", | |||||
format="default", | |||||
): | |||||
def _matmul(inp1, inp2, transpose_a=False, transpose_b=False, compute_mode="default"): | |||||
dim1, dim2 = inp1.ndim, inp2.ndim | dim1, dim2 = inp1.ndim, inp2.ndim | ||||
assert dim1 > 0 and dim2 > 0 | assert dim1 > 0 and dim2 > 0 | ||||
maxdim = dim1 if dim1 > dim2 else dim2 | maxdim = dim1 if dim1 > dim2 else dim2 | ||||
@@ -1206,6 +1206,7 @@ def batch_norm( | |||||
if x is None: | if x is None: | ||||
x = Const(value, inp.dtype, inp.device) | x = Const(value, inp.dtype, inp.device) | ||||
x.format = inp.format | |||||
shape = astensor1d(pshape, inp, dtype="int32", device=inp.device) | shape = astensor1d(pshape, inp, dtype="int32", device=inp.device) | ||||
(result,) = apply(builtin.Broadcast(), x, shape) | (result,) = apply(builtin.Broadcast(), x, shape) | ||||
return result | return result | ||||
@@ -1227,14 +1228,14 @@ def batch_norm( | |||||
if not training: | if not training: | ||||
op = builtin.BatchNorm( | op = builtin.BatchNorm( | ||||
fwd_mode=BatchNorm.FwdMode.INFERENCE, param_dim="dim_1c11", epsilon=eps | |||||
fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim="dim_1c11" | |||||
) | ) | ||||
ret = apply(op, inp, weight, bias, running_mean, running_var)[-1] | ret = apply(op, inp, weight, bias, running_mean, running_var)[-1] | ||||
return ret | return ret | ||||
else: | else: | ||||
op = builtin.BatchNorm( | op = builtin.BatchNorm( | ||||
avg_factor=1 - momentum, param_dim="dim_1c11", epsilon=eps | |||||
avg_factor=1 - momentum, epsilon=eps, param_dim="dim_1c11" | |||||
) | ) | ||||
if has_mean or has_var: | if has_mean or has_var: | ||||
running_mean = make_full_if_none(running_mean, 0) | running_mean = make_full_if_none(running_mean, 0) | ||||
@@ -272,6 +272,9 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor: | |||||
x = Const(value, inp.dtype, inp.device) | x = Const(value, inp.dtype, inp.device) | ||||
if inp.ndim == 0: | if inp.ndim == 0: | ||||
return x | return x | ||||
# set x's format to use FormatTransformation rule for Broadcast. | |||||
x.format = inp.format | |||||
return broadcast_to(x, inp.shape) | return broadcast_to(x, inp.shape) | ||||
@@ -91,13 +91,14 @@ class Optimizer(metaclass=ABCMeta): | |||||
else: | else: | ||||
param_group["params"] = list(param_group["params"]) | param_group["params"] = list(param_group["params"]) | ||||
for param in param_group["params"]: | |||||
if not isinstance(param, Parameter): | |||||
raise TypeError( | |||||
"optimizer can only optimize Parameters, but one of the params is " | |||||
+ str(type(param)) | |||||
) | |||||
param._reset(Tensor(param.numpy(), no_cache=True, format=param.format)) | |||||
with _config._override(auto_format_convert=False): | |||||
for param in param_group["params"]: | |||||
if not isinstance(param, Parameter): | |||||
raise TypeError( | |||||
"optimizer can only optimize Parameters, but one of the params is " | |||||
+ str(type(param)) | |||||
) | |||||
param._reset(Tensor(param.numpy(), no_cache=True, format=param.format)) | |||||
for name, default in self._defaults.items(): | for name, default in self._defaults.items(): | ||||
if default is required and name not in param_group: | if default is required and name not in param_group: | ||||
@@ -58,7 +58,6 @@ def run_around_tests(): | |||||
"benchmark_kernel": config.benchmark_kernel, | "benchmark_kernel": config.benchmark_kernel, | ||||
"deterministic_kernel": config.deterministic_kernel, | "deterministic_kernel": config.deterministic_kernel, | ||||
"compute_mode": config._compute_mode, | "compute_mode": config._compute_mode, | ||||
"conv_format": config._conv_format, | |||||
"amp_enabled": amp.enabled, | "amp_enabled": amp.enabled, | ||||
"convert_inputs": _get_convert_inputs(), | "convert_inputs": _get_convert_inputs(), | ||||
"amp_dtype_autocast": _get_amp_dtype_autocast(), | "amp_dtype_autocast": _get_amp_dtype_autocast(), | ||||
@@ -82,7 +81,6 @@ def run_around_tests(): | |||||
"benchmark_kernel": config.benchmark_kernel, | "benchmark_kernel": config.benchmark_kernel, | ||||
"deterministic_kernel": config.deterministic_kernel, | "deterministic_kernel": config.deterministic_kernel, | ||||
"compute_mode": config._compute_mode, | "compute_mode": config._compute_mode, | ||||
"conv_format": config._conv_format, | |||||
"amp_enabled": amp.enabled, | "amp_enabled": amp.enabled, | ||||
"convert_inputs": _get_convert_inputs(), | "convert_inputs": _get_convert_inputs(), | ||||
"amp_dtype_autocast": _get_amp_dtype_autocast(), | "amp_dtype_autocast": _get_amp_dtype_autocast(), | ||||
@@ -386,13 +386,6 @@ def test_backward_conv2d_dimshuffle(is_symbolic): | |||||
return F.transpose(self.conv(inp), (0, 2, 3, 1)).reshape(1, 18, 2) | return F.transpose(self.conv(inp), (0, 2, 3, 1)).reshape(1, 18, 2) | ||||
inp = mge.tensor(np.arange(0, 24).reshape((1, 2, 3, 4))) | inp = mge.tensor(np.arange(0, 24).reshape((1, 2, 3, 4))) | ||||
# x = tensor(data.transpose(0, 2, 3, 1), format="nhwc") | |||||
# w = mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc") | |||||
# b = mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc") | |||||
# grads = [ | |||||
# np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)), | |||||
# np.array([12, 12, 12]).reshape((1, 1, 1, 3)), | |||||
# ] | |||||
_compare_backward([inp], Net(), is_symbolic) | _compare_backward([inp], Net(), is_symbolic) | ||||
@@ -403,37 +396,10 @@ def test_backward_groupconv2d_bn(is_symbolic): | |||||
super().__init__() | super().__init__() | ||||
self.conv0 = M.Conv2d(32, 256, 3, groups=32, stride=2) | self.conv0 = M.Conv2d(32, 256, 3, groups=32, stride=2) | ||||
self.conv1 = M.Conv2d(256, 2048, 3, groups=32, stride=2) | self.conv1 = M.Conv2d(256, 2048, 3, groups=32, stride=2) | ||||
# self.bn = M.BatchNorm2d(2048) | |||||
self.bn = M.BatchNorm2d(2048) | |||||
def forward(self, inp): | def forward(self, inp): | ||||
# test manually convert to NHWC, usually used in detection head | |||||
return self.conv1(self.conv0(inp)) | |||||
return self.bn(self.conv1(self.conv0(inp))) | |||||
inp = mge.tensor(np.ones(shape=(32, 32, 56, 56)).astype("float32")) | inp = mge.tensor(np.ones(shape=(32, 32, 56, 56)).astype("float32")) | ||||
_compare_backward([inp], Net(), is_symbolic) | _compare_backward([inp], Net(), is_symbolic) | ||||
# def func(x, w, b, bn_w, bn_b): | |||||
# x = F.conv2d(x, w, b, groups=2) | |||||
# x = F.batch_norm( | |||||
# x, | |||||
# running_mean=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), | |||||
# running_var=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), | |||||
# weight=bn_w, | |||||
# bias=bn_b, | |||||
# training=True, | |||||
# inplace=True, | |||||
# ) | |||||
# return x | |||||
# data = np.arange(0, 24).reshape((1, 2, 3, 4)) | |||||
# x = tensor(data.transpose(0, 2, 3, 1), format="nhwc") | |||||
# w = tensor(np.ones((2, 1, 1, 1, 1)), format="nhwc") | |||||
# b = tensor(np.ones((1, 1, 1, 2)), format="nhwc") | |||||
# bn_w = tensor(np.ones((1, 1, 1, 2)), format="nhwc") | |||||
# bn_b = tensor(np.ones((1, 1, 1, 2)), format="nhwc") | |||||
# grads = [ | |||||
# np.array([66, 210]).reshape((2, 1, 1, 1, 1)), | |||||
# np.array([12, 12]).reshape((1, 1, 1, 2)), | |||||
# np.array([12, 12]).reshape((1, 1, 1, 2)), | |||||
# np.array([12, 12]).reshape((1, 1, 1, 2)), | |||||
# ] | |||||
# _compare_backward(x, func, [w, b, bn_w, bn_b], grads, is_symbolic) |