You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_jit.py 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import contextlib
  10. import os
  11. import tempfile
  12. import numpy as np
  13. import pytest
  14. import megengine as mge
  15. import megengine._internal as mgb
  16. import megengine.module as M
  17. from megengine import functional as F
  18. from megengine import jit, tensor
  19. from megengine.core.tensor import Tensor
  20. from megengine.jit import SublinearMemoryConfig
  21. from megengine.test import assertTensorClose
  22. @contextlib.contextmanager
  23. def mkstemp():
  24. fd, path = tempfile.mkstemp()
  25. try:
  26. os.close(fd)
  27. yield path
  28. finally:
  29. os.remove(path)
  30. def load_and_compile(fpath):
  31. cg, _, outputs = mgb.load_comp_graph_from_file(fpath)
  32. inputs = mgb.cgtools.get_dep_vars(outputs, "Host2DeviceCopy")
  33. inputs = sorted(inputs, key=lambda i: i.name)
  34. outputs = list(map(mgb.copy_output, outputs))
  35. if len(outputs) == 1:
  36. (outputs,) = outputs
  37. return cg.compile(inputs, outputs)
  38. def test_symbolic():
  39. @jit.trace(symbolic=False)
  40. def f(x):
  41. return Tensor(mgb.opr.assert_equal(x._symvar, x._symvar + 1))
  42. with pytest.raises(mgb.exc.MegBrainError):
  43. f.trace(0)
  44. @jit.trace(symbolic=True)
  45. def f(x):
  46. return Tensor(mgb.opr.assert_equal(x._symvar, x._symvar + 1))
  47. f.trace(0)
  48. def test_add_update_semantic():
  49. for symbolic in [False, True]:
  50. x = tensor(0)
  51. @jit.trace(symbolic=symbolic)
  52. def f():
  53. F.add_update(x, 1)
  54. return x + 1
  55. np.testing.assert_equal(f().numpy(), [2])
  56. np.testing.assert_equal(f().numpy(), [3])
  57. def test_dump():
  58. @jit.trace(symbolic=True)
  59. def f(x, y):
  60. return x * y
  61. f.trace(0, 0)
  62. with mkstemp() as out:
  63. f.dump(out)
  64. g = load_and_compile(out)
  65. np.testing.assert_allclose(g([1, 2, 3], [1, 2, 3]), [1, 4, 9])
  66. def test_goptions():
  67. @jit.trace(symbolic=True, opt_level=0)
  68. def f(x):
  69. return x / x
  70. @jit.trace(symbolic=True, opt_level=1)
  71. def g(x):
  72. return x / x
  73. out = f([0.0]).numpy()
  74. # out is nan
  75. if out == out:
  76. raise
  77. # with gopt, x / x returns 1
  78. out = g([0.0]).numpy()
  79. assert out == 1
  80. def test_json_prof():
  81. @jit.trace(symbolic=True, profiling=True)
  82. def f(x):
  83. return x * x
  84. f([0.0])
  85. out = f.get_profile()
  86. assert out.get("profiler")
  87. def test_capture_dump():
  88. p = tensor(7)
  89. @jit.trace(symbolic=True)
  90. def f(x):
  91. return x * p
  92. f.trace(0)
  93. with mkstemp() as out:
  94. f.dump(out)
  95. g = load_and_compile(out)
  96. np.testing.assert_allclose(g([1, 2, 3]), [7, 14, 21])
  97. def test_dump_volatile():
  98. p = tensor(7)
  99. @jit.trace(symbolic=True)
  100. def f(x):
  101. return x * p
  102. f.trace(0)
  103. with mkstemp() as out:
  104. f.dump(out)
  105. cg, _, outputs = mgb.load_comp_graph_from_file(out)
  106. (out,) = outputs
  107. assert mgb.cgtools.get_type(mgb.cgtools.get_inputs(out)[1]) == "SharedDeviceTensor"
  108. def test_shape_tracing():
  109. for symbolic in [False, True]:
  110. @jit.trace(symbolic=symbolic)
  111. def f(x):
  112. a, b = x.shape
  113. return a * b
  114. assert f(np.zeros([4, 3], dtype="float32")).item() == 12
  115. assert f(np.zeros([6, 4], dtype="float32")).item() == 24
  116. def test_shape_infer():
  117. @jit.trace(symbolic=True)
  118. def f(x):
  119. a, b = x.shape
  120. return sum(x[i] for i in range(a))
  121. x = np.random.randn(3, 10).astype("float32")
  122. assertTensorClose(f(x), x.sum(0))
  123. x = np.random.randn(4, 10).astype("float32")
  124. assertTensorClose(f(x), x[:3].sum(0))
  125. def test_dump_bn_fused():
  126. class ConvBNReLU(M.Sequential):
  127. def __init__(self):
  128. super(ConvBNReLU, self).__init__(
  129. M.Conv2d(3, 4, 3, 1, 1, groups=1, bias=False),
  130. M.BatchNorm2d(4),
  131. M.ReLU(),
  132. )
  133. net = ConvBNReLU()
  134. net.eval()
  135. @jit.trace(symbolic=True)
  136. def fun(data):
  137. return net(data)
  138. data = np.random.random([1, 3, 224, 224]).astype(np.float32)
  139. fun.trace(data)
  140. with mkstemp() as out:
  141. fun.dump(out, optimize_for_inference=True)
  142. cg, _, outputs = mgb.load_comp_graph_from_file(out)
  143. (out,) = outputs
  144. inputs = mgb.cgtools.get_inputs(out)
  145. assert len(inputs) == 2 and (
  146. mgb.cgtools.get_type(inputs[0]) == "MultipleDeviceTensorHolder"
  147. and mgb.cgtools.get_type(inputs[1]) == "ConvolutionForward"
  148. )
  149. # Simply verify the options passed down
  150. def test_sublinear():
  151. config = SublinearMemoryConfig(genetic_nr_iter=10)
  152. @jit.trace(symbolic=True, sublinear_memory_config=config)
  153. def f(x):
  154. return x + x
  155. f([0.0])

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台