
dump_with_testcase_mge.py 18 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import argparse
import os
import re
import struct

import cv2
import numpy as np

import megengine as mge
import megengine.core._imperative_rt as rt
import megengine.core.tensor.megbrain_graph as G
from megengine.utils import comp_graph_tools as cgtools
from megengine.core.ops import builtin
from megengine.core.tensor.core import apply
from megengine.core.tensor.megbrain_graph import VarNode
from megengine.core.tensor.raw_tensor import as_raw_tensor

logger = mge.get_logger(__name__)
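
# Example invocation (illustrative file names; see README.md for full usage):
#   python3 dump_with_testcase_mge.py model.mge -o model-with-testcase.mge \
#       -d "#rand(0,255,1,3,224,224)"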


def auto_reformat_image(args, path, data, dst_shape):
    """reformat image to target shape

    :param args: parsed command-line arguments (only ``resize_input`` is used)
    :param path: source path of the image, used for logging
    :param data: image data as numpy array
    :param dst_shape: target shape
    """
    dim3_format = False  # required input format does not contain batch
    hwc_format = False  # required input format is NHWC

    if not dst_shape:  # input tensor shape is not predefined
        if len(data.shape) == 2:
            chl = 1
            h = data.shape[0]
            w = data.shape[1]
        else:
            assert len(data.shape) == 3, "Input image must be of dimension 2 or 3"
            h, w, chl = data.shape
        dst_shape = (1, chl, h, w)

    if len(dst_shape) == 3:
        dst_shape = (1,) + dst_shape
        dim3_format = True

    assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape)
    chl = dst_shape[1]
    if chl in [1, 3]:
        n, c, h, w = dst_shape
        dst_shape = (n, h, w, c)
    else:
        chl = dst_shape[3]
        assert chl in [1, 3], "can not infer input format from shape: {}".format(
            dst_shape
        )
        hwc_format = True

    # dst_shape has now been normalized to NHWC format
    if args.resize_input:
        h, w = dst_shape[1:3]
        data = cv2.resize(data, (w, h))
        logger.info("input {} resized to {}".format(path, data.shape))

    if chl == 1:
        data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY)
        data = data[:, :, np.newaxis]

    assert data.ndim == 3
    data = data[np.newaxis]
    # data normalized to NHWC format

    if not hwc_format:
        data = np.transpose(data, (0, 3, 1, 2))

    if dim3_format:
        data = np.squeeze(data, 0)

    return data


def read_input_data(args, dst_shape, dtype, path, repeat):
    def check_shape_equal(dst_shape, data_shape):
        if len(dst_shape):
            assert len(data_shape) == len(
                dst_shape
            ), "input/data shapes mismatch: {} vs {}".format(dst_shape, data_shape)

            if data_shape[1:] != dst_shape[1:]:
                logger.warning(
                    "dst_shape is {}; data_shape is {}".format(dst_shape, data_shape)
                )

    if path.startswith("#"):
        assert not args.resize_input
        assert not args.input_transform
        spec = path
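        # Accepted specs look like "#rand(min,max)" or "#rand(min,max,shape...)",
        # e.g. "#rand(0,255)" or "#rand(0,255,1,3,224,224)"; a trailing "..."
        # fills the remaining dimensions from the model's input shape.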
        m = re.match(r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec)
        assert m, "bad spec {}".format(spec)

        rng_min = float(m.group(1))
        rng_max = float(m.group(2))
        if m.group(3):
            shape_str = m.group(3)
            try:
                shape = shape_str[1:].split(",")
                if shape[-1].strip() == "...":
                    shape = shape[:-1]
                    shape.extend(list(dst_shape[len(shape) :]))
                data_shape = tuple(map(int, shape))
            except ValueError as e:
                raise ValueError("bad spec {}: {}".format(spec, e.args))
        else:
            data_shape = dst_shape
        check_shape_equal(dst_shape, data_shape)
        return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype)

    # try to load image
    data = cv2.imread(path, cv2.IMREAD_COLOR)
    if data is None:
        assert not args.resize_input
        data = np.load(path)
        assert isinstance(data, np.ndarray)
    else:
        # load image succeeds, so we expect input format is image format
        data = auto_reformat_image(args, path, data, dst_shape)

    data = np.repeat(data, repeat, axis=0)
    if repeat > 1:
        logger.info(
            "repeat input for {} times, data shape is {}".format(repeat, data.shape)
        )

    check_shape_equal(dst_shape, data.shape)
    if args.input_transform:
        data = eval(args.input_transform, {"data": data, "np": np})

    return data


def gen_one_testcase(args, inputs, spec):
    paths = spec.split(";")
    if len(paths) != len(inputs):
        if len(paths) == 1 and paths[0].startswith("#"):
            paths = ["{}:{}".format(name, paths[0]) for name in inputs.keys()]
    assert len(paths) == len(inputs), "required inputs: {}; data paths: {}".format(
        inputs.keys(), paths
    )
    if len(paths) == 1 and ":" not in paths[0]:
        paths[0] = next(iter(inputs.keys())) + ":" + paths[0]

    ret = {}
    for path in paths:
        var, path = path.split(":")
        if args.repeat:
            repeat = args.repeat
        else:
            repeat = 1
        ret[var] = read_input_data(
            args, inputs[var].shape, inputs[var].dtype, path, repeat
        )
    return ret


def make_feeds(args):
    cg_rt, _, outputs = G.load_graph(args.input)
    inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy")

    inputs = {i.name: i for i in inputs}
    if not args.no_assert:

        replace_varmap = {}
        inp_map = {}
        # replace input vars with InputNode so values can be set at runtime
        for name, var in inputs.items():
            inp = G.InputNode(
                device="xpux", dtype=var.dtype, shape=var.shape, graph=cg_rt
            )
            replace_varmap[var] = inp.outputs[0]
            inp_map[name] = inp

        new = cgtools.replace_vars(outputs, replace_varmap)
        if isinstance(new, rt.VarNode):
            new = [new]

        output_nodes = [G.OutputNode(var) for var in new]
        func = cg_rt.compile([node.outputs[0] for node in output_nodes])

        def make_dev_tensor(value, dtype=None, device=None):
            return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor()
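
        # run the original graph on the given inputs to produce the groundtruth
        # outputs that are later stored as "<name>:expect" entries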
        def calculate(*args, **kwargs):
            output_val = []
            # set input values
            for name, var in inputs.items():
                val = kwargs.pop(name, None)
                assert val is not None, "missing input name {}".format(name)
                dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux")
                inp_map[name].set_value(dev_tensor)

            func.execute()

            for res in output_nodes:
                output_val.append(res.get_value().numpy())
            return output_val

        def expect_name(var):
            return "{}:expect".format(var.name)

    testcases = []

    np.set_printoptions(precision=2, threshold=4, suppress=True)

    data_list = []
    for item in args.data:
        if item.startswith("@"):
            with open(item[1:], "r") as f:
                data_list.extend([line.rstrip() for line in f if line.rstrip() != ""])
        else:
            data_list.append(item)

    for inp_spec in data_list:
        cur_testcase = gen_one_testcase(args, inputs, inp_spec)
        assert len(cur_testcase) == len(
            inputs
        ), "required inputs: {}; given data: {}".format(
            inputs.keys(), cur_testcase.keys()
        )

        if not args.no_assert:
            outputs_get = calculate(**cur_testcase)
            for var, val in zip(outputs, outputs_get):
                cur_testcase[expect_name(var)] = val
                logger.info(
                    "generate test groundtruth: var={} shape={} range=({}, {})"
                    " mean={} var={}".format(
                        var, val.shape, val.min(), val.max(), np.mean(val), np.var(val)
                    )
                )
        testcases.append(cur_testcase)
        logger.info(
            "add testcase: \n {}".format(
                "\n ".join(
                    "{}: shape={} dtype={} range=({:.2f},{:.2f}) "
                    "mean={:.2f} sd={:.2f}".format(
                        k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v)
                    )
                    for k, v in sorted(cur_testcase.items())
                )
            )
        )

    if not args.no_assert:

        def expect_shp(var):
            ret = var.shape
            if ret:
                return ret
            return testcases[0][expect_name(var)].shape

        def assert_equal(expect, real, **kwargs):
            op = builtin.AssertEqual(**kwargs)
            (res,) = apply(op, expect, real)
            return res

        verbose = not args.silent
        outputs_new = []
        for i in outputs:
            device = rt.CompNode("xpux")
            dtype = i.dtype
            name = expect_name(i)
            shape = expect_shp(i)
            # make the expected output an extra input of the model
            expect_get = rt.make_h2d(cg_rt, device, dtype, shape, name)
            # insert an AssertEqual opr to check expected vs. actual values
            outputs_new.append(
                assert_equal(
                    G.VarNode(expect_get),
                    G.VarNode(i),
                    verbose=verbose,
                    maxerr=args.maxerr,
                )
            )
            inputs[expect_name(i)] = expect_get
        outputs = outputs_new

    return {"outputs": outputs, "testcases": testcases}


def optimize_for_inference(args, outputs):
    # map command-line switches to G.optimize_for_inference keyword arguments
    args_map = {
        "enable_io16xc32": "f16_io_f32_comp",
        "enable_ioc16": "f16_io_comp",
        "enable_hwcd4": "use_nhwcd4",
        "enable_nchw4": "use_nchw4",
        "enable_nchw88": "use_nchw88",
        "enable_nchw44": "use_nchw44",
        "enable_nchw44_dot": "use_nchw44_dot",
        "enable_nchw32": "use_nchw32",
        "enable_chwn4": "use_chwn4",
        "enable_fuse_conv_bias_nonlinearity": "fuse_conv_bias_nonlinearity",
        "enable_fuse_conv_bias_with_z": "fuse_conv_bias_with_z",
    }
    kwargs = {}
    for k, v in args_map.items():
        if getattr(args, k):
            assert (
                args.optimize_for_inference
            ), "optimize_for_inference should be set when {} is given".format(k)
            kwargs[v] = True

    if args.optimize_for_inference:
        outputs = [i._node for i in G.optimize_for_inference(outputs, **kwargs)]

    return outputs


def main():
    parser = argparse.ArgumentParser(
        description="Pack computing graph, input values and expected output "
        "values into one file for checking correctness. README.md gives more "
        "details on the usage",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("input", help="MegEngine dumped model file")
    parser.add_argument("-o", "--output", help="output file", required=True)
    parser.add_argument(
        "-d",
        "--data",
        default=[],
        action="append",
        required=True,
        help="input test data for the network; the current network outputs "
        "will be used as the groundtruth. "
        "The format is var0:file0;var1:file1... to specify data files for "
        "input vars. It can also be #rand(min,max,shape...) for generating "
        "random input data, for example, #rand(0,255), "
        "#rand(0,255,1,3,224,224) or #rand(0, 255, 1, ...) where `...` means "
        "the remaining part of the original shape. "
        "If the shape is not specified, the shape of "
        "corresponding input tensors in the network will be used. "
        "If there is only one input var, its name can be omitted. "
        "Each data file can either be an image which can be loaded by opencv, "
        "or a pickled numpy.ndarray. "
        "This option can be given multiple times to add multiple testcases. "
        " *NOTE* "
        "If you start the data with the letter @, the rest should be a "
        "filename, and each line in the file should be a single datum in "
        "the format described above. ",
    )
    parser.add_argument(
        "--repeat",
        type=int,
        default=1,
        help="Specify how many times the input image is repeated. "
        "Useful when running benchmarks for batch sizes other than one. "
        "Has no effect on randomly generated input data.",
    )
    parser.add_argument(
        "--silent",
        action="store_true",
        help="set verbose to False in the assert_equal opr",
    )
    parser.add_argument(
        "--optimize-for-inference",
        action="store_false",
        help="enable optimization for inference",
    )
    parser.add_argument(
        "--no-assert",
        action="store_true",
        help="do not insert assert_equal opr to check result; "
        "this option is useful for benchmarking",
    )
    parser.add_argument(
        "--maxerr",
        type=float,
        default=1e-4,
        help="max error for assert_equal check during runtime",
    )
    parser.add_argument(
        "--resize-input",
        action="store_true",
        help="resize input image to fit input var shape",
    )
    parser.add_argument(
        "--input-transform",
        help="a python expression to transform the input data. "
        "Example: data / np.std(data)",
    )
    parser.add_argument(
        "--discard-var-name",
        action="store_true",
        help="discard variable and param names in the generated output",
    )
    parser.add_argument(
        "--output-strip-info", action="store_true", help="output code strip information"
    )
    parser.add_argument(
        "--enable-io16xc32",
        action="store_true",
        help="transform the mode to float16 io float32 compute",
    )
    parser.add_argument(
        "--enable-ioc16",
        action="store_true",
        help="transform the dtype of the model to float16 io and compute",
    )
    parser.add_argument(
        "--enable-fuse-conv-bias-nonlinearity",
        action="store_true",
        help="fuse convolution bias and nonlinearity opr to a "
        "conv_bias opr and compute",
    )
    parser.add_argument(
        "--enable-hwcd4",
        action="store_true",
        help="transform the model format from NCHW to NHWCD4 "
        "for inference; you may need to disable CUDA and set "
        "MGB_USE_MEGDNN_DBG=2",
    )
    parser.add_argument(
        "--enable-nchw4",
        action="store_true",
        help="transform the model format from NCHW to NCHW4 for inference",
    )
    parser.add_argument(
        "--enable-nchw88",
        action="store_true",
        help="transform the model format from NCHW to NCHW88 for inference",
    )
    parser.add_argument(
        "--enable-nchw44",
        action="store_true",
        help="transform the model format from NCHW to NCHW44 for inference",
    )
    parser.add_argument(
        "--enable-nchw44-dot",
        action="store_true",
        help="transform the model format from NCHW to NCHW44_DOT "
        "for optimizing armv8.2 dot in inference",
    )
    parser.add_argument(
        "--enable-nchw32",
        action="store_true",
        help="transform the model format from NCHW4 to NCHW32 "
        "for inference on nvidia TensorCore",
    )
    parser.add_argument(
        "--enable-chwn4",
        action="store_true",
        help="transform the model format to CHWN4 "
        "for inference, mainly used for nvidia tensorcore",
    )
    parser.add_argument(
        "--enable-fuse-conv-bias-with-z",
        action="store_true",
        help="fuse conv_bias with z input for inference on "
        "nvidia GPU (this optimization pass will result in a mismatch "
        "between the output precision of training and inference)",
    )
    args = parser.parse_args()

    feeds = make_feeds(args)

    assert isinstance(feeds, dict) and feeds["testcases"], "testcases can not be empty"

    output_mgbvars = feeds["outputs"]
    output_mgbvars = optimize_for_inference(args, output_mgbvars)

    inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy")
    inputs = sorted((i.name, i.dtype) for i in inputs)

    if args.discard_var_name:
        sereg_kwargs = dict(keep_var_name=0, keep_param_name=False)
    else:
        sereg_kwargs = dict(keep_var_name=2, keep_param_name=True)

    strip_info_file = args.output + ".json" if args.output_strip_info else None
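
    # Output file layout (as written below): the 8-byte magic b"mgbtest0",
    # a struct-packed testcase count, the serialized graph, then one
    # const-only graph appended per testcase.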
    with open(args.output, "wb") as fout:
        fout.write(b"mgbtest0")
        fout.write(struct.pack("I", len(feeds["testcases"])))
        if isinstance(output_mgbvars, dict):
            wrap_output_vars = dict([(i, VarNode(j)) for i, j in output_mgbvars])
        else:
            wrap_output_vars = [VarNode(i) for i in output_mgbvars]
        dump_content, stat = G.dump_graph(
            wrap_output_vars,
            append_json=True,
            strip_info_file=strip_info_file,
            **sereg_kwargs
        )
        fout.write(dump_content)

        logger.info(
            "graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB".format(
                stat.tot_bytes / 1024, (stat.tot_bytes - stat.tensor_value_bytes) / 1024
            )
        )
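
    # Each testcase is serialized as a graph of const tensors (the inputs plus
    # the "<name>:expect" groundtruth values) and appended to the same file.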
    def make_dev_tensor(value, dtype=None, device=None):
        return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor()

    for testcase in feeds["testcases"]:
        assert isinstance(testcase, dict)
        cg = G.Graph()
        output_mgbvars = []
        for name, dtype in inputs:
            output_mgbvars.append(
                cg.make_const(
                    make_dev_tensor(testcase.pop(name), dtype=dtype, device="cpux")
                )
            )
        assert not testcase, "extra inputs provided in testcase: {}".format(
            testcase.keys()
        )
        with open(args.output, "ab") as fout:
            dump_content, _ = G.dump_graph(
                output_mgbvars, strip_info_file=strip_info_file, append_json=True
            )
            fout.write(dump_content)


if __name__ == "__main__":
    main()
