You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

load_network_and_run.py 14 kB

(line-number gutter: lines 1–418)
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import argparse
  9. import logging
  10. import time
  11. from collections import OrderedDict
  12. import numpy as np
  13. import megengine as mge
  14. from megengine.core.tensor import megbrain_graph as G
  15. from megengine.device import get_device_count, set_default_device
  16. from megengine.functional.debug_param import set_execution_strategy
  17. from megengine.logger import enable_debug_log, get_logger, set_log_file
  18. from megengine.utils import comp_graph_tools as tools
# Module-level logger obtained through MegEngine's ``get_logger`` helper.
logger = get_logger(__name__)
  20. def make_data_given_desc(args, inputs, shape0_multiply=1):
  21. if args.load_input_data:
  22. logger.info("load data from {}".format(args.load_input_data))
  23. data = mge.load(args.load_input_data)
  24. data_names = [inp.name for inp in inputs]
  25. if isinstance(data, np.ndarray):
  26. assert len(data_names) == 1, (
  27. "data is given as a single numpy array, so there should be "
  28. "exactly one input in the graph; got: {}".format(data_names)
  29. )
  30. data = {data_names[0]: data}
  31. assert isinstance(data, dict)
  32. for v in data.values():
  33. assert isinstance(
  34. v, np.ndarray
  35. ), "data should provide ndarray; got {} instead".format(v)
  36. if args.batchsize:
  37. for k, v in list(data.items()):
  38. assert (
  39. args.batchsize % v.shape[0] == 0
  40. ), "current batch size must divide given batch size: {} {}".format(
  41. args.batchsize, v.shape[0]
  42. )
  43. data[k] = np.repeat(v, args.batchsize // v.shape[0], axis=0)
  44. return data
  45. def iter_inpdesc(desc):
  46. if not desc:
  47. return
  48. for pair in desc.split(";"):
  49. name, value = pair.split(":")
  50. if name not in data_shapes:
  51. logger.warning("rng name {} not in data provider".format(name))
  52. yield name, value
  53. rng = np.random.RandomState(args.seed)
  54. data_shapes = OrderedDict((inp.name, list(inp.shape)) for inp in inputs)
  55. data_dtypes = OrderedDict((inp.name, inp.dtype) for inp in inputs)
  56. for name, shape in iter_inpdesc(args.input_desc):
  57. data_shapes[name] = list(map(int, shape.split(",")))
  58. if args.batchsize:
  59. for i in data_shapes.values():
  60. i[0] = args.batchsize
  61. data_rngs = dict(iter_inpdesc(args.rng))
  62. result = OrderedDict()
  63. for name, shape in data_shapes.items():
  64. shape[0] *= shape0_multiply
  65. rng_expr = data_rngs.get(name)
  66. if rng_expr:
  67. value = eval("rng.{}".format(rng_expr).format(shape), {"rng": rng})
  68. else:
  69. value = rng.uniform(size=shape)
  70. value = np.ascontiguousarray(value, dtype=data_dtypes[name])
  71. assert value.shape == tuple(shape)
  72. result[name] = value
  73. return result
  74. def get_execution_strategy(args):
  75. if not args.fast_run:
  76. logger.warning("--fast-run not enabled; execution may be slow")
  77. strategy = "HEURISTIC"
  78. else:
  79. logger.warning("--fast-run enabled; compile may be slow")
  80. strategy = "PROFILE"
  81. if args.reproducible:
  82. strategy += "_REPRODUCIBLE"
  83. return strategy
  84. def get_opt_kwargs(args):
  85. args_list = [
  86. "enable_io16xc32",
  87. "enable_ioc16",
  88. "enable_hwcd4",
  89. "enable_nchw4",
  90. "enable_nchw88",
  91. "enable_nchw44",
  92. "enable_nchw44_dot",
  93. "enable_nchw32",
  94. "enable_chwn4",
  95. "enable_fuse_conv_bias_nonlinearity",
  96. "enable_fuse_conv_bias_with_z",
  97. ]
  98. kwargs = {}
  99. for k in args_list:
  100. if getattr(args, k):
  101. kwargs[k] = True
  102. return kwargs
  103. def run_model(args, graph, inputs, outputs, data):
  104. # must use level0 to avoid unintended opr modification
  105. graph.options.graph_opt_level = 0
  106. logger.info("input tensors: ")
  107. for k, v in data.items():
  108. logger.info(" {}: {}".format(k, v.shape))
  109. G.modify_opr_algo_strategy_inplace(outputs, get_execution_strategy(args))
  110. if args.optimize_for_inference:
  111. opt_kwargs = get_opt_kwargs(args)
  112. outputs = G.optimize_for_inference(outputs, **opt_kwargs)
  113. # embed inputs must be on the last, to avoid const fold
  114. if args.embed_input:
  115. outputs, inp_dict = tools.embed_inputs(outputs, data.values(), inputs=inputs)
  116. else:
  117. outputs, inp_dict = tools.convert_inputs(outputs, inputs=inputs)
  118. if args.dump_cpp_model:
  119. dump_content, _ = G.dump_graph(outputs, keep_var_name=2)
  120. with open(args.dump_cpp_model, "wb") as file:
  121. file.write(dump_content)
  122. logger.info("C++ model written to {}".format(args.dump_cpp_model))
  123. outputs, output_dict = tools.convert_outputs(outputs)
  124. if args.profile:
  125. profiler = tools.GraphProfiler(graph)
  126. func = graph.compile(outputs)
  127. def run():
  128. if not args.embed_input:
  129. for key in inp_dict:
  130. inp_dict[key].set_value(mge.Tensor(data[key])._dev_tensor())
  131. func.execute()
  132. func.wait()
  133. return [oup_node.get_value().numpy() for oup_node in output_dict.values()]
  134. if args.warm_up:
  135. logger.info("warming up")
  136. run()
  137. total_time = 0
  138. for i in range(args.iter):
  139. logger.info("iter {}".format(i))
  140. start_time = time.time()
  141. retval = run()
  142. cur_time = time.time() - start_time
  143. total_time += cur_time
  144. avg_speed = (i + 1) / total_time
  145. if "data" in data:
  146. avg_speed *= data["data"].shape[0]
  147. avg_speed_txt = "{:.3f}sample/s".format(avg_speed)
  148. else:
  149. avg_speed_txt = "{:.3f}batch/s".format(avg_speed)
  150. msg = (
  151. "iter {}: duration={:.4f}({:.4f})s average={:.4f}s "
  152. "avg_speed={} time={:.4f}s"
  153. ).format(
  154. i,
  155. cur_time,
  156. func.get_prev_exec_time(),
  157. total_time / (i + 1),
  158. avg_speed_txt,
  159. total_time,
  160. )
  161. if args.calc_output_rms:
  162. rms = []
  163. for v in retval:
  164. rms.append("{:.3g}".format(float(((v ** 2).mean()) ** 0.5)))
  165. msg += " output_rms=[{}]".format(", ".join(rms))
  166. if logger.level > logging.INFO:
  167. print(msg)
  168. else:
  169. logger.info(msg)
  170. if args.focused_nvprof:
  171. if get_device_count("gpu") < 1:
  172. logger.warning(
  173. "No cuda device detected. ``focused_nvprof`` will be ignored."
  174. )
  175. else:
  176. try:
  177. import pycuda.driver as D
  178. D.start_profiler()
  179. func.execute()
  180. func.wait()
  181. D.stop_profiler()
  182. except ImportError:
  183. logger.error("`focused_nvprof need pycuda`", exc_info=True)
  184. if args.profile:
  185. with open(args.profile, "w") as fout:
  186. fout.write(profiler.get())
  187. return avg_speed
  188. def main():
  189. parser = argparse.ArgumentParser(
  190. description="load a network and run inference on random data",
  191. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  192. )
  193. parser.add_argument("net")
  194. parser.add_argument(
  195. "--device", "-d", help="set defult device, like 'gpux' or 'cpux'"
  196. )
  197. parser.add_argument(
  198. "--calc-output-rms",
  199. action="store_true",
  200. help="compute RMS of outputs; useful for comparing computing results",
  201. )
  202. parser.add_argument(
  203. "--output-name",
  204. nargs="*",
  205. help="Specify output name. This option can be"
  206. " specified multiple times. We will look for opr/var"
  207. " in the graph",
  208. )
  209. parser.add_argument(
  210. "--load-input-data",
  211. help="load input data from pickle file; it should be"
  212. " a numpy array or a dict of numpy array",
  213. )
  214. parser.add_argument("--profile", help="profiler output file")
  215. parser.add_argument(
  216. "--fast-run",
  217. action="store_true",
  218. help="enable fast running by profiling conv algorithms during compiling.",
  219. )
  220. parser.add_argument(
  221. "--reproducible", action="store_true", help="use reproducible kernels"
  222. )
  223. parser.add_argument(
  224. "--input-desc",
  225. help="specifiy input names and shapes manually in"
  226. " format: <name>:<shape>[;<name>:<shape>, ...], where"
  227. " name is a string and shape is a comma separated"
  228. ' string. e.g., "data:128,1,28,28,label:128".'
  229. " different input tensor are separated by semicolon.",
  230. )
  231. parser.add_argument(
  232. "--batchsize",
  233. type=int,
  234. help="change batchsize; the first dimension of each"
  235. " input is assumed to be batch size",
  236. )
  237. parser.add_argument(
  238. "--warm-up",
  239. action="store_true",
  240. help="warm up model before do timing " " for better estimation",
  241. )
  242. parser.add_argument(
  243. "--verbose",
  244. "-v",
  245. action="store_true",
  246. help="verbose output, logging in debug mode",
  247. )
  248. parser.add_argument(
  249. "--iter", type=int, default=1, help="number of iters to run the model"
  250. )
  251. parser.add_argument("--log", help="give a file path to duplicate log to")
  252. parser.add_argument(
  253. "--seed",
  254. type=int,
  255. default=0,
  256. help="seed for random number generator for input data",
  257. )
  258. parser.add_argument(
  259. "--rng",
  260. help="special RNG options to generate input data in"
  261. " format: <name>:func[;<name>:func, ...] where name is"
  262. " a string and func is a python expression containing"
  263. ' "{}" for the size param, e.g. '
  264. ' "label:randint(low=0,high=1000,size={})"',
  265. )
  266. parser.add_argument(
  267. "--focused-nvprof",
  268. action="store_true",
  269. help="only profile last iter for `nvprof --profile-from-start off`",
  270. )
  271. parser.add_argument(
  272. "--optimize-for-inference",
  273. action="store_true",
  274. help="optimize model for inference",
  275. )
  276. parser.add_argument(
  277. "--enable-io16xc32",
  278. action="store_true",
  279. help="transform the mode to float16 io float32 compute",
  280. )
  281. parser.add_argument(
  282. "--enable-ioc16",
  283. action="store_true",
  284. help="transform the dtype of the model to float16 io and compute",
  285. )
  286. parser.add_argument(
  287. "--enable-hwcd4",
  288. action="store_true",
  289. help="transform the model format from NCHW to NHWCD4 for inference",
  290. )
  291. parser.add_argument(
  292. "--enable-nchw4",
  293. action="store_true",
  294. help="transform the model format from NCHW to NCHW4 for inference",
  295. )
  296. parser.add_argument(
  297. "--enable-nchw88",
  298. action="store_true",
  299. help="transform the model format from NCHW to NCHW88 for inference",
  300. )
  301. parser.add_argument(
  302. "--enable-nchw44",
  303. action="store_true",
  304. help="transform the model format from NCHW to NCHW44 for inference",
  305. )
  306. parser.add_argument(
  307. "--enable-nchw44-dot",
  308. action="store_true",
  309. help="transform the model format from NCHW to NCHW44_DOT "
  310. "for optimizing armv8.2 dot in inference",
  311. )
  312. parser.add_argument(
  313. "--enable-chwn4",
  314. action="store_true",
  315. help="transform the model format to CHWN4 "
  316. "for inference, mainly used for nvidia tensorcore",
  317. )
  318. parser.add_argument(
  319. "--enable-nchw32",
  320. action="store_true",
  321. help="transform the model format from NCHW4 to NCHW32 "
  322. "for inference on nvidia TensoCore",
  323. )
  324. parser.add_argument(
  325. "--enable-fuse-conv-bias-nonlinearity",
  326. action="store_true",
  327. help="fuse convolution bias and nonlinearity opr to a "
  328. "conv_bias opr and compute",
  329. )
  330. parser.add_argument(
  331. "--enable-fuse-conv-bias-with-z",
  332. action="store_true",
  333. help="fuse conv_bias with z input for inference on "
  334. "nvidia GPU (this optimization pass will result in mismatch "
  335. "of the precision of output of training and inference)",
  336. )
  337. parser.add_argument(
  338. "--dump-cpp-model",
  339. help="write a C++ model that can be loaded by "
  340. "megbrain/sdk/load-and-run; "
  341. "this implies --embed-input",
  342. )
  343. parser.add_argument(
  344. "--embed-input",
  345. action="store_true",
  346. help="embed input data as SharedDeviceTensor in model, "
  347. "to remove memory copy for inputs",
  348. )
  349. args = parser.parse_args()
  350. if args.verbose:
  351. enable_debug_log()
  352. if args.log:
  353. set_log_file(args.log)
  354. if args.device:
  355. set_default_device(args.device)
  356. if args.dump_cpp_model:
  357. args.embed_input = True
  358. logger.info("loading model ...")
  359. graph, _, output_vars = G.load_graph(args.net)
  360. input_vars = tools.get_dep_vars(output_vars, "Host2DeviceCopy")
  361. if args.output_name is not None:
  362. output_vars = tools.find_vars_by_name(output_vars, args.output_name)
  363. data = make_data_given_desc(args, input_vars)
  364. run_model(args, graph, input_vars, output_vars, data)
# Run the command-line tool only when this file is executed as a script.
if __name__ == "__main__":
    main()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。