You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_tracing.py 19 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import inspect
  10. import io
  11. import itertools
  12. import random
  13. from tempfile import mkstemp
  14. import numpy as np
  15. import pytest
  16. import megengine.core.tensor.megbrain_graph as G
  17. import megengine.functional as F
  18. import megengine.optimizer as optim
  19. import megengine.utils.comp_graph_tools as cgtools
  20. from megengine import Parameter, tensor
  21. from megengine.autodiff import GradManager
  22. from megengine.core._trace_option import set_symbolic_shape
  23. from megengine.core.ops import builtin as ops
  24. from megengine.core.ops.builtin import Elemwise
  25. from megengine.core.tensor.utils import isscalar
  26. from megengine.functional import exp, log
  27. from megengine.jit import GraphOptimizationConfig, TraceError, exclude_from_trace, trace
  28. from megengine.module import Module
  29. from megengine.random import normal, uniform
  30. from megengine.utils.naming import AutoNaming
  31. @pytest.mark.parametrize("trace_mode", [False, True])
  32. @pytest.mark.parametrize("return_mode", ["Value", "Tuple", "List", "Dict"])
  33. def test_trace(trace_mode, return_mode):
  34. @trace(symbolic=trace_mode)
  35. def f(x):
  36. if return_mode == "Tuple":
  37. return (-x,)
  38. elif return_mode == "List":
  39. return [-x]
  40. elif return_mode == "Dict":
  41. return {"neg": -x}
  42. else:
  43. return -x
  44. def get_numpy(y):
  45. if return_mode == "Tuple" or return_mode == "List":
  46. return y[0].numpy()
  47. elif return_mode == "Dict":
  48. return y["neg"].numpy()
  49. return y.numpy()
  50. x = tensor([1])
  51. y = get_numpy(f(x))
  52. for i in range(3):
  53. np.testing.assert_equal(get_numpy(f(x)), y)
  54. def test_output_copy_trace():
  55. class Simple(Module):
  56. def __init__(self):
  57. super().__init__()
  58. self.a = Parameter([1.0], dtype=np.float32)
  59. def forward(self, x):
  60. x = x * self.a
  61. # will result into a copy of output in grad
  62. x = F.exp(x)
  63. return x
  64. ys = {False: [], True: []}
  65. for symbolic in [False, True]:
  66. net = Simple()
  67. gm = GradManager().attach(net.parameters())
  68. opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
  69. data = tensor(np.arange(4).reshape(2, 2), dtype="float32")
  70. @trace(symbolic=symbolic)
  71. def train_func(d):
  72. with gm:
  73. loss = net(d)
  74. gm.backward(loss)
  75. opt.step().clear_grad()
  76. return loss
  77. for i in range(3):
  78. y = train_func(data).numpy()
  79. ys[symbolic].append(y)
  80. for i in range(3):
  81. np.testing.assert_equal(ys[False][i], ys[True][i])
  82. @pytest.mark.parametrize("trace_mode", [False, True])
  83. def test_tensor_detach(trace_mode):
  84. @trace(symbolic=True)
  85. def f(x):
  86. y = x.detach() ** 2
  87. z = y.detach() + 1
  88. return z.detach()
  89. x = tensor([1, 2, 3, 4])
  90. for _ in range(3):
  91. f(x).numpy()
  92. @pytest.mark.parametrize("trace_mode", [False, True])
  93. def test_exclude_from_trace(trace_mode):
  94. @trace(symbolic=trace_mode)
  95. def f(x):
  96. x = -x
  97. with exclude_from_trace():
  98. if i % 2:
  99. x = -x
  100. x = -x
  101. return x
  102. x = tensor([1])
  103. for i in range(3):
  104. y = f(x).numpy()
  105. np.testing.assert_equal(f(x).numpy(), y)
  106. @pytest.mark.parametrize("trace_mode", [False, True])
  107. def test_elemwise_fuse(trace_mode):
  108. # explicitly declare opt_level as 2
  109. @trace(symbolic=trace_mode, opt_level=2)
  110. def f(a, b):
  111. base = 0
  112. c = b - a
  113. _, idx = F.topk(c, 3)
  114. # internally, biased_idx will be idx as gopt will ignore the addition
  115. biased_idx = base + idx
  116. return biased_idx
  117. a = tensor(np.ones((7, 2)), dtype=np.int32)
  118. b = tensor(2 * np.ones((7, 2)), dtype=np.float32)
  119. for i in range(3):
  120. y = f(a, b)
  121. y.numpy()
  122. @pytest.mark.parametrize("trace_mode", [False, True])
  123. def test_elemwise_fuse_in_grad(trace_mode):
  124. w = Parameter(np.ones([4, 6]), dtype="float32")
  125. gm = GradManager().attach(w)
  126. opt = optim.SGD([w], lr=0.01, momentum=0.9, weight_decay=5e-4)
  127. # explicitly declare opt_level as 2
  128. @trace(symbolic=trace_mode, opt_level=2)
  129. def f():
  130. with gm:
  131. wm = F.sum(w ** 2, axis=1) ** 0.5
  132. loss = wm.mean()
  133. gm.backward(loss)
  134. opt.step().clear_grad()
  135. return loss
  136. for i in range(3):
  137. y = f()
  138. y.numpy()
  139. def test_print_in_trace():
  140. for symbolic in [False]: # cannot read value in symbolic mode
  141. @trace(symbolic=symbolic)
  142. def f(x):
  143. nonlocal buf
  144. x = -x
  145. buf = x.numpy()
  146. x = -x
  147. return x
  148. buf = None
  149. x = tensor([1])
  150. for i in range(3):
  151. y = f(x).numpy()
  152. z = buf
  153. buf = None
  154. np.testing.assert_equal(f(x).numpy(), y)
  155. np.testing.assert_equal(z, buf)
  156. @pytest.mark.parametrize(
  157. "dump_format",
  158. [
  159. "FBS",
  160. ],
  161. )
  162. def test_dump(dump_format):
  163. @trace(symbolic=True, capture_as_const=True)
  164. def f(a, b):
  165. return a + b
  166. # prevent from remaining scope from exception test
  167. AutoNaming.clear()
  168. a = tensor([2])
  169. b = tensor([4])
  170. y = f(a, b).numpy()
  171. for i in range(3):
  172. np.testing.assert_equal(f(a, b).numpy(), y)
  173. file = io.BytesIO()
  174. dump_info = f.dump(file, dump_format=dump_format)
  175. assert dump_info.nr_opr == 3
  176. np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"])
  177. np.testing.assert_equal(dump_info.outputs, ["ADD"])
  178. file.seek(0)
  179. infer_cg = cgtools.GraphInference(file)
  180. result = list((infer_cg.run(a, b)).values())[0]
  181. np.testing.assert_equal(result[0], y)
  182. def test_capture_dump():
  183. a = tensor([2])
  184. @trace(symbolic=True, capture_as_const=True)
  185. def f(x):
  186. return x * a
  187. x = tensor([3])
  188. y = f(x).numpy()
  189. for i in range(3):
  190. np.testing.assert_equal(f(x).numpy(), y)
  191. file = io.BytesIO()
  192. f.dump(file)
  193. file.seek(0)
  194. infer_cg = cgtools.GraphInference(file)
  195. result = list((infer_cg.run(x)).values())[0]
  196. np.testing.assert_equal(result[0], y)
  197. def test_dump_volatile():
  198. p = tensor([2])
  199. @trace(symbolic=True, capture_as_const=True)
  200. def f(x):
  201. return x * p
  202. x = tensor([3])
  203. y = f(x).numpy()
  204. for i in range(3):
  205. np.testing.assert_equal(f(x).numpy(), y)
  206. file = io.BytesIO()
  207. f.dump(file, optimize_for_inference=False)
  208. file.seek(0)
  209. (out,) = G.load_graph(file).output_vars_list
  210. assert (
  211. cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1])
  212. == "ImmutableTensor"
  213. )
  214. def test_dump_backward_graph():
  215. x0 = tensor(np.random.randn(3, 4))
  216. x1 = tensor(np.random.randn(3, 4))
  217. gm = GradManager().attach(x0)
  218. @trace(symbolic=True, capture_as_const=True)
  219. def f(x0, x1):
  220. with gm:
  221. y = x0 * x1
  222. gm.backward(y, F.ones_like(y))
  223. dx0 = x0.grad
  224. return y, dx0
  225. y, dx0 = f(x0, x1)
  226. np.testing.assert_equal(dx0.numpy(), x1)
  227. file = io.BytesIO()
  228. f.dump(file, optimize_for_inference=False)
  229. file.seek(0)
  230. infer_cg = cgtools.GraphInference(file)
  231. results = list((infer_cg.run(x0, x1)).values())
  232. np.testing.assert_equal(results[0], y)
  233. np.testing.assert_equal(results[1], dx0)
  234. def test_dump_with_testcase():
  235. @trace(symbolic=True, capture_as_const=True)
  236. def f(x):
  237. return exp(x)
  238. f(tensor(1.0))
  239. file = io.BytesIO()
  240. f.dump(file, input_data=["#rand(0, 255, 1)"])
  241. @pytest.mark.parametrize("trace_mode", [False, True])
  242. def test_trace_profiler(trace_mode):
  243. @trace(symbolic=trace_mode, profiling=True)
  244. def f(x):
  245. return -x
  246. x = tensor([1])
  247. y = f(x).numpy()
  248. f(x)
  249. f(x) # XXX: has to run twice
  250. out = f.get_profile()
  251. assert out.get("profiler")
  252. def test_goptions():
  253. @trace(symbolic=True, opt_level=0, capture_as_const=True)
  254. def f(x):
  255. # directly return x / x will not trigger gopt
  256. # since there's no way to tell the two x are the same
  257. y = 2.0 * x
  258. return y / y
  259. @trace(symbolic=True, opt_level=1, capture_as_const=True)
  260. def g(x):
  261. y = 2.0 * x
  262. return y / y
  263. d = tensor(0.0)
  264. assert not np.isfinite(f(d).numpy())
  265. np.testing.assert_equal(g(d).numpy().item(), 1.0)
  266. def test_goptions_log_sum_exp():
  267. @trace(symbolic=True, opt_level=0, capture_as_const=True)
  268. def f(x, y):
  269. return log(exp(x) + exp(y))
  270. @trace(symbolic=True, opt_level=1, capture_as_const=True)
  271. def g(x, y):
  272. return log(exp(x) + exp(y))
  273. val = 1.0e4
  274. d = tensor(val)
  275. o = tensor(0.0)
  276. assert not np.isfinite(f(d, o).numpy())
  277. np.testing.assert_almost_equal(g(d, o), val)
  278. def test_goptions_log_exp():
  279. @trace(symbolic=True, opt_level=0, capture_as_const=True)
  280. def f(x):
  281. return log(exp(x))
  282. @trace(symbolic=True, opt_level=1, capture_as_const=True)
  283. def g(x):
  284. return log(exp(x))
  285. f(tensor(1.0))
  286. _, out = mkstemp()
  287. f.dump(out, optimize_for_inference=False)
  288. outputs = G.load_graph(out).output_vars_list
  289. oprs_1 = cgtools.get_oprs_seq(outputs)
  290. g(tensor(1.0))
  291. g.dump(out, optimize_for_inference=False)
  292. outputs = G.load_graph(out).output_vars_list
  293. oprs_2 = cgtools.get_oprs_seq(outputs)
  294. assert len(oprs_1) - len(oprs_2) == 2
  295. def test_optimize_for_inference():
  296. @trace(symbolic=True, capture_as_const=True)
  297. def f(x):
  298. return exp(x)
  299. _, out = mkstemp()
  300. f(tensor(5.0))
  301. f.dump(out, enable_io16xc32=True)
  302. res = G.load_graph(out)
  303. computing_input = res.output_vars_list[0].owner.inputs[0]
  304. assert computing_input.dtype == np.float16
  305. def test_optimize_for_inference_broadcast():
  306. a = tensor(np.ones(1, dtype=np.float32))
  307. @trace(capture_as_const=True, symbolic_shape=True)
  308. def f():
  309. return a._broadcast(tensor([1, 10], dtype=np.int32))
  310. f()
  311. f.dump(io.BytesIO())
  312. def test_trace_cvt_bool():
  313. x = tensor([0], dtype=np.int32)
  314. @trace(symbolic=True)
  315. def f(x):
  316. a = x.shape
  317. b = a[0]
  318. assert isscalar(b)
  319. return b == 0
  320. for i in range(3):
  321. np.testing.assert_equal(f(x).numpy(), False)
  322. @pytest.mark.parametrize("trace_mode", [False, True])
  323. def test_trace_reshape(trace_mode):
  324. x1 = tensor(np.random.randn(2, 10, 10))
  325. x2 = tensor(np.random.randn(4, 10, 10))
  326. x3 = tensor(np.random.randn(8, 10, 10))
  327. @trace(symbolic=trace_mode, capture_as_const=True)
  328. def f(x):
  329. y = x.reshape(x.shape[0], 100)
  330. return y
  331. f(x1)
  332. f(x2)
  333. f(x3)
  334. def test_trace_topk():
  335. x = tensor([5, 2, 7, 1, 0, 3, 2])
  336. @trace(symbolic=True)
  337. def f(x):
  338. y = F.topk(x, 3)
  339. np.testing.assert_equal(y[0].shape.numpy(), np.array([3,]))
  340. return y
  341. for i in range(3):
  342. f(x)
  343. def test_trace_warp_perspective():
  344. inp_shape = (1, 1, 4, 4)
  345. x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
  346. M_shape = (1, 3, 3)
  347. M = tensor(
  348. np.array(
  349. [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
  350. ).reshape(M_shape)
  351. )
  352. @trace(symbolic=True)
  353. def f(x, M):
  354. out = F.vision.warp_perspective(x, M, (2, 2))
  355. np.testing.assert_equal(out.shape.numpy(), np.array([1, 1, 2, 2]))
  356. return out
  357. for i in range(3):
  358. f(x, M)
  359. @pytest.mark.parametrize(
  360. "normal_expr, mismatch_expr, reason",
  361. [
  362. ("a + b + c", "a + b - c", "operator mismatch"),
  363. ("a + b + 1", "a + b + 2", "tensors not equals"),
  364. ("((a + b), (b + c))[0]", "a + b", "mismature end"),
  365. ("a + b + c", "c + (a + b)", "expect internal node, got external"),
  366. ("c + (a + b)", "a + b + c", "expect external node, got internal"),
  367. ("a + b + c", "a + b + c + c", "too many instructions"),
  368. ("((a + b), (b + c))[1]", "((a + b), (b + c))[0]", "data unreadable"),
  369. ("((a + b), (b + c))[1] + a", "((a + b), (b + c))[0] + a", "input id mismatch"),
  370. ],
  371. )
  372. def test_trace_mismatch(normal_expr, mismatch_expr, reason):
  373. a = tensor([1, 2, 3, 4])
  374. b = tensor([5, 6, 7, 8])
  375. c = tensor([9, 0, 1, 2])
  376. mismatch = False
  377. @trace(symbolic=True)
  378. def fn(a, b, c):
  379. if not mismatch:
  380. result = eval(normal_expr)
  381. else:
  382. result = eval(mismatch_expr)
  383. return result
  384. for i in range(20):
  385. try:
  386. d = fn(a, b, c)
  387. except TraceError as e:
  388. assert mismatch
  389. assert str(e) == "trace error because {}".format(reason)
  390. except:
  391. pytest.fail("unexpected trace error")
  392. else:
  393. assert not mismatch
  394. np.testing.assert_equal(d.numpy(), eval(normal_expr).numpy())
  395. mismatch = random.random() > 0.8
  396. def test_exception_in_trace():
  397. a = tensor([1, 2, 3, 4])
  398. b = tensor([5, 6, 7, 8])
  399. c = tensor([9, 0, 1, 2])
  400. mismatch = False
  401. exc = Exception()
  402. @trace(symbolic=True)
  403. def fn(a, b, c):
  404. result = a + b
  405. if not mismatch:
  406. result += c
  407. else:
  408. raise exc
  409. return result
  410. for i in range(20):
  411. try:
  412. d = fn(a, b, c)
  413. except TraceError as e:
  414. pytest.fail("unexpected trace error")
  415. except Exception as e:
  416. assert mismatch
  417. assert e is exc
  418. else:
  419. assert not mismatch
  420. np.testing.assert_equal(d.numpy(), (a + b + c).numpy())
  421. mismatch = random.random() > 0.8
  422. def test_graph_error():
  423. a = tensor(np.arange(8).reshape((2, 4)))
  424. b = tensor(np.arange(8).reshape((2, 4)))
  425. @trace(symbolic=True)
  426. def fn(a, b):
  427. return a + b
  428. fn(a, b)
  429. with pytest.raises(RuntimeError):
  430. fn(a, b.transpose())
  431. fn(a, b)
  432. @pytest.mark.parametrize("trace_mode", [False, True])
  433. def test_trace_broadcast(trace_mode):
  434. x1 = tensor(np.random.randn(3, 1, 1))
  435. x2 = tensor(np.random.randn(1, 4, 1))
  436. x3 = tensor(np.random.randn(1, 1, 5))
  437. @trace(symbolic=trace_mode, capture_as_const=True)
  438. def f(x):
  439. y = F.broadcast_to(x, (3, 4, 5))
  440. return y
  441. f(x1)
  442. f(x2)
  443. f(x3)
  444. def test_trace_nms():
  445. def make_inputs(n):
  446. boxes = np.zeros((n, 4))
  447. boxes[:, :2] = np.random.rand(n, 2) * 100
  448. boxes[:, 2:] = np.random.rand(n, 2) * 100 + 100
  449. scores = np.random.rand(n)
  450. return tensor(boxes), tensor(scores)
  451. @trace(symbolic=False)
  452. def f(boxes, scores):
  453. # with tracing, max_output must be specified
  454. results = F.vision.nms(boxes, scores=scores, iou_thresh=0.5, max_output=20)
  455. # without tracing, max output can be inferred inside nms
  456. with exclude_from_trace():
  457. _ = F.vision.nms(boxes, scores=scores, iou_thresh=0.5)
  458. return results
  459. f(*make_inputs(10))
  460. f(*make_inputs(20))
  461. f(*make_inputs(30))
  462. def test_trace_valid_broadcast():
  463. x1 = tensor(np.random.randn(1, 1))
  464. x2 = tensor(np.random.randn(1, 2))
  465. shape = (tensor([2]), tensor([2]))
  466. @trace(symbolic=False)
  467. def f(x, shape):
  468. y = F.broadcast_to(x, shape)
  469. return y
  470. f(x1, shape)
  471. f(x2, shape)
  472. @pytest.mark.parametrize("trace_mode", [False, True])
  473. def test_clip(trace_mode):
  474. x = tensor(np.random.randn(10, 10))
  475. @trace(symbolic=trace_mode)
  476. def f(x, lower, upper):
  477. y = F.clip(x, lower, upper)
  478. return y
  479. for i in range(3):
  480. f(x, tensor([0]), tensor([1]))
  481. for i in range(3):
  482. f(x, tensor([5]), tensor([4]))
  483. # test returning noncontiguous tensor from trace
  484. def test_slice():
  485. @trace
  486. def f(x):
  487. return x[:, 1::2]
  488. x = F.arange(8).reshape(2, 4)
  489. f(x)
  490. y = f(x)
  491. np.testing.assert_array_equal(y.numpy(), x.numpy()[:, 1::2])
  492. y + y
  493. @pytest.mark.parametrize("shape_mode", [False, True])
  494. def test_random(shape_mode):
  495. def run_test(op):
  496. @trace(symbolic=True, symbolic_shape=shape_mode)
  497. def f():
  498. out = op(size=[10, 10])
  499. out_shape = out.shape
  500. assert out_shape is not None
  501. if not isinstance(out_shape, tuple):
  502. assert out.shape.numpy() is not None
  503. return out
  504. for _ in range(3):
  505. f()
  506. run_test(uniform)
  507. run_test(normal)
  508. @pytest.mark.parametrize("shape_mode", [False, True])
  509. def test_trace_advance_indexing(shape_mode):
  510. funcs = [
  511. lambda x, i: x[i],
  512. lambda x, i, j: x[i, j],
  513. lambda x, i, j: x[i, :, j, ...],
  514. lambda x, start, end: x[start:end],
  515. lambda x, start, end: x[:, 0, start:end, ..., 1],
  516. lambda x, vec: x[vec],
  517. lambda x, vec: x[vec, ..., 0, 1:3],
  518. lambda x, vec: x[vec, vec[0], vec[1]],
  519. # lambda x, i, start, end, vec: x[i, ..., :, vec, start:end], # FIXME
  520. lambda x, mask: x[mask],
  521. ]
  522. inputs = {
  523. "x": np.random.randn(5, 5, 5, 5, 5).astype("float32"),
  524. "i": 4,
  525. "j": 2,
  526. "start": 1,
  527. "end": 3,
  528. "vec": [1, 2, 3],
  529. "mask": np.random.randn(5, 5, 5, 5, 5) >= 0,
  530. }
  531. for f in funcs:
  532. sig = inspect.signature(f)
  533. param_names = list(sig._parameters.keys())
  534. params = {}
  535. params_np = {}
  536. f_traced = trace(f, symbolic=False, symbolic_shape=shape_mode)
  537. for name in param_names:
  538. params[name] = tensor(inputs[name])
  539. params_np[name] = inputs[name]
  540. expected = f(**params_np)
  541. result_imperative = f(**params)
  542. np.testing.assert_equal(expected, result_imperative.numpy())
  543. for _ in range(3):
  544. result_trace = f_traced(**params)
  545. np.testing.assert_equal(expected, result_trace.numpy())
  546. @pytest.mark.require_ngpu(1) # nvrtc backend
  547. def test_trace_jit_config():
  548. def run(fuse_dimshuffle, fuse_reduce):
  549. config = GraphOptimizationConfig()
  550. config.jit_fuse_dimshuffle = fuse_dimshuffle
  551. config.jit_fuse_reduce = fuse_reduce
  552. # set opt_level = 1 to avoid fusing dimshuffle and reduce at the same time
  553. @trace(opt_level=1, graph_opt_config=config)
  554. def func(x):
  555. return x + 1
  556. x = tensor(2)
  557. y = func(x)
  558. y = func(x)
  559. # func._compile()
  560. options = func._trace.options
  561. mapping = {None: 0, False: 1, True: 2}
  562. assert options.graph_opt.jit == 0
  563. assert options.graph_opt.jit_config.fuse_dimshuffle == mapping[fuse_dimshuffle]
  564. assert options.graph_opt.jit_config.fuse_reduce == mapping[fuse_reduce]
  565. for fuse_dimshuffle in [None, False, True]:
  566. for fuse_reduce in [None, False, True]:
  567. run(fuse_dimshuffle, fuse_reduce)