Browse Source

fix(mge/module): fix named_children of Sequential

GitOrigin-RevId: d3220fb361
tags/v1.0.0-rc1
Megvii Engine Team Xu Xinran 4 years ago
parent
commit
8269215959
2 changed files with 36 additions and 14 deletions
  1. +15
    -13
      python_module/megengine/module/sequential.py
  2. +21
    -1
      python_module/test/unit/module/test_module.py

+ 15
- 13
python_module/megengine/module/sequential.py View File

@@ -19,7 +19,7 @@ class Sequential(Module):
To make it easier to understand, here is a small example:

.. testcode::
from collections import OrderedDict
import numpy as np
import megengine.nn as nn
import megengine.nn.functional as F
@@ -29,34 +29,35 @@ class Sequential(Module):
label = nn.Input("label", shape=(batch_size,), dtype=np.int32, value=np.zeros(batch_size,))

data = data.reshape(batch_size, -1)
net = nn.Sequential(

net0 = nn.Sequential(
nn.Linear(28 * 28, 320),
nn.Linear(320, 500),
nn.Linear(500, 320),
nn.Linear(320, 10)
)
pred = net(data)

loss = F.cross_entropy_with_softmax(pred, label)
pred0 = net0(data)

modules = OrderedDict()
modules["fc0"] = nn.Linear(28 * 28, 320)
modules["fc1"] = nn.Linear(320, 10)
net1 = nn.Sequential(modules)

pred1 = net1(data)
"""

def __init__(self, *args):
super().__init__()
self.layer_keys = []
self.layer_values = []
if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items():
# self.add_module(key, module)
setattr(self, key, module)
self.layer_keys.append(key)
self.layer_values.append(module)
else:
for idx, module in enumerate(args):
# self.add_module(str(idx), module)
setattr(self, str(idx), module)
self.layer_keys.append(str(idx))
self.layer_values.append(module)

def __getitem__(self, idx):
if isinstance(idx, slice):
@@ -64,11 +65,10 @@ class Sequential(Module):
OrderedDict(zip(self.layer_keys[idx], self.layer_values[idx]))
)
else:
return self.layer_values[idx]
return getattr(self, self.layer_keys[idx])

def __setitem__(self, idx, module):
key = self.layer_keys[idx]
self.layer_values[idx] = module
return setattr(self, key, module)

def __delitem__(self, idx):
@@ -76,11 +76,9 @@ class Sequential(Module):
for key in self.layer_keys[idx]:
delattr(self, key)
del self.layer_keys[idx]
del self.layer_values[idx]
else:
delattr(self, self.layer_keys[idx])
del self.layer_keys[idx]
del self.layer_values[idx]

def __len__(self):
return len(self.layer_keys)
@@ -88,6 +86,10 @@ class Sequential(Module):
def __iter__(self):
return iter(self.layer_values)

@property
def layer_values(self):
return [getattr(self, key) for key in self.layer_keys]

def forward(self, inp):
for layer in self.layer_values:
inp = layer(inp)


+ 21
- 1
python_module/test/unit/module/test_module.py View File

@@ -7,6 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import tempfile
from collections import OrderedDict
from io import BytesIO

import numpy as np
@@ -16,7 +17,14 @@ from helpers import MLP
import megengine as mge
import megengine._internal as mgb
from megengine.core import Buffer, Parameter, Tensor, tensor
from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential
from megengine.module import (
BatchNorm1d,
BatchNorm2d,
Conv2d,
Linear,
Module,
Sequential,
)
from megengine.quantization.quantize import quantize, quantize_qat
from megengine.test import assertTensorClose

@@ -238,6 +246,18 @@ def test_module_api_with_sequential():
]


def test_sequential_named_children():
modules = OrderedDict()
modules["name0"] = Linear(20, 10)
modules["name1"] = Linear(10, 5)
modules["name2"] = Linear(5, 1)
m = Sequential(modules)
l = list(m.named_children())
assert l[0][0] == "name0"
assert l[1][0] == "name1"
assert l[2][0] == "name2"


def test_state_dict():
data_shape = (2, 28)
data = tensor()


Loading…
Cancel
Save