|
|
@@ -14,14 +14,15 @@ import megengine |
|
|
|
import megengine.functional as F |
|
|
|
import megengine.module as M |
|
|
|
import megengine.utils.comp_graph_tools as cgtools |
|
|
|
from megengine.core.ops.builtin import Elemwise |
|
|
|
from megengine.core.tensor import megbrain_graph as mgb_graph |
|
|
|
from megengine.core.tensor.raw_tensor import as_raw_tensor |
|
|
|
from megengine.core.tensor.megbrain_graph import apply_normal_op |
|
|
|
from megengine.core.tensor.utils import astensor1d |
|
|
|
from megengine.jit import trace |
|
|
|
|
|
|
|
|
|
|
|
def make_dev_tensor(value, dtype=None, device=None): |
|
|
|
return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor() |
|
|
|
return megengine.tensor(value, dtype=dtype, device=device)._dev_tensor() |
|
|
|
|
|
|
|
|
|
|
|
def test_replace_vars(): |
|
|
@@ -30,10 +31,12 @@ def test_replace_vars(): |
|
|
|
device = "xpux" |
|
|
|
dtype = np.float32 |
|
|
|
a = mgb_graph.InputNode(device=device, dtype=dtype, graph=g) |
|
|
|
const = g.make_const(1.234) |
|
|
|
a_plus_a = F.add(a.outputs[0], a.outputs[0]) |
|
|
|
a_plus_a_mul_const = F.mul(a_plus_a, const) |
|
|
|
rst = F.add(a_plus_a_mul_const, a.outputs[0]) |
|
|
|
const = g.make_const(1.234, device=device) |
|
|
|
add_op = Elemwise(Elemwise.Mode.ADD) |
|
|
|
mul_op = Elemwise(Elemwise.Mode.MUL) |
|
|
|
a_plus_a = apply_normal_op(add_op, a.outputs[0], a.outputs[0])[0] |
|
|
|
a_plus_a_mul_const = apply_normal_op(mul_op, a_plus_a, const)[0] |
|
|
|
rst = apply_normal_op(add_op, a_plus_a_mul_const, a.outputs[0])[0] |
|
|
|
(new,) = cgtools.replace_vars([rst._node], {const._node: a_plus_a._node}) |
|
|
|
out = mgb_graph.OutputNode(mgb_graph.VarNode(new)) |
|
|
|
func = g.compile(out.outputs[0]) |
|
|
@@ -50,11 +53,13 @@ def test_replace_oprs(): |
|
|
|
device = "xpux" |
|
|
|
dtype = np.float32 |
|
|
|
a = mgb_graph.InputNode(device=device, dtype=dtype, graph=g) |
|
|
|
const = g.make_const(1.25) |
|
|
|
a_plus_a = F.add(a.outputs[0], a.outputs[0]) |
|
|
|
const = g.make_const(1.25, device=device) |
|
|
|
add_op = Elemwise(Elemwise.Mode.ADD) |
|
|
|
mul_op = Elemwise(Elemwise.Mode.MUL) |
|
|
|
a_plus_a = apply_normal_op(add_op, a.outputs[0], a.outputs[0])[0] |
|
|
|
old_opr = a_plus_a.op |
|
|
|
a_plus_a_mul_const = F.mul(a_plus_a, const) |
|
|
|
a_mul_a = F.mul(a.outputs[0], a.outputs[0]) |
|
|
|
a_plus_a_mul_const = apply_normal_op(mul_op, a_plus_a, const)[0] |
|
|
|
a_mul_a = apply_normal_op(mul_op, a.outputs[0], a.outputs[0])[0] |
|
|
|
new_opr = a_mul_a.op |
|
|
|
(new,) = cgtools.replace_oprs( |
|
|
|
[a_plus_a_mul_const._node], {old_opr._node: new_opr._node} |
|
|
|