You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_jit.py 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import contextlib
  10. import os
  11. import tempfile
  12. import numpy as np
  13. import pytest
  14. import megengine as mge
  15. import megengine._internal as mgb
  16. import megengine.module as M
  17. from megengine import jit, tensor
  18. from megengine.core.tensor import Tensor
  19. from megengine.test import assertTensorClose
  20. @contextlib.contextmanager
  21. def mkstemp():
  22. fd, path = tempfile.mkstemp()
  23. try:
  24. os.close(fd)
  25. yield path
  26. finally:
  27. os.remove(path)
  28. def load_and_compile(fpath):
  29. cg, _, outputs = mgb.load_comp_graph_from_file(fpath)
  30. inputs = mgb.cgtools.get_dep_vars(outputs, "Host2DeviceCopy")
  31. inputs = sorted(inputs, key=lambda i: i.name)
  32. outputs = list(map(mgb.copy_output, outputs))
  33. if len(outputs) == 1:
  34. (outputs,) = outputs
  35. return cg.compile(inputs, outputs)
  36. def test_symbolic():
  37. @jit.trace(symbolic=False)
  38. def f(x):
  39. return Tensor(mgb.opr.assert_equal(x._symvar, x._symvar + 1))
  40. with pytest.raises(mgb.exc.MegBrainError):
  41. f.trace(0)
  42. @jit.trace(symbolic=True)
  43. def f(x):
  44. return Tensor(mgb.opr.assert_equal(x._symvar, x._symvar + 1))
  45. f.trace(0)
  46. def test_dump():
  47. @jit.trace(symbolic=True)
  48. def f(x, y):
  49. return x * y
  50. f.trace(0, 0)
  51. with mkstemp() as out:
  52. f.dump(out)
  53. g = load_and_compile(out)
  54. np.testing.assert_allclose(g([1, 2, 3], [1, 2, 3]), [1, 4, 9])
  55. def test_goptions():
  56. @jit.trace(symbolic=True, opt_level=0)
  57. def f(x):
  58. return x / x
  59. @jit.trace(symbolic=True, opt_level=1)
  60. def g(x):
  61. return x / x
  62. out = f([0.0]).numpy()
  63. # out is nan
  64. if out == out:
  65. raise
  66. # with gopt, x / x returns 1
  67. out = g([0.0]).numpy()
  68. assert out == 1
  69. def test_json_prof():
  70. @jit.trace(symbolic=True, profiling=True)
  71. def f(x):
  72. return x * x
  73. f([0.0])
  74. out = f.get_profile()
  75. assert out.get("profiler")
  76. def test_capture_dump():
  77. p = tensor(7)
  78. @jit.trace(symbolic=True)
  79. def f(x):
  80. return x * p
  81. f.trace(0)
  82. with mkstemp() as out:
  83. f.dump(out)
  84. g = load_and_compile(out)
  85. np.testing.assert_allclose(g([1, 2, 3]), [7, 14, 21])
  86. def test_dump_volatile():
  87. p = tensor(7)
  88. @jit.trace(symbolic=True)
  89. def f(x):
  90. return x * p
  91. f.trace(0)
  92. with mkstemp() as out:
  93. f.dump(out)
  94. cg, _, outputs = mgb.load_comp_graph_from_file(out)
  95. (out,) = outputs
  96. assert mgb.cgtools.get_type(mgb.cgtools.get_inputs(out)[1]) == "SharedDeviceTensor"
  97. def test_shape_tracing():
  98. for symbolic in [False, True]:
  99. @jit.trace(symbolic=symbolic)
  100. def f(x):
  101. a, b = x.shape
  102. return a * b
  103. assert f(np.zeros([4, 3], dtype="float32")).item() == 12
  104. assert f(np.zeros([6, 4], dtype="float32")).item() == 24
  105. def test_shape_infer():
  106. @jit.trace(symbolic=True)
  107. def f(x):
  108. a, b = x.shape
  109. return sum(x[i] for i in range(a))
  110. x = np.random.randn(3, 10).astype("float32")
  111. assertTensorClose(f(x), x.sum(0))
  112. x = np.random.randn(4, 10).astype("float32")
  113. assertTensorClose(f(x), x[:3].sum(0))
  114. def test_dump_bn_fused():
  115. class ConvBNReLU(M.Sequential):
  116. def __init__(self):
  117. super(ConvBNReLU, self).__init__(
  118. M.Conv2d(3, 4, 3, 1, 1, groups=1, bias=False),
  119. M.BatchNorm2d(4),
  120. M.ReLU(),
  121. )
  122. net = ConvBNReLU()
  123. net.eval()
  124. @jit.trace(symbolic=True)
  125. def fun(data):
  126. return net(data)
  127. data = np.random.random([1, 3, 224, 224]).astype(np.float32)
  128. fun.trace(data)
  129. with mkstemp() as out:
  130. fun.dump(out, optimize_for_inference=True)
  131. cg, _, outputs = mgb.load_comp_graph_from_file(out)
  132. (out,) = outputs
  133. inputs = mgb.cgtools.get_inputs(out)
  134. assert len(inputs) == 2 and (
  135. mgb.cgtools.get_type(inputs[0]) == "MultipleDeviceTensorHolder"
  136. and mgb.cgtools.get_type(inputs[1]) == "ConvolutionForward"
  137. )

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台

Contributors (1)