You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_dump_naming.py 9.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. # -*- coding: utf-8 -*-
  2. import io
  3. import numpy as np
  4. import pytest
  5. import megengine.functional as F
  6. import megengine.module as M
  7. import megengine.utils.comp_graph_tools as cgtools
  8. from megengine import Parameter, Tensor
  9. from megengine.core.tensor import megbrain_graph as G
  10. from megengine.jit.tracing import trace
  11. from megengine.quantization.quantize import quantize, quantize_qat
  12. from megengine.utils.naming import AutoNaming
  13. def _dump_and_load(func, symbolic, keep_opr_name=True):
  14. AutoNaming.clear()
  15. func = trace(func, symbolic=symbolic, capture_as_const=True)
  16. x = Tensor(np.ones(shape=(2, 3)))
  17. func(x).numpy()
  18. file = io.BytesIO()
  19. func.dump(
  20. file,
  21. optimize_for_inference=False,
  22. arg_names=("x",),
  23. keep_opr_name=keep_opr_name,
  24. keep_var_name=2,
  25. )
  26. file.seek(0)
  27. outputs = G.load_graph(file).output_vars_list
  28. ops = cgtools.get_oprs_seq(outputs)
  29. return ops
  30. @pytest.mark.parametrize("symbolic", [False, True])
  31. def test_auto_naming(symbolic):
  32. class Simple(M.Module):
  33. def __init__(self, name):
  34. super().__init__()
  35. self.name = name
  36. def forward(self, x):
  37. return x + x
  38. m = Simple("simple")
  39. op = _dump_and_load(m, symbolic)[-1]
  40. assert op.name == "simple.ADD"
  41. assert op.outputs[0].name == "simple.ADD"
  42. @pytest.mark.parametrize("symbolic", [False, True])
  43. def test_user_named_tensor(symbolic):
  44. class Simple(M.Module):
  45. def __init__(self, name):
  46. super().__init__()
  47. self.name = name
  48. self.k = Parameter(1.0, name="k")
  49. def forward(self, x):
  50. x = x + x
  51. x.name = "o_x"
  52. return x
  53. m = Simple("simple")
  54. op = _dump_and_load(m, symbolic)[-1]
  55. assert op.name == "simple.ADD"
  56. assert op.outputs[0].name == "o_x"
  57. @pytest.mark.parametrize("symbolic", [False, True])
  58. def test_user_named_param(symbolic):
  59. class Simple(M.Module):
  60. def __init__(self, name):
  61. super().__init__()
  62. self.name = name
  63. self.k = Parameter(2.0, name="k")
  64. def forward(self, x):
  65. return self.k * x
  66. m = Simple("simple")
  67. op = _dump_and_load(m, symbolic)[-1]
  68. assert op.inputs[0].name == "x"
  69. assert op.inputs[1].name == "simple.k"
  70. @pytest.mark.parametrize("symbolic", [False, True])
  71. def test_without_module(symbolic):
  72. def f(x):
  73. return 2 * x
  74. op = _dump_and_load(f, symbolic)[-1]
  75. assert op.name == "MUL"
  76. @pytest.mark.parametrize("symbolic", [False, True])
  77. def test_ignore_top_module(symbolic):
  78. class Simple(M.Module):
  79. def forward(self, x):
  80. return x + x
  81. m = Simple()
  82. op = _dump_and_load(m, symbolic)[-1]
  83. assert op.name == "ADD"
  84. assert op.outputs[0].name == "ADD"
  85. @pytest.mark.parametrize("symbolic", [False, True])
  86. def test_with_submodule(symbolic):
  87. class Simple(M.Module):
  88. def __init__(self, name):
  89. super().__init__()
  90. self.name = name
  91. self.linear = M.Linear(3, 3)
  92. def forward(self, x):
  93. x = self.linear(x)
  94. return x
  95. m = Simple("simple")
  96. ops = _dump_and_load(m, symbolic)
  97. assert ops[-1].name == "simple.linear.ADD"
  98. assert ops[-2].name == "simple.linear.MatrixMul"
  99. assert ops[-1].outputs[0].name == "simple.linear.ADD"
  100. @pytest.mark.parametrize("symbolic", [False, True])
  101. def test_with_submodule_in_container(symbolic):
  102. class Simple(M.Module):
  103. def __init__(self, name):
  104. super().__init__()
  105. self.name = name
  106. self.l0 = [M.Linear(3, 3) for _ in range(2)]
  107. self.l1 = tuple(self.l0)
  108. self.l2 = dict(zip(["l2-0", "l2-1"], self.l0))
  109. def forward(self, x):
  110. for i in range(2):
  111. x = self.l0[i](x)
  112. x = self.l1[i](x)
  113. x = self.l2["l2-%d" % i](x)
  114. return x
  115. m = Simple("simple")
  116. ops = _dump_and_load(m, symbolic)
  117. assert ops[-1].outputs[0].name == "simple.l0.1.ADD[2]"
  118. assert ops[-1].name == "simple.l0.1.ADD[2]"
  119. assert ops[-2].name == "simple.l0.1.MatrixMul[2]"
  120. assert ops[-3].name == "simple.l0.1.ADD[1]"
  121. assert ops[-4].name == "simple.l0.1.MatrixMul[1]"
  122. assert ops[-5].name == "simple.l0.1.ADD[0]"
  123. assert ops[-6].name == "simple.l0.1.MatrixMul[0]"
  124. @pytest.mark.parametrize("symbolic", [False, True])
  125. def test_named_submodule(symbolic):
  126. class Simple(M.Module):
  127. def __init__(self, name):
  128. super().__init__()
  129. self.name = name
  130. self.linear = M.Linear(3, 3, name="x")
  131. def forward(self, x):
  132. x = self.linear(x)
  133. return x
  134. m = Simple("simple")
  135. ops = _dump_and_load(m, symbolic)
  136. assert ops[-1].name == "simple.x.ADD"
  137. assert ops[-2].name == "simple.x.MatrixMul"
  138. assert ops[-1].outputs[0].name == "simple.x.ADD"
  139. @pytest.mark.parametrize("symbolic", [False, True])
  140. def test_with_same_operators(symbolic):
  141. class Simple(M.Module):
  142. def __init__(self, name):
  143. super().__init__()
  144. self.name = name
  145. def forward(self, x):
  146. x = F.relu(x)
  147. x = F.relu(x)
  148. return x
  149. m = Simple("simple")
  150. ops = _dump_and_load(m, symbolic)
  151. assert ops[-1].name == "simple.RELU[1]"
  152. assert ops[-2].name == "simple.RELU[0]"
  153. @pytest.mark.parametrize("symbolic", [False, True])
  154. def test_not_keep_opr_name(symbolic):
  155. def f(x):
  156. return 2 * x
  157. op = _dump_and_load(f, symbolic, False)[-1]
  158. assert op.name == "MUL(x,const<2>[2])[4]"
  159. @pytest.mark.parametrize("tensor_name, var_name", [("data", "data"), (None, "arg_0")])
  160. def test_catch_input_name(tensor_name, var_name):
  161. def f(x):
  162. return 2 * x
  163. func = trace(f, symbolic=True, capture_as_const=True)
  164. x = Tensor(np.ones(shape=(2, 3)), name=tensor_name)
  165. func(x).numpy()
  166. file = io.BytesIO()
  167. func.dump(file, optimize_for_inference=False, keep_opr_name=True, keep_var_name=2)
  168. file.seek(0)
  169. outputs = G.load_graph(file).output_vars_list
  170. op = cgtools.get_oprs_seq(outputs)[-1]
  171. assert op.inputs[0].name == var_name
  172. @pytest.mark.parametrize("symbolic", [False, True])
  173. def test_quantized_module_auto_naming(symbolic):
  174. class Simple(M.Module):
  175. def __init__(self, name):
  176. super().__init__(name=name)
  177. self.quant = M.QuantStub()
  178. self.linear = M.Linear(3, 3, bias=True)
  179. self.dequant = M.DequantStub()
  180. def forward(self, x):
  181. out = self.quant(x)
  182. out = self.linear(out)
  183. out = self.dequant(out)
  184. return out
  185. m = Simple("simple")
  186. quantize_qat(m)
  187. quantize(m)
  188. m.eval()
  189. ops = _dump_and_load(m, symbolic)
  190. ops_name = (
  191. "x",
  192. "simple.quant.TypeCvt",
  193. "simple.linear.MatrixMul",
  194. "simple.linear.ADD",
  195. "simple.linear.TypeCvt",
  196. "simple.dequant.TypeCvt",
  197. )
  198. for op, name in zip(ops, ops_name):
  199. assert op.name == name
  200. @pytest.mark.parametrize("symbolic", [False, True])
  201. def test_quantized_module_user_naming(symbolic):
  202. class Simple(M.Module):
  203. def __init__(self, name):
  204. super().__init__(name=name)
  205. self.quant = M.QuantStub()
  206. self.linear = M.Linear(3, 3, bias=True, name="user-linear")
  207. self.dequant = M.DequantStub()
  208. def forward(self, x):
  209. out = self.quant(x)
  210. out = self.linear(out)
  211. out = self.dequant(out)
  212. return out
  213. m = Simple("simple")
  214. quantize_qat(m)
  215. quantize(m)
  216. m.eval()
  217. ops = _dump_and_load(m, symbolic)
  218. ops_name = (
  219. "x",
  220. "simple.quant.TypeCvt",
  221. "simple.user-linear.MatrixMul",
  222. "simple.user-linear.ADD",
  223. "simple.user-linear.TypeCvt",
  224. "simple.dequant.TypeCvt",
  225. )
  226. for op, name in zip(ops, ops_name):
  227. assert op.name == name
  228. @pytest.mark.parametrize("symbolic", [False, True])
  229. def test_quantized_module_user_naming_param(symbolic):
  230. class Simple(M.Module):
  231. def __init__(self, name):
  232. super().__init__(name=name)
  233. self.quant = M.QuantStub()
  234. self.linear = M.Linear(3, 3, bias=True)
  235. self.dequant = M.DequantStub()
  236. self.linear.weight.name = "user-weight"
  237. self.linear.bias.name = "user-bias"
  238. def forward(self, x):
  239. out = self.quant(x)
  240. out = self.linear(out)
  241. out = self.dequant(out)
  242. return out
  243. m = Simple("simple")
  244. quantize_qat(m)
  245. quantize(m)
  246. m.eval()
  247. ops = _dump_and_load(m, symbolic)
  248. (matrix_mul_op,) = [op for op in ops if op.name == "simple.linear.MatrixMul"]
  249. for var in matrix_mul_op.inputs:
  250. assert var.name in ("simple.quant.TypeCvt", "simple.linear.user-weight")
  251. # WONTFIX: bias' name does not meet expectations because of astype operator after quantization