
test_dump_naming.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import io

import numpy as np
import pytest

import megengine.functional as F
import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools
from megengine import Parameter, Tensor
from megengine.core.tensor import megbrain_graph as G
from megengine.jit.tracing import trace
from megengine.quantization.quantize import quantize, quantize_qat
from megengine.utils.naming import auto_naming


def _dump_and_load(func, symbolic, keep_opr_name=True):
    # Trace ``func``, dump the traced graph into an in-memory buffer, then
    # reload it and return the operator sequence of the loaded graph.
    auto_naming.clear()
    func = trace(func, symbolic=symbolic, capture_as_const=True)
    x = Tensor(np.ones(shape=(2, 3)))
    func(x).numpy()
    file = io.BytesIO()
    func.dump(
        file,
        optimize_for_inference=False,
        arg_names=("x",),
        keep_opr_name=keep_opr_name,
        keep_var_name=2,
    )
    file.seek(0)
    *_, outputs = G.load_graph(file)
    ops = cgtools.get_oprs_seq(outputs)
    return ops


@pytest.mark.parametrize("symbolic", [False, True])
def test_auto_naming(symbolic):
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__()
            self.name = name

        def forward(self, x):
            return x + x

    m = Simple("simple")

    op = _dump_and_load(m, symbolic)[-1]
    assert op.name == "simple.ADD"
    assert op.outputs[0].name == "simple.ADD"


@pytest.mark.parametrize("symbolic", [False, True])
def test_user_named_tensor(symbolic):
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__()
            self.name = name
            self.k = Parameter(1.0, name="k")

        def forward(self, x):
            x = x + x
            x.name = "o_x"
            return x

    m = Simple("simple")

    op = _dump_and_load(m, symbolic)[-1]
    assert op.name == "simple.ADD"
    assert op.outputs[0].name == "o_x"


@pytest.mark.parametrize("symbolic", [False, True])
def test_user_named_param(symbolic):
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__()
            self.name = name
            self.k = Parameter(2.0, name="k")

        def forward(self, x):
            return self.k * x

    m = Simple("simple")

    op = _dump_and_load(m, symbolic)[-1]
    assert op.inputs[0].name == "x"
    assert op.inputs[1].name == "simple.k"


@pytest.mark.parametrize("symbolic", [False, True])
def test_without_module(symbolic):
    def f(x):
        return 2 * x

    op = _dump_and_load(f, symbolic)[-1]
    assert op.name == "MUL"


@pytest.mark.parametrize("symbolic", [False, True])
def test_with_submodule(symbolic):
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__()
            self.name = name
            self.linear = M.Linear(3, 3)

        def forward(self, x):
            x = self.linear(x)
            return x

    m = Simple("simple")

    ops = _dump_and_load(m, symbolic)
    assert ops[-1].name == "simple.linear.ADD"
    assert ops[-2].name == "simple.linear.MatrixMul"
    assert ops[-1].outputs[0].name == "simple.linear.ADD"


@pytest.mark.parametrize("symbolic", [False, True])
def test_with_submodule_in_container(symbolic):
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__()
            self.name = name
            self.l0 = [M.Linear(3, 3) for _ in range(2)]
            self.l1 = tuple(self.l0)
            self.l2 = dict(zip(["l2-0", "l2-1"], self.l0))

        def forward(self, x):
            for i in range(2):
                x = self.l0[i](x)
                x = self.l1[i](x)
                x = self.l2["l2-%d" % i](x)
            return x

    m = Simple("simple")

    ops = _dump_and_load(m, symbolic)
    assert ops[-1].outputs[0].name == "simple.l2.l2-1.ADD"
    assert ops[-1].name == "simple.l2.l2-1.ADD"
    assert ops[-2].name == "simple.l2.l2-1.MatrixMul"
    assert ops[-3].name == "simple.l1.1.ADD"
    assert ops[-4].name == "simple.l1.1.MatrixMul"
    assert ops[-5].name == "simple.l0.1.ADD"
    assert ops[-6].name == "simple.l0.1.MatrixMul"


@pytest.mark.parametrize("symbolic", [False, True])
def test_named_submodule(symbolic):
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__()
            self.name = name
            self.linear = M.Linear(3, 3, name="x")

        def forward(self, x):
            x = self.linear(x)
            return x

    m = Simple("simple")

    ops = _dump_and_load(m, symbolic)
    assert ops[-1].name == "simple.x.ADD"
    assert ops[-2].name == "simple.x.MatrixMul"
    assert ops[-1].outputs[0].name == "simple.x.ADD"


@pytest.mark.parametrize("symbolic", [False, True])
def test_with_same_operators(symbolic):
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__()
            self.name = name

        def forward(self, x):
            x = F.relu(x)
            x = F.relu(x)
            return x

    m = Simple("simple")

    ops = _dump_and_load(m, symbolic)
    assert ops[-1].name == "simple.RELU[1]"
    assert ops[-2].name == "simple.RELU[0]"


def test_not_keep_opr_name():
    def f(x):
        return 2 * x

    op = _dump_and_load(f, True, False)[-1]
    assert op.name == "MUL(x,2[2])[4]"


@pytest.mark.parametrize("symbolic", [False, True])
def test_quantized_module_auto_naming(symbolic):
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__(name=name)
            self.quant = M.QuantStub()
            self.linear = M.Linear(3, 3, bias=True)
            self.dequant = M.DequantStub()

        def forward(self, x):
            out = self.quant(x)
            out = self.linear(out)
            out = self.dequant(out)
            return out

    m = Simple("simple")
    quantize_qat(m)
    quantize(m)
    m.eval()

    ops = _dump_and_load(m, symbolic)
    ops_name = (
        "x",
        "simple.quant.TypeCvt",
        "simple.linear.MatrixMul",
        "simple.linear.ADD",
        "simple.linear.TypeCvt",
        "simple.dequant.TypeCvt",
    )
    for op, name in zip(ops, ops_name):
        assert op.name == name


@pytest.mark.parametrize("symbolic", [False, True])
def test_quantized_module_user_naming(symbolic):
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__(name=name)
            self.quant = M.QuantStub()
            self.linear = M.Linear(3, 3, bias=True, name="user-linear")
            self.dequant = M.DequantStub()

        def forward(self, x):
            out = self.quant(x)
            out = self.linear(out)
            out = self.dequant(out)
            return out

    m = Simple("simple")
    quantize_qat(m)
    quantize(m)
    m.eval()

    ops = _dump_and_load(m, symbolic)
    ops_name = (
        "x",
        "simple.quant.TypeCvt",
        "simple.user-linear.MatrixMul",
        "simple.user-linear.ADD",
        "simple.user-linear.TypeCvt",
        "simple.dequant.TypeCvt",
    )
    for op, name in zip(ops, ops_name):
        assert op.name == name


@pytest.mark.parametrize("symbolic", [False, True])
def test_quantized_module_user_naming_param(symbolic):
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__(name=name)
            self.quant = M.QuantStub()
            self.linear = M.Linear(3, 3, bias=True)
            self.dequant = M.DequantStub()

            self.linear.weight.name = "user-weight"
            self.linear.bias.name = "user-bias"

        def forward(self, x):
            out = self.quant(x)
            out = self.linear(out)
            out = self.dequant(out)
            return out

    m = Simple("simple")
    quantize_qat(m)
    quantize(m)
    m.eval()

    ops = _dump_and_load(m, symbolic)

    (matrix_mul_op,) = [op for op in ops if op.name == "simple.linear.MatrixMul"]
    for var in matrix_mul_op.inputs:
        assert var.name in ("simple.quant.TypeCvt", "simple.linear.user-weight")
    # WONTFIX: the bias name does not meet expectations because of the astype operator inserted after quantization
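
The helper above keeps the dump in memory, but the same reload-and-inspect pattern works on a model dumped to disk. The sketch below is illustrative only and is not part of the test file: the path "model.mge" is a hypothetical dump produced elsewhere by trace(...).dump(...), while G.load_graph and cgtools.get_oprs_seq are the same utilities used by _dump_and_load.

import megengine.utils.comp_graph_tools as cgtools
from megengine.core.tensor import megbrain_graph as G

# "model.mge" is a hypothetical file created beforehand with trace(...).dump(...).
with open("model.mge", "rb") as f:
    *_, outputs = G.load_graph(f)

# Print the name recorded for every operator and its output variables.
for op in cgtools.get_oprs_seq(outputs):
    print(op.name, [var.name for var in op.outputs])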
