
benchmark_op.py

# -*- coding: utf-8 -*-
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as TF
from tabulate import tabulate

import megengine as mge
import megengine.functional as MF
import megengine.module as MM
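
# Module pairs (MegEngine, PyTorch) are built once up front so that module
# construction and parameter initialization stay out of the timed loops.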
module_cache = {
    "conv2d": (MM.Conv2d(32, 32, 3, 1, 0), nn.Conv2d(32, 32, 3, 1, 0).cuda()),
    "dw_conv2d": (
        MM.Conv2d(32, 32, 3, 1, 0, groups=32),
        nn.Conv2d(32, 32, 3, 1, 0, groups=32).cuda(),
    ),
    "conv3d": (MM.Conv3d(32, 32, 3, 1, 0), nn.Conv3d(32, 32, 3, 1, 0).cuda()),
    "ConvTranspose2d": (
        MM.ConvTranspose2d(32, 32, 3, 1, 0),
        nn.ConvTranspose2d(32, 32, 3, 1, 0).cuda(),
    ),
    "BatchNorm2d": (MM.BatchNorm2d(64), nn.BatchNorm2d(64).cuda()),
    "Linear": (MM.Linear(1000, 1000), nn.Linear(1000, 1000).cuda()),
}
test_cases = [
    # (name, mge op, torch op, small inp shapes, large inp shapes, unpack_inps, reps)
    (
        "adaptive_avg_pool2d",
        lambda x: MF.adaptive_avg_pool2d(x, (7, 7)),
        lambda x: TF.adaptive_avg_pool2d(x, (7, 7)),
        [(2, 32, 16, 16)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "adaptive_max_pool2d",
        lambda x: MF.adaptive_max_pool2d(x, (7, 7)),
        lambda x: TF.adaptive_max_pool2d(x, (7, 7)),
        [(2, 32, 16, 16)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    ("argsort", MF.argsort, torch.argsort, [(1000,)], [(1000, 1000)], True, 1000),
    (
        "avg_pool2d",
        lambda x: MF.avg_pool2d(x, 2),
        lambda x: TF.avg_pool2d(x, 2),
        [(2, 32, 16, 16)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "broadcast",
        lambda x: MF.broadcast_to(x, (5,) + x.shape),
        lambda x: torch.broadcast_to(x, (5,) + x.shape),
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "batchedmatmul",
        MF.matmul,
        torch.matmul,
        [(8, 64, 32), (8, 32, 64)],
        [(8, 2048, 512), (8, 512, 2048)],
        True,
        1000,
    ),
    (
        "batchnorm2d",
        lambda x: module_cache["BatchNorm2d"][0](x),
        lambda x: module_cache["BatchNorm2d"][1](x),
        [(2, 64, 16, 16)],
        [(64, 64, 128, 128)],
        True,
        1000,
    ),
    (
        "concat",
        MF.concat,
        torch.cat,
        [(20, 100), (50, 100), (30, 100)],
        [(64, 512, 16, 16), (64, 512, 16, 16), (64, 512, 16, 16)],
        False,
        1000,
    ),
    (
        "conv2d",
        lambda x: module_cache["conv2d"][0](x),
        lambda x: module_cache["conv2d"][1](x),
        [(2, 32, 16, 16)],
        [(32, 32, 128, 128)],
        True,
        1000,
    ),
    (
        "conv3d",
        lambda x: module_cache["conv3d"][0](x),
        lambda x: module_cache["conv3d"][1](x),
        [(2, 32, 8, 8, 8)],
        [(32, 32, 16, 16, 16)],
        True,
        1000,
    ),
    (
        "convTranspose2d",
        lambda x: module_cache["ConvTranspose2d"][0](x),
        lambda x: module_cache["ConvTranspose2d"][1](x),
        [(2, 32, 16, 16)],
        [(32, 32, 128, 128)],
        True,
        1000,
    ),
    (
        "dropout",
        lambda x: MF.dropout(x, 0.5),
        TF.dropout,
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "dw_conv2d",
        lambda x: module_cache["dw_conv2d"][0](x),
        lambda x: module_cache["dw_conv2d"][1](x),
        [(2, 32, 16, 16)],
        [(32, 32, 128, 128)],
        True,
        1000,
    ),
    (
        "elemwise.unary",
        MF.log,
        torch.log,
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "elemwise.binary",
        MF.add,
        torch.add,
        [(100, 100), (100, 100)],
        [(64, 512, 16, 16), (64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "expand_dims",
        lambda x: MF.expand_dims(x, 0),
        lambda x: torch.unsqueeze(x, 0),
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    ("gelu", MF.gelu, TF.gelu, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("hswish", MF.hswish, TF.hardswish, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    (
        "hsigmoid",
        MF.hsigmoid,
        TF.hardsigmoid,
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    ("isinf", MF.isinf, torch.isinf, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    (
        "indexingMultiAxisVec",
        lambda x: x[[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]],
        lambda x: x[[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]],
        [(10, 10, 10, 10)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "logsigmoid",
        MF.logsigmoid,
        TF.logsigmoid,
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "leaky_relu",
        lambda x: MF.leaky_relu(x, 0.5),
        lambda x: TF.leaky_relu(x, 0.5),
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "linear",
        lambda x: module_cache["Linear"][0](x),
        lambda x: module_cache["Linear"][1](x),
        [(10, 1000)],
        [(64, 128, 1000)],
        True,
        1000,
    ),
    ("matinv", MF.matinv, torch.inverse, [(10, 10)], [(30, 30)], True, 1000),
    (
        "matmul",
        MF.matmul,
        torch.matmul,
        [(64, 32), (32, 64)],
        [(2048, 1024), (1024, 2048)],
        True,
        1000,
    ),
    (
        "max_pool2d",
        lambda x: MF.max_pool2d(x, 2),
        lambda x: TF.max_pool2d(x, 2),
        [(2, 32, 16, 16)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "normal",
        lambda x: mge.random.normal(0, 1, x.shape),
        lambda x: torch.randn(x.shape, device="cuda"),
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "prelu",
        MF.prelu,
        TF.prelu,
        [(100, 100), (1,)],
        [(64, 512, 16, 16), (1,)],
        True,
        1000,
    ),
    (
        "reduce.max",
        lambda x: MF.max(x, 0),
        lambda x: torch.max(x, 0),
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
    (
        "reduce.mean",
        lambda x: MF.mean(x, 0),
        lambda x: torch.mean(x, 0),
        [(100, 100)],
        [(64, 512, 16, 16)],
        True,
        1000,
    ),
  275. ("relu", MF.relu, TF.relu, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
  276. ("relu6", MF.relu6, TF.relu6, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
  277. (
  278. "repeat",
  279. lambda x: MF.repeat(x, 5),
  280. lambda x: torch.repeat_interleave(x, 5),
  281. [(100, 100)],
  282. [(64, 512, 16, 16)],
  283. True,
  284. 1000,
  285. ),
  286. ("silu", MF.silu, TF.silu, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
  287. (
  288. "split",
  289. lambda x: MF.split(x, 5),
  290. lambda x: torch.split(x, 5),
  291. [(100, 100)],
  292. [(64, 512, 16, 16)],
  293. True,
  294. 1000,
  295. ),
  296. ("sigmoid", MF.sigmoid, TF.sigmoid, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
  297. (
  298. "softmax",
  299. lambda x: MF.softmax(x, axis=1),
  300. lambda x: TF.softmax(x, dim=1),
  301. [(100, 100)],
  302. [(64, 512, 16, 16)],
  303. True,
  304. 1000,
  305. ),
  306. (
  307. "softplus",
  308. MF.softplus,
  309. TF.softplus,
  310. [(100, 100)],
  311. [(64, 512, 16, 16)],
  312. True,
  313. 1000,
  314. ),
  315. (
  316. "squeeze",
  317. lambda x: MF.squeeze(x, 0),
  318. lambda x: torch.squeeze(x, 0),
  319. [(1, 100, 100)],
  320. [(1, 64, 512, 16, 16)],
  321. True,
  322. 1000,
  323. ),
  324. (
  325. "stack",
  326. MF.stack,
  327. torch.stack,
  328. [(100, 100), (100, 100)],
  329. [(64, 512, 16, 16), (64, 512, 16, 16)],
  330. False,
  331. 10000,
  332. ),
  333. (
  334. "subtensor",
  335. lambda x: x[0:20, 10:60],
  336. lambda x: x[0:20, 10:60],
  337. [(100, 100)],
  338. [(64, 512, 16, 16)],
  339. True,
  340. 1000,
  341. ),
  342. (
  343. "topk",
  344. lambda x: MF.topk(x, 10),
  345. lambda x: torch.topk(x, 10),
  346. [(100, 100)],
  347. [(1000, 1000)],
  348. True,
  349. 1000,
  350. ),
  351. (
  352. "tile",
  353. lambda x: MF.tile(x, (2,) * len(x.shape)),
  354. lambda x: torch.tile(x, (2,) * len(x.shape)),
  355. [(100, 100)],
  356. [(64, 512, 16, 16)],
  357. True,
  358. 1000,
  359. ),
  360. (
  361. "transpose",
  362. lambda x: MF.transpose(x, list(range(len(x.shape)))[::-1]),
  363. lambda x: torch.permute(x, list(range(len(x.shape)))[::-1]),
  364. [(100, 100)],
  365. [(64, 512, 16, 16)],
  366. True,
  367. 1000,
  368. ),
  369. (
  370. "where",
  371. lambda x: MF.where(x > 0.5, x, x),
  372. lambda x: torch.where(x > 0.5, x, x),
  373. [(100, 100)],
  374. [(64, 512, 16, 16)],
  375. True,
  376. 1000,
  377. ),
  378. (
  379. "uniform",
  380. lambda x: mge.random.uniform(0, 1, x.shape),
  381. lambda x: torch.rand(x.shape, device="cuda"),
  382. [(100, 100)],
  383. [(64, 512, 16, 16)],
  384. True,
  385. 1000,
  386. ),
  387. ]
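
# Wall-clock `reps` back-to-back launches of `func` on `inps`. A full device
# synchronization before taking the start time and again after the loop makes
# sure asynchronously queued GPU kernels are counted in the measurement.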
def perf_func(func, inps, reps, unpack_inps, is_mge):
    if is_mge:
        mge._full_sync()
        tik = time.time()
        for _ in range(reps):
            if unpack_inps:
                out = func(*inps)
            else:
                out = func(inps)
        mge._full_sync()
    else:
        torch.cuda.synchronize()
        with torch.no_grad():
            tik = time.time()
            for _ in range(reps):
                if unpack_inps:
                    out = func(*inps)
                else:
                    out = func(inps)
        torch.cuda.synchronize()
    return time.time() - tik
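
# Two untimed warm-up rounds absorb one-off costs (kernel selection, caches),
# then the reported figure is the mean of five timed rounds.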
def get_avg_time(func, inps, reps, unpack_inps, is_mge):
    # warm up
    for _ in range(2):
        perf_func(func, inps, reps, unpack_inps, is_mge)

    times = []
    for _ in range(5):
        t = perf_func(func, inps, reps, unpack_inps, is_mge)
        times.append(t)
    return np.mean(times)
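
# Both frameworks are fed float32 tensors built from the same numpy arrays,
# so each pair of timings runs on identical data.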
def get_perf_results(mge_func, torch_func, shapes, unpack_inps, reps):
    inps = [np.random.randn(*shape) for shape in shapes]

    inps_mge = [mge.tensor(inp, dtype="float32") for inp in inps]
    avg_time_mge = get_avg_time(mge_func, inps_mge, reps, unpack_inps, True)

    inps_torch = [torch.tensor(inp, dtype=torch.float32, device="cuda") for inp in inps]
    avg_time_torch = get_avg_time(torch_func, inps_torch, reps, unpack_inps, False)

    return avg_time_mge, avg_time_torch
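
# Each table row holds the ratio time(mge) / time(torch) for the small and
# large input shapes; a value above 1.00 means MegEngine was slower there.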
if __name__ == "__main__":
    header = [
        "opr_name",
        "time(mge/pytorch; small input)",
        "time(mge/pytorch; large input)",
    ]
    table = []
    for case in test_cases:
        assert len(case) == 7
        name, mge_func, torch_func, small_shapes, large_shapes, unpack_inps, reps = case
        data = []
        data.append(name)
        print("========== op: {}".format(name))

        avg_time_mge, avg_time_torch = get_perf_results(
            mge_func, torch_func, small_shapes, unpack_inps, reps
        )
        print("mge time: {}".format(avg_time_mge))
        print("torch time: {}".format(avg_time_torch))
        data.append("{:.2f}".format(avg_time_mge / avg_time_torch))

        avg_time_mge, avg_time_torch = get_perf_results(
            mge_func, torch_func, large_shapes, unpack_inps, reps
        )
        print("mge time: {}".format(avg_time_mge))
        print("torch time: {}".format(avg_time_torch))
        data.append("{:.2f}".format(avg_time_mge / avg_time_torch))

        table.append(data)
    print(tabulate(table, header, tablefmt="github"))
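
# Usage sketch (assumes a CUDA-capable GPU and working installs of megengine,
# torch, numpy, and tabulate; the package names here are the usual pip ones):
#
#   python3 benchmark_op.py
#
# The script prints per-op timings as it runs and finishes with a
# github-formatted table of the two ratio columns.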