You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

dump_with_testcase_mge.py 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521
  1. # -*- coding: utf-8 -*-
  2. import argparse
  3. import os
  4. import re
  5. import struct
  6. import cv2
  7. import numpy as np
  8. import megengine as mge
  9. import megengine.core._imperative_rt as rt
  10. import megengine.core.tensor.megbrain_graph as G
  11. from megengine import tensor
  12. from megengine.core.ops import builtin
  13. from megengine.utils import comp_graph_tools as cgtools
# Module-level logger shared by every function in this script.
logger = mge.get_logger(__name__)
  15. def auto_reformat_image(args, path, data, dst_shape):
  16. """reformat image to target shape
  17. :param data: image data as numpy array
  18. :param dst_shape: target shape
  19. """
  20. dim3_format = False # required input format does not contain batch
  21. hwc_format = False # required input format is NHWC
  22. if not dst_shape: # input tensor shape is not predefined
  23. if len(data.shape) == 2:
  24. chl = 1
  25. h = data.shape[0]
  26. w = data.shape[1]
  27. else:
  28. assert len(data.shape) == 3, "Input image must be of dimension 2 or 3"
  29. h, w, chl = data.shape
  30. dst_shape = (1, chl, h, w)
  31. if len(dst_shape) == 3:
  32. dst_shape = (1,) + dst_shape
  33. dim3_format = True
  34. assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape)
  35. chl = dst_shape[1]
  36. if chl in [1, 3]:
  37. n, c, h, w = dst_shape
  38. dst_shape = (n, h, w, c)
  39. else:
  40. chl = dst_shape[3]
  41. assert chl in [1, 3], "can not infer input format from shape: {}".format(
  42. dst_shape
  43. )
  44. hwc_format = True
  45. # dst_shape has now been normalized to NHWC format
  46. if args.resize_input:
  47. h, w = dst_shape[1:3]
  48. data = cv2.resize(data, (w, h))
  49. logger.info("input {} resized to {}".format(path, data.shape))
  50. if chl == 1:
  51. data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY)
  52. data = data[:, :, np.newaxis]
  53. assert data.ndim == 3
  54. data = data[np.newaxis]
  55. # data normalized to NHWC format
  56. if not hwc_format:
  57. data = np.transpose(data, (0, 3, 1, 2))
  58. if dim3_format:
  59. data = np.squeeze(data, 0)
  60. return data
  61. def read_input_data(args, dst_shape, dtype, path, repeat):
  62. def check_shape_equal(dst_shape, data_shape):
  63. if len(dst_shape):
  64. assert len(data_shape) == len(
  65. dst_shape
  66. ), "input/data shapes mismatch: {} vs {}".format(dst_shape, data_shape)
  67. if data_shape[1:] != dst_shape[1:]:
  68. logger.warning(
  69. "dst_shape is {}; data_shape is {}".format(dst_shape, data_shape)
  70. )
  71. if path.startswith("#"):
  72. assert not args.resize_input
  73. assert not args.input_transform
  74. spec = path
  75. m = re.match(r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec)
  76. assert m, "bad spec {}".format(spec)
  77. rng_min = float(m.group(1))
  78. rng_max = float(m.group(2))
  79. if m.group(3):
  80. shape_str = m.group(3)
  81. try:
  82. shape = shape_str[1:].split(",")
  83. if shape[-1].strip() == "...":
  84. shape = shape[:-1]
  85. shape.extend(list(dst_shape[len(shape) :]))
  86. data_shape = tuple(map(int, shape))
  87. except ValueError as e:
  88. raise ValueError("bad spec {}: {}".format(spec, e.args))
  89. else:
  90. data_shape = dst_shape
  91. check_shape_equal(dst_shape, data_shape)
  92. return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype)
  93. # try to load image
  94. data = cv2.imread(path, cv2.IMREAD_COLOR)
  95. if data is None:
  96. assert not args.resize_input
  97. data = np.load(path)
  98. assert isinstance(data, np.ndarray)
  99. else:
  100. # load image succeeds, so we expect input format is image format
  101. data = auto_reformat_image(args, path, data, dst_shape)
  102. data = np.repeat(data, repeat, axis=0)
  103. if repeat > 1:
  104. logger.info(
  105. "repeat input for {} times, data shape is {}".format(repeat, data.shape)
  106. )
  107. check_shape_equal(dst_shape, data.shape)
  108. if args.input_transform:
  109. data = eval(args.input_transform, {"data": data, "np": np})
  110. return data
  111. def gen_one_testcase(args, inputs, spec):
  112. paths = spec.split(";")
  113. if len(paths) != len(inputs):
  114. if len(paths) == 1 and paths[0].startswith("#"):
  115. paths = ["{}:{}".format(name, paths[0]) for name in inputs.keys()]
  116. assert len(paths) == len(inputs), "required inputs: {}; data paths: {}".format(
  117. inputs.keys(), paths
  118. )
  119. if len(paths) == 1 and ":" not in paths[0]:
  120. paths[0] = next(iter(inputs.keys())) + ":" + paths[0]
  121. ret = {}
  122. for path in paths:
  123. var, path = path.split(":")
  124. if args.repeat:
  125. repeat = args.repeat
  126. else:
  127. repeat = 1
  128. ret[var] = read_input_data(
  129. args, inputs[var].shape, inputs[var].dtype, path, repeat
  130. )
  131. return ret
def make_feeds(args):
    """Load the model, generate testcases and (optionally) add result checks.

    Reads ``args.input``, ``args.data``, ``args.no_assert``, ``args.silent``
    and ``args.maxerr``.  Unless ``args.no_assert`` is set, the loaded graph is
    executed once per testcase to obtain reference outputs, which are stored in
    the testcase under ``"<var name>:expect"``, and every output var is wrapped
    in an ``AssertEqual`` opr comparing it against a new ``<var name>:expect``
    input.

    :param args: parsed command line options
    :return: dict with ``"outputs"`` (the output vars to dump) and
        ``"testcases"`` (list of dicts mapping input name to numpy array)
    """
    ret = G.load_graph(args.input)
    cg_rt, outputs = ret.graph, ret.output_vars_list
    # The model inputs are the Host2DeviceCopy vars the outputs depend on.
    inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy")

    inputs = {i.name: i for i in inputs}
    if not args.no_assert:

        replace_varmap = {}
        inp_map = {}
        # replace var use InputNode
        for name, var in inputs.items():
            inp = G.InputNode(
                device="xpux", dtype=var.dtype, shape=var.shape, graph=cg_rt
            )
            replace_varmap[var] = inp.outputs[0]
            inp_map[name] = inp

        new = cgtools.replace_vars(outputs, replace_varmap)
        # NOTE(review): list() of a single VarNode looks suspicious; presumably
        # replace_vars always returns a list so this branch is not taken.
        if isinstance(new, rt.VarNode):
            new = list(new)

        output_nodes = [G.OutputNode(var) for var in new]
        func = cg_rt.compile([node.outputs[0] for node in output_nodes])

        def make_dev_tensor(value, dtype=None, device=None):
            # Wrap a host value into a device tensor accepted by set_value().
            return tensor(value, dtype=dtype, device=device)._dev_tensor()

        def calculate(*args, **kwargs):
            # Run the compiled graph once; inputs are passed by name as
            # keyword arguments, outputs are returned as numpy arrays.
            output_val = []
            # set inputs value
            for name, var in inputs.items():
                val = kwargs.pop(name, None)
                assert val is not None, "miss input name{}".format(name)
                dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux")
                inp_map[name].set_value(dev_tensor)

            func.execute()

            for res in output_nodes:
                output_val.append(res.get_value().numpy())
            return output_val

        def expect_name(var):
            # Name under which the reference value of ``var`` is stored.
            return "{}:expect".format(var.name)

    testcases = []

    # Keep array representations in log messages short.
    np.set_printoptions(precision=2, threshold=4, suppress=True)

    # Expand "@file" items: each non-empty line of the file is one data spec.
    data_list = []
    for item in args.data:
        if item.startswith("@"):
            with open(item[1:], "r") as f:
                data_list.extend([line.rstrip() for line in f if line.rstrip() != ""])
        else:
            data_list.append(item)

    for inp_spec in data_list:
        cur_testcase = gen_one_testcase(args, inputs, inp_spec)
        assert len(cur_testcase) == len(
            inputs
        ), "required inputs: {}; given data: {}".format(
            inputs.keys(), cur_testcase.keys()
        )

        if not args.no_assert:
            # Compute the reference outputs and store them with the testcase.
            outputs_get = calculate(**cur_testcase)
            for var, val in zip(outputs, outputs_get):
                cur_testcase[expect_name(var)] = val
                logger.info(
                    "generate test groundtruth: var={} shape={} range=({}, {})"
                    " mean={} var={}".format(
                        var, val.shape, val.min(), val.max(), np.mean(val), np.var(val)
                    )
                )
        testcases.append(cur_testcase)
        logger.info(
            "add testcase: \n {}".format(
                "\n ".join(
                    "{}: shape={} dtype={} range=({:.2f},{:.2f}) "
                    "mean={:.2f} sd={:.2f}".format(
                        k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v)
                    )
                    for k, v in sorted(cur_testcase.items())
                )
            )
        )

    if not args.no_assert:

        def expect_shp(var):
            # Use the var's own shape if it has one, otherwise fall back to
            # the shape of the reference value from the first testcase.
            ret = var.shape
            if ret:
                return ret
            return testcases[0][expect_name(var)].shape

        def assert_equal(expect, real, **kwargs):
            # Insert an AssertEqual opr taking ``expect`` and ``real``.
            op = builtin.AssertEqual(**kwargs)
            (res,) = G.apply_normal_varnode(op, expect, real)
            return res

        verbose = not args.silent

        outputs_new = []
        for i in outputs:
            device = rt.CompNode("xpux")
            dtype = i.dtype
            name = expect_name(i)
            shape = expect_shp(i)
            # make expect output as one input of model.
            expect_get = rt.make_h2d(cg_rt, device, dtype, shape, name)
            # insert assert opr to check expect and real.
            outputs_new.append(
                assert_equal(expect_get, i, verbose=verbose, maxerr=args.maxerr,)
            )
            inputs[expect_name(i)] = expect_get
        # From here on the graph outputs are the assert oprs.
        outputs = outputs_new

    return {"outputs": outputs, "testcases": testcases}
  232. def optimize_for_inference(args, outputs):
  233. args_list = [
  234. "enable_io16xc32",
  235. "enable_ioc16",
  236. "enable_hwcd4",
  237. "enable_nchw4",
  238. "enable_nchw88",
  239. "enable_nchw44",
  240. "enable_nchw44_dot",
  241. "enable_nchw32",
  242. "enable_chwn4",
  243. "enable_fuse_conv_bias_nonlinearity",
  244. "enable_fuse_conv_bias_with_z",
  245. "enable_fuse_preprocess",
  246. ]
  247. kwargs = {}
  248. for k in args_list:
  249. if getattr(args, k):
  250. kwargs[k] = True
  251. if args.optimize_for_inference:
  252. outputs = G.optimize_for_inference(outputs, **kwargs)
  253. return outputs
  254. def main():
  255. parser = argparse.ArgumentParser(
  256. description="Pack computing graph, input values and expected output "
  257. "values into one file for checking correctness. README.md gives more "
  258. "details on the usage",
  259. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  260. )
  261. parser.add_argument("input", help="MegEngine dumped model file")
  262. parser.add_argument("-o", "--output", help="output file", required=True)
  263. parser.add_argument(
  264. "-d",
  265. "--data",
  266. default=[],
  267. action="append",
  268. required=True,
  269. help="Given input test data when input file is a network, "
  270. "and current network output would be used as groundtruth. "
  271. "The format is var0:file0;var1:file1... to specify data files for "
  272. "input vars. It can also be #rand(min,max,shape...) for generating "
  273. "random input data, for example, #rand(0,255), "
  274. "#rand(0,255,1,3,224,224) or #rand(0, 255, 1, ...) where `...` means "
  275. "the remaining part of the original shape. "
  276. "If the shape is not specified, the shape of "
  277. "corresponding input tensors in the network will be used. "
  278. "If there is only one input var, its name can be omitted. "
  279. "Each data file can either be an image which can be loaded by opencv, "
  280. "or a pickled numpy.ndarray. "
  281. "This option can be given multiple times to add multiple testcases. "
  282. " *NOTE* "
  283. "If you start the data with the letter @, the rest should be a "
  284. "filename, and each line in the file should be a single datum in "
  285. "the format described above. ",
  286. )
  287. parser.add_argument(
  288. "--repeat",
  289. type=int,
  290. default=1,
  291. help="Specify how many times the input image is repeated. "
  292. "Useful when running benchmark for batch size other than one. "
  293. "Have no effect on randomly generated input data.",
  294. )
  295. parser.add_argument(
  296. "--silent",
  297. action="store_true",
  298. help="set verbose to False in asserti_equal opr",
  299. )
  300. parser.add_argument(
  301. "--optimize-for-inference",
  302. action="store_true",
  303. help="enable optimization for inference",
  304. )
  305. parser.add_argument(
  306. "--no-assert",
  307. action="store_true",
  308. help="do not insert assert_equal opr to check result; "
  309. "this option is useful for benchmarking",
  310. )
  311. parser.add_argument(
  312. "--maxerr",
  313. type=float,
  314. default=1e-4,
  315. help="max error for assert_equal check during runtime",
  316. )
  317. parser.add_argument(
  318. "--resize-input",
  319. action="store_true",
  320. help="resize input image to fit input var shape",
  321. )
  322. parser.add_argument(
  323. "--input-transform",
  324. help="a python expression to transform the input data. "
  325. "Example: data / np.std(data)",
  326. )
  327. parser.add_argument(
  328. "--discard-var-name",
  329. action="store_true",
  330. help="discard variable and param names in the " "generated output",
  331. )
  332. parser.add_argument(
  333. "--output-strip-info", action="store_true", help="output code strip information"
  334. )
  335. parser.add_argument(
  336. "--enable-io16xc32",
  337. action="store_true",
  338. help="transform the mode to float16 io float32 compute",
  339. )
  340. parser.add_argument(
  341. "--enable-ioc16",
  342. action="store_true",
  343. help="transform the dtype of the model to float16 io " "and compute",
  344. )
  345. parser.add_argument(
  346. "--enable-fuse-conv-bias-nonlinearity",
  347. action="store_true",
  348. help="fuse convolution bias and nonlinearity opr to a "
  349. "conv_bias opr and compute",
  350. )
  351. parser.add_argument(
  352. "--enable-hwcd4",
  353. action="store_true",
  354. help="transform the model format from NCHW to NHWCD4 "
  355. "for inference; you may need to disable CUDA and set "
  356. "MGB_USE_MEGDNN_DBG=2",
  357. )
  358. parser.add_argument(
  359. "--enable-nchw4",
  360. action="store_true",
  361. help="transform the model format from NCHW to NCHW4 " "for inference",
  362. )
  363. parser.add_argument(
  364. "--enable-nchw88",
  365. action="store_true",
  366. help="transform the model format from NCHW to NCHW88 " "for inference",
  367. )
  368. parser.add_argument(
  369. "--enable-nchw44",
  370. action="store_true",
  371. help="transform the model format from NCHW to NCHW44 " "for inference",
  372. )
  373. parser.add_argument(
  374. "--enable-nchw44-dot",
  375. action="store_true",
  376. help="transform the model format from NCHW to NCHW44_DOT "
  377. "for optimizing armv8.2 dot in inference",
  378. )
  379. parser.add_argument(
  380. "--enable-nchw32",
  381. action="store_true",
  382. help="transform the model format from NCHW4 to NCHW32 "
  383. "for inference on nvidia TensoCore",
  384. )
  385. parser.add_argument(
  386. "--enable-chwn4",
  387. action="store_true",
  388. help="transform the model format to CHWN4 "
  389. "for inference, mainly used for nvidia tensorcore",
  390. )
  391. parser.add_argument(
  392. "--enable-fuse-conv-bias-with-z",
  393. action="store_true",
  394. help="fuse conv_bias with z input for inference on "
  395. "nvidia GPU (this optimization pass will result in mismatch "
  396. "of the precision of output of training and inference)",
  397. )
  398. parser.add_argument(
  399. "--enable-fuse-preprocess",
  400. action="store_true",
  401. help="fuse astype\pad_channel\dimshuffle and etc opr " "from h2d opr",
  402. )
  403. args = parser.parse_args()
  404. feeds = make_feeds(args)
  405. assert isinstance(feeds, dict) and feeds["testcases"], "testcases can not be empty"
  406. output_mgbvars = feeds["outputs"]
  407. output_mgbvars = optimize_for_inference(args, output_mgbvars)
  408. inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy")
  409. inputs = sorted((i.name, i.dtype) for i in inputs)
  410. if args.discard_var_name:
  411. sereg_kwargs = dict(keep_var_name=0, keep_param_name=False)
  412. else:
  413. sereg_kwargs = dict(keep_var_name=2, keep_param_name=True)
  414. strip_info_file = args.output + ".json" if args.output_strip_info else None
  415. with open(args.output, "wb") as fout:
  416. fout.write(b"mgbtest0")
  417. fout.write(struct.pack("I", len(feeds["testcases"])))
  418. dump_content, stat = G.dump_graph(
  419. output_mgbvars,
  420. append_json=True,
  421. strip_info_file=strip_info_file,
  422. **sereg_kwargs,
  423. )
  424. fout.write(dump_content)
  425. logger.info(
  426. "graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB".format(
  427. stat.tot_bytes / 1024, (stat.tot_bytes - stat.tensor_value_bytes) / 1024
  428. )
  429. )
  430. def make_dev_tensor(value, dtype=None, device=None):
  431. return tensor(value, dtype=dtype, device=device)._dev_tensor()
  432. for testcase in feeds["testcases"]:
  433. assert isinstance(testcase, dict)
  434. cg = G.Graph()
  435. output_mgbvars = []
  436. for name, dtype in inputs:
  437. output_mgbvars.append(
  438. cg.make_const(
  439. make_dev_tensor(testcase.pop(name), dtype=dtype, device="cpux")
  440. )
  441. )
  442. assert not testcase, "extra inputs provided in testcase: {}".format(
  443. testcase.keys()
  444. )
  445. with open(args.output, "ab") as fout:
  446. dump_content, _ = G.dump_graph(
  447. output_mgbvars, strip_info_file=strip_info_file, append_json=True
  448. )
  449. fout.write(dump_content)
# Run the command line tool only when executed as a script, not on import.
if __name__ == "__main__":
    main()