GitOrigin-RevId: ac13cc4659
release-1.1
@@ -55,6 +55,9 @@ class Softmax(Module): | |||||
def forward(self, inputs): | def forward(self, inputs): | ||||
return softmax(inputs, self.axis) | return softmax(inputs, self.axis) | ||||
def _module_info_string(self) -> str: | |||||
return "axis={axis}".format(axis=self.axis) | |||||
class Sigmoid(Module): | class Sigmoid(Module): | ||||
r""" | r""" | ||||
@@ -113,6 +113,13 @@ class _BatchNorm(Module): | |||||
return output | return output | ||||
def _module_info_string(self) -> str: | |||||
s = ( | |||||
"{num_features}, eps={eps}, momentum={momentum}, affine={affine}, " | |||||
"track_running_stats={track_running_stats}" | |||||
) | |||||
return s.format(**self.__dict__) | |||||
class SyncBatchNorm(_BatchNorm): | class SyncBatchNorm(_BatchNorm): | ||||
r""" | r""" | ||||
@@ -70,6 +70,21 @@ class _ConvNd(Module): | |||||
def _infer_bias_shape(self): | def _infer_bias_shape(self): | ||||
pass | pass | ||||
def _module_info_string(self): | |||||
s = "{in_channels}, {out_channels}, kernel_size={kernel_size}" | |||||
if self.stride != (1,) * len(self.stride): | |||||
s += ", stride={stride}" | |||||
if self.padding != (0,) * len(self.padding): | |||||
s += ", padding={padding}" | |||||
if self.dilation != (1,) * len(self.dilation): | |||||
s += ", dilation={dilation}" | |||||
if self.groups != 1: | |||||
s += ", groups={groups}" | |||||
if self.bias is None: | |||||
s += ", bias=False" | |||||
return s.format(**self.__dict__) | |||||
class Conv2d(_ConvNd): | class Conv2d(_ConvNd): | ||||
r"""Applies a 2D convolution over an input tensor. | r"""Applies a 2D convolution over an input tensor. | ||||
@@ -28,3 +28,6 @@ class Dropout(Module): | |||||
return dropout(inputs, self.drop_prob, training=True) | return dropout(inputs, self.drop_prob, training=True) | ||||
else: | else: | ||||
return inputs | return inputs | ||||
def _module_info_string(self) -> str: | |||||
return "drop_prob={drop_prob}".format(drop_prob=self.drop_prob) |
@@ -78,3 +78,8 @@ class Linear(Module): | |||||
def forward(self, x): | def forward(self, x): | ||||
return self._calc_linear(x, self.weight, self.bias) | return self._calc_linear(x, self.weight, self.bias) | ||||
def _module_info_string(self) -> str: | |||||
return "in_features={}, out_features={}, bias={}".format( | |||||
self.in_features, self.out_features, self.bias is not None | |||||
) |
@@ -69,6 +69,8 @@ class Module(metaclass=ABCMeta): | |||||
self._forward_pre_hooks = OrderedDict() | self._forward_pre_hooks = OrderedDict() | ||||
self._forward_hooks = OrderedDict() | self._forward_hooks = OrderedDict() | ||||
self._modules = [] | |||||
@abstractmethod | @abstractmethod | ||||
def forward(self, inputs): | def forward(self, inputs): | ||||
pass | pass | ||||
@@ -518,3 +520,57 @@ class Module(metaclass=ABCMeta): | |||||
loaded.append(k) | loaded.append(k) | ||||
return set(loaded), set(skipped) | return set(loaded), set(skipped) | ||||
def __setattr__(self, name: str, value): | |||||
if _is_module(value): | |||||
modules = self.__dict__.get("_modules") | |||||
if modules is None: | |||||
raise AttributeError( | |||||
"cannot assign module before Module.__init__() call" | |||||
) | |||||
if name not in self.__dict__: | |||||
modules.append(name) | |||||
super().__setattr__(name, value) | |||||
def __delattr__(self, name: str): | |||||
if name in self.__dict__ and _is_module(self.__dict__[name]): | |||||
modules = self.__dict__.get("_modules") | |||||
modules.remove(name) | |||||
super().__delattr__(name) | |||||
def _module_info_string(self) -> str: | |||||
r"""Set the extra representation of the module. | |||||
""" | |||||
return "" | |||||
def __repr__(self): | |||||
def add_indent(repr_str, num_spaces): | |||||
s = repr_str.split("\n") | |||||
# don't do anything for single-line stuff | |||||
if len(s) == 1: | |||||
return repr_str | |||||
first = s.pop(0) | |||||
s = [(num_spaces * " ") + line for line in s] | |||||
s = "\n".join(s) | |||||
s = first + "\n" + s | |||||
return s | |||||
extra_lines = [] | |||||
extra_repr = self._module_info_string() | |||||
if extra_repr: | |||||
extra_lines = extra_repr.split("\n") | |||||
child_lines = [ | |||||
"(" + name + "): " + add_indent(repr(self.__dict__[name]), 2) | |||||
for name in self._modules | |||||
] | |||||
lines = extra_lines + child_lines | |||||
main_str = self.__class__.__name__ + "(" | |||||
if lines: | |||||
# simple one-liner info, which most builtin Modules will use | |||||
if len(extra_lines) == 1 and not child_lines: | |||||
main_str += extra_lines[0] | |||||
else: | |||||
main_str += "\n " + "\n ".join(lines) + "\n" | |||||
main_str += ")" | |||||
return main_str |
@@ -29,6 +29,11 @@ class _PoolNd(Module): | |||||
def forward(self, inp): | def forward(self, inp): | ||||
pass | pass | ||||
def _module_info_string(self) -> str: | |||||
return "kernel_size={kernel_size}, stride={stride}, padding={padding}".format( | |||||
**self.__dict__ | |||||
) | |||||
class MaxPool2d(_PoolNd): | class MaxPool2d(_PoolNd): | ||||
r"""Applies a 2D max pooling over an input. | r"""Applies a 2D max pooling over an input. | ||||
@@ -21,9 +21,12 @@ from megengine.module import ( | |||||
BatchNorm1d, | BatchNorm1d, | ||||
BatchNorm2d, | BatchNorm2d, | ||||
Conv2d, | Conv2d, | ||||
Dropout, | |||||
Linear, | Linear, | ||||
MaxPool2d, | |||||
Module, | Module, | ||||
Sequential, | Sequential, | ||||
Softmax, | |||||
) | ) | ||||
from megengine.quantization.quantize import quantize, quantize_qat | from megengine.quantization.quantize import quantize, quantize_qat | ||||
from megengine.test import assertTensorClose | from megengine.test import assertTensorClose | ||||
@@ -609,3 +612,111 @@ def test_load_quantized(): | |||||
assertTensorClose( | assertTensorClose( | ||||
pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), max_err=5e-6 | pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), max_err=5e-6 | ||||
) | ) | ||||
def test_repr_basic(): | |||||
# test whether __repr__ can output correct information | |||||
class ConvModel(Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.conv1 = Conv2d(3, 128, 3, stride=2, bias=False) | |||||
self.conv2 = Conv2d(3, 128, 3, padding=1, bias=False) | |||||
self.conv3 = Conv2d(3, 128, 3, dilation=2, bias=False) | |||||
self.bn1 = BatchNorm2d(128) | |||||
self.bn2 = BatchNorm1d(128) | |||||
self.dropout = Dropout(drop_prob=0.1) | |||||
self.softmax = Softmax(axis=100) | |||||
self.pooling = MaxPool2d(kernel_size=2, padding=0) | |||||
self.submodule1 = Sequential(Dropout(drop_prob=0.1), Softmax(axis=100),) | |||||
self.fc1 = Linear(512, 1024) | |||||
def forward(self, inputs): | |||||
pass | |||||
ground_truth = ( | |||||
"ConvModel(\n" | |||||
" (conv1): Conv2d(3, 128, kernel_size=(3, 3), stride=(2, 2), bias=False)\n" | |||||
" (conv2): Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), bias=False)\n" | |||||
" (conv3): Conv2d(3, 128, kernel_size=(3, 3), dilation=(2, 2), bias=False)\n" | |||||
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n" | |||||
" (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n" | |||||
" (dropout): Dropout(drop_prob=0.1)\n (softmax): Softmax(axis=100)\n" | |||||
" (pooling): MaxPool2d(kernel_size=2, stride=2, padding=0)\n" | |||||
" (submodule1): Sequential(\n" | |||||
" (0): Dropout(drop_prob=0.1)\n" | |||||
" (1): Softmax(axis=100)\n )\n" | |||||
" (fc1): Linear(in_features=512, out_features=1024, bias=True)\n" | |||||
")" | |||||
) | |||||
net = ConvModel() | |||||
output = net.__repr__() | |||||
assert output == ground_truth | |||||
def test_repr_module_reassign(): | |||||
# test whether __repr__ can deal with module reassign | |||||
class ConvModel1(Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.conv1 = Conv2d(3, 128, 3, bias=False) | |||||
self.conv2 = Conv2d(3, 128, 3, padding=1, bias=False) | |||||
self.conv1 = Conv2d(3, 256, 3, dilation=2, bias=False) | |||||
def forward(self, inputs): | |||||
pass | |||||
ground_truth = ( | |||||
"ConvModel1(\n" | |||||
" (conv1): Conv2d(3, 256, kernel_size=(3, 3), dilation=(2, 2), bias=False)\n" | |||||
" (conv2): Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), bias=False)\n" | |||||
")" | |||||
) | |||||
net = ConvModel1() | |||||
output = net.__repr__() | |||||
assert output == ground_truth | |||||
def test_repr_module_rereference(): | |||||
# test whether __repr__ can deal with module re-reference | |||||
class ConvModel2(Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.conv1 = Conv2d(3, 128, 3, bias=False) | |||||
self.conv2 = self.conv1 | |||||
self.conv3 = self.conv1 | |||||
def forward(self, inputs): | |||||
pass | |||||
ground_truth = ( | |||||
"ConvModel2(\n" | |||||
" (conv1): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n" | |||||
" (conv2): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n" | |||||
" (conv3): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n" | |||||
")" | |||||
) | |||||
net = ConvModel2() | |||||
output = net.__repr__() | |||||
assert output == ground_truth | |||||
def test_repr_module_delete(): | |||||
# test whether __repr__ can deal with module delete | |||||
class ConvModel3(Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.conv1 = Conv2d(3, 128, 3, bias=False) | |||||
self.softmax = Softmax(100) | |||||
def forward(self, inputs): | |||||
pass | |||||
ground_truth = ( | |||||
"ConvModel3(\n" | |||||
" (conv1): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n" | |||||
")" | |||||
) | |||||
net = ConvModel3() | |||||
del net.softmax | |||||
output = net.__repr__() | |||||
assert output == ground_truth |