You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dump_naming.py 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import io
  10. import numpy as np
  11. import pytest
  12. import megengine.functional as F
  13. import megengine.module as M
  14. import megengine.utils.comp_graph_tools as cgtools
  15. from megengine import Parameter, Tensor
  16. from megengine.core.tensor import megbrain_graph as G
  17. from megengine.jit.tracing import trace
  18. from megengine.utils.naming import auto_naming
  19. def _dump_and_load(func, symbolic, keep_opr_name=True):
  20. auto_naming.clear()
  21. func = trace(func, symbolic=symbolic, capture_as_const=True)
  22. x = Tensor(np.ones(shape=(2, 3)))
  23. func(x).numpy()
  24. file = io.BytesIO()
  25. func.dump(
  26. file,
  27. optimize_for_inference=False,
  28. arg_names="x",
  29. keep_opr_name=keep_opr_name,
  30. keep_var_name=2,
  31. )
  32. file.seek(0)
  33. *_, outputs = G.load_graph(file)
  34. op = cgtools.get_oprs_seq(outputs)[-1]
  35. return op
  36. @pytest.mark.parametrize("symbolic", [False, True])
  37. def test_auto_naming(symbolic):
  38. class Simple(M.Module):
  39. def __init__(self, name):
  40. super().__init__()
  41. self.name = name
  42. def forward(self, x):
  43. return x + x
  44. m = Simple("simple")
  45. op = _dump_and_load(m, symbolic)
  46. assert op.name == "simple.ADD"
  47. assert op.outputs[0].name == "simple.ADD"
  48. @pytest.mark.parametrize("symbolic", [False, True])
  49. def test_user_named_tensor(symbolic):
  50. class Simple(M.Module):
  51. def __init__(self, name):
  52. super().__init__()
  53. self.name = name
  54. self.k = Parameter(1.0, name="k")
  55. def forward(self, x):
  56. x = x + x
  57. x.name = "o_x"
  58. return x
  59. m = Simple("simple")
  60. op = _dump_and_load(m, symbolic)
  61. assert op.name == "simple.ADD"
  62. assert op.outputs[0].name == "o_x"
  63. @pytest.mark.parametrize("symbolic", [False, True])
  64. def test_user_named_param(symbolic):
  65. class Simple(M.Module):
  66. def __init__(self, name):
  67. super().__init__()
  68. self.name = name
  69. self.k = Parameter(2.0, name="k")
  70. def forward(self, x):
  71. return self.k * x
  72. m = Simple("simple")
  73. op = _dump_and_load(m, symbolic)
  74. assert op.inputs[0].name == "x"
  75. assert op.inputs[1].name == "simple.k"
  76. @pytest.mark.parametrize("symbolic", [False, True])
  77. def test_without_module(symbolic):
  78. def f(x):
  79. return 2 * x
  80. op = _dump_and_load(f, symbolic)
  81. assert op.name == "MUL"
  82. @pytest.mark.parametrize("symbolic", [False, True])
  83. def test_with_submodule(symbolic):
  84. class Simple(M.Module):
  85. def __init__(self, name):
  86. super().__init__()
  87. self.name = name
  88. self.linear = M.Linear(3, 3)
  89. def forward(self, x):
  90. x = self.linear(x)
  91. return x
  92. m = Simple("simple")
  93. op = _dump_and_load(m, symbolic)
  94. assert op.name == "simple.linear.ADD"
  95. assert op.inputs[0].owner.name == "simple.linear.MatrixMul"
  96. assert op.outputs[0].name == "simple.linear.ADD"
  97. @pytest.mark.parametrize("symbolic", [False, True])
  98. def test_named_submodule(symbolic):
  99. class Simple(M.Module):
  100. def __init__(self, name):
  101. super().__init__()
  102. self.name = name
  103. self.linear = M.Linear(3, 3, name="x")
  104. def forward(self, x):
  105. x = self.linear(x)
  106. return x
  107. m = Simple("simple")
  108. op = _dump_and_load(m, symbolic)
  109. assert op.name == "simple.x.ADD"
  110. assert op.inputs[0].owner.name == "simple.x.MatrixMul"
  111. assert op.outputs[0].name == "simple.x.ADD"
  112. @pytest.mark.parametrize("symbolic", [False, True])
  113. def test_with_same_operators(symbolic):
  114. class Simple(M.Module):
  115. def __init__(self, name):
  116. super().__init__()
  117. self.name = name
  118. def forward(self, x):
  119. x = F.relu(x)
  120. x = F.relu(x)
  121. return x
  122. m = Simple("simple")
  123. op = _dump_and_load(m, symbolic)
  124. assert op.name == "simple.RELU[1]"
  125. assert op.inputs[0].owner.name == "simple.RELU[0]"
  126. def test_not_keep_opr_name():
  127. def f(x):
  128. return 2 * x
  129. op = _dump_and_load(f, True, False)
  130. assert op.name == "MUL(x,2[2])[4]"

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。