|
|
@@ -13,6 +13,7 @@ import numpy as np |
|
|
|
|
|
|
|
import megengine.functional as F |
|
|
|
import megengine.module as M |
|
|
|
import megengine.module.qat as qat |
|
|
|
from megengine.module.identity import Identity |
|
|
|
from megengine.traced_module import trace_module |
|
|
|
from megengine.traced_module.expr import CallFunction, CallMethod, Expr, GetAttr, Input |
|
|
@@ -199,6 +200,31 @@ def test_insert_module(): |
|
|
|
assert n.value is None |
|
|
|
|
|
|
|
|
|
|
|
def test_insert_qat_module(): |
|
|
|
class concat(qat.Concat): |
|
|
|
pass |
|
|
|
|
|
|
|
traced_module, x, expect = _init_block() |
|
|
|
graph = traced_module.graph |
|
|
|
self = graph.inputs[0] |
|
|
|
out = graph.outputs[0] |
|
|
|
setattr(traced_module, "cat_0", qat.Concat()) |
|
|
|
setattr(traced_module, "cat_1", concat()) |
|
|
|
|
|
|
|
with graph.insert_exprs(): |
|
|
|
x_0 = self.cat_0([out, out]) |
|
|
|
x_1 = self.cat_1([out, x_0]) |
|
|
|
graph.replace_node({out: x_1}) |
|
|
|
graph.compile() |
|
|
|
|
|
|
|
x = F.copy(x) |
|
|
|
np.testing.assert_allclose( |
|
|
|
F.concat([expect, expect, expect]), traced_module(x), atol=1e-6 |
|
|
|
) |
|
|
|
assert not hasattr(traced_module.cat_0, "graph") |
|
|
|
assert traced_module.cat_1.graph is not None |
|
|
|
|
|
|
|
|
|
|
|
def test_add_input_and_output(): |
|
|
|
traced_module, x, y = _init_module() |
|
|
|
|
|
|
|