You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dump_naming.py 7.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import io
  10. import numpy as np
  11. import pytest
  12. import megengine.functional as F
  13. import megengine.module as M
  14. import megengine.utils.comp_graph_tools as cgtools
  15. from megengine import Parameter, Tensor
  16. from megengine.core.tensor import megbrain_graph as G
  17. from megengine.jit.tracing import trace
  18. from megengine.quantization.quantize import quantize, quantize_qat
  19. from megengine.utils.naming import auto_naming
  20. def _dump_and_load(func, symbolic, keep_opr_name=True):
  21. auto_naming.clear()
  22. func = trace(func, symbolic=symbolic, capture_as_const=True)
  23. x = Tensor(np.ones(shape=(2, 3)))
  24. func(x).numpy()
  25. file = io.BytesIO()
  26. func.dump(
  27. file,
  28. optimize_for_inference=False,
  29. arg_names=("x",),
  30. keep_opr_name=keep_opr_name,
  31. keep_var_name=2,
  32. )
  33. file.seek(0)
  34. *_, outputs = G.load_graph(file)
  35. ops = cgtools.get_oprs_seq(outputs)
  36. return ops
  37. @pytest.mark.parametrize("symbolic", [False, True])
  38. def test_auto_naming(symbolic):
  39. class Simple(M.Module):
  40. def __init__(self, name):
  41. super().__init__()
  42. self.name = name
  43. def forward(self, x):
  44. return x + x
  45. m = Simple("simple")
  46. op = _dump_and_load(m, symbolic)[-1]
  47. assert op.name == "simple.ADD"
  48. assert op.outputs[0].name == "simple.ADD"
  49. @pytest.mark.parametrize("symbolic", [False, True])
  50. def test_user_named_tensor(symbolic):
  51. class Simple(M.Module):
  52. def __init__(self, name):
  53. super().__init__()
  54. self.name = name
  55. self.k = Parameter(1.0, name="k")
  56. def forward(self, x):
  57. x = x + x
  58. x.name = "o_x"
  59. return x
  60. m = Simple("simple")
  61. op = _dump_and_load(m, symbolic)[-1]
  62. assert op.name == "simple.ADD"
  63. assert op.outputs[0].name == "o_x"
  64. @pytest.mark.parametrize("symbolic", [False, True])
  65. def test_user_named_param(symbolic):
  66. class Simple(M.Module):
  67. def __init__(self, name):
  68. super().__init__()
  69. self.name = name
  70. self.k = Parameter(2.0, name="k")
  71. def forward(self, x):
  72. return self.k * x
  73. m = Simple("simple")
  74. op = _dump_and_load(m, symbolic)[-1]
  75. assert op.inputs[0].name == "x"
  76. assert op.inputs[1].name == "simple.k"
  77. @pytest.mark.parametrize("symbolic", [False, True])
  78. def test_without_module(symbolic):
  79. def f(x):
  80. return 2 * x
  81. op = _dump_and_load(f, symbolic)[-1]
  82. assert op.name == "MUL"
  83. @pytest.mark.parametrize("symbolic", [False, True])
  84. def test_with_submodule(symbolic):
  85. class Simple(M.Module):
  86. def __init__(self, name):
  87. super().__init__()
  88. self.name = name
  89. self.linear = M.Linear(3, 3)
  90. def forward(self, x):
  91. x = self.linear(x)
  92. return x
  93. m = Simple("simple")
  94. ops = _dump_and_load(m, symbolic)
  95. assert ops[-1].name == "simple.linear.ADD"
  96. assert ops[-2].name == "simple.linear.MatrixMul"
  97. assert ops[-1].outputs[0].name == "simple.linear.ADD"
  98. @pytest.mark.parametrize("symbolic", [False, True])
  99. def test_named_submodule(symbolic):
  100. class Simple(M.Module):
  101. def __init__(self, name):
  102. super().__init__()
  103. self.name = name
  104. self.linear = M.Linear(3, 3, name="x")
  105. def forward(self, x):
  106. x = self.linear(x)
  107. return x
  108. m = Simple("simple")
  109. ops = _dump_and_load(m, symbolic)
  110. assert ops[-1].name == "simple.x.ADD"
  111. assert ops[-2].name == "simple.x.MatrixMul"
  112. assert ops[-1].outputs[0].name == "simple.x.ADD"
  113. @pytest.mark.parametrize("symbolic", [False, True])
  114. def test_with_same_operators(symbolic):
  115. class Simple(M.Module):
  116. def __init__(self, name):
  117. super().__init__()
  118. self.name = name
  119. def forward(self, x):
  120. x = F.relu(x)
  121. x = F.relu(x)
  122. return x
  123. m = Simple("simple")
  124. ops = _dump_and_load(m, symbolic)
  125. assert ops[-1].name == "simple.RELU[1]"
  126. assert ops[-2].name == "simple.RELU[0]"
  127. def test_not_keep_opr_name():
  128. def f(x):
  129. return 2 * x
  130. op = _dump_and_load(f, True, False)[-1]
  131. assert op.name == "MUL(x,2[2])[4]"
  132. @pytest.mark.parametrize("symbolic", [False, True])
  133. def test_quantized_module_auto_naming(symbolic):
  134. class Simple(M.Module):
  135. def __init__(self, name):
  136. super().__init__(name=name)
  137. self.quant = M.QuantStub()
  138. self.linear = M.Linear(3, 3, bias=True)
  139. self.dequant = M.DequantStub()
  140. def forward(self, x):
  141. out = self.quant(x)
  142. out = self.linear(out)
  143. out = self.dequant(out)
  144. return out
  145. m = Simple("simple")
  146. quantize_qat(m)
  147. quantize(m)
  148. m.eval()
  149. ops = _dump_and_load(m, symbolic)
  150. ops_name = (
  151. "x",
  152. "simple.quant.TypeCvt",
  153. "simple.linear.MatrixMul",
  154. "simple.linear.ADD",
  155. "simple.linear.TypeCvt",
  156. "simple.dequant.TypeCvt",
  157. )
  158. for op, name in zip(ops, ops_name):
  159. assert op.name == name
  160. @pytest.mark.parametrize("symbolic", [False, True])
  161. def test_quantized_module_user_naming(symbolic):
  162. class Simple(M.Module):
  163. def __init__(self, name):
  164. super().__init__(name=name)
  165. self.quant = M.QuantStub()
  166. self.linear = M.Linear(3, 3, bias=True, name="user-linear")
  167. self.dequant = M.DequantStub()
  168. def forward(self, x):
  169. out = self.quant(x)
  170. out = self.linear(out)
  171. out = self.dequant(out)
  172. return out
  173. m = Simple("simple")
  174. quantize_qat(m)
  175. quantize(m)
  176. m.eval()
  177. ops = _dump_and_load(m, symbolic)
  178. ops_name = (
  179. "x",
  180. "simple.quant.TypeCvt",
  181. "simple.user-linear.MatrixMul",
  182. "simple.user-linear.ADD",
  183. "simple.user-linear.TypeCvt",
  184. "simple.dequant.TypeCvt",
  185. )
  186. for op, name in zip(ops, ops_name):
  187. assert op.name == name
  188. @pytest.mark.parametrize("symbolic", [False, True])
  189. def test_quantized_module_user_naming_param(symbolic):
  190. class Simple(M.Module):
  191. def __init__(self, name):
  192. super().__init__(name=name)
  193. self.quant = M.QuantStub()
  194. self.linear = M.Linear(3, 3, bias=True)
  195. self.dequant = M.DequantStub()
  196. self.linear.weight.name = "user-weight"
  197. self.linear.bias.name = "user-bias"
  198. def forward(self, x):
  199. out = self.quant(x)
  200. out = self.linear(out)
  201. out = self.dequant(out)
  202. return out
  203. m = Simple("simple")
  204. quantize_qat(m)
  205. quantize(m)
  206. m.eval()
  207. ops = _dump_and_load(m, symbolic)
  208. (matrix_mul_op,) = [op for op in ops if op.name == "simple.linear.MatrixMul"]
  209. for var in matrix_mul_op.inputs:
  210. assert var.name in ("simple.quant.TypeCvt", "simple.linear.user-weight")
  211. # BUG bias' name does not meet expectations because of astype operator after quantization

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。