Browse Source

fix(mge/module): fix named_children of Sequential

GitOrigin-RevId: d3220fb361
tags/v1.0.0-rc1
Megvii Engine Team Xu Xinran 4 years ago
parent
commit
8269215959
2 changed files with 36 additions and 14 deletions
  1. +15
    -13
      python_module/megengine/module/sequential.py
  2. +21
    -1
      python_module/test/unit/module/test_module.py

+ 15
- 13
python_module/megengine/module/sequential.py View File

@@ -19,7 +19,7 @@ class Sequential(Module):
To make it easier to understand, here is a small example: To make it easier to understand, here is a small example:


.. testcode:: .. testcode::
from collections import OrderedDict
import numpy as np import numpy as np
import megengine.nn as nn import megengine.nn as nn
import megengine.nn.functional as F import megengine.nn.functional as F
@@ -29,34 +29,35 @@ class Sequential(Module):
label = nn.Input("label", shape=(batch_size,), dtype=np.int32, value=np.zeros(batch_size,)) label = nn.Input("label", shape=(batch_size,), dtype=np.int32, value=np.zeros(batch_size,))


data = data.reshape(batch_size, -1) data = data.reshape(batch_size, -1)
net = nn.Sequential(

net0 = nn.Sequential(
nn.Linear(28 * 28, 320), nn.Linear(28 * 28, 320),
nn.Linear(320, 500),
nn.Linear(500, 320),
nn.Linear(320, 10) nn.Linear(320, 10)
) )
pred = net(data)


loss = F.cross_entropy_with_softmax(pred, label)
pred0 = net0(data)


modules = OrderedDict()
modules["fc0"] = nn.Linear(28 * 28, 320)
modules["fc1"] = nn.Linear(320, 10)
net1 = nn.Sequential(modules)

pred1 = net1(data)
""" """


def __init__(self, *args): def __init__(self, *args):
super().__init__() super().__init__()
self.layer_keys = [] self.layer_keys = []
self.layer_values = []
if len(args) == 1 and isinstance(args[0], OrderedDict): if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items(): for key, module in args[0].items():
# self.add_module(key, module) # self.add_module(key, module)
setattr(self, key, module) setattr(self, key, module)
self.layer_keys.append(key) self.layer_keys.append(key)
self.layer_values.append(module)
else: else:
for idx, module in enumerate(args): for idx, module in enumerate(args):
# self.add_module(str(idx), module) # self.add_module(str(idx), module)
setattr(self, str(idx), module) setattr(self, str(idx), module)
self.layer_keys.append(str(idx)) self.layer_keys.append(str(idx))
self.layer_values.append(module)


def __getitem__(self, idx): def __getitem__(self, idx):
if isinstance(idx, slice): if isinstance(idx, slice):
@@ -64,11 +65,10 @@ class Sequential(Module):
OrderedDict(zip(self.layer_keys[idx], self.layer_values[idx])) OrderedDict(zip(self.layer_keys[idx], self.layer_values[idx]))
) )
else: else:
return self.layer_values[idx]
return getattr(self, self.layer_keys[idx])


def __setitem__(self, idx, module): def __setitem__(self, idx, module):
key = self.layer_keys[idx] key = self.layer_keys[idx]
self.layer_values[idx] = module
return setattr(self, key, module) return setattr(self, key, module)


def __delitem__(self, idx): def __delitem__(self, idx):
@@ -76,11 +76,9 @@ class Sequential(Module):
for key in self.layer_keys[idx]: for key in self.layer_keys[idx]:
delattr(self, key) delattr(self, key)
del self.layer_keys[idx] del self.layer_keys[idx]
del self.layer_values[idx]
else: else:
delattr(self, self.layer_keys[idx]) delattr(self, self.layer_keys[idx])
del self.layer_keys[idx] del self.layer_keys[idx]
del self.layer_values[idx]


def __len__(self): def __len__(self):
return len(self.layer_keys) return len(self.layer_keys)
@@ -88,6 +86,10 @@ class Sequential(Module):
def __iter__(self): def __iter__(self):
return iter(self.layer_values) return iter(self.layer_values)


@property
def layer_values(self):
return [getattr(self, key) for key in self.layer_keys]

def forward(self, inp): def forward(self, inp):
for layer in self.layer_values: for layer in self.layer_values:
inp = layer(inp) inp = layer(inp)


+ 21
- 1
python_module/test/unit/module/test_module.py View File

@@ -7,6 +7,7 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import tempfile import tempfile
from collections import OrderedDict
from io import BytesIO from io import BytesIO


import numpy as np import numpy as np
@@ -16,7 +17,14 @@ from helpers import MLP
import megengine as mge import megengine as mge
import megengine._internal as mgb import megengine._internal as mgb
from megengine.core import Buffer, Parameter, Tensor, tensor from megengine.core import Buffer, Parameter, Tensor, tensor
from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential
from megengine.module import (
BatchNorm1d,
BatchNorm2d,
Conv2d,
Linear,
Module,
Sequential,
)
from megengine.quantization.quantize import quantize, quantize_qat from megengine.quantization.quantize import quantize, quantize_qat
from megengine.test import assertTensorClose from megengine.test import assertTensorClose


@@ -238,6 +246,18 @@ def test_module_api_with_sequential():
] ]




def test_sequential_named_children():
modules = OrderedDict()
modules["name0"] = Linear(20, 10)
modules["name1"] = Linear(10, 5)
modules["name2"] = Linear(5, 1)
m = Sequential(modules)
l = list(m.named_children())
assert l[0][0] == "name0"
assert l[1][0] == "name1"
assert l[2][0] == "name2"


def test_state_dict(): def test_state_dict():
data_shape = (2, 28) data_shape = (2, 28)
data = tensor() data = tensor()


Loading…
Cancel
Save