|
|
@@ -1,3 +1,4 @@ |
|
|
|
import collections |
|
|
|
import math |
|
|
|
from copy import deepcopy |
|
|
|
|
|
|
@@ -27,6 +28,31 @@ def test_module_stats(): |
|
|
|
assert (total_stats.flops, total_stats.act_dims) == (gt_flops, gt_acts,) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif( |
|
|
|
use_symbolic_shape(), reason="This test do not support symbolic shape.", |
|
|
|
) |
|
|
|
def test_other_input_module_state(): |
|
|
|
a = [1, 2] |
|
|
|
b = {"1": 1, "2": 2} |
|
|
|
nt = collections.namedtuple("nt", ["n", "t"]) |
|
|
|
_nt = nt(n=1, t=2) |
|
|
|
net = FakeNet() |
|
|
|
net(a) |
|
|
|
net(b) |
|
|
|
net(_nt) |
|
|
|
|
|
|
|
|
|
|
|
class FakeNet(M.Module): |
|
|
|
def __init__(self): |
|
|
|
super().__init__() |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
assert isinstance( |
|
|
|
x, |
|
|
|
(np.ndarray, collections.abc.Mapping, collections.abc.Sequence, mge.Tensor), |
|
|
|
) or (isinstance(x, tuple) and hasattr(x, "_fields")) |
|
|
|
|
|
|
|
|
|
|
|
class BasicBlock(M.Module): |
|
|
|
expansion = 1 |
|
|
|
|
|
|
|