You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tracing.py 32 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817
  1. # -*- coding: utf-8 -*-
  2. import collections
  3. import contextlib
  4. import functools
  5. import itertools
  6. import json
  7. import os
  8. import pickle
  9. import re
  10. import struct
  11. import sys
  12. from typing import Any
  13. import cv2
  14. import numpy as np
  15. from .. import tensor
  16. from ..core import _imperative_rt as rt
  17. from ..core._imperative_rt import GraphProfiler, GraphProfiler2, SerializationMetadata
  18. from ..core._imperative_rt.core2 import Tensor as RawTensor
  19. from ..core._imperative_rt.core2 import Trace, TraceError, name_tensor # skip_tracing,
  20. from ..core._imperative_rt.graph import _set_priority_to_id
  21. from ..core._imperative_rt.ops import (
  22. AssertEqual,
  23. CollectiveComm,
  24. ExternOpr,
  25. RemoteRecv,
  26. RemoteSend,
  27. set_jit_enabled,
  28. )
  29. from ..core._trace_option import set_symbolic_shape
  30. from ..core.tensor import megbrain_graph as G
  31. from ..logger import get_logger
  32. from ..utils import comp_graph_tools as cgtools
  33. from ..utils.naming import AutoNaming
  34. from ..utils.profiler import is_profiling
  35. from .dtr_config import DTRConfig
  36. from .graph_opt_config import GraphOptimizationConfig
  37. from .sublinear_memory_config import SublinearMemoryConfig
# Module-level logger; used by ``trace`` when generating dump testcases.
logger = get_logger(__name__)
  39. def _input_node_use_static_shape():
  40. return os.environ.get("MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE") is not None
  41. active_trace = None
  42. skip_tracing = False
  43. def is_tracing():
  44. if active_trace is None:
  45. return False
  46. else:
  47. return not skip_tracing
  48. @contextlib.contextmanager
  49. def exclude_from_trace():
  50. global skip_tracing
  51. if skip_tracing or (active_trace is None):
  52. yield
  53. return
  54. try:
  55. skip_tracing = True
  56. if active_trace is not None:
  57. active_trace._begin_excluded_region()
  58. yield
  59. if active_trace is not None:
  60. active_trace._end_excluded_region()
  61. finally:
  62. skip_tracing = False
  63. def array_comparator(lhs, rhs):
  64. return np.all(lhs == rhs)
  65. class trace:
  66. """Wraps a callable and provide:
  67. * tracing via :meth:`.trace` and :meth:`.dump`
  68. * accelerated evalutaion via :meth:`.__call__`
  69. Args:
  70. function: the function will be traced.
  71. symbolic: whether to apply symbolic execution for tracing. Default: False
  72. capture_as_const: capture global vars or closures as const value. Default: False
  73. record_only: if True, won't run even if call the function. Default: False
  74. sublinear_memory_config: configuration for sublinear memory optimization.
  75. If not None, it enables sublinear memory optimization with given setting.
  76. profiling: whether to profile compiled trace. Default: False
  77. opt_level: optimization level for compiling trace. Default: 2
  78. graph_opt_config: configuration for graph optimization. Default: None
  79. symbolic_shape: whether to use symbolic shape for tracing. Default: True
  80. """
  81. def __new__(cls, *args, **kwargs):
  82. if not args:
  83. return functools.partial(cls, **kwargs)
  84. return super().__new__(cls)
  85. def __init__(
  86. self,
  87. function,
  88. symbolic=False,
  89. capture_as_const=False,
  90. record_only=False,
  91. sublinear_memory_config: SublinearMemoryConfig = None,
  92. dtr_config: DTRConfig = None,
  93. profiling: bool = False,
  94. opt_level: int = 2,
  95. graph_opt_config: GraphOptimizationConfig = None,
  96. symbolic_shape: bool = True,
  97. ):
  98. self.__wrapped__ = function
  99. self._capture_as_const = capture_as_const or record_only
  100. self._arg_bindings = None
  101. self._kwarg_bindings = None
  102. self._output_bindings = None
  103. self._symbolic_shape = symbolic_shape
  104. self._graph_options = {
  105. "no_force_inplace": True,
  106. "graph_opt_level": opt_level,
  107. "seq_opt.enable_seq_comp_node_opt": False,
  108. }
  109. # prevent cyclic reference
  110. graph_options = self._graph_options
  111. if dtr_config is not None:
  112. graph_options["enable_dtr_memory_opt"] = True
  113. graph_options[
  114. "dtr_config.eviction_threshold"
  115. ] = dtr_config.eviction_threshold
  116. graph_options[
  117. "dtr_config.evictee_minimum_size"
  118. ] = dtr_config.evictee_minimum_size
  119. graph_options[
  120. "dtr_config.recomp_memory_factor"
  121. ] = dtr_config.recomp_memory_factor
  122. graph_options[
  123. "dtr_config.recomp_time_factor"
  124. ] = dtr_config.recomp_time_factor
  125. if graph_opt_config is not None:
  126. mapping = {None: 0, False: 1, True: 2}
  127. graph_options["graph_opt.jit_config.fuse_dimshuffle"] = mapping[
  128. graph_opt_config.jit_fuse_dimshuffle
  129. ]
  130. graph_options["graph_opt.jit_config.fuse_reduce"] = mapping[
  131. graph_opt_config.jit_fuse_reduce
  132. ]
  133. if sublinear_memory_config is not None:
  134. graph_options["enable_sublinear_memory_opt"] = True
  135. graph_options[
  136. "sublinear_mem_config.lb_memory_mb"
  137. ] = sublinear_memory_config.lb_memory_mb
  138. graph_options[
  139. "sublinear_mem_config.genetic_nr_iter"
  140. ] = sublinear_memory_config.genetic_nr_iter
  141. graph_options[
  142. "sublinear_mem_config.genetic_pool_size"
  143. ] = sublinear_memory_config.genetic_pool_size
  144. graph_options[
  145. "sublinear_mem_config.thresh_nr_try"
  146. ] = sublinear_memory_config.thresh_nr_try
  147. graph_options[
  148. "sublinear_mem_config.num_worker"
  149. ] = sublinear_memory_config.num_worker
  150. if int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")):
  151. graph_options["var_sanity_check_first_run"] = False
  152. def apply_options(options):
  153. for k, v in graph_options.items():
  154. words = k.split(".")
  155. suboptions = options
  156. for word in words[:-1]:
  157. suboptions = getattr(suboptions, word)
  158. setattr(suboptions, words[-1], v)
  159. self._trace = Trace()
  160. self._trace.symbolic = symbolic or record_only
  161. self._trace.capture_as_const = capture_as_const or record_only
  162. self._trace.no_exec = record_only
  163. self._trace.options_visitor = apply_options
  164. self._trace.profile = profiling
  165. self._trace.array_comparator = array_comparator
  166. self._trace.record_input_shapes = _input_node_use_static_shape()
  167. def __call__(self, *args, **kwargs):
  168. global active_trace
  169. symbolic_shape = None
  170. outputs = None
  171. try:
  172. active_trace = self
  173. self._trace.enter()
  174. if self._capture_as_const:
  175. self._process_inputs(*args, **kwargs)
  176. symbolic_shape = set_symbolic_shape(self._symbolic_shape)
  177. outputs = self.__wrapped__(*args, **kwargs)
  178. finally:
  179. handling_exc = sys.exc_info() != (None,) * 3
  180. active_trace = None
  181. if symbolic_shape is not None:
  182. symbolic_shape = set_symbolic_shape(symbolic_shape)
  183. assert symbolic_shape == self._symbolic_shape
  184. if self._capture_as_const and (outputs is not None):
  185. self._process_outputs(outputs)
  186. try:
  187. # may raise TraceError
  188. self._trace.exit()
  189. except TraceError:
  190. if not handling_exc:
  191. raise
  192. return outputs
  193. def _process_inputs(self, *args, **kwargs):
  194. for i, arg in enumerate(args):
  195. name_tensor("arg_{}".format(i), arg)
  196. # TODO: mark kwargs in order
  197. for k, kwarg in kwargs.items():
  198. if isinstance(kwarg, RawTensor):
  199. name_tensor("kwarg_{}".format(k), kwarg)
  200. if self._arg_bindings is None:
  201. self._arg_bindings = [
  202. ("arg_{}".format(i), arg._tuple_shape) for i, arg in enumerate(args)
  203. ]
  204. if self._kwarg_bindings is None:
  205. self._kwarg_bindings = {
  206. "kwarg_{}".format(k): (k, kwarg._tuple_shape)
  207. for k, kwarg in kwargs.items()
  208. if isinstance(kwarg, RawTensor)
  209. }
  210. def _process_outputs(self, outputs):
  211. if isinstance(outputs, RawTensor):
  212. outputs = [outputs]
  213. if isinstance(outputs, collections.abc.Mapping):
  214. output_names, outputs = zip(*sorted(outputs.items()))
  215. else:
  216. # output_names = ["output_{}".format(i) for i in range(len(outputs))]
  217. output_names = None
  218. self._output_names = output_names
  219. for i, output in enumerate(outputs):
  220. name_tensor("output_{}".format(i), output)
  221. if self._output_bindings is None:
  222. self._output_bindings = ["output_{}".format(i) for i in range(len(outputs))]
  223. def _begin_excluded_region(self):
  224. self._trace.begin_excluded_region()
  225. def _end_excluded_region(self):
  226. self._trace.end_excluded_region()
  227. def _make_feed(
  228. self,
  229. graph,
  230. outputs,
  231. input_data,
  232. repeat,
  233. silent,
  234. no_assert,
  235. maxerr,
  236. resize_input,
  237. input_transform,
  238. ):
  239. def auto_reformat_image(path, data, dst_shape):
  240. """reformat image to target shape
  241. :param data: image data as numpy array
  242. :param dst_shape: target shape
  243. """
  244. dim3_format = False # required input format does not contain batch
  245. hwc_format = False # required input format is NHWC
  246. if not dst_shape: # input tensor shape is not predefined
  247. if len(data.shape) == 2:
  248. chl = 1
  249. h = data.shape[0]
  250. w = data.shape[1]
  251. else:
  252. assert (
  253. len(data.shape) == 3
  254. ), "Input image must be of dimension 2 or 3"
  255. h, w, chl = data.shape
  256. dst_shape = (1, chl, h, w)
  257. if len(dst_shape) == 3:
  258. dst_shape = (1,) + dst_shape
  259. dim3_format = True
  260. assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape)
  261. chl = dst_shape[1]
  262. if chl in [1, 3]:
  263. n, c, h, w = dst_shape
  264. dst_shape = (n, h, w, c)
  265. else:
  266. chl = dst_shape[3]
  267. assert chl in [
  268. 1,
  269. 3,
  270. ], "can not infer input format from shape: {}".format(dst_shape)
  271. hwc_format = True
  272. # dst_shape has now been normalized to NHWC format
  273. if resize_input:
  274. h, w = dst_shape[1:3]
  275. data = cv2.resize(data, (w, h))
  276. logger.info("input {} resized to {}".format(path, data.shape))
  277. if chl == 1:
  278. data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY)
  279. data = data[:, :, np.newaxis]
  280. assert data.ndim == 3
  281. data = data[np.newaxis]
  282. # data normalized to NHWC format
  283. if not hwc_format:
  284. data = np.transpose(data, (0, 3, 1, 2))
  285. if dim3_format:
  286. data = np.squeeze(data, 0)
  287. return data
  288. def read_input_data(dst_shape, dtype, path):
  289. def check_shape_equal(dst_shape, data_shape):
  290. if len(dst_shape):
  291. assert len(data_shape) == len(
  292. dst_shape
  293. ), "input/data shapes mismatch: {} vs {}".format(
  294. dst_shape, data_shape
  295. )
  296. if data_shape[1:] != dst_shape[1:]:
  297. logger.warning(
  298. "dst_shape is {}; data_shape is {}".format(
  299. dst_shape, data_shape
  300. )
  301. )
  302. if path.startswith("#"):
  303. assert not resize_input
  304. assert not input_transform
  305. spec = path
  306. m = re.match(
  307. r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec
  308. )
  309. assert m, "bad spec {}".format(spec)
  310. rng_min = float(m.group(1))
  311. rng_max = float(m.group(2))
  312. if m.group(3):
  313. shape_str = m.group(3)
  314. try:
  315. shape = shape_str[1:].split(",")
  316. if shape[-1].strip() == "...":
  317. shape = shape[:-1]
  318. shape.extend(list(dst_shape[len(shape) :]))
  319. data_shape = tuple(map(int, shape))
  320. except ValueError as e:
  321. raise ValueError("bad spec {}: {}".format(spec, e.args))
  322. else:
  323. data_shape = dst_shape
  324. check_shape_equal(dst_shape, data_shape)
  325. return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype)
  326. # try to load image
  327. data = cv2.imread(path, cv2.IMREAD_COLOR)
  328. if data is None:
  329. assert not resize_input
  330. data = np.load(path)
  331. assert isinstance(data, np.ndarray)
  332. else:
  333. # load image succeeds, so we expect input format is image format
  334. data = auto_reformat_image(path, data, dst_shape)
  335. data = np.repeat(data, repeat, axis=0)
  336. if repeat > 1:
  337. logger.info(
  338. "repeat input for {} times, data shape is {}".format(
  339. repeat, data.shape
  340. )
  341. )
  342. check_shape_equal(dst_shape, data.shape)
  343. if input_transform:
  344. data = eval(input_transform, {"data": data, "np": np})
  345. return data
  346. def gen_one_testcase(inputs, spec):
  347. paths = spec.split(";")
  348. if len(paths) != len(inputs):
  349. if len(paths) == 1 and paths[0].startswith("#"):
  350. paths = ["{}:{}".format(name, paths[0]) for name in inputs.keys()]
  351. assert len(paths) == len(
  352. inputs
  353. ), "required inputs: {}; data paths: {}".format(inputs.keys(), paths)
  354. if len(paths) == 1 and ":" not in paths[0]:
  355. paths[0] = next(iter(inputs.keys())) + ":" + paths[0]
  356. ret = {}
  357. for path in paths:
  358. var, path = path.split(":")
  359. ret[var] = read_input_data(inputs[var].shape, inputs[var].dtype, path)
  360. return ret
  361. inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy")
  362. inputs = {i.name: i for i in inputs}
  363. if not no_assert:
  364. replace_varmap = {}
  365. inp_map = {}
  366. # replace var use InputNode
  367. for name, var in inputs.items():
  368. inp = G.InputNode(
  369. device="xpux", dtype=var.dtype, shape=var.shape, graph=graph
  370. )
  371. replace_varmap[var] = inp.outputs[0]._node
  372. inp_map[name] = inp
  373. new = cgtools.replace_vars(outputs, replace_varmap)
  374. if isinstance(new, rt.VarNode):
  375. new = list(new)
  376. output_nodes = [G.OutputNode(var) for var in new]
  377. func = graph.compile(*[node.outputs[0]._node for node in output_nodes])
  378. def make_dev_tensor(value, dtype=None, device=None):
  379. return tensor(value, dtype=dtype, device=device)._dev_tensor()
  380. def calculate(*args, **kwargs):
  381. output_val = []
  382. # set inputs value
  383. for name, var in inputs.items():
  384. val = kwargs.pop(name, None)
  385. assert val is not None, "miss input name{}".format(name)
  386. dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux")
  387. inp_map[name].set_value(dev_tensor)
  388. func.execute()
  389. for res in output_nodes:
  390. output_val.append(res.get_value().numpy())
  391. return output_val
  392. def expect_name(var):
  393. return "{}:expect".format(var.name)
  394. testcases = []
  395. np.set_printoptions(precision=2, threshold=4, suppress=True)
  396. data_list = []
  397. for item in input_data:
  398. if item.startswith("@"):
  399. with open(item[1:], "r") as f:
  400. data_list.extend(
  401. [line.rstrip() for line in f if line.rstrip() != ""]
  402. )
  403. else:
  404. data_list.append(item)
  405. for inp_spec in data_list:
  406. cur_testcase = gen_one_testcase(inputs, inp_spec)
  407. assert len(cur_testcase) == len(
  408. inputs
  409. ), "required inputs: {}; given data: {}".format(
  410. inputs.keys(), cur_testcase.keys()
  411. )
  412. if not no_assert:
  413. outputs_get = calculate(**cur_testcase)
  414. for var, val in zip(outputs, outputs_get):
  415. cur_testcase[expect_name(var)] = val
  416. logger.info(
  417. "generate test groundtruth: var={} shape={} range=({}, {})"
  418. " mean={} var={}".format(
  419. var,
  420. val.shape,
  421. val.min(),
  422. val.max(),
  423. np.mean(val),
  424. np.var(val),
  425. )
  426. )
  427. testcases.append(cur_testcase)
  428. logger.info(
  429. "add testcase: \n {}".format(
  430. "\n ".join(
  431. "{}: shape={} dtype={} range=({:.2f},{:.2f}) "
  432. "mean={:.2f} sd={:.2f}".format(
  433. k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v)
  434. )
  435. for k, v in sorted(cur_testcase.items())
  436. )
  437. )
  438. )
  439. if not no_assert:
  440. def expect_shp(var):
  441. ret = var.shape
  442. if ret:
  443. return ret
  444. return testcases[0][expect_name(var)].shape
  445. def assert_equal(expect, real, **kwargs):
  446. op = AssertEqual(**kwargs)
  447. (res,) = G.apply_normal_varnode(op, expect, real)
  448. return res._node
  449. verbose = not silent
  450. outputs_new = []
  451. for i in outputs:
  452. device = rt.CompNode("xpux")
  453. dtype = i.dtype
  454. name = expect_name(i)
  455. shape = expect_shp(i)
  456. # make expect output as one input of model.
  457. expect_get = rt.make_h2d(graph, device, dtype, shape, name)
  458. # insert assert opr to check expect and real.
  459. outputs_new.append(
  460. assert_equal(expect_get, i, verbose=verbose, maxerr=maxerr,)
  461. )
  462. inputs[expect_name(i)] = expect_get
  463. outputs = outputs_new
  464. return {"outputs": outputs, "testcases": testcases}
  465. def dump(
  466. self,
  467. file,
  468. *,
  469. arg_names=None,
  470. output_names=None,
  471. append=False,
  472. keep_var_name: int = 1,
  473. keep_opr_name: bool = False,
  474. keep_param_name: bool = False,
  475. keep_opr_priority: bool = False,
  476. no_change_graph: bool = False,
  477. strip_info_file=None,
  478. append_json=False,
  479. optimize_for_inference=True,
  480. user_info: Any = None,
  481. enable_metadata: bool = True,
  482. input_data=None,
  483. repeat=1,
  484. silent=False,
  485. no_assert=False,
  486. maxerr=1e-4,
  487. resize_input=False,
  488. input_transform=None,
  489. dump_format: str = None,
  490. model_version: int = 2,
  491. **kwargs
  492. ):
  493. r"""Serializes trace to file system.
  494. Args:
  495. file: output file, could be file object or filename.
  496. arg_names: names of the input tensors in the traced function.
  497. output_names: names of the output tensors in the traced function,
  498. use the default name if not specified.
  499. append: whether output is appended to ``file``.
  500. Only works when ``file`` is str.
  501. keep_var_name: level for keeping variable names:
  502. * 0: none of the names are kept
  503. * 1: (default)keep names of output vars
  504. * 2: keep names of all (output and internal) vars
  505. keep_opr_name: whether to keep operator names.
  506. keep_param_name: whether to keep param names, so param values can be
  507. easily manipulated after loading model
  508. keep_opr_priority: whether to keep priority setting for operators
  509. no_change_graph: whether to change the compute graph when dump, for
  510. model compatibility, some operators will convert to its compatible
  511. format in this version.
  512. * if set False, some operators maybe convert to other operator for
  513. compatibility, all operators will ensure compatibility.
  514. * if set True, no operator will change in the graph when dump.
  515. strip_info_file: a string for path or a file handler. if is not None,
  516. then the dump information for code strip would be written to ``strip_info_file``
  517. append_json: will be check when `strip_info_file` is not None. if set
  518. true, the information for code strip will be append to strip_info_file.
  519. if set false, will rewrite strip_info_file
  520. optimize_for_inference: enbale optmizations,
  521. will skip all optimize options if this is False. Default: True
  522. user_info: any type object, which will be pickled to bytes.
  523. enable_metadata: whether to save metadata into output file.
  524. input_data: input test data and current network output would be used as groundtruth.
  525. The format is "var0:file0;var1:file1..." to specify data files for input vars.
  526. It can also be "#rand(min,max,shape...)" for generating random input data, for
  527. example, "#rand(0,255)", "#rand(0,255,1,3,224,224)" or "#rand(0, 255, 1, ...)"
  528. where `...` means the remaining part of the original shape. If the shape is not
  529. specified, the shape of corresponding input tensors in the network will be used.
  530. If there is only one input var, its name can be omitted. Each data file can either
  531. be an image which can be loaded by opencv, or a pickled numpy.ndarray. This option
  532. can be given multiple times to add multiple testcases. If you start the data
  533. with the letter @, the rest should be a filename, and each line in the file should
  534. be a single datum in the format described above. *NOTE* If `input_data` is not None,
  535. you can only use load-and-run to run the output file.
  536. repeat: how many times the input image is repeated. Useful when running benchmark for
  537. batch size other than one. Have no effect on randomly generated input data.
  538. silent: whether set verbose to False in assert_equal opr.
  539. no_assert: whether insert assert_equal opr to check result; this option is useful for
  540. benchmarking.
  541. maxerr: max error for assert_equal check during runtime.
  542. resize_input: whether resize input image to fit input var shape.
  543. input_transform: a python expression to transform the input data.
  544. Example: data / np.std(data)
  545. dump_format: using different dump formats. the open source MegEngine
  546. defaults to the FBS_V2 format, there are two format FBS_V2 and FBS to choose,
  547. internal MegEngine have an other choice of internal proprietary formats
  548. model_version: the model version of FBS_V2, begin with version 2, this
  549. works only when dump format is FBS_V2.
  550. Keyword Arguments:
  551. * enable_io16xc32 --
  552. whether to use float16 for I/O between oprs and use
  553. float32 as internal computation precision. Note the output var would be
  554. changed to float16.
  555. * enable_ioc16 --
  556. whether to use float16 for both I/O and computation
  557. precision.
  558. * enable_hwcd4 --
  559. whether to use NHWCD4 data layout. This is faster on some
  560. OpenCL backend.
  561. * enable_nchw88 --
  562. whether to use NCHW88 data layout, currently
  563. used in X86 AVX backend.
  564. * enable_nchw44 --
  565. whether to use NCHW44 data layout, currently
  566. used in arm backend.
  567. * enable_nchw44_dot --
  568. whether to use NCHW44_dot data layout, currently
  569. used in armv8.2+dotprod backend.
  570. * enable_nchw4 --
  571. whether to use NCHW4 data layout, currently
  572. used in nvidia backend(based on cudnn).
  573. * enable_nchw32 --
  574. whether to use NCHW32 data layout, currently
  575. used in nvidia backend with tensorcore(based on cudnn).
  576. * enable_chwn4 --
  577. whether to use CHWN4 data layout, currently
  578. used in nvidia backend with tensorcore.
  579. * enable_nchw64 --
  580. whether to use NCHW64 data layout, used for fast int4
  581. support on Nvidia GPU.
  582. * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
  583. into one opr.
  584. * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
  585. input for inference on nvidia backend(this optimization pass will
  586. result in mismatch of the precision of output of training and
  587. inference)
  588. * enable_fuse_preprocess: whether to fuse astype\pad_channel\dimshuffle and
  589. etc opr
  590. """
  591. if not self._capture_as_const:
  592. raise ValueError(
  593. "you must specify capture_as_const=True at __init__ to use dump"
  594. )
  595. if self._output_names and output_names:
  596. raise TypeError(
  597. "cannot specify output_names when output is already in dict format"
  598. )
  599. if output_names and isinstance(output_names, str):
  600. output_names = (output_names,)
  601. if output_names and len(output_names) != len(self._output_bindings):
  602. raise ValueError(
  603. "wrong number of output_names, should be {} values".format(
  604. len(self._output_bindings)
  605. )
  606. )
  607. prefer_input_names = arg_names is not None
  608. if arg_names is None:
  609. arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))]
  610. if isinstance(arg_names, str):
  611. arg_names = (arg_names,)
  612. arg_names = [arg_name if arg_name is not None else "" for arg_name in arg_names]
  613. if arg_names and len(arg_names) != len(self._arg_bindings):
  614. raise ValueError(
  615. "wrong number of arg_names, should be {} values".format(
  616. len(self._arg_bindings)
  617. )
  618. )
  619. output_names = output_names or self._output_names
  620. if output_names is None:
  621. output_names = [""] * len(self._output_bindings)
  622. # output_names = ["output_{}".format(i) for i in range(len(self._output_bindings))]
  623. input_bindings = []
  624. def normalize_shape(shape):
  625. return (1,) if shape == () else shape
  626. for arg_name, (arg_id, arg_shape) in zip(arg_names, self._arg_bindings):
  627. input_bindings.append((arg_id, arg_name, normalize_shape(arg_shape)))
  628. for kwarg_id, (kwarg_name, kwarg_shape) in self._kwarg_bindings.items():
  629. input_bindings.append((kwarg_id, kwarg_name, normalize_shape(kwarg_shape)))
  630. graph = G.Graph()
  631. jit_enabled = set_jit_enabled(False)
  632. dest_vars = self._trace.dump(
  633. graph,
  634. input_bindings,
  635. [*zip(self._output_bindings, output_names)],
  636. prefer_input_names,
  637. )
  638. set_jit_enabled(jit_enabled)
  639. # dest_vars = [i._node for i in dest_vars]
  640. if input_data is not None:
  641. feeds = self._make_feed(
  642. graph,
  643. dest_vars,
  644. input_data,
  645. repeat,
  646. silent,
  647. no_assert,
  648. maxerr,
  649. resize_input,
  650. input_transform,
  651. )
  652. assert (
  653. isinstance(feeds, dict) and feeds["testcases"]
  654. ), "testcases can not be empty"
  655. dest_vars = feeds["outputs"]
  656. if optimize_for_inference:
  657. dest_vars, optimize_options = G.optimize_for_inference(dest_vars, **kwargs)
  658. dest_vars = [i._node for i in dest_vars]
  659. metadata = SerializationMetadata()
  660. if enable_metadata:
  661. metadata.user_info = pickle.dumps(user_info)
  662. metadata.is_valid = True
  663. metadata.graph_modified = False
  664. if optimize_for_inference:
  665. metadata.optimize_options = optimize_options
  666. if isinstance(file, str):
  667. permission = "wb" if append == False else "ab"
  668. file = open(file, permission)
  669. if keep_opr_priority:
  670. _set_priority_to_id(dest_vars)
  671. if input_data is not None:
  672. file.write(b"mgbtest0")
  673. file.write(struct.pack("I", len(feeds["testcases"])))
  674. dump_content, dump_info = G.dump_graph(
  675. dest_vars,
  676. keep_var_name=keep_var_name,
  677. keep_opr_name=keep_opr_name,
  678. keep_param_name=keep_param_name,
  679. keep_opr_priority=keep_opr_priority,
  680. no_change_graph=no_change_graph,
  681. strip_info_file=strip_info_file,
  682. append_json=append_json,
  683. metadata=metadata,
  684. dump_format=dump_format,
  685. model_version=model_version,
  686. )
  687. file.write(dump_content)
  688. if input_data is not None:
  689. inputs = cgtools.get_dep_vars(dest_vars, "Host2DeviceCopy")
  690. inputs = sorted((i.name, i.dtype) for i in inputs)
  691. def make_dev_tensor(value, dtype=None, device=None):
  692. return tensor(value, dtype=dtype, device=device)._dev_tensor()
  693. for testcase in feeds["testcases"]:
  694. assert isinstance(testcase, dict)
  695. cg = G.Graph()
  696. output_mgbvars = []
  697. for name, dtype in inputs:
  698. output_mgbvars.append(
  699. cg.make_const(
  700. make_dev_tensor(
  701. testcase.pop(name), dtype=dtype, device="cpux"
  702. )
  703. )
  704. )
  705. assert not testcase, "extra inputs provided in testcase: {}".format(
  706. testcase.keys()
  707. )
  708. dump_content, _ = G.dump_graph(
  709. output_mgbvars, strip_info_file=strip_info_file, append_json=True,
  710. )
  711. file.write(dump_content)
  712. return dump_info
  713. def get_profile(self):
  714. return json.loads(self._trace.get_profile())