|
|
@@ -19,7 +19,7 @@ class Sequential(Module): |
|
|
|
To make it easier to understand, here is a small example: |
|
|
|
|
|
|
|
.. testcode:: |
|
|
|
|
|
|
|
from collections import OrderedDict |
|
|
|
import numpy as np |
|
|
|
import megengine.nn as nn |
|
|
|
import megengine.nn.functional as F |
|
|
@@ -29,34 +29,35 @@ class Sequential(Module): |
|
|
|
label = nn.Input("label", shape=(batch_size,), dtype=np.int32, value=np.zeros(batch_size,)) |
|
|
|
|
|
|
|
data = data.reshape(batch_size, -1) |
|
|
|
net = nn.Sequential( |
|
|
|
|
|
|
|
net0 = nn.Sequential( |
|
|
|
nn.Linear(28 * 28, 320), |
|
|
|
nn.Linear(320, 500), |
|
|
|
nn.Linear(500, 320), |
|
|
|
nn.Linear(320, 10) |
|
|
|
) |
|
|
|
pred = net(data) |
|
|
|
|
|
|
|
loss = F.cross_entropy_with_softmax(pred, label) |
|
|
|
pred0 = net0(data) |
|
|
|
|
|
|
|
modules = OrderedDict() |
|
|
|
modules["fc0"] = nn.Linear(28 * 28, 320) |
|
|
|
modules["fc1"] = nn.Linear(320, 10) |
|
|
|
net1 = nn.Sequential(modules) |
|
|
|
|
|
|
|
pred1 = net1(data) |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, *args): |
|
|
|
super().__init__() |
|
|
|
self.layer_keys = [] |
|
|
|
self.layer_values = [] |
|
|
|
if len(args) == 1 and isinstance(args[0], OrderedDict): |
|
|
|
for key, module in args[0].items(): |
|
|
|
# self.add_module(key, module) |
|
|
|
setattr(self, key, module) |
|
|
|
self.layer_keys.append(key) |
|
|
|
self.layer_values.append(module) |
|
|
|
else: |
|
|
|
for idx, module in enumerate(args): |
|
|
|
# self.add_module(str(idx), module) |
|
|
|
setattr(self, str(idx), module) |
|
|
|
self.layer_keys.append(str(idx)) |
|
|
|
self.layer_values.append(module) |
|
|
|
|
|
|
|
def __getitem__(self, idx): |
|
|
|
if isinstance(idx, slice): |
|
|
@@ -64,11 +65,10 @@ class Sequential(Module): |
|
|
|
OrderedDict(zip(self.layer_keys[idx], self.layer_values[idx])) |
|
|
|
) |
|
|
|
else: |
|
|
|
return self.layer_values[idx] |
|
|
|
return getattr(self, self.layer_keys[idx]) |
|
|
|
|
|
|
|
def __setitem__(self, idx, module): |
|
|
|
key = self.layer_keys[idx] |
|
|
|
self.layer_values[idx] = module |
|
|
|
return setattr(self, key, module) |
|
|
|
|
|
|
|
def __delitem__(self, idx): |
|
|
@@ -76,11 +76,9 @@ class Sequential(Module): |
|
|
|
for key in self.layer_keys[idx]: |
|
|
|
delattr(self, key) |
|
|
|
del self.layer_keys[idx] |
|
|
|
del self.layer_values[idx] |
|
|
|
else: |
|
|
|
delattr(self, self.layer_keys[idx]) |
|
|
|
del self.layer_keys[idx] |
|
|
|
del self.layer_values[idx] |
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
return len(self.layer_keys) |
|
|
@@ -88,6 +86,10 @@ class Sequential(Module): |
|
|
|
def __iter__(self): |
|
|
|
return iter(self.layer_values) |
|
|
|
|
|
|
|
@property |
|
|
|
def layer_values(self): |
|
|
|
return [getattr(self, key) for key in self.layer_keys] |
|
|
|
|
|
|
|
def forward(self, inp): |
|
|
|
for layer in self.layer_values: |
|
|
|
inp = layer(inp) |
|
|
|