You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tracing.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import inspect
  10. import io
  11. import itertools
  12. from tempfile import mkstemp
  13. import numpy as np
  14. import pytest
  15. import megengine.core.tensor.megbrain_graph as G
  16. import megengine.functional as F
  17. import megengine.optimizer as optim
  18. import megengine.utils.comp_graph_tools as cgtools
  19. from megengine import Parameter, tensor
  20. from megengine.autodiff import GradManager
  21. from megengine.core._trace_option import set_symbolic_shape
  22. from megengine.core.ops import builtin as ops
  23. from megengine.core.ops.builtin import Elemwise
  24. from megengine.core.tensor.utils import isscalar
  25. from megengine.functional import exp, log
  26. from megengine.jit import exclude_from_trace, trace
  27. from megengine.module import Module
  28. from megengine.random import normal, uniform
  29. from megengine.utils.naming import auto_naming
  30. @pytest.mark.parametrize("trace_mode", [False, True])
  31. @pytest.mark.parametrize("return_mode", ["Value", "Tuple", "List", "Dict"])
  32. def test_trace(trace_mode, return_mode):
  33. @trace(symbolic=trace_mode)
  34. def f(x):
  35. if return_mode == "Tuple":
  36. return (-x,)
  37. elif return_mode == "List":
  38. return [-x]
  39. elif return_mode == "Dict":
  40. return {"neg": -x}
  41. else:
  42. return -x
  43. def get_numpy(y):
  44. if return_mode == "Tuple" or return_mode == "List":
  45. return y[0].numpy()
  46. elif return_mode == "Dict":
  47. return y["neg"].numpy()
  48. return y.numpy()
  49. x = tensor([1])
  50. y = get_numpy(f(x))
  51. for i in range(3):
  52. np.testing.assert_equal(get_numpy(f(x)), y)
  53. def test_output_copy_trace():
  54. class Simple(Module):
  55. def __init__(self):
  56. super().__init__()
  57. self.a = Parameter([1.0], dtype=np.float32)
  58. def forward(self, x):
  59. x = x * self.a
  60. # will result into a copy of output in grad
  61. x = F.exp(x)
  62. return x
  63. ys = {False: [], True: []}
  64. for symbolic in [False, True]:
  65. net = Simple()
  66. gm = GradManager().attach(net.parameters())
  67. opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
  68. data = tensor(np.arange(4).reshape(2, 2), dtype="float32")
  69. @trace(symbolic=symbolic)
  70. def train_func(d):
  71. with gm:
  72. loss = net(d)
  73. gm.backward(loss)
  74. opt.step().clear_grad()
  75. return loss
  76. for i in range(3):
  77. y = train_func(data).numpy()
  78. ys[symbolic].append(y)
  79. for i in range(3):
  80. np.testing.assert_equal(ys[False][i], ys[True][i])
  81. @pytest.mark.parametrize("trace_mode", [False, True])
  82. def test_exclude_from_trace(trace_mode):
  83. @trace(symbolic=trace_mode)
  84. def f(x):
  85. x = -x
  86. with exclude_from_trace():
  87. if i % 2:
  88. x = -x
  89. x = -x
  90. return x
  91. x = tensor([1])
  92. for i in range(3):
  93. y = f(x).numpy()
  94. np.testing.assert_equal(f(x).numpy(), y)
  95. def test_print_in_trace():
  96. for symbolic in [False]: # cannot read value in symbolic mode
  97. @trace(symbolic=symbolic)
  98. def f(x):
  99. nonlocal buf
  100. x = -x
  101. buf = x.numpy()
  102. x = -x
  103. return x
  104. buf = None
  105. x = tensor([1])
  106. for i in range(3):
  107. y = f(x).numpy()
  108. z = buf
  109. buf = None
  110. np.testing.assert_equal(f(x).numpy(), y)
  111. np.testing.assert_equal(z, buf)
  112. def test_dump():
  113. @trace(symbolic=True, capture_as_const=True)
  114. def f(a, b):
  115. return a + b
  116. # prevent from remaining scope from exception test
  117. auto_naming.clear()
  118. a = tensor([2])
  119. b = tensor([4])
  120. y = f(a, b).numpy()
  121. for i in range(3):
  122. np.testing.assert_equal(f(a, b).numpy(), y)
  123. file = io.BytesIO()
  124. dump_info = f.dump(file)
  125. assert dump_info.nr_opr == 3
  126. np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"])
  127. np.testing.assert_equal(dump_info.outputs, ["ADD"])
  128. file.seek(0)
  129. infer_cg = cgtools.GraphInference(file)
  130. result = list((infer_cg.run(a, b)).values())[0]
  131. np.testing.assert_equal(result[0], y)
  132. def test_capture_dump():
  133. a = tensor([2])
  134. @trace(symbolic=True, capture_as_const=True)
  135. def f(x):
  136. return x * a
  137. x = tensor([3])
  138. y = f(x).numpy()
  139. for i in range(3):
  140. np.testing.assert_equal(f(x).numpy(), y)
  141. file = io.BytesIO()
  142. f.dump(file)
  143. file.seek(0)
  144. infer_cg = cgtools.GraphInference(file)
  145. result = list((infer_cg.run(x)).values())[0]
  146. np.testing.assert_equal(result[0], y)
  147. def test_dump_volatile():
  148. p = tensor([2])
  149. @trace(symbolic=True, capture_as_const=True)
  150. def f(x):
  151. return x * p
  152. x = tensor([3])
  153. y = f(x).numpy()
  154. for i in range(3):
  155. np.testing.assert_equal(f(x).numpy(), y)
  156. file = io.BytesIO()
  157. f.dump(file, optimize_for_inference=False)
  158. file.seek(0)
  159. cg, _, outputs = G.load_graph(file)
  160. (out,) = outputs
  161. assert (
  162. cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1])
  163. == "ImmutableTensor"
  164. )
  165. @pytest.mark.parametrize("trace_mode", [False, True])
  166. def test_trace_profiler(trace_mode):
  167. @trace(symbolic=trace_mode, profiling=True)
  168. def f(x):
  169. return -x
  170. x = tensor([1])
  171. y = f(x).numpy()
  172. f(x)
  173. f(x) # XXX: has to run twice
  174. out = f.get_profile()
  175. assert out.get("profiler")
  176. @pytest.mark.skip(reason="force opt_level=0 when building graph")
  177. def test_goptions():
  178. @trace(symbolic=True, opt_level=0, capture_as_const=True)
  179. def f(x):
  180. # directly return x / x will not trigger gopt
  181. # since there's no way to tell the two x are the same
  182. y = 2.0 * x
  183. return y / y
  184. @trace(symbolic=True, opt_level=1, capture_as_const=True)
  185. def g(x):
  186. y = 2.0 * x
  187. return y / y
  188. d = tensor(0.0)
  189. assert not np.isfinite(f(d).numpy())
  190. np.testing.assert_equal(g(d).numpy().item(), 1.0)
  191. @pytest.mark.skip(reason="force opt_level=0 when building graph")
  192. def test_goptions_log_sum_exp():
  193. @trace(symbolic=True, opt_level=0, capture_as_const=True)
  194. def f(x, y):
  195. return log(exp(x) + exp(y))
  196. @trace(symbolic=True, opt_level=1, capture_as_const=True)
  197. def g(x, y):
  198. return log(exp(x) + exp(y))
  199. val = 1.0e4
  200. d = tensor(val)
  201. o = tensor(0.0)
  202. assert not np.isfinite(f(d, o).numpy())
  203. np.testing.assert_almost_equal(g(d, o), val)
  204. def test_goptions_log_exp():
  205. @trace(symbolic=True, opt_level=0, capture_as_const=True)
  206. def f(x):
  207. return log(exp(x))
  208. @trace(symbolic=True, opt_level=1, capture_as_const=True)
  209. def g(x):
  210. return log(exp(x))
  211. f(tensor(1.0))
  212. _, out = mkstemp()
  213. f.dump(out, optimize_for_inference=False)
  214. *_, outputs = G.load_graph(out)
  215. oprs_1 = cgtools.get_oprs_seq(outputs)
  216. g(tensor(1.0))
  217. g.dump(out, optimize_for_inference=False)
  218. *_, outputs = G.load_graph(out)
  219. oprs_2 = cgtools.get_oprs_seq(outputs)
  220. assert len(oprs_1) - len(oprs_2) == 2
  221. def test_optimize_for_inference():
  222. @trace(symbolic=True, capture_as_const=True)
  223. def f(x):
  224. return exp(x)
  225. _, out = mkstemp()
  226. f(tensor(5.0))
  227. f.dump(out, enable_io16xc32=True)
  228. res = G.load_graph(out)
  229. computing_input = res.output_vars_list[0].owner.inputs[0]
  230. assert computing_input.dtype == np.float16
  231. def test_optimize_for_inference_broadcast():
  232. a = tensor(np.ones(1, dtype=np.float32))
  233. @trace(capture_as_const=True, symbolic_shape=True)
  234. def f():
  235. return a._broadcast(tensor([1, 10], dtype=np.int32))
  236. f()
  237. f.dump(io.BytesIO())
  238. def test_trace_cvt_bool():
  239. x = tensor([0], dtype=np.int32)
  240. @trace(symbolic=True)
  241. def f(x):
  242. a = x.shape
  243. b = a[0]
  244. assert isscalar(b)
  245. return b == 0
  246. for i in range(3):
  247. np.testing.assert_equal(f(x).numpy(), False)
  248. @pytest.mark.parametrize("trace_mode", [False, True])
  249. def test_trace_reshape(trace_mode):
  250. x1 = tensor(np.random.randn(2, 10, 10))
  251. x2 = tensor(np.random.randn(4, 10, 10))
  252. x3 = tensor(np.random.randn(8, 10, 10))
  253. @trace(symbolic=trace_mode, capture_as_const=True)
  254. def f(x):
  255. y = x.reshape(x.shape[0], 100)
  256. return y
  257. f(x1)
  258. f(x2)
  259. f(x3)
  260. def test_trace_topk():
  261. x = tensor([5, 2, 7, 1, 0, 3, 2])
  262. @trace(symbolic=True)
  263. def f(x):
  264. y = F.topk(x, 3)
  265. np.testing.assert_equal(y[0].shape.numpy(), np.array([3,]))
  266. return y
  267. for i in range(3):
  268. f(x)
  269. def test_trace_warp_perspective():
  270. inp_shape = (1, 1, 4, 4)
  271. x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
  272. M_shape = (1, 3, 3)
  273. M = tensor(
  274. np.array(
  275. [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
  276. ).reshape(M_shape)
  277. )
  278. @trace(symbolic=True)
  279. def f(x, M):
  280. out = F.vision.warp_perspective(x, M, (2, 2))
  281. np.testing.assert_equal(out.shape.numpy(), np.array([1, 1, 2, 2]))
  282. return out
  283. for i in range(3):
  284. f(x, M)
  285. def test_raise_on_trace():
  286. step_count = 0
  287. catch_count = 0
  288. bad_step = 10
  289. class CatchMe(Exception):
  290. pass
  291. a = tensor([1, 2, 3, 4])
  292. b = tensor([5, 6, 7, 8])
  293. c = tensor([9, 0, 1, 2])
  294. @trace
  295. def add_abc(a, b, c):
  296. ps = a + b
  297. result = ps + c
  298. if step_count == bad_step:
  299. raise CatchMe("catch me")
  300. return result
  301. for i in range(100):
  302. try:
  303. d = add_abc(a, b, c)
  304. except CatchMe as e:
  305. catch_count += 1
  306. else:
  307. np.testing.assert_equal(d.numpy(), (a + b + c).numpy())
  308. step_count += 1
  309. assert catch_count == 1
  310. @pytest.mark.parametrize("trace_mode", [False, True])
  311. def test_trace_broadcast(trace_mode):
  312. x1 = tensor(np.random.randn(3, 1, 1))
  313. x2 = tensor(np.random.randn(1, 4, 1))
  314. x3 = tensor(np.random.randn(1, 1, 5))
  315. @trace(symbolic=trace_mode, capture_as_const=True)
  316. def f(x):
  317. y = F.broadcast_to(x, (3, 4, 5))
  318. return y
  319. f(x1)
  320. f(x2)
  321. f(x3)
  322. def test_trace_nms():
  323. def make_inputs(n):
  324. boxes = np.zeros((n, 4))
  325. boxes[:, :2] = np.random.rand(n, 2) * 100
  326. boxes[:, 2:] = np.random.rand(n, 2) * 100 + 100
  327. scores = np.random.rand(n)
  328. return tensor(boxes), tensor(scores)
  329. @trace(symbolic=False)
  330. def f(boxes, scores):
  331. # with tracing, max_output must be specified
  332. results = F.vision.nms(boxes, scores=scores, iou_thresh=0.5, max_output=20)
  333. # without tracing, max output can be inferred inside nms
  334. with exclude_from_trace():
  335. _ = F.vision.nms(boxes, scores=scores, iou_thresh=0.5)
  336. return results
  337. f(*make_inputs(10))
  338. f(*make_inputs(20))
  339. f(*make_inputs(30))
  340. def test_trace_valid_broadcast():
  341. x1 = tensor(np.random.randn(1, 1))
  342. x2 = tensor(np.random.randn(1, 2))
  343. shape = (tensor([2]), tensor([2]))
  344. @trace(symbolic=False)
  345. def f(x, shape):
  346. y = F.broadcast_to(x, shape)
  347. return y
  348. f(x1, shape)
  349. f(x2, shape)
  350. def test_clip():
  351. x = tensor(np.random.randn(10, 10))
  352. @trace(symbolic=True)
  353. def f(x, lower, upper):
  354. y = F.clip(x, lower, upper)
  355. return y
  356. for i in range(3):
  357. f(x, tensor([0]), tensor([1]))
  358. # test returning noncontiguous tensor from trace
  359. def test_slice():
  360. @trace
  361. def f(x):
  362. return x[:, 1::2]
  363. x = F.arange(8).reshape(2, 4)
  364. f(x)
  365. y = f(x)
  366. np.testing.assert_array_equal(y.numpy(), x.numpy()[:, 1::2])
  367. y + y
  368. @pytest.mark.parametrize("shape_mode", [False, True])
  369. def test_random(shape_mode):
  370. def run_test(op):
  371. @trace(symbolic=True, symbolic_shape=shape_mode)
  372. def f():
  373. out = op(size=[10, 10])
  374. out_shape = out.shape
  375. assert out_shape is not None
  376. if not isinstance(out_shape, tuple):
  377. assert out.shape.numpy() is not None
  378. return out
  379. for _ in range(3):
  380. f()
  381. run_test(uniform)
  382. run_test(normal)
  383. @pytest.mark.parametrize("shape_mode", [False, True])
  384. def test_trace_advance_indexing(shape_mode):
  385. funcs = [
  386. lambda x, i: x[i],
  387. # lambda x, i, j: x[i, j], # FIXME
  388. lambda x, i, j: x[i, :, j, ...],
  389. # lambda x, start, end: x[start:end], # FIXME
  390. lambda x, start, end: x[:, 0, start:end, ..., 1],
  391. lambda x, vec: x[vec],
  392. lambda x, vec: x[vec, ..., 0, 1:3],
  393. lambda x, vec: x[vec, vec[0], vec[1]],
  394. # lambda x, i, start, end, vec: x[i, ..., :, vec, start:end], # FIXME
  395. lambda x, mask: x[mask],
  396. ]
  397. inputs = {
  398. "x": np.random.randn(5, 5, 5, 5, 5).astype("float32"),
  399. "i": 0,
  400. "j": 2,
  401. "start": 1,
  402. "end": 3,
  403. "vec": [1, 2, 3],
  404. "mask": np.random.randn(5, 5, 5, 5, 5) >= 0,
  405. }
  406. for f in funcs:
  407. sig = inspect.signature(f)
  408. param_names = list(sig._parameters.keys())
  409. params = {}
  410. params_np = {}
  411. f_traced = trace(f, symbolic=False, symbolic_shape=shape_mode)
  412. for name in param_names:
  413. params[name] = tensor(inputs[name])
  414. params_np[name] = inputs[name]
  415. expected = f(**params_np)
  416. result_imperative = f(**params)
  417. np.testing.assert_equal(expected, result_imperative.numpy())
  418. for _ in range(3):
  419. result_trace = f_traced(**params)
  420. np.testing.assert_equal(expected, result_trace.numpy())

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。