You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_jit.py 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import contextlib
  10. import os
  11. import tempfile
  12. import numpy as np
  13. import pytest
  14. import megengine as mge
  15. import megengine._internal as mgb
  16. import megengine.functional as F
  17. import megengine.module as M
  18. from megengine import functional as F
  19. from megengine import jit, tensor
  20. from megengine.core.tensor import Tensor
  21. from megengine.jit import SublinearMemoryConfig
  22. from megengine.test import assertTensorClose
  23. @contextlib.contextmanager
  24. def mkstemp():
  25. fd, path = tempfile.mkstemp()
  26. try:
  27. os.close(fd)
  28. yield path
  29. finally:
  30. os.remove(path)
  31. def load_and_compile(fpath):
  32. cg, _, outputs = mgb.load_comp_graph_from_file(fpath)
  33. inputs = mgb.cgtools.get_dep_vars(outputs, "Host2DeviceCopy")
  34. inputs = sorted(inputs, key=lambda i: i.name)
  35. outputs = list(map(mgb.copy_output, outputs))
  36. if len(outputs) == 1:
  37. (outputs,) = outputs
  38. return cg.compile(inputs, outputs)
  39. def test_symbolic():
  40. @jit.trace(symbolic=False)
  41. def f(x):
  42. return Tensor(mgb.opr.assert_equal(x._symvar, x._symvar + 1))
  43. with pytest.raises(mgb.exc.MegBrainError):
  44. f.trace(0)
  45. @jit.trace(symbolic=True)
  46. def f(x):
  47. return Tensor(mgb.opr.assert_equal(x._symvar, x._symvar + 1))
  48. f.trace(0)
  49. def test_add_update_semantic():
  50. for symbolic in [False, True]:
  51. x = tensor(0)
  52. @jit.trace(symbolic=symbolic)
  53. def f():
  54. F.add_update(x, 1)
  55. return x + 1
  56. np.testing.assert_equal(f().numpy(), [2])
  57. np.testing.assert_equal(f().numpy(), [3])
  58. def test_dump():
  59. @jit.trace(symbolic=True)
  60. def f(x, y):
  61. return x * y
  62. f.trace(0, 0)
  63. with mkstemp() as out:
  64. f.dump(out)
  65. g = load_and_compile(out)
  66. np.testing.assert_allclose(g([1, 2, 3], [1, 2, 3]), [1, 4, 9])
  67. def test_goptions():
  68. @jit.trace(symbolic=True, opt_level=0)
  69. def f(x):
  70. return x / x
  71. @jit.trace(symbolic=True, opt_level=1)
  72. def g(x):
  73. return x / x
  74. out = f([0.0]).numpy()
  75. # out is nan
  76. if out == out:
  77. raise
  78. # with gopt, x / x returns 1
  79. out = g([0.0]).numpy()
  80. assert out == 1
  81. def test_json_prof():
  82. @jit.trace(symbolic=True, profiling=True)
  83. def f(x):
  84. return x * x
  85. f([0.0])
  86. out = f.get_profile()
  87. assert out.get("profiler")
  88. def test_capture_dump():
  89. p = tensor(7)
  90. @jit.trace(symbolic=True)
  91. def f(x):
  92. return x * p
  93. f.trace(0)
  94. with mkstemp() as out:
  95. f.dump(out)
  96. g = load_and_compile(out)
  97. np.testing.assert_allclose(g([1, 2, 3]), [7, 14, 21])
  98. def test_dump_volatile():
  99. p = tensor(7)
  100. @jit.trace(symbolic=True)
  101. def f(x):
  102. return x * p
  103. f.trace(0)
  104. with mkstemp() as out:
  105. f.dump(out)
  106. cg, _, outputs = mgb.load_comp_graph_from_file(out)
  107. (out,) = outputs
  108. assert mgb.cgtools.get_type(mgb.cgtools.get_inputs(out)[1]) == "SharedDeviceTensor"
  109. def test_graph_traversal():
  110. net = M.Conv2d(3, 4, 3, 1, 1, groups=1, bias=False)
  111. net.eval()
  112. @jit.trace(symbolic=True)
  113. def fun(data):
  114. return net(data)
  115. data = np.random.random([1, 3, 224, 224]).astype(np.float32)
  116. fun.trace(data)
  117. with mkstemp() as out:
  118. fun.dump(out)
  119. *_, outputs = mgb.load_comp_graph_from_file(out)
  120. _, map_vars, var2oprs, *_ = mgb.cgtools.graph_traversal(outputs)
  121. input_var = map_vars[1]
  122. _, var_idx = var2oprs[input_var.id][0]
  123. assert var_idx == 0
  124. def test_network_visitor():
  125. @jit.trace(symbolic=True)
  126. def f(x):
  127. # this line will produce shape_of, subtensor and concat op
  128. # after pruning, they will be deleted
  129. target_shape = (x.shape[0], -1)
  130. return x.reshape(*target_shape)
  131. f.trace(tensor(np.random.random([2, 3, 4, 5]).astype(np.float32)))
  132. with mkstemp() as out:
  133. f.dump(out)
  134. *_, outputs = mgb.load_comp_graph_from_file(out)
  135. all_oprs = mgb.cgtools.get_oprs_seq(outputs)
  136. pruned_oprs = mgb.cgtools.get_oprs_seq(outputs, prune_reshape=True)
  137. assert len(all_oprs) == len(pruned_oprs) + 3
  138. def test_shape_tracing():
  139. for symbolic in [False, True]:
  140. @jit.trace(symbolic=symbolic)
  141. def f(x):
  142. a, b = x.shape
  143. return a * b
  144. assert f(np.zeros([4, 3], dtype="float32")).item() == 12
  145. assert f(np.zeros([6, 4], dtype="float32")).item() == 24
  146. def test_shape_infer():
  147. @jit.trace(symbolic=True)
  148. def f(x):
  149. a, b = x.shape
  150. return sum(x[i] for i in range(a))
  151. x = np.random.randn(3, 10).astype("float32")
  152. assertTensorClose(f(x), x.sum(0))
  153. x = np.random.randn(4, 10).astype("float32")
  154. assertTensorClose(f(x), x[:3].sum(0))
  155. def test_dump_bn_fused():
  156. class ConvBNReLU(M.Sequential):
  157. def __init__(self):
  158. super(ConvBNReLU, self).__init__(
  159. M.Conv2d(3, 4, 3, 1, 1, groups=1, bias=False),
  160. M.BatchNorm2d(4),
  161. M.ReLU(),
  162. )
  163. net = ConvBNReLU()
  164. net.eval()
  165. @jit.trace(symbolic=True)
  166. def fun(data):
  167. return net(data)
  168. data = np.random.random([1, 3, 224, 224]).astype(np.float32)
  169. fun.trace(data)
  170. with mkstemp() as out:
  171. fun.dump(out, optimize_for_inference=True)
  172. cg, _, outputs = mgb.load_comp_graph_from_file(out)
  173. (out,) = outputs
  174. inputs = mgb.cgtools.get_inputs(out)
  175. assert len(inputs) == 2 and (
  176. mgb.cgtools.get_type(inputs[0]) == "MultipleDeviceTensorHolder"
  177. and mgb.cgtools.get_type(inputs[1]) == "ConvolutionForward"
  178. )
  179. # Simply verify the options passed down
  180. def test_sublinear():
  181. config = SublinearMemoryConfig(genetic_nr_iter=10)
  182. @jit.trace(symbolic=True, sublinear_memory_config=config)
  183. def f(x):
  184. return x + x
  185. f([0.0])

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台