You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

tracing.py 61 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import contextlib
  11. import functools
  12. import itertools
  13. import json
  14. import os
  15. import pickle
  16. import re
  17. import struct
  18. from typing import Any
  19. import cv2
  20. import numpy as np
  21. from megengine.logger import get_logger
  22. from .. import tensor
  23. from ..core import _imperative_rt as rt
  24. from ..core._imperative_rt import GraphProfiler, GraphProfiler2, SerializationMetadata
  25. from ..core._imperative_rt.core2 import Tensor as RawTensor
  26. from ..core._imperative_rt.core2 import (
  27. TensorWeakRef,
  28. apply,
  29. set_tracing,
  30. skip_tracing,
  31. unset_tracing,
  32. )
  33. from ..core._imperative_rt.ops import (
  34. AssertEqual,
  35. CollectiveComm,
  36. ExternOpr,
  37. RemoteRecv,
  38. RemoteSend,
  39. )
  40. from ..core._trace_option import set_symbolic_shape
  41. from ..core._wrap import as_device
  42. from ..core.ops.builtin import BatchNorm, OpDef
  43. from ..core.tensor import megbrain_graph as G
  44. from ..core.tensor.utils import setscalar
  45. from ..utils import comp_graph_tools as cgtools
  46. from ..utils.naming import AutoNaming
  47. from ..utils.profiler import is_profiling
  48. from .dtr_config import DTRConfig
  49. from .graph_opt_config import GraphOptimizationConfig
  50. from .sublinear_memory_config import SublinearMemoryConfig
  51. logger = get_logger(__name__)
  52. def _input_node_use_static_shape():
  53. return os.environ.get("MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE") is not None
  54. class TraceMismatchError(RuntimeError):
  55. pass
  56. active_trace = None
  57. def is_tracing():
  58. if active_trace is None:
  59. return False
  60. else:
  61. return not skip_tracing
@contextlib.contextmanager
def exclude_from_trace():
    """Context manager that suspends tracing for the enclosed code.

    If tracing is already being skipped, or no trace is active, the body runs
    unchanged. Otherwise ``skip_tracing`` is set, ``unset_tracing()`` is
    called and the active trace is notified via ``_begin_excluded_region()``
    before the body runs. On exit, even when the body raised,
    ``skip_tracing`` is cleared and ``set_tracing()`` is called again.
    """
    global skip_tracing
    # Nothing to suspend: run the body as-is.
    if skip_tracing or (active_trace is None):
        yield
        return
    try:
        skip_tracing = True
        unset_tracing()
        if active_trace is not None:
            active_trace._begin_excluded_region()
        yield
    finally:
        # Always restore the tracing state, including on error.
        skip_tracing = False
        set_tracing()
  77. class TensorInfo:
  78. __slots__ = (
  79. # collected attributes
  80. "name",
  81. "external",
  82. "data_read",
  83. "shape_read",
  84. "value_read",
  85. "exported",
  86. "device",
  87. "dtype",
  88. "shape",
  89. "is_const",
  90. "bound_data",
  91. # resources for execution
  92. "varnode",
  93. "data_setter",
  94. "shape_reader",
  95. "value_reader",
  96. "data_reader",
  97. )
  98. def __init__(self):
  99. self.name = None
  100. self.exported = None
  101. self.data_read = None
  102. self.shape_read = None
  103. self.value_read = None
  104. self.bound_data = None
  105. self.data_setter = None
  106. self.shape_reader = None
  107. self.value_reader = None
  108. self.data_reader = None
# Op types for which ``trace._compile`` sets ``require_links`` when building
# the graph.
_io_op_types = {AssertEqual, CollectiveComm, RemoteSend, RemoteRecv}
class trace:
    """Wraps a callable and provides:

    * tracing via :meth:`.trace` and :meth:`.dump`
    * accelerated evaluation via :meth:`.__call__`

    Args:
        function: the function will be traced.
        symbolic: whether to apply symbolic execution for tracing. Default: False
        capture_as_const: capture global vars or closures as const value. Default: False
        record_only: if True, won't run even if call the function. Default: False
        sublinear_memory_config: configuration for sublinear memory optimization.
            If not None, it enables sublinear memory optimization with given setting.
        dtr_config: configuration for DTR memory optimization.
            If not None, it enables DTR memory optimization with given setting.
        profiling: whether to profile compiled trace. Default: False
        opt_level: optimization level for compiling trace. Default: 2
        graph_opt_config: configuration for graph optimization. Default: None
        symbolic_shape: whether to use symbolic shape for tracing. Default: True
    """
  126. def __new__(cls, *args, **kwargs):
  127. if not args:
  128. return functools.partial(cls, **kwargs)
  129. return super().__new__(cls)
  130. def __init__(
  131. self,
  132. function,
  133. symbolic=False,
  134. capture_as_const=False,
  135. record_only=False,
  136. sublinear_memory_config: SublinearMemoryConfig = None,
  137. dtr_config: DTRConfig = None,
  138. profiling: bool = False,
  139. opt_level: int = 2,
  140. graph_opt_config: GraphOptimizationConfig = None,
  141. symbolic_shape: bool = True,
  142. ):
  143. self.__wrapped__ = function
  144. self._symbolic = symbolic or record_only
  145. self._capture_as_const = capture_as_const or record_only
  146. self._record_only = record_only
  147. self._sublinear_memory_config = sublinear_memory_config
  148. self._dtr_config = dtr_config
  149. self._profiling = profiling
  150. self._profiler = None
  151. self._profiler2 = None
  152. self._graph_opt_level = opt_level
  153. self._graph_opt_config = graph_opt_config
  154. self._symbolic_shape = symbolic_shape
  155. self._output_handles = set()
  156. self._reset()
  157. def _reset(self):
  158. self._untraced = True
  159. self._tinfo = [] # handle -> TensorInfo
  160. self._seq = []
  161. self._pc = 0
  162. self._graph = None
  163. self._need_reset_nodes = None
  164. self._lazy_eval_graph = None
  165. self._lazy_eval_tensors = set()
  166. self._lazy_eval_links = None
  167. self._active_tensors = set()
  168. self._tensor_remaps = None
  169. self._inputs_to_restore = None
  170. self._arg_bindings = None
  171. self._kwarg_bindings = None
  172. self._output_bindings = None
  173. self._output_names = None
  174. def _new_handle(self):
  175. handle = len(self._tinfo)
  176. info = TensorInfo()
  177. self._tinfo.append(info)
  178. return handle, info
    def _apply_op(self, op, args):
        """Replay one op against the recorded sequence and return its outputs.

        The op at the current program counter must equal ``op`` and take the
        same number of inputs; every input is validated against the recorded
        ``TensorInfo``. Non-const external inputs are fed into the graph
        through their ``data_setter``.

        Args:
            op: the op being applied in this run.
            args: input tensors of the op.

        Returns:
            A list of new ``RawTensor`` objects, one per recorded output handle.

        Raises:
            TraceMismatchError: if this call diverges from the recorded trace.
        """
        assert not self._untraced
        # check against trace
        if self._pc >= len(self._seq):
            raise TraceMismatchError("trace should end here, but more op observed")
        record = self._seq[self._pc]
        op_, ihandles, ohandles = record
        if (isinstance(op_, str) and op_ == "Const") or (op != op_):
            raise TraceMismatchError("op different from last time")
        if len(ihandles) != len(args):
            raise TraceMismatchError("op input size different from last time")

        # check all inputs of current op
        for h, x in zip(ihandles, args):
            info = self._tinfo[h]
            if info.external:
                if (
                    x._compiled_info is not None
                    and not self._tinfo[x._mixin_handle].exported
                ):
                    raise TraceMismatchError(
                        "failed to capture: input was an external tensor "
                        "last time, got an internal tensor this time"
                    )
                if info.bound_data:
                    # input was captured as a constant: it must still hold
                    # the same value
                    if x._compiled_info is not None:
                        raise TraceMismatchError(
                            "const capture violated: was an external tensor "
                            "last time, got an internal tensor this time"
                        )
                    if x._handle != info.bound_data._handle:
                        if not np.array_equal(x.numpy(), info.bound_data.numpy()):
                            raise TraceMismatchError(
                                "const capture violated: got "
                                "a different tensor this time"
                            )
                else:
                    if info.dtype != x.dtype:
                        raise TraceMismatchError(
                            "failed to capture: different dtype from last time"
                        )
                    if info.device != x.device:
                        raise TraceMismatchError(
                            "failed to capture: different device from last time"
                        )
                    # feed this run's value into the graph input
                    info.data_setter.set_value(x._dev_tensor())
            else:
                if x._mixin_handle == -1:
                    if x._handle not in self._tensor_remaps:
                        raise TraceMismatchError(
                            "unexpected capture: trying to use an external tensor as "
                            "input, but that input was an internal tensor last time"
                        )
                    else:
                        x._mixin_handle = self._tensor_remaps[
                            x._handle
                        ]._CompiledTensorProxy__handle
                if x._mixin_handle != h:
                    raise TraceMismatchError(
                        "mis-wiring: input edge to an data flow "
                        "graph node is different from last time"
                    )

        self._pc += 1
        outputs = []
        for h in ohandles:
            info = self._tinfo[h]
            # generate output tensor and create compiled info
            y = RawTensor(info.varnode)
            y._compiled_info = CompiledTensorProxy(h)
            y._mixin_handle = h
            outputs += [y]
            self._active_tensors.add(TensorWeakRef(y))
        self._output_handles.update(ohandles)
        return outputs
  252. def _apply_const(self, value, dtype, device):
  253. assert not self._untraced
  254. # check against trace
  255. if self._pc >= len(self._seq):
  256. raise TraceMismatchError("trace should end here, but more op observed")
  257. record = self._seq[self._pc]
  258. op_, ihandles, ohandles = record
  259. # Const op is represented by a str
  260. assert isinstance(op_, str) and op_ == "Const"
  261. expected = self._tinfo[ohandles[0]].bound_data.numpy()
  262. shape = value.shape
  263. if shape != expected.shape or dtype != expected.dtype:
  264. eq = False
  265. elif shape == ():
  266. eq = expected.item() == value.item()
  267. elif shape == (1,):
  268. eq = expected[0] == value[0]
  269. else:
  270. eq = np.all(value == expected)
  271. if not eq:
  272. raise TraceMismatchError(
  273. "const tensor violated: got a different tensor this time"
  274. )
  275. self._pc += 1
  276. (h,) = ohandles
  277. outputs = [self._tinfo[h].bound_data]
  278. return outputs
  279. # run in first step, record information for trace
  280. def _record_op(self, op, inputs, outputs):
  281. if skip_tracing:
  282. for x in inputs:
  283. h = getattr(x, "_mixin_handle", -1)
  284. if h >= 0:
  285. self._tinfo[h].data = True
  286. return
  287. ihandles = []
  288. for x in inputs:
  289. h = getattr(x, "_mixin_handle", -1)
  290. if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
  291. h, info = self._new_handle()
  292. name = AutoNaming.gen_name(x)
  293. info.name = name
  294. info.external = True
  295. info.device = x.device
  296. info.dtype = x.dtype
  297. info.shape = x.shape
  298. if self._capture_as_const:
  299. info.bound_data = RawTensor(
  300. x.numpy(), x.dtype, x.device, False, name
  301. )
  302. ihandles.append(h)
  303. ohandles = []
  304. for x in outputs:
  305. h, info = self._new_handle()
  306. ohandles.append(h)
  307. info.external = False
  308. x._mixin_handle = h
  309. x._recording = True
  310. x._trace_mixin_info = info
  311. self._active_tensors.add(TensorWeakRef(x))
  312. if self._symbolic:
  313. self._lazy_eval_tensors.add(TensorWeakRef(x))
  314. self._seq.append((op, tuple(ihandles), tuple(ohandles)))
  315. def _record_const(self, outputs):
  316. if skip_tracing:
  317. (x,) = outputs
  318. h = getattr(x, "_mixin_handle", -1)
  319. if h >= 0:
  320. self._tinfo[h].data_read = True
  321. return
  322. (x,) = outputs
  323. h, info = self._new_handle()
  324. ohandles = [h]
  325. info.external = True
  326. info.device = x.device
  327. info.dtype = x.dtype
  328. info.shape = x.shape
  329. info.bound_data = x
  330. info.is_const = True
  331. x._mixin_handle = h
  332. x._recording = True
  333. x._trace_mixin_info = info
  334. if self._symbolic:
  335. self._lazy_eval_tensors.add(TensorWeakRef(x))
  336. self._seq.append(("Const", tuple(), tuple(ohandles)))
  337. def _set_active(self, active: bool):
  338. global active_trace
  339. if active:
  340. if active_trace:
  341. raise NotImplementedError("sorry, not implemented: nested trace")
  342. active_trace = self
  343. else:
  344. assert active_trace is self
  345. active_trace = None
  346. def _init_trace(self, symbolic: bool):
  347. if symbolic:
  348. self._lazy_eval_graph = G.Graph()
  349. self._apply_graph_options(self._lazy_eval_graph)
  350. self._lazy_eval_links = ()
  351. def _take_escaped_tensors(self):
  352. escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors))
  353. self._active_tensors.clear()
  354. return escaped_tensors
  355. def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
  356. lazy_eval_tensors = [x() for x in lazy_eval_tensors]
  357. lazy_eval_tensors = [x for x in lazy_eval_tensors if x is not None]
  358. readers = [G.OutputNode(x._varnode).outputs[0] for x in lazy_eval_tensors]
  359. self._apply_graph_options(lazy_eval_graph)
  360. lazy_eval_graph.options.graph_opt_level = self._graph_opt_level
  361. lazy_eval_graph._set_priority_to_id([*lazy_eval_links, *readers])
  362. lazy_eval_graph.compile(*lazy_eval_links, *readers)
  363. self._execute_graph(lazy_eval_graph)
  364. lazy_eval_graph.wait()
  365. for r, x in zip(readers, lazy_eval_tensors):
  366. # get values from lazy_eval_graph and assign to lazy_eval tensor
  367. x._handle = RawTensor(r.op.get_value())._handle
  368. x._reset_varnode()
    @contextlib.contextmanager
    def _setup(self):
        """Context manager wrapping one call of the traced function.

        ``do_enter`` runs before the body, ``do_exit`` after a body that did
        not raise, and ``do_finalize`` always runs last. If anything raised,
        the exception is re-raised and all recorded state is dropped via
        ``_reset()``.
        """
        interrupted = False

        def do_enter():
            # Activate tracing; on replays compile (once) and start the graph.
            set_tracing()
            self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape)
            self._set_active(True)
            if self._untraced:
                self._init_trace(self._symbolic)
            else:
                if self._graph is None:
                    self._compile()
                self._execute_graph(self._graph)

        def do_finalize():
            escaped_tensors = self._take_escaped_tensors()
            if self._untraced:
                if self._record_only:
                    self._lazy_eval_graph = None
                    self._lazy_eval_tensors = None
                    self._lazy_eval_links = None
                else:
                    # tensors still alive after the first run: mark their
                    # data as read and detach them from the trace
                    for x in escaped_tensors:
                        if x():
                            info = self._tinfo[x()._mixin_handle]
                            info.data_read = True
                            x()._mixin_handle = -1
                            x()._recording = False
                    if self._inputs_to_restore:
                        for x in self._inputs_to_restore:
                            x._mixin_handle = -1
                            x._recording = False
                    if self._symbolic and (
                        self._lazy_eval_tensors or self._lazy_eval_links
                    ):
                        # eval lazy eval tensors
                        self._lazy_eval(
                            self._lazy_eval_graph,
                            self._lazy_eval_tensors,
                            self._lazy_eval_links,
                        )
                        self._lazy_eval_graph = None
                        self._lazy_eval_tensors = None
                        self._lazy_eval_links = None
                self._untraced = False
            else:
                # compiled_tensor leaks
                if self._pc == len(self._seq):
                    for x in escaped_tensors:
                        try:
                            x().__init__(RawTensor(x()._dev_tensor()))
                        except RuntimeError:
                            # TraceMismatchError thrown in do_exit
                            pass
                    self._graph.wait()
                    self._reset_exec_env()

            # reset status
            self._pc = 0
            self._tensor_remaps = None
            self._set_active(False)
            set_symbolic_shape(self._save_symbolic_shape)
            unset_tracing()

        def do_exit():
            unset_tracing()
            # a replay must consume the whole recorded sequence
            if not self._untraced and self._pc != len(self._seq):
                raise TraceMismatchError("premature end")
            if not self._symbolic or not self._untraced:
                # reset output tensors
                for x in self._active_tensors.copy():
                    strong_x = x()
                    if strong_x is not None:
                        strong_x._dev_tensor()
                        strong_x._reset_varnode()
                        strong_x._mixin_handle = -1
                        strong_x._recording = False
                        strong_x._trace_mixin_info = None

        try:
            do_enter()
            yield
            do_exit()
        except:
            # remember the failure so the trace is reset after finalizing
            interrupted = True
            raise
        finally:
            do_finalize()
            if interrupted:
                self._reset()
  455. def _begin_excluded_region(self):
  456. if self._capture_as_const:
  457. raise RuntimeError(
  458. "exclude_from_trace cannot be used with capture_as_const"
  459. )
  460. if self._untraced:
  461. # conditionally reading a compiled tensor in excluded region
  462. # is permitted, so we have to assume every tensor might be read
  463. for x in self._active_tensors:
  464. strong_x = x()
  465. if strong_x:
  466. info = self._tinfo[strong_x._mixin_handle]
  467. info.exported = True
  468. info.data_read = True
  469. else:
  470. for x in self._active_tensors:
  471. strong_x = x()
  472. if strong_x:
  473. strong_x._dev_tensor()
    def _apply_graph_options(self, graph):
        """Copy this trace's configuration onto ``graph.options``.

        Sets the optimization level and, when the corresponding config is
        not None, the DTR, JIT-fusion and sublinear-memory settings. Also
        attaches a ``GraphProfiler`` when profiling was requested.

        Args:
            graph: the graph whose options are updated in place.
        """
        graph.options.no_force_inplace = True
        graph.options.seq_opt.enable_seq_comp_node_opt = False
        graph.options.graph_opt_level = self._graph_opt_level
        # DTR memory optimization
        if self._dtr_config is not None:
            graph.options.enable_dtr_memory_opt = True
            graph.options.dtr_config.eviction_threshold = (
                self._dtr_config.eviction_threshold
            )
            graph.options.dtr_config.evictee_minimum_size = (
                self._dtr_config.evictee_minimum_size
            )
            graph.options.dtr_config.recomp_memory_factor = (
                self._dtr_config.recomp_memory_factor
            )
            graph.options.dtr_config.recomp_time_factor = (
                self._dtr_config.recomp_time_factor
            )
        # graph optimization
        if self._graph_opt_config is not None:
            # tri-state option (None / False / True) encoded as 0 / 1 / 2
            mapping = {None: 0, False: 1, True: 2}
            jit_config = graph.options.graph_opt.jit_config
            jit_config.fuse_dimshuffle = mapping[
                self._graph_opt_config.jit_fuse_dimshuffle
            ]
            jit_config.fuse_reduce = mapping[self._graph_opt_config.jit_fuse_reduce]
        # sublinear
        if self._sublinear_memory_config is not None:
            graph.options.enable_sublinear_memory_opt = True
            sublinear_config = graph.options.sublinear_mem_config
            sublinear_config.lb_memory_mb = self._sublinear_memory_config.lb_memory_mb
            sublinear_config.genetic_nr_iter = (
                self._sublinear_memory_config.genetic_nr_iter
            )
            sublinear_config.genetic_pool_size = (
                self._sublinear_memory_config.genetic_pool_size
            )
            sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try
            sublinear_config.num_worker = self._sublinear_memory_config.num_worker
        # profile
        if self._profiling:
            self._profiler = GraphProfiler(graph)
            self._profiler2 = None
        # a non-zero MEGENGINE_INPLACE_UPDATE disables the first-run sanity check
        if int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")):
            graph.options.var_sanity_check_first_run = False
  519. def _execute_graph(self, graph: G.Graph, *args):
  520. if is_profiling() and (self._profiler2 is None):
  521. self._profiler2 = GraphProfiler2(graph)
  522. elif not is_profiling() and (self._profiler2 is not None):
  523. self._profiler2 = None
  524. graph.execute(*args)
    def _compile(self):
        """Build and compile the computing graph from the recorded sequence.

        Walks ``self._seq``, creating a varnode for every handle (constants,
        graph inputs, op outputs) and an output node for every tensor whose
        data, value or shape was read. The resulting graph is stored in
        ``self._graph``; nodes that need resetting after each run are
        collected in ``self._need_reset_nodes``.
        """
        graph = self._graph = G.Graph()
        graph.options.async_exec_level = 0b100
        self._apply_graph_options(graph)
        need_reset_nodes = self._need_reset_nodes = []
        # links enforce ordering of I/O nodes
        in_out_links = ()
        io_links = ()
        readers = []

        if self._capture_as_const:
            # bound arguments become explicit graph inputs
            for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
                info = self._tinfo[h]
                opnode = info.data_setter = G.InputNode(
                    device=info.device,
                    dtype=info.dtype,
                    shape=info.shape or (1,),
                    graph=graph,
                    use_static_shape=_input_node_use_static_shape(),
                )
                need_reset_nodes.append(opnode)
                info.varnode = opnode.outputs[0]
                in_out_links += opnode.outputs[1:]

        for op, ihandles, ohandles in self._seq:
            if isinstance(op, str) and op == "Const":
                assert len(ihandles) == 0
                (h,) = ohandles
                info = self._tinfo[h]
                # "varnode" is an unset slot until the node has been created
                if not hasattr(info, "varnode"):
                    assert info.external
                    assert info.bound_data
                    info.varnode = graph.make_const(
                        info.bound_data.numpy(),
                        info.bound_data.dtype,
                        info.bound_data.device,
                    )
                continue

            require_links = type(op) in _io_op_types
            ivars = []
            for i, h in enumerate(ihandles):
                info = self._tinfo[h]
                if not hasattr(info, "varnode"):
                    assert info.external
                    if info.bound_data:
                        if getattr(info, "is_const", False):
                            info.varnode = graph.make_const(
                                info.bound_data.numpy(),
                                info.bound_data.dtype,
                                info.bound_data.device,
                            )
                        else:
                            info.varnode = graph.make_const(
                                info.bound_data._dev_tensor()
                                # info.bound_data.numpy()
                            )
                    else:
                        # external, unbound input: fed on every run through
                        # its data_setter
                        opnode = info.data_setter = G.InputNode(
                            *in_out_links,
                            device=info.device,
                            dtype=info.dtype,
                            shape=info.shape or (1,),
                            graph=graph,
                            use_static_shape=_input_node_use_static_shape(),
                        )
                        need_reset_nodes.append(opnode)
                        info.varnode, *in_out_links = opnode.outputs
                if require_links and i == 0 and len(io_links) > 0:
                    # make the first input depend on the current io_links
                    opnode = G.VirtualDepNode(
                        [info.varnode, *io_links], str(io_links[0].device)
                    )
                    info.varnode = opnode.outputs[0]
                    io_links = (info.varnode,)

                ivars.append(info.varnode)

            ovars = G.apply_normal_varnode(op, *ivars)

            if require_links and len(ovars) > 0:
                io_links = (ovars[0],)
            assert len(ovars) == len(ohandles)
            for h, v in zip(ohandles, ovars):
                info = self._tinfo[h]
                info.varnode = v

                def add_reader(opnode):
                    # register an output node and link later nodes to it
                    nonlocal in_out_links
                    need_reset_nodes.append(opnode)
                    readers.append(opnode.outputs[0])
                    in_out_links = opnode.outputs

                if info.data_read:
                    # Shape can be obtained from data so doesn't need its own
                    # output node. On the other hand, value is read separately
                    # to leverage eager h2d copy
                    info.shape_read = False
                    opnode = info.data_reader = G.OutputNode(v, *in_out_links)
                    add_reader(opnode)
                if info.value_read:
                    opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links)
                    add_reader(opnode)
                if info.shape_read:
                    opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
                    add_reader(opnode)

        graph.options.graph_opt_level = self._graph_opt_level
        graph._set_priority_to_id([*readers, *in_out_links, *io_links])
        graph.compile(*readers, *in_out_links, *io_links)
  625. def _reset_exec_env(self):
  626. for opnode in self._need_reset_nodes:
  627. opnode.reset()
  628. def __call__(self, *args, **kwargs):
  629. with self._setup():
  630. if self._capture_as_const:
  631. self._process_inputs(*args, **kwargs)
  632. outputs = self.__wrapped__(*args, **kwargs)
  633. if self._capture_as_const:
  634. self._process_outputs(outputs)
  635. return outputs
  636. def _make_feed(
  637. self,
  638. graph,
  639. outputs,
  640. input_data,
  641. repeat,
  642. silent,
  643. no_assert,
  644. maxerr,
  645. resize_input,
  646. input_transform,
  647. ):
  648. def auto_reformat_image(path, data, dst_shape):
  649. """reformat image to target shape
  650. :param data: image data as numpy array
  651. :param dst_shape: target shape
  652. """
  653. dim3_format = False # required input format does not contain batch
  654. hwc_format = False # required input format is NHWC
  655. if not dst_shape: # input tensor shape is not predefined
  656. if len(data.shape) == 2:
  657. chl = 1
  658. h = data.shape[0]
  659. w = data.shape[1]
  660. else:
  661. assert (
  662. len(data.shape) == 3
  663. ), "Input image must be of dimension 2 or 3"
  664. h, w, chl = data.shape
  665. dst_shape = (1, chl, h, w)
  666. if len(dst_shape) == 3:
  667. dst_shape = (1,) + dst_shape
  668. dim3_format = True
  669. assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape)
  670. chl = dst_shape[1]
  671. if chl in [1, 3]:
  672. n, c, h, w = dst_shape
  673. dst_shape = (n, h, w, c)
  674. else:
  675. chl = dst_shape[3]
  676. assert chl in [
  677. 1,
  678. 3,
  679. ], "can not infer input format from shape: {}".format(dst_shape)
  680. hwc_format = True
  681. # dst_shape has now been normalized to NHWC format
  682. if resize_input:
  683. h, w = dst_shape[1:3]
  684. data = cv2.resize(data, (w, h))
  685. logger.info("input {} resized to {}".format(path, data.shape))
  686. if chl == 1:
  687. data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY)
  688. data = data[:, :, np.newaxis]
  689. assert data.ndim == 3
  690. data = data[np.newaxis]
  691. # data normalized to NHWC format
  692. if not hwc_format:
  693. data = np.transpose(data, (0, 3, 1, 2))
  694. if dim3_format:
  695. data = np.squeeze(data, 0)
  696. return data
  697. def read_input_data(dst_shape, dtype, path):
  698. def check_shape_equal(dst_shape, data_shape):
  699. if len(dst_shape):
  700. assert len(data_shape) == len(
  701. dst_shape
  702. ), "input/data shapes mismatch: {} vs {}".format(
  703. dst_shape, data_shape
  704. )
  705. if data_shape[1:] != dst_shape[1:]:
  706. logger.warning(
  707. "dst_shape is {}; data_shape is {}".format(
  708. dst_shape, data_shape
  709. )
  710. )
  711. if path.startswith("#"):
  712. assert not resize_input
  713. assert not input_transform
  714. spec = path
  715. m = re.match(
  716. r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec
  717. )
  718. assert m, "bad spec {}".format(spec)
  719. rng_min = float(m.group(1))
  720. rng_max = float(m.group(2))
  721. if m.group(3):
  722. shape_str = m.group(3)
  723. try:
  724. shape = shape_str[1:].split(",")
  725. if shape[-1].strip() == "...":
  726. shape = shape[:-1]
  727. shape.extend(list(dst_shape[len(shape) :]))
  728. data_shape = tuple(map(int, shape))
  729. except ValueError as e:
  730. raise ValueError("bad spec {}: {}".format(spec, e.args))
  731. else:
  732. data_shape = dst_shape
  733. check_shape_equal(dst_shape, data_shape)
  734. return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype)
  735. # try to load image
  736. data = cv2.imread(path, cv2.IMREAD_COLOR)
  737. if data is None:
  738. assert not resize_input
  739. data = np.load(path)
  740. assert isinstance(data, np.ndarray)
  741. else:
  742. # load image succeeds, so we expect input format is image format
  743. data = auto_reformat_image(path, data, dst_shape)
  744. data = np.repeat(data, repeat, axis=0)
  745. if repeat > 1:
  746. logger.info(
  747. "repeat input for {} times, data shape is {}".format(
  748. repeat, data.shape
  749. )
  750. )
  751. check_shape_equal(dst_shape, data.shape)
  752. if input_transform:
  753. data = eval(input_transform, {"data": data, "np": np})
  754. return data
  755. def gen_one_testcase(inputs, spec):
  756. paths = spec.split(";")
  757. if len(paths) != len(inputs):
  758. if len(paths) == 1 and paths[0].startswith("#"):
  759. paths = ["{}:{}".format(name, paths[0]) for name in inputs.keys()]
  760. assert len(paths) == len(
  761. inputs
  762. ), "required inputs: {}; data paths: {}".format(inputs.keys(), paths)
  763. if len(paths) == 1 and ":" not in paths[0]:
  764. paths[0] = next(iter(inputs.keys())) + ":" + paths[0]
  765. ret = {}
  766. for path in paths:
  767. var, path = path.split(":")
  768. ret[var] = read_input_data(inputs[var].shape, inputs[var].dtype, path)
  769. return ret
  770. inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy")
  771. inputs = {i.name: i for i in inputs}
  772. if not no_assert:
  773. replace_varmap = {}
  774. inp_map = {}
  775. # replace var use InputNode
  776. for name, var in inputs.items():
  777. inp = G.InputNode(
  778. device="xpux", dtype=var.dtype, shape=var.shape, graph=graph
  779. )
  780. replace_varmap[var] = inp.outputs[0]._node
  781. inp_map[name] = inp
  782. new = cgtools.replace_vars(outputs, replace_varmap)
  783. if isinstance(new, rt.VarNode):
  784. new = list(new)
  785. output_nodes = [G.OutputNode(var) for var in new]
  786. func = graph.compile(*[node.outputs[0]._node for node in output_nodes])
  787. def make_dev_tensor(value, dtype=None, device=None):
  788. return tensor(value, dtype=dtype, device=device)._dev_tensor()
  789. def calculate(*args, **kwargs):
  790. output_val = []
  791. # set inputs value
  792. for name, var in inputs.items():
  793. val = kwargs.pop(name, None)
  794. assert val is not None, "miss input name{}".format(name)
  795. dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux")
  796. inp_map[name].set_value(dev_tensor)
  797. func.execute()
  798. for res in output_nodes:
  799. output_val.append(res.get_value().numpy())
  800. return output_val
  801. def expect_name(var):
  802. return "{}:expect".format(var.name)
  803. testcases = []
  804. np.set_printoptions(precision=2, threshold=4, suppress=True)
  805. data_list = []
  806. for item in input_data:
  807. if item.startswith("@"):
  808. with open(item[1:], "r") as f:
  809. data_list.extend(
  810. [line.rstrip() for line in f if line.rstrip() != ""]
  811. )
  812. else:
  813. data_list.append(item)
  814. for inp_spec in data_list:
  815. cur_testcase = gen_one_testcase(inputs, inp_spec)
  816. assert len(cur_testcase) == len(
  817. inputs
  818. ), "required inputs: {}; given data: {}".format(
  819. inputs.keys(), cur_testcase.keys()
  820. )
  821. if not no_assert:
  822. outputs_get = calculate(**cur_testcase)
  823. for var, val in zip(outputs, outputs_get):
  824. cur_testcase[expect_name(var)] = val
  825. logger.info(
  826. "generate test groundtruth: var={} shape={} range=({}, {})"
  827. " mean={} var={}".format(
  828. var,
  829. val.shape,
  830. val.min(),
  831. val.max(),
  832. np.mean(val),
  833. np.var(val),
  834. )
  835. )
  836. testcases.append(cur_testcase)
  837. logger.info(
  838. "add testcase: \n {}".format(
  839. "\n ".join(
  840. "{}: shape={} dtype={} range=({:.2f},{:.2f}) "
  841. "mean={:.2f} sd={:.2f}".format(
  842. k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v)
  843. )
  844. for k, v in sorted(cur_testcase.items())
  845. )
  846. )
  847. )
  848. if not no_assert:
  849. def expect_shp(var):
  850. ret = var.shape
  851. if ret:
  852. return ret
  853. return testcases[0][expect_name(var)].shape
  854. def assert_equal(expect, real, **kwargs):
  855. op = AssertEqual(**kwargs)
  856. (res,) = G.apply_normal_varnode(op, expect, real)
  857. return res._node
  858. verbose = not silent
  859. outputs_new = []
  860. for i in outputs:
  861. device = rt.CompNode("xpux")
  862. dtype = i.dtype
  863. name = expect_name(i)
  864. shape = expect_shp(i)
  865. # make expect output as one input of model.
  866. expect_get = rt.make_h2d(graph, device, dtype, shape, name)
  867. # insert assert opr to check expect and real.
  868. outputs_new.append(
  869. assert_equal(expect_get, i, verbose=verbose, maxerr=maxerr,)
  870. )
  871. inputs[expect_name(i)] = expect_get
  872. outputs = outputs_new
  873. return {"outputs": outputs, "testcases": testcases}
  def dump(
      self,
      file,
      *,
      arg_names=None,
      output_names=None,
      append=False,
      keep_var_name: int = 1,
      keep_opr_name: bool = False,
      keep_param_name: bool = False,
      keep_opr_priority: bool = False,
      strip_info_file=None,
      append_json=False,
      optimize_for_inference=True,
      user_info: Any = None,
      enable_metadata: bool = True,
      input_data=None,
      repeat=1,
      silent=False,
      no_assert=False,
      maxerr=1e-4,
      resize_input=False,
      input_transform=None,
      dump_format: str = None,
      **kwargs
  ):
      r"""Serializes trace to file system.

      Args:
          file: output file, could be file object or filename. When a filename
              is given, the file is opened by this method and closed again
              before returning (also on error); a file object passed in by the
              caller is left open.
          arg_names: names of the input tensors in the traced function.
          output_names: names of the output tensors in the traced function,
              use the default name if not specified.
          append: whether output is appended to ``file``.
              Only works when ``file`` is str.
          keep_var_name: level for keeping variable names:

              * 0: none of the names are kept
              * 1: (default)keep names of output vars
              * 2: keep names of all (output and internal) vars

          keep_opr_name: whether to keep operator names.
          keep_param_name: whether to keep param names, so param values can be
              easily manipulated after loading model
          keep_opr_priority: whether to keep priority setting for operators
          strip_info_file: a string for path or a file handler. if is not None,
              then the dump information for code strip would be written to ``strip_info_file``
          append_json: will be checked when `strip_info_file` is not None. if set
              true, the information for code strip will be appended to strip_info_file.
              if set false, will rewrite strip_info_file
          optimize_for_inference: enable optimizations,
              will skip all optimize options if this is False. Default: True
          user_info: any type object, which will be pickled to bytes.
          enable_metadata: whether to save metadata into output file.
          input_data: input test data and current network output would be used as groundtruth.
              The format is "var0:file0;var1:file1..." to specify data files for input vars.
              It can also be "#rand(min,max,shape...)" for generating random input data, for
              example, "#rand(0,255)", "#rand(0,255,1,3,224,224)" or "#rand(0, 255, 1, ...)"
              where `...` means the remaining part of the original shape. If the shape is not
              specified, the shape of corresponding input tensors in the network will be used.
              If there is only one input var, its name can be omitted. Each data file can either
              be an image which can be loaded by opencv, or a pickled numpy.ndarray. This option
              can be given multiple times to add multiple testcases. If you start the data
              with the letter @, the rest should be a filename, and each line in the file should
              be a single datum in the format described above. *NOTE* If `input_data` is not None,
              you can only use load-and-run to run the output file.
          repeat: how many times the input image is repeated. Useful when running benchmark for
              batch size other than one. Have no effect on randomly generated input data.
          silent: whether set verbose to False in assert_equal opr.
          no_assert: whether insert assert_equal opr to check result; this option is useful for
              benchmarking.
          maxerr: max error for assert_equal check during runtime.
          resize_input: whether resize input image to fit input var shape.
          input_transform: a python expression to transform the input data.
              Example: data / np.std(data)
          dump_format: using different dump formats.

      Keyword Arguments:

          * enable_io16xc32 --
              whether to use float16 for I/O between oprs and use
              float32 as internal computation precision. Note the output var would be
              changed to float16.
          * enable_ioc16 --
              whether to use float16 for both I/O and computation
              precision.
          * enable_hwcd4 --
              whether to use NHWCD4 data layout. This is faster on some
              OpenCL backend.
          * enable_nchw88 --
              whether to use NCHW88 data layout, currently
              used in X86 AVX backend.
          * enable_nchw44 --
              whether to use NCHW44 data layout, currently
              used in arm backend.
          * enable_nchw44_dot --
              whether to use NCHW44_dot data layout, currently
              used in armv8.2+dotprod backend.
          * enable_nchw4 --
              whether to use NCHW4 data layout, currently
              used in nvidia backend(based on cudnn).
          * enable_nchw32 --
              whether to use NCHW32 data layout, currently
              used in nvidia backend with tensorcore(based on cudnn).
          * enable_chwn4 --
              whether to use CHWN4 data layout, currently
              used in nvidia backend with tensorcore.
          * enable_nchw64 --
              whether to use NCHW64 data layout, used for fast int4
              support on Nvidia GPU.
          * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
              into one opr.
          * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
              input for inference on nvidia backend(this optimization pass will
              result in mismatch of the precision of output of training and
              inference)
          * enable_fuse_preprocess: whether to fuse astype\pad_channel\dimshuffle and
              etc opr

      Returns:
          the second value returned by ``G.dump_graph`` for the network graph.

      Raises:
          ValueError: if the trace was not created with ``capture_as_const=True``,
              or the number of ``arg_names`` / ``output_names`` is wrong.
          RuntimeError: if nothing has been recorded yet.
          TypeError: if ``output_names`` is given although the traced function
              already returned its outputs as a dict.
      """
      if not self._capture_as_const:
          raise ValueError(
              "you must specify capture_as_const=True at __init__ to use dump"
          )
      if self._untraced and len(self._seq) == 0:
          raise RuntimeError("should do record first before dump")
      if self._output_names and output_names:
          raise TypeError(
              "cannot specify output_names when output is already in dict format"
          )
      if output_names and not isinstance(output_names, collections.abc.Sequence):
          output_names = (output_names,)
      if output_names and len(output_names) != len(self._output_bindings):
          raise ValueError(
              "wrong number of output_names, should be {} values".format(
                  len(self._output_bindings)
              )
          )
      without_arg_names = arg_names is None
      if without_arg_names:
          arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))]
      if arg_names and not isinstance(arg_names, collections.abc.Sequence):
          arg_names = (arg_names,)
      if arg_names and len(arg_names) != len(self._arg_bindings):
          raise ValueError(
              "wrong number of arg_names, should be {} values".format(
                  len(self._arg_bindings)
              )
          )
      output_names = output_names or self._output_names

      def dumped_device(info):
          # cpu*/gpu*/xpu* devices are all dumped as the generic "xpux" device
          device_name = info.device.logical_name
          if device_name[:3] in ("cpu", "gpu", "xpu"):
              return as_device("xpux")
          return info.device

      h2v = {}
      graph = G.Graph()

      def var_of(h):
          # Return the graph var bound to handle ``h``; a handle that has no
          # var yet must be an external tensor with bound data and is baked
          # into the graph as a constant.
          info = self._tinfo[h]
          if h not in h2v:
              assert info.external
              assert info.bound_data
              h2v[h] = graph.make_const(
                  info.bound_data.numpy(),
                  dtype=info.dtype,
                  device=dumped_device(info),
                  name=info.name,
              )
          return h2v[h]

      # apply graph_opt_level in dump
      if self._graph_opt_level is not None:
          graph.options.graph_opt_level = self._graph_opt_level

      # positional and keyword tensor arguments become host-to-device inputs
      for i, h in enumerate(self._arg_bindings):
          info = self._tinfo[h]
          h2v[h] = graph.make_h2d(
              dtype=info.dtype,
              device=dumped_device(info),
              shape=info.shape or (1,),
              name=info.name if without_arg_names and info.name else arg_names[i],
          )
      for k, h in self._kwarg_bindings.items():
          info = self._tinfo[h]
          h2v[h] = graph.make_h2d(
              dtype=info.dtype,
              device=dumped_device(info),
              shape=info.shape or (1,),
              name=k,
          )

      # replay the recorded op sequence on the fresh graph
      for op, ihandles, ohandles in self._seq:
          if isinstance(op, str) and op == "Const":
              assert len(ihandles) == 0
              (h,) = ohandles
              var_of(h)
              continue
          ivars = [var_of(h) for h in ihandles]
          if isinstance(op, BatchNorm):
              assert (
                  op.fwd_mode == BatchNorm.FwdMode.INFERENCE
              ), "can not dump BatchNorm in training mode, maybe you forget to do model.eval()?"
          ovars = G.apply_normal_varnode(op, *ivars)

          AutoNaming.record_opnode(ovars[0].op)

          assert len(ovars) == len(ohandles)
          h2v.update(zip(ohandles, ovars))

          for i in ohandles:
              name = AutoNaming.get_var_name(i)
              if name is not None:
                  h2v[i].name = name

      AutoNaming.remove_duplicate_names()

      dest_vars = []
      for i, h in enumerate(self._output_bindings):
          v = h2v[h]
          if output_names:
              v.name = output_names[i]
          dest_vars.append(v)

      dest_vars = [i._node for i in dest_vars]

      if input_data is not None:
          feeds = self._make_feed(
              graph,
              dest_vars,
              input_data,
              repeat,
              silent,
              no_assert,
              maxerr,
              resize_input,
              input_transform,
          )
          assert (
              isinstance(feeds, dict) and feeds["testcases"]
          ), "testcases can not be empty"
          dest_vars = feeds["outputs"]

      if optimize_for_inference:
          dest_vars, optimize_options = G.optimize_for_inference(dest_vars, **kwargs)
          dest_vars = [i._node for i in dest_vars]

      metadata = SerializationMetadata()
      if enable_metadata:
          metadata.user_info = pickle.dumps(user_info)
          metadata.is_valid = True
          metadata.graph_modified = False
          if optimize_for_inference:
              metadata.optimize_options = optimize_options

      # Remember whether the handle is ours: a path opened here must always be
      # closed again, while a caller-supplied file object is left untouched.
      owns_file = isinstance(file, str)
      if owns_file:
          # explicit comparison kept so that non-bool values of ``append``
          # select the same mode as before
          permission = "wb" if append == False else "ab"  # noqa: E712
          file = open(file, permission)

      try:
          if keep_opr_priority:
              graph._set_priority_to_id(dest_vars)

          if input_data is not None:
              # header: magic bytes followed by the number of testcases
              file.write(b"mgbtest0")
              file.write(struct.pack("I", len(feeds["testcases"])))
          dump_content, dump_info = G.dump_graph(
              dest_vars,
              keep_var_name=keep_var_name,
              keep_opr_name=keep_opr_name,
              keep_param_name=keep_param_name,
              keep_opr_priority=keep_opr_priority,
              strip_info_file=strip_info_file,
              append_json=append_json,
              metadata=metadata,
              dump_format=dump_format,
          )
          file.write(dump_content)

          if input_data is not None:
              inputs = cgtools.get_dep_vars(dest_vars, "Host2DeviceCopy")
              inputs = sorted((i.name, i.dtype) for i in inputs)

              def make_dev_tensor(value, dtype=None, device=None):
                  return tensor(value, dtype=dtype, device=device)._dev_tensor()

              # every testcase is appended as its own graph of constants,
              # ordered by input name
              for testcase in feeds["testcases"]:
                  assert isinstance(testcase, dict)
                  cg = G.Graph()
                  output_mgbvars = [
                      cg.make_const(
                          make_dev_tensor(
                              testcase.pop(name), dtype=dtype, device="cpux"
                          )
                      )
                      for name, dtype in inputs
                  ]
                  assert not testcase, "extra inputs provided in testcase: {}".format(
                      testcase.keys()
                  )
                  dump_content, _ = G.dump_graph(
                      output_mgbvars, strip_info_file=strip_info_file, append_json=True,
                  )
                  file.write(dump_content)
      finally:
          if owns_file:
              file.close()

      return dump_info
  def _process_inputs(self, *args, **kwargs):
      r"""Bind the tensor arguments of the current call to trace handles.

      While the trace is still untraced, a new handle is created for every
      positional tensor and every tensor-valued keyword argument and the
      bindings are stored on ``self``. On later calls the arguments are
      checked against the recorded bindings (count, dtype, device, keyword
      names) and their device tensors are fed into the recorded data setters.

      Raises:
          TypeError: if a positional argument is not a tensor, or a dtype /
              device differs from the recorded one.
          TraceMismatchError: if the number of positional arguments or the set
              of tensor-valued keyword arguments differs from the recorded one.
      """
      not_a_tensor = (
          "positional arguments should all be tensor "
          "but args[%d] cannot be recognized as one"
      )

      if self._untraced:
          self._inputs_to_restore = []

          def record_input(x):
              if x is None:
                  return
              handle, info = self._new_handle()
              info.external = False
              info.name = x.c_name
              info.device = x.device
              info.dtype = x.dtype
              info.shape = x.numpy().shape
              x._mixin_handle = handle
              x._recording = True
              x._trace_mixin_info = info
              self._inputs_to_restore.append(x)
              return handle

          self._arg_bindings = []
          for pos, arg in enumerate(args):
              if not isinstance(arg, RawTensor):
                  raise TypeError(not_a_tensor % pos)
              self._arg_bindings.append(record_input(arg))

          # non-tensor keyword arguments are not bound
          self._kwarg_bindings = {}
          for key, value in kwargs.items():
              if isinstance(value, RawTensor):
                  self._kwarg_bindings[key] = record_input(value)
          return

      # ---- replay: validate against the recorded bindings and feed data ----
      if len(args) != len(self._arg_bindings):
          raise TraceMismatchError("positional argument length mismatch")

      self._tensor_remaps = {}

      for pos, (handle, arg) in enumerate(zip(self._arg_bindings, args)):
          if not isinstance(arg, RawTensor):
              raise TypeError(not_a_tensor % pos)
          info = self._tinfo[handle]
          if arg.dtype != info.dtype:
              raise TypeError("args[%d].dtype different from last time" % pos)
          if arg.device != info.device:
              raise TypeError("args[%d].device different from last time" % pos)
          info.data_setter.set_value(arg._dev_tensor())
          self._tensor_remaps[arg._handle] = CompiledTensorProxy(handle)

      kwargs_tensors = {
          key: value for key, value in kwargs.items() if isinstance(value, RawTensor)
      }
      given_keys = set(kwargs_tensors)
      bound_keys = set(self._kwarg_bindings)
      if given_keys != bound_keys:
          too_many = given_keys - bound_keys
          too_few = bound_keys - given_keys
          if too_many:
              raise TraceMismatchError(
                  "keyword arguments found to be tensor this time "
                  "but were non-tensor previously: %s" % " ".join(too_many)
              )
          if too_few:
              raise TraceMismatchError(
                  "keyword arguments found to be non-tensor this time "
                  "but were tensor previously: %s" % " ".join(too_few)
              )

      for key, handle in self._kwarg_bindings.items():
          value = kwargs_tensors[key]
          info = self._tinfo[handle]
          if value.dtype != info.dtype:
              raise TypeError("kwargs[%s].dtype different from last time" % key)
          if value.device != info.device:
              raise TypeError("kwargs[%s].device different from last time" % key)
          info.data_setter.set_value(value._dev_tensor())
          self._tensor_remaps[value._handle] = CompiledTensorProxy(handle)
  def _process_outputs(self, outputs):
      r"""Record or verify the outputs of the traced function.

      ``outputs`` may be a mapping (its items are sorted by key and the keys
      become the output names), a sequence, or a single value (wrapped into a
      one-element tuple). While the trace is untraced the output names and the
      handle of each output are stored; on later calls the names, the number
      of outputs and each output's handle are compared with the stored ones.

      Raises:
          TypeError: if an output item is not a tensor.
          RuntimeError: if an output does not carry a valid / known handle.
          TraceMismatchError: if the output keys, the number of outputs or an
              individual output tensor differ from the recorded run.
      """
      output_names = None
      if isinstance(outputs, collections.abc.Mapping):
          output_names, outputs = zip(*sorted(outputs.items()))
      elif not isinstance(outputs, collections.abc.Sequence):
          outputs = (outputs,)

      if not self._untraced:
          if output_names != self._output_names:
              # Either side is None when that run did not return a mapping;
              # treat it as "no keys" so the mismatch is reported as a
              # TraceMismatchError instead of crashing in set(None).
              current = set(output_names or ())
              previous = set(self._output_names or ())
              too_many = current - previous
              too_few = previous - current
              if too_many:
                  raise TraceMismatchError(
                      "output has more keys than last time: %s" % " ".join(too_many)
                  )
              if too_few:
                  raise TraceMismatchError(
                      "output has less keys than last time: %s" % " ".join(too_few)
                  )
          if len(outputs) != len(self._output_bindings):
              raise TraceMismatchError("output size differs from last time")
      else:
          self._output_names = output_names
          self._output_bindings = []

      for i, x in enumerate(outputs):
          if not isinstance(x, RawTensor):
              raise TypeError("every item of return value should be tensor")
          h = x._mixin_handle
          if self._untraced:
              if h < 0:
                  raise RuntimeError("output is not computed from inputs")
              self._output_bindings.append(h)
          else:
              if h not in self._output_handles:
                  raise RuntimeError("output is not computed from inputs")
              if h != self._output_bindings[i]:
                  # name the output by its key when keys exist, else by index
                  raise TraceMismatchError(
                      "retval[%s] is a different tensor than last time"
                      % (output_names[i] if output_names else i)
                  )
  def get_profile(self):
      r"""Get profiling result for compiled trace.

      Return:
          a json compatible object, decoded from what the profiler's ``get()``
          returns.

      Raises:
          RuntimeError: if no profiler is attached to this trace.
      """
      if self._profiler:
          raw_profile = self._profiler.get()
          return json.loads(raw_profile)
      raise RuntimeError("trace is not set with profiling=True")
  class CompiledTensorProxy:
      r"""Duck-typed RawTensor

      Stands in for a tensor identified by a trace handle. Shape, host value
      and device data are fetched lazily from the readers stored on the
      handle's tensor info, cached on first access, and the fetched values are
      dropped from the readers again when the proxy is destroyed.
      """

      def __init__(self, handle):
          self.__handle = handle
          self._isscalar = False
          # tensor info of the handle in the currently active trace
          self.__info = active_trace._tinfo[handle]
          # lazily filled caches
          self.__shape = None
          self.__data = None
          self.__value = None

      @property
      def dtype(self):
          return self.__info.varnode.dtype

      @property
      def device(self):
          return self.__info.varnode.device

      @property
      def shape(self):
          if self._isscalar:
              return ()
          if self.__shape is not None:
              return self.__shape
          info = self.__info
          if info.shape_read:
              self.__shape = info.shape_reader.get_value().shape
          elif info.data_read:
              self.__shape = self._dev_tensor().shape
          # when neither reader is available the cache stays None and None is
          # returned: c++ will throw TraceReadError
          return self.__shape

      def numpy(self):
          if self.__value is not None:
              # c++ side will handle scalar case
              return self.__value
          info = self.__info
          if info.value_read:
              self.__value = info.value_reader.get_value()
          elif info.data_read:
              self.__value = self._dev_tensor().numpy()
          # when neither reader is available the cache stays None and None is
          # returned: c++ will throw TraceReadError
          return self.__value

      def _dev_tensor(self):
          # fetch once; without a data reader None is returned and
          # c++ will throw TraceReadError
          if self.__data is None and self.__info.data_read:
              self.__data = self.__info.data_reader.get_value()
          return self.__data

      def __del__(self):
          # release only the values this proxy actually fetched
          info = self.__info
          if info.shape_read and self.__shape is not None:
              info.shape_reader.drop_value()
          if info.value_read and self.__value is not None:
              info.value_reader.drop_value()
          if info.data_read and self.__data is not None:
              info.data_reader.drop_value()
  def apply_symbolic_mode(op: OpDef, *args: RawTensor):
      r"""Apply ``op`` on the active trace's lazy-evaluation graph.

      Arguments that already carry a ``_varnode`` are used directly; any other
      argument is fed into the graph through a new static-shape ``InputNode``.
      For op types listed in ``_io_op_types`` the first input is chained to the
      previously recorded link vars through a ``VirtualDepNode``, and the first
      output becomes the new link.

      Returns:
          list of ``RawTensor`` wrapping the op's output vars.
      """
      lazy_graph = active_trace._lazy_eval_graph

      ivars = []
      for arg in args:
          node = getattr(arg, "_varnode", None)
          if node:
              ivars.append(node)
              continue
          # plain tensor: expose its current value to the graph as an input
          feeder = G.InputNode(
              device=arg.device,
              dtype=arg.dtype,
              shape=arg.numpy().shape or (1,),
              graph=lazy_graph,
              use_static_shape=True,
          )
          node = feeder.outputs[0]
          ivars.append(node)
          feeder.set_value(arg._dev_tensor())

      needs_links = type(op) in _io_op_types

      if needs_links and active_trace._lazy_eval_links:
          assert len(ivars) > 0, "op should has at least one input"
          links = active_trace._lazy_eval_links
          dep_node = G.VirtualDepNode([ivars[0], *links], str(links[0].device))
          ivars[0] = dep_node.outputs[0]
          active_trace._lazy_eval_links = (ivars[0],)

      outputs = [RawTensor(ovar) for ovar in G.apply_normal_varnode(op, *ivars)]

      if needs_links:
          active_trace._lazy_eval_links = (G.VarNode(outputs[0]._varnode),)

      return outputs
  def apply_const_symbolic_mode(value, dtype, device, name):
      r"""Create a constant on the active trace's lazy-evaluation graph.

      Returns:
          a one-element tuple holding the ``RawTensor`` that wraps the
          constant; it is marked as scalar when ``value`` is 0-dimensional.
      """
      # don't need to unset tracing
      # because varnode construction will ignore tracing flag
      const_var = active_trace._lazy_eval_graph.make_const(
          value, dtype=dtype, device=device, name=name
      )
      ret = RawTensor(const_var)
      is_scalar = np.array(value).ndim == 0
      if is_scalar:
          setscalar(ret)
      return (ret,)
  def apply_compiled_mode(op: OpDef, *args: RawTensor):
      r"""Apply ``op`` while a compiled trace is running.

      Normally the op is delegated to the active trace. When the module-level
      ``skip_tracing`` flag is set, the op is instead applied directly with
      tracing switched off, after every ``CompiledTensorProxy`` argument has
      been replaced by a ``RawTensor`` built from its device tensor.
      """
      if not skip_tracing:
          return active_trace._apply_op(op, args)

      real_args = []
      for x in args:
          if x.__class__ is CompiledTensorProxy:
              x = RawTensor(x._dev_tensor())
          real_args.append(x)

      unset_tracing()
      ret = apply(op, *real_args)
      set_tracing()
      return ret
  def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name):
      r"""Create a constant while a compiled trace is running.

      Normally the request is delegated to the active trace. When the
      module-level ``skip_tracing`` flag is set, a ``RawTensor`` is built
      directly with tracing switched off. ``is_const`` and ``no_cache`` are
      accepted but not used here.
      """
      if not skip_tracing:
          return active_trace._apply_const(value, dtype, device)

      unset_tracing()
      ret = RawTensor(value, dtype, device, False, name)
      set_tracing()
      return ret
  def apply_with_tracing(op: OpDef, *args: RawTensor):
      r"""Apply ``op`` under the active trace and record it.

      A trace that already owns a graph handles the op in compiled mode.
      Otherwise the op is evaluated symbolically or, with tracing temporarily
      switched off, eagerly, and the op with its inputs and outputs is
      recorded on the active trace.

      Returns:
          list of the op's outputs (whatever compiled mode returns when the
          trace is compiled).
      """
      trace_ctx = active_trace
      if trace_ctx._graph:
          # if member _graph exits, then is_compiled
          return apply_compiled_mode(op, *args)

      if hasattr(op, "scope"):
          op.scope = AutoNaming.get_scope()

      if not trace_ctx._symbolic:
          unset_tracing()
          outputs = apply(op, *args)
          set_tracing()
      else:
          outputs = apply_symbolic_mode(op, *args)

      trace_ctx._record_op(op, args, outputs)
      return list(outputs)
  def apply_const_with_tracing(value, dtype, device, is_const, no_cache, name):
      r"""Create a constant under the active trace and record it.

      A trace that already owns a graph handles the request in compiled mode.
      Otherwise the constant is created symbolically or, with tracing
      temporarily switched off, as a plain ``RawTensor`` (marked as scalar when
      ``value`` is 0-dimensional), and is recorded on the active trace.

      Returns:
          one-element list holding the constant (whatever compiled mode
          returns when the trace is compiled).
      """
      trace_ctx = active_trace
      if trace_ctx._graph:
          return apply_const_compiled_mode(
              value, dtype, device, is_const, no_cache, name
          )

      if not trace_ctx._symbolic:
          unset_tracing()
          const_tensor = RawTensor(value, dtype, device, False, name)
          if np.array(value).ndim == 0:
              setscalar(const_tensor)
          outputs = (const_tensor,)
          set_tracing()
      else:
          outputs = apply_const_symbolic_mode(value, dtype, device, name)

      trace_ctx._record_const(outputs)
      return list(outputs)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。