You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cgtools.py 6.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. import io
  2. import numpy as np
  3. import pytest
  4. import megengine
  5. import megengine.functional as F
  6. import megengine.module as M
  7. import megengine.utils.comp_graph_tools as cgtools
  8. from megengine.core.ops.builtin import Elemwise
  9. from megengine.core.tensor import megbrain_graph as mgb_graph
  10. from megengine.core.tensor.megbrain_graph import apply_normal_varnode
  11. from megengine.core.tensor.utils import astensor1d
  12. from megengine.jit import trace
  13. from megengine.utils.network import Network
  14. def make_dev_tensor(value, dtype=None, device=None):
  15. return megengine.tensor(value, dtype=dtype, device=device)._dev_tensor()
  16. def test_replace_vars():
  17. g = mgb_graph.Graph()
  18. g.options.async_exec_level = 0b100
  19. device = "xpux"
  20. dtype = np.float32
  21. a = mgb_graph.InputNode(device=device, dtype=dtype, graph=g)
  22. const = g.make_const(1.234, device=device)
  23. add_op = Elemwise(Elemwise.Mode.ADD)
  24. mul_op = Elemwise(Elemwise.Mode.MUL)
  25. a_plus_a = apply_normal_varnode(add_op, a.outputs[0], a.outputs[0])[0]
  26. a_plus_a_mul_const = apply_normal_varnode(mul_op, a_plus_a, const)[0]
  27. rst = apply_normal_varnode(add_op, a_plus_a_mul_const, a.outputs[0])[0]
  28. (new,) = cgtools.replace_vars([rst._node], {const._node: a_plus_a._node})
  29. out = mgb_graph.OutputNode(mgb_graph.VarNode(new))
  30. func = g.compile(out.outputs[0])
  31. func.execute()
  32. x = make_dev_tensor(5.0, device=device)
  33. a.set_value(x)
  34. res = out.get_value().numpy()
  35. np.testing.assert_equal(res, np.array([105.0]))
  36. def test_replace_oprs():
  37. g = mgb_graph.Graph()
  38. g.options.async_exec_level = 0b100
  39. device = "xpux"
  40. dtype = np.float32
  41. a = mgb_graph.InputNode(device=device, dtype=dtype, graph=g)
  42. const = g.make_const(1.25, device=device)
  43. add_op = Elemwise(Elemwise.Mode.ADD)
  44. mul_op = Elemwise(Elemwise.Mode.MUL)
  45. a_plus_a = apply_normal_varnode(add_op, a.outputs[0], a.outputs[0])[0]
  46. old_opr = a_plus_a.op
  47. a_plus_a_mul_const = apply_normal_varnode(mul_op, a_plus_a, const)[0]
  48. a_mul_a = apply_normal_varnode(mul_op, a.outputs[0], a.outputs[0])[0]
  49. new_opr = a_mul_a.op
  50. (new,) = cgtools.replace_oprs(
  51. [a_plus_a_mul_const._node], {old_opr._node: new_opr._node}
  52. )
  53. out = mgb_graph.OutputNode(mgb_graph.VarNode(new))
  54. func = g.compile(out.outputs[0])
  55. func.execute()
  56. x = make_dev_tensor(5.0, device=device)
  57. a.set_value(x)
  58. res = out.get_value().numpy()
  59. np.testing.assert_equal(res, np.array([5.0 * 5.0 * 1.25]))
  60. def test_graph_traversal():
  61. net = M.Conv2d(3, 32, 3)
  62. @trace(symbolic=True, capture_as_const=True)
  63. def fun(data):
  64. x = net(data)
  65. return x
  66. data = np.random.random([1, 3, 224, 224]).astype(np.float32)
  67. for _ in range(3):
  68. fun(megengine.tensor(data))
  69. file = io.BytesIO()
  70. fun.dump(file, optimize_for_inference=False)
  71. file.seek(0)
  72. outputs = mgb_graph.load_graph(file).output_vars_list
  73. _, map_vars, var2oprs, *_ = cgtools.graph_traversal(outputs)
  74. input_var = map_vars[1]
  75. _, var_idx = var2oprs[input_var.id][0]
  76. assert var_idx == 0
def test_load_refcnt():
    """Access a loaded var node after every other reference has been dropped.

    A constant is dumped and loaded back; the names ``graph`` and ``varnode``
    are rebound to the loaded objects, then the load result and the graph
    name are deleted so that only ``varnode`` remains before its ``owner``
    attribute is read.  There is no explicit assertion.
    """
    graph = mgb_graph.Graph()
    varnode = graph.make_const(0)
    buf, _ = mgb_graph.dump_graph([varnode])
    ret = mgb_graph.load_graph(io.BytesIO(buf))
    # Rebinding ``graph`` and ``varnode`` drops this function's references to
    # the originally built graph and constant.
    graph, (varnode,) = ret.graph, ret.output_vars_list
    del ret
    del graph
    # NOTE(review): presumably this access would crash if the var node did not
    # keep its owning graph alive on its own -- confirm against the bindings.
    varnode.owner
  86. def test_get_opr_seq():
  87. class Net(M.Module):
  88. def __init__(self):
  89. super().__init__()
  90. self.data = megengine.tensor(
  91. np.random.random((1, 1, 4, 4)), dtype=np.float32
  92. )
  93. def forward(self, input):
  94. A = input.shape[0]
  95. shape = astensor1d((A, A), self.data, dtype="int32", device=input.device)
  96. x = F.reshape(self.data, shape)
  97. o = input + x
  98. return o
  99. net = Net()
  100. input = megengine.tensor(np.random.random((4, 4)), dtype=np.float32)
  101. @trace(symbolic=True, capture_as_const=True)
  102. def func(inp, *, net=None):
  103. return net(inp)
  104. func(input, net=net)
  105. file = io.BytesIO()
  106. func.dump(file, optimize_for_inference=False)
  107. file.seek(0)
  108. outputs = mgb_graph.load_graph(file).output_vars_list
  109. seq_1 = cgtools.get_oprs_seq(outputs, True)
  110. assert len(seq_1) == 5
  111. seq_2 = cgtools.get_oprs_seq(outputs, False)
  112. assert len(seq_2) == 6
  113. def test_topological_sort():
  114. @trace(symbolic=True, capture_as_const=True)
  115. def func(x, y):
  116. a = x + y
  117. a1 = F.relu(a)
  118. a2 = F.abs(a)
  119. a3 = F.ceil(a) * 2
  120. a4 = F.floor(a)
  121. r = a1 - a2
  122. r1 = a3 / a4
  123. return r, r1
  124. file = io.BytesIO()
  125. func(megengine.tensor(1.0), megengine.tensor(2.0))
  126. func.dump(
  127. file, optimize_for_inference=False, keep_opr_name=True, keep_opr_priority=True
  128. )
  129. file.seek(0)
  130. g = Network.load(file)
  131. oprseq1 = g.all_oprs
  132. gt = [
  133. "Host2DeviceCopy",
  134. "Host2DeviceCopy",
  135. "ADD",
  136. "RELU",
  137. "ABS",
  138. "CEIL",
  139. "ImmutableTensor",
  140. "MUL",
  141. "FLOOR",
  142. "SUB",
  143. "TRUE_DIV",
  144. ]
  145. for op, mode in zip(oprseq1, gt):
  146. if op.type == "Elemwise":
  147. assert op.params["mode"] == mode
  148. else:
  149. assert op.type == mode
  150. def test_graph_function():
  151. class Net(M.Module):
  152. def forward(self, a, b):
  153. return a - b, a * b
  154. net = Net()
  155. @trace(symbolic=True, capture_as_const=True)
  156. def function(a, b, *, net=None):
  157. return net(a, b)
  158. a = np.array([1, 2, 3])
  159. b = np.array([3])
  160. x, y = function(megengine.tensor(a), megengine.tensor(b), net=net)
  161. file = io.BytesIO()
  162. function.dump(
  163. file,
  164. arg_names=["a", "b"],
  165. output_names=["x", "y"],
  166. optimize_for_inference=False,
  167. )
  168. file.seek(0)
  169. graph = cgtools.GraphInference(file)
  170. results = graph.run(inp_dict={"a": a, "b": b})
  171. np.testing.assert_equal(x.numpy(), results["x"])
  172. np.testing.assert_equal(y.numpy(), results["y"])
  173. results = graph.run(a, inp_dict={"b": b})
  174. np.testing.assert_equal(x.numpy(), results["x"])
  175. np.testing.assert_equal(y.numpy(), results["y"])
  176. results = graph.run(a, b)
  177. np.testing.assert_equal(x.numpy(), results["x"])
  178. np.testing.assert_equal(y.numpy(), results["y"])
  179. file.seek(0)
  180. graph1 = cgtools.GraphInference(file, outputs=["x"])
  181. results = graph1.run(inp_dict={"a": a, "b": b})
  182. np.testing.assert_equal(x.numpy(), results["x"])
  183. assert "y" not in results