GitOrigin-RevId: ac13cc4659
release-1.1
@@ -55,6 +55,9 @@ class Softmax(Module): | |||
def forward(self, inputs): | |||
return softmax(inputs, self.axis) | |||
def _module_info_string(self) -> str: | |||
return "axis={axis}".format(axis=self.axis) | |||
class Sigmoid(Module): | |||
r""" | |||
@@ -113,6 +113,13 @@ class _BatchNorm(Module): | |||
return output | |||
def _module_info_string(self) -> str: | |||
s = ( | |||
"{num_features}, eps={eps}, momentum={momentum}, affine={affine}, " | |||
"track_running_stats={track_running_stats}" | |||
) | |||
return s.format(**self.__dict__) | |||
class SyncBatchNorm(_BatchNorm): | |||
r""" | |||
@@ -70,6 +70,21 @@ class _ConvNd(Module): | |||
def _infer_bias_shape(self): | |||
pass | |||
def _module_info_string(self): | |||
s = "{in_channels}, {out_channels}, kernel_size={kernel_size}" | |||
if self.stride != (1,) * len(self.stride): | |||
s += ", stride={stride}" | |||
if self.padding != (0,) * len(self.padding): | |||
s += ", padding={padding}" | |||
if self.dilation != (1,) * len(self.dilation): | |||
s += ", dilation={dilation}" | |||
if self.groups != 1: | |||
s += ", groups={groups}" | |||
if self.bias is None: | |||
s += ", bias=False" | |||
return s.format(**self.__dict__) | |||
class Conv2d(_ConvNd): | |||
r"""Applies a 2D convolution over an input tensor. | |||
@@ -28,3 +28,6 @@ class Dropout(Module): | |||
return dropout(inputs, self.drop_prob, training=True) | |||
else: | |||
return inputs | |||
def _module_info_string(self) -> str: | |||
return "drop_prob={drop_prob}".format(drop_prob=self.drop_prob) |
@@ -78,3 +78,8 @@ class Linear(Module): | |||
def forward(self, x): | |||
return self._calc_linear(x, self.weight, self.bias) | |||
def _module_info_string(self) -> str: | |||
return "in_features={}, out_features={}, bias={}".format( | |||
self.in_features, self.out_features, self.bias is not None | |||
) |
@@ -69,6 +69,8 @@ class Module(metaclass=ABCMeta): | |||
self._forward_pre_hooks = OrderedDict() | |||
self._forward_hooks = OrderedDict() | |||
self._modules = [] | |||
@abstractmethod | |||
def forward(self, inputs): | |||
pass | |||
@@ -518,3 +520,57 @@ class Module(metaclass=ABCMeta): | |||
loaded.append(k) | |||
return set(loaded), set(skipped) | |||
def __setattr__(self, name: str, value): | |||
if _is_module(value): | |||
modules = self.__dict__.get("_modules") | |||
if modules is None: | |||
raise AttributeError( | |||
"cannot assign module before Module.__init__() call" | |||
) | |||
if name not in self.__dict__: | |||
modules.append(name) | |||
super().__setattr__(name, value) | |||
def __delattr__(self, name: str): | |||
if name in self.__dict__ and _is_module(self.__dict__[name]): | |||
modules = self.__dict__.get("_modules") | |||
modules.remove(name) | |||
super().__delattr__(name) | |||
def _module_info_string(self) -> str: | |||
r"""Set the extra representation of the module. | |||
""" | |||
return "" | |||
def __repr__(self): | |||
def add_indent(repr_str, num_spaces): | |||
s = repr_str.split("\n") | |||
# don't do anything for single-line stuff | |||
if len(s) == 1: | |||
return repr_str | |||
first = s.pop(0) | |||
s = [(num_spaces * " ") + line for line in s] | |||
s = "\n".join(s) | |||
s = first + "\n" + s | |||
return s | |||
extra_lines = [] | |||
extra_repr = self._module_info_string() | |||
if extra_repr: | |||
extra_lines = extra_repr.split("\n") | |||
child_lines = [ | |||
"(" + name + "): " + add_indent(repr(self.__dict__[name]), 2) | |||
for name in self._modules | |||
] | |||
lines = extra_lines + child_lines | |||
main_str = self.__class__.__name__ + "(" | |||
if lines: | |||
# simple one-liner info, which most builtin Modules will use | |||
if len(extra_lines) == 1 and not child_lines: | |||
main_str += extra_lines[0] | |||
else: | |||
main_str += "\n " + "\n ".join(lines) + "\n" | |||
main_str += ")" | |||
return main_str |
@@ -29,6 +29,11 @@ class _PoolNd(Module): | |||
def forward(self, inp): | |||
pass | |||
def _module_info_string(self) -> str: | |||
return "kernel_size={kernel_size}, stride={stride}, padding={padding}".format( | |||
**self.__dict__ | |||
) | |||
class MaxPool2d(_PoolNd): | |||
r"""Applies a 2D max pooling over an input. | |||
@@ -21,9 +21,12 @@ from megengine.module import ( | |||
BatchNorm1d, | |||
BatchNorm2d, | |||
Conv2d, | |||
Dropout, | |||
Linear, | |||
MaxPool2d, | |||
Module, | |||
Sequential, | |||
Softmax, | |||
) | |||
from megengine.quantization.quantize import quantize, quantize_qat | |||
from megengine.test import assertTensorClose | |||
@@ -609,3 +612,111 @@ def test_load_quantized(): | |||
assertTensorClose( | |||
pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), max_err=5e-6 | |||
) | |||
def test_repr_basic(): | |||
# test whether __repr__ can output correct information | |||
class ConvModel(Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.conv1 = Conv2d(3, 128, 3, stride=2, bias=False) | |||
self.conv2 = Conv2d(3, 128, 3, padding=1, bias=False) | |||
self.conv3 = Conv2d(3, 128, 3, dilation=2, bias=False) | |||
self.bn1 = BatchNorm2d(128) | |||
self.bn2 = BatchNorm1d(128) | |||
self.dropout = Dropout(drop_prob=0.1) | |||
self.softmax = Softmax(axis=100) | |||
self.pooling = MaxPool2d(kernel_size=2, padding=0) | |||
self.submodule1 = Sequential(Dropout(drop_prob=0.1), Softmax(axis=100),) | |||
self.fc1 = Linear(512, 1024) | |||
def forward(self, inputs): | |||
pass | |||
ground_truth = ( | |||
"ConvModel(\n" | |||
" (conv1): Conv2d(3, 128, kernel_size=(3, 3), stride=(2, 2), bias=False)\n" | |||
" (conv2): Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), bias=False)\n" | |||
" (conv3): Conv2d(3, 128, kernel_size=(3, 3), dilation=(2, 2), bias=False)\n" | |||
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n" | |||
" (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n" | |||
" (dropout): Dropout(drop_prob=0.1)\n (softmax): Softmax(axis=100)\n" | |||
" (pooling): MaxPool2d(kernel_size=2, stride=2, padding=0)\n" | |||
" (submodule1): Sequential(\n" | |||
" (0): Dropout(drop_prob=0.1)\n" | |||
" (1): Softmax(axis=100)\n )\n" | |||
" (fc1): Linear(in_features=512, out_features=1024, bias=True)\n" | |||
")" | |||
) | |||
net = ConvModel() | |||
output = net.__repr__() | |||
assert output == ground_truth | |||
def test_repr_module_reassign(): | |||
# test whether __repr__ can deal with module reassign | |||
class ConvModel1(Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.conv1 = Conv2d(3, 128, 3, bias=False) | |||
self.conv2 = Conv2d(3, 128, 3, padding=1, bias=False) | |||
self.conv1 = Conv2d(3, 256, 3, dilation=2, bias=False) | |||
def forward(self, inputs): | |||
pass | |||
ground_truth = ( | |||
"ConvModel1(\n" | |||
" (conv1): Conv2d(3, 256, kernel_size=(3, 3), dilation=(2, 2), bias=False)\n" | |||
" (conv2): Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), bias=False)\n" | |||
")" | |||
) | |||
net = ConvModel1() | |||
output = net.__repr__() | |||
assert output == ground_truth | |||
def test_repr_module_rereference(): | |||
# test whether __repr__ can deal with module re-reference | |||
class ConvModel2(Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.conv1 = Conv2d(3, 128, 3, bias=False) | |||
self.conv2 = self.conv1 | |||
self.conv3 = self.conv1 | |||
def forward(self, inputs): | |||
pass | |||
ground_truth = ( | |||
"ConvModel2(\n" | |||
" (conv1): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n" | |||
" (conv2): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n" | |||
" (conv3): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n" | |||
")" | |||
) | |||
net = ConvModel2() | |||
output = net.__repr__() | |||
assert output == ground_truth | |||
def test_repr_module_delete(): | |||
# test whether __repr__ can deal with module delete | |||
class ConvModel3(Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.conv1 = Conv2d(3, 128, 3, bias=False) | |||
self.softmax = Softmax(100) | |||
def forward(self, inputs): | |||
pass | |||
ground_truth = ( | |||
"ConvModel3(\n" | |||
" (conv1): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n" | |||
")" | |||
) | |||
net = ConvModel3() | |||
del net.softmax | |||
output = net.__repr__() | |||
assert output == ground_truth |