You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tracing.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import io
  10. from tempfile import mkstemp
  11. import numpy as np
  12. import pytest
  13. import megengine.core.tensor.megbrain_graph as G
  14. import megengine.functional as F
  15. from megengine import cgtools, tensor
  16. from megengine.core._trace_option import set_tensor_shape
  17. from megengine.core.ops import builtin as ops
  18. from megengine.core.tensor.core import apply
  19. from megengine.core.tensor.raw_tensor import as_raw_tensor
  20. from megengine.functional import exp, log
  21. from megengine.jit import exclude_from_trace, trace
  22. def test_trace():
  23. for symbolic in [False, True]:
  24. @trace(symbolic=symbolic)
  25. def f(x):
  26. op = ops.Elemwise(mode="negate")
  27. (y,) = apply(op, x)
  28. return y
  29. x = as_raw_tensor([1]).numpy()
  30. y = f.__wrapped__(as_raw_tensor(x)).numpy()
  31. for i in range(3):
  32. np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
  33. def test_exclude_from_trace():
  34. for symbolic in [False, True]:
  35. @trace(symbolic=symbolic)
  36. def f(x):
  37. neg = ops.Elemwise(mode="negate")
  38. (x,) = apply(neg, x)
  39. with exclude_from_trace():
  40. if i % 2:
  41. (x,) = apply(neg, x)
  42. (x,) = apply(neg, x)
  43. return x
  44. x = as_raw_tensor([1]).numpy()
  45. for i in range(3):
  46. y = f.__wrapped__(as_raw_tensor(x)).numpy()
  47. np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
  48. def test_print_in_trace():
  49. for symbolic in [False]: # cannot read value in symbolic mode
  50. @trace(symbolic=symbolic)
  51. def f(x):
  52. nonlocal buf
  53. neg = ops.Elemwise(mode="negate")
  54. (x,) = apply(neg, x)
  55. buf = x.numpy()
  56. (x,) = apply(neg, x)
  57. return x
  58. buf = None
  59. x = as_raw_tensor([1]).numpy()
  60. for i in range(3):
  61. y = f.__wrapped__(as_raw_tensor(x)).numpy()
  62. z = buf
  63. buf = None
  64. np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
  65. np.testing.assert_equal(z, buf)
  66. def test_dump():
  67. @trace(symbolic=True, capture_as_const=True)
  68. def f(a, b):
  69. op = ops.Elemwise(mode="add")
  70. (y,) = apply(op, a, b)
  71. return y
  72. a = as_raw_tensor([2]).numpy()
  73. b = as_raw_tensor([4]).numpy()
  74. y = f.__wrapped__(as_raw_tensor(a), as_raw_tensor(b)).numpy()
  75. for i in range(3):
  76. np.testing.assert_equal(f(as_raw_tensor(a), as_raw_tensor(b)).numpy(), y)
  77. file = io.BytesIO()
  78. dump_info = f.dump(file)
  79. assert dump_info.nr_opr == 3
  80. np.testing.assert_equal(dump_info.inputs, ["h2d[0]", "h2d[2]"])
  81. np.testing.assert_equal(dump_info.outputs, ["ADD(h2d[0],h2d[2])[4]"])
  82. file.seek(0)
  83. result = cgtools.load_and_inference(file, [a, b])
  84. np.testing.assert_equal(result[0], y)
  85. def test_capture_dump():
  86. a = as_raw_tensor([2])
  87. @trace(symbolic=True, capture_as_const=True)
  88. def f(x):
  89. op = ops.Elemwise(mode="mul")
  90. (y,) = apply(op, x, a)
  91. return y
  92. x = as_raw_tensor([3]).numpy()
  93. y = f.__wrapped__(as_raw_tensor(x)).numpy()
  94. for i in range(3):
  95. np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
  96. file = io.BytesIO()
  97. f.dump(file)
  98. file.seek(0)
  99. result = cgtools.load_and_inference(file, [x])
  100. np.testing.assert_equal(result[0], y)
  101. def test_dump_volatile():
  102. p = as_raw_tensor([2])
  103. @trace(symbolic=True, capture_as_const=True)
  104. def f(x):
  105. op = ops.Elemwise(mode="mul")
  106. (y,) = apply(op, x, p)
  107. return y
  108. x = as_raw_tensor([3]).numpy()
  109. y = f.__wrapped__(as_raw_tensor(x)).numpy()
  110. for i in range(3):
  111. np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
  112. file = io.BytesIO()
  113. f.dump(file, optimize_for_inference=False)
  114. file.seek(0)
  115. cg, _, outputs = G.load_graph(file)
  116. (out,) = outputs
  117. assert (
  118. cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1])
  119. == "ImmutableTensor"
  120. )
  121. def test_trace_profiler():
  122. for symbolic in [False, True]:
  123. @trace(symbolic=symbolic, profiling=True)
  124. def f(x):
  125. op = ops.Elemwise(mode="negate")
  126. (y,) = apply(op, x)
  127. return y
  128. x = as_raw_tensor([1]).numpy()
  129. y = f.__wrapped__(as_raw_tensor(x)).numpy()
  130. f(as_raw_tensor(x))
  131. f(as_raw_tensor(x)) # XXX: has to run twice
  132. out = f.get_profile()
  133. assert out.get("profiler")
  134. @pytest.mark.skip(reason="could not disable opt_level")
  135. def test_goptions_log_exp():
  136. @trace(symbolic=True, opt_level=0, capture_as_const=True)
  137. def f(x):
  138. return log(exp(x))
  139. @trace(symbolic=True, opt_level=1, capture_as_const=True)
  140. def g(x):
  141. return log(exp(x))
  142. f(tensor(1.0))
  143. _, out = mkstemp()
  144. f.dump(out, optimize_for_inference=False)
  145. *_, outputs = G.load_graph(out)
  146. oprs_1 = cgtools.get_oprs_seq(outputs)
  147. g(tensor(1.0))
  148. g.dump(out, optimize_for_inference=False)
  149. *_, outputs = G.load_graph(out)
  150. oprs_2 = cgtools.get_oprs_seq(outputs)
  151. assert len(oprs_1) - len(oprs_2) == 2
  152. @pytest.mark.skip(reason="could not disable opt_level")
  153. def test_goptions_log_sum_exp():
  154. @trace(symbolic=True, opt_level=0, capture_as_const=True)
  155. def f(x, y):
  156. return log(exp(x) + exp(y))
  157. @trace(symbolic=True, opt_level=1, capture_as_const=True)
  158. def g(x, y):
  159. return log(exp(x) + exp(y))
  160. f(tensor(1.0), tensor(2.0))
  161. _, out = mkstemp()
  162. f.dump(out, optimize_for_inference=False)
  163. *_, outputs = G.load_graph(out)
  164. oprs_1 = cgtools.get_oprs_seq(outputs)
  165. g(tensor(1.0), tensor(2.0))
  166. g.dump(out, optimize_for_inference=False)
  167. *_, outputs = G.load_graph(out)
  168. oprs_2 = cgtools.get_oprs_seq(outputs)
  169. assert len(oprs_1) - len(oprs_2) == 2
  170. def test_optimize_for_inference():
  171. @trace(symbolic=True, capture_as_const=True)
  172. def f(x):
  173. return exp(x)
  174. _, out = mkstemp()
  175. f(tensor(5.0))
  176. f.dump(out, enable_io16xc32=True)
  177. res = G.load_graph(out)
  178. computing_input = res.output_vars_list[0].owner.inputs[0]
  179. assert computing_input.dtype == np.float16
  180. def test_optimize_for_inference_broadcast():
  181. a = tensor(np.ones(1, dtype=np.float32))
  182. @trace(capture_as_const=True, tensor_shape=True)
  183. def f():
  184. (b,) = apply(ops.Broadcast(), a, tensor([1, 10], dtype=np.int32))
  185. return b
  186. f()
  187. f.dump(io.BytesIO())
  188. def test_trace_cvt_bool():
  189. set_tensor_shape(True)
  190. x = tensor([0], dtype=np.int32)
  191. @trace(symbolic=True)
  192. def f(x):
  193. return x.shape[0] == 0
  194. for i in range(3):
  195. np.testing.assert_equal(f(x).numpy()[0], False)
  196. def test_trace_reshape():
  197. for symbolic in [False, True]:
  198. set_tensor_shape(True)
  199. x1 = tensor(np.random.randn(2, 10, 10))
  200. x2 = tensor(np.random.randn(4, 10, 10))
  201. x3 = tensor(np.random.randn(8, 10, 10))
  202. @trace(symbolic=symbolic, capture_as_const=True)
  203. def f(x):
  204. y = x.reshape(x.shape[0], 100)
  205. return y
  206. f(x1)
  207. f(x2)
  208. f(x3)
  209. def test_trace_topk():
  210. x = tensor([5, 2, 7, 1, 0, 3, 2])
  211. @trace(symbolic=True)
  212. def f(x):
  213. y = F.topk(x, 3)
  214. np.testing.assert_equal(y[0].shape.numpy(), np.array([3,]))
  215. return y
  216. for i in range(3):
  217. f(x)
  218. def test_trace_warp_perspective():
  219. inp_shape = (1, 1, 4, 4)
  220. x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
  221. M_shape = (1, 3, 3)
  222. M = tensor(
  223. np.array(
  224. [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
  225. ).reshape(M_shape)
  226. )
  227. @trace(symbolic=True)
  228. def f(x, M):
  229. out = F.warp_perspective(x, M, (2, 2))
  230. np.testing.assert_equal(out.shape.numpy(), np.array([1, 1, 2, 2]))
  231. return out
  232. for i in range(1):
  233. f(x, M)
  234. def test_raise_on_trace():
  235. step_count = 0
  236. catch_count = 0
  237. bad_step = 10
  238. class CatchMe(Exception):
  239. pass
  240. a = tensor([1, 2, 3, 4])
  241. b = tensor([5, 6, 7, 8])
  242. c = tensor([9, 0, 1, 2])
  243. @trace
  244. def add_abc(a, b, c):
  245. print("Hello")
  246. ps = a + b
  247. result = ps + c
  248. if step_count == bad_step:
  249. raise CatchMe("catch me")
  250. return result
  251. for i in range(100):
  252. try:
  253. d = add_abc(a, b, c)
  254. except CatchMe as e:
  255. catch_count += 1
  256. else:
  257. np.testing.assert_equal(d.numpy(), (a + b + c).numpy())
  258. step_count += 1
  259. assert catch_count == 1
  260. def test_trace_broadcast():
  261. for symbolic in [False, True]:
  262. set_tensor_shape(True)
  263. x1 = tensor(np.random.randn(3, 1, 1))
  264. x2 = tensor(np.random.randn(1, 4, 1))
  265. x3 = tensor(np.random.randn(1, 1, 5))
  266. @trace(symbolic=symbolic, capture_as_const=True)
  267. def f(x):
  268. y = F.broadcast_to(x, (3, 4, 5))
  269. return y
  270. f(x1)
  271. f(x2)
  272. f(x3)
  273. def test_trace_nms():
  274. def make_inputs(n):
  275. boxes = np.zeros((n, 4))
  276. boxes[:, :2] = np.random.rand(n, 2) * 100
  277. boxes[:, 2:] = np.random.rand(n, 2) * 100 + 100
  278. scores = np.random.rand(n)
  279. return tensor(boxes), tensor(scores)
  280. @trace(symbolic=False)
  281. def f(boxes, scores):
  282. results = F.nn.nms(boxes, scores=scores, iou_thresh=0.5, max_output=20)
  283. with exclude_from_trace():
  284. _ = F.nn.nms(boxes, scores=scores, iou_thresh=0.5)
  285. return results
  286. f(*make_inputs(10))
  287. f(*make_inputs(20))
  288. f(*make_inputs(30))

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台