You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

benchmark_op.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import time
  10. import numpy as np
  11. import torch
  12. import torch.nn as nn
  13. import torch.nn.functional as TF
  14. from tabulate import tabulate
  15. import megengine as mge
  16. import megengine.functional as MF
  17. import megengine.module as MM
  18. module_cache = {
  19. "conv2d": (MM.Conv2d(32, 32, 3, 1, 0), nn.Conv2d(32, 32, 3, 1, 0).cuda()),
  20. "dw_conv2d": (
  21. MM.Conv2d(32, 32, 3, 1, 0, groups=32),
  22. nn.Conv2d(32, 32, 3, 1, 0, groups=32).cuda(),
  23. ),
  24. "conv3d": (MM.Conv3d(32, 32, 3, 1, 0), nn.Conv3d(32, 32, 3, 1, 0).cuda()),
  25. "ConvTranspose2d": (
  26. MM.ConvTranspose2d(32, 32, 3, 1, 0),
  27. nn.ConvTranspose2d(32, 32, 3, 1, 0).cuda(),
  28. ),
  29. "BatchNorm2d": (MM.BatchNorm2d(64), nn.BatchNorm2d(64).cuda()),
  30. "Linear": (MM.Linear(1000, 1000), nn.Linear(1000, 1000).cuda()),
  31. }
  32. test_cases = [
  33. # (mge op, torch op, small inps, large inps, unpack_inps, rep)
  34. (
  35. "adaptive_avg_pool2d",
  36. lambda x: MF.adaptive_avg_pool2d(x, (7, 7)),
  37. lambda x: TF.adaptive_avg_pool2d(x, (7, 7)),
  38. [(2, 32, 16, 16)],
  39. [(64, 512, 16, 16)],
  40. True,
  41. 1000,
  42. ),
  43. (
  44. "adaptive_max_pool2d",
  45. lambda x: MF.adaptive_max_pool2d(x, (7, 7)),
  46. lambda x: TF.adaptive_max_pool2d(x, (7, 7)),
  47. [(2, 32, 16, 16)],
  48. [(64, 512, 16, 16)],
  49. True,
  50. 1000,
  51. ),
  52. ("argsort", MF.argsort, torch.argsort, [(1000,)], [(1000, 1000),], True, 1000),
  53. (
  54. "avg_pool2d",
  55. lambda x: MF.avg_pool2d(x, 2),
  56. lambda x: TF.avg_pool2d(x, 2),
  57. [(2, 32, 16, 16)],
  58. [(64, 512, 16, 16)],
  59. True,
  60. 1000,
  61. ),
  62. (
  63. "broadcast",
  64. lambda x: MF.broadcast_to(x, (5,) + x.shape),
  65. lambda x: torch.broadcast_to(x, (5,) + x.shape),
  66. [(100, 100)],
  67. [(64, 512, 16, 16)],
  68. True,
  69. 1000,
  70. ),
  71. (
  72. "batchedmatmul",
  73. MF.matmul,
  74. torch.matmul,
  75. [(8, 64, 32), (8, 32, 64)],
  76. [(8, 2048, 512), (8, 512, 2048)],
  77. True,
  78. 1000,
  79. ),
  80. (
  81. "batchnrom2d",
  82. lambda x: module_cache["BatchNorm2d"][0](x),
  83. lambda x: module_cache["BatchNorm2d"][1](x),
  84. [(2, 64, 16, 16)],
  85. [(64, 64, 128, 128)],
  86. True,
  87. 1000,
  88. ),
  89. (
  90. "concat",
  91. MF.concat,
  92. torch.cat,
  93. [(20, 100), (50, 100), (30, 100)],
  94. [(64, 512, 16, 16), (64, 512, 16, 16), (64, 512, 16, 16)],
  95. False,
  96. 1000,
  97. ),
  98. (
  99. "conv2d",
  100. lambda x: module_cache["conv2d"][0](x),
  101. lambda x: module_cache["conv2d"][1](x),
  102. [(2, 32, 16, 16)],
  103. [(32, 32, 128, 128)],
  104. True,
  105. 1000,
  106. ),
  107. (
  108. "conv3d",
  109. lambda x: module_cache["conv3d"][0](x),
  110. lambda x: module_cache["conv3d"][1](x),
  111. [(2, 32, 8, 8, 8)],
  112. [(32, 32, 16, 16, 16)],
  113. True,
  114. 1000,
  115. ),
  116. (
  117. "convTranspose2d",
  118. lambda x: module_cache["ConvTranspose2d"][0](x),
  119. lambda x: module_cache["ConvTranspose2d"][1](x),
  120. [(2, 32, 16, 16)],
  121. [(32, 32, 128, 128)],
  122. True,
  123. 1000,
  124. ),
  125. (
  126. "dropout",
  127. lambda x: MF.dropout(x, 0.5),
  128. TF.dropout,
  129. [(100, 100)],
  130. [(64, 512, 16, 16)],
  131. True,
  132. 1000,
  133. ),
  134. (
  135. "dw_conv2d",
  136. lambda x: module_cache["dw_conv2d"][0](x),
  137. lambda x: module_cache["dw_conv2d"][1](x),
  138. [(2, 32, 16, 16)],
  139. [(32, 32, 128, 128)],
  140. True,
  141. 1000,
  142. ),
  143. (
  144. "elemwise.unary",
  145. MF.log,
  146. torch.log,
  147. [(100, 100)],
  148. [(64, 512, 16, 16)],
  149. True,
  150. 1000,
  151. ),
  152. (
  153. "elemwise.binary",
  154. MF.add,
  155. torch.add,
  156. [(100, 100), (100, 100)],
  157. [(64, 512, 16, 16), (64, 512, 16, 16)],
  158. True,
  159. 1000,
  160. ),
  161. (
  162. "expand_dims",
  163. lambda x: MF.expand_dims(x, 0),
  164. lambda x: torch.unsqueeze(x, 0),
  165. [(100, 100)],
  166. [(64, 512, 16, 16)],
  167. True,
  168. 1000,
  169. ),
  170. ("gelu", MF.gelu, TF.gelu, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
  171. ("hswish", MF.hswish, TF.hardswish, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
  172. (
  173. "hsigmoid",
  174. MF.hsigmoid,
  175. TF.hardsigmoid,
  176. [(100, 100)],
  177. [(64, 512, 16, 16)],
  178. True,
  179. 1000,
  180. ),
  181. ("isinf", MF.isinf, torch.isinf, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
  182. (
  183. "indeixngMultiAxisVec",
  184. lambda x: x[[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]],
  185. lambda x: x[[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]],
  186. [(10, 10, 10, 10)],
  187. [(64, 512, 16, 16)],
  188. True,
  189. 1000,
  190. ),
  191. (
  192. "logsigmoid",
  193. MF.logsigmoid,
  194. TF.logsigmoid,
  195. [(100, 100)],
  196. [(64, 512, 16, 16)],
  197. True,
  198. 1000,
  199. ),
  200. (
  201. "leaky_relu",
  202. lambda x: MF.leaky_relu(x, 0.5),
  203. lambda x: TF.leaky_relu(x, 0.5),
  204. [(100, 100)],
  205. [(64, 512, 16, 16)],
  206. True,
  207. 1000,
  208. ),
  209. (
  210. "linear",
  211. lambda x: module_cache["Linear"][0](x),
  212. lambda x: module_cache["Linear"][1](x),
  213. [(10, 1000)],
  214. [(64, 128, 1000)],
  215. True,
  216. 1000,
  217. ),
  218. ("matinv", MF.matinv, torch.inverse, [(10, 10)], [(30, 30)], True, 1000),
  219. (
  220. "matmul",
  221. MF.matmul,
  222. torch.matmul,
  223. [(64, 32), (32, 64)],
  224. [(2048, 1024), (1024, 2048)],
  225. True,
  226. 1000,
  227. ),
  228. (
  229. "max_pool2d",
  230. lambda x: MF.max_pool2d(x, 2),
  231. lambda x: TF.max_pool2d(x, 2),
  232. [(2, 32, 16, 16)],
  233. [(64, 512, 16, 16)],
  234. True,
  235. 1000,
  236. ),
  237. (
  238. "normal",
  239. lambda x: mge.random.normal(0, 1, x.shape),
  240. lambda x: torch.randn(x.shape, device="cuda"),
  241. [(100, 100)],
  242. [(64, 512, 16, 16)],
  243. True,
  244. 1000,
  245. ),
  246. (
  247. "prelu",
  248. MF.prelu,
  249. TF.prelu,
  250. [(100, 100), (1,)],
  251. [(64, 512, 16, 16), (1,)],
  252. True,
  253. 1000,
  254. ),
  255. (
  256. "reduce.max",
  257. lambda x: MF.max(x, 0),
  258. lambda x: torch.max(x, 0),
  259. [(100, 100)],
  260. [(64, 512, 16, 16)],
  261. True,
  262. 1000,
  263. ),
  264. (
  265. "reduce.mean",
  266. lambda x: MF.mean(x, 0),
  267. lambda x: torch.mean(x, 0),
  268. [(100, 100)],
  269. [(64, 512, 16, 16)],
  270. True,
  271. 1000,
  272. ),
  273. (
  274. "reduce.mean",
  275. lambda x: MF.mean(x, 0),
  276. lambda x: torch.mean(x, 0),
  277. [(100, 100)],
  278. [(64, 512, 16, 16)],
  279. True,
  280. 1000,
  281. ),
  282. ("relu", MF.relu, TF.relu, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
  283. ("relu6", MF.relu6, TF.relu6, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
  284. (
  285. "repeat",
  286. lambda x: MF.repeat(x, 5),
  287. lambda x: torch.repeat_interleave(x, 5),
  288. [(100, 100)],
  289. [(64, 512, 16, 16)],
  290. True,
  291. 1000,
  292. ),
  293. ("silu", MF.silu, TF.silu, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
  294. (
  295. "split",
  296. lambda x: MF.split(x, 5),
  297. lambda x: torch.split(x, 5),
  298. [(100, 100)],
  299. [(64, 512, 16, 16)],
  300. True,
  301. 1000,
  302. ),
  303. ("sigmoid", MF.sigmoid, TF.sigmoid, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
  304. (
  305. "softmax",
  306. lambda x: MF.softmax(x, axis=1),
  307. lambda x: TF.softmax(x, dim=1),
  308. [(100, 100)],
  309. [(64, 512, 16, 16)],
  310. True,
  311. 1000,
  312. ),
  313. (
  314. "softplus",
  315. MF.softplus,
  316. TF.softplus,
  317. [(100, 100)],
  318. [(64, 512, 16, 16)],
  319. True,
  320. 1000,
  321. ),
  322. (
  323. "squeeze",
  324. lambda x: MF.squeeze(x, 0),
  325. lambda x: torch.squeeze(x, 0),
  326. [(1, 100, 100)],
  327. [(1, 64, 512, 16, 16)],
  328. True,
  329. 1000,
  330. ),
  331. (
  332. "stack",
  333. MF.stack,
  334. torch.stack,
  335. [(100, 100), (100, 100)],
  336. [(64, 512, 16, 16), (64, 512, 16, 16)],
  337. False,
  338. 10000,
  339. ),
  340. (
  341. "subtensor",
  342. lambda x: x[0:20, 10:60],
  343. lambda x: x[0:20, 10:60],
  344. [(100, 100)],
  345. [(64, 512, 16, 16)],
  346. True,
  347. 1000,
  348. ),
  349. (
  350. "topk",
  351. lambda x: MF.topk(x, 10),
  352. lambda x: torch.topk(x, 10),
  353. [(100, 100)],
  354. [(1000, 1000)],
  355. True,
  356. 1000,
  357. ),
  358. (
  359. "tile",
  360. lambda x: MF.tile(x, (2,) * len(x.shape)),
  361. lambda x: torch.tile(x, (2,) * len(x.shape)),
  362. [(100, 100)],
  363. [(64, 512, 16, 16)],
  364. True,
  365. 1000,
  366. ),
  367. (
  368. "transpose",
  369. lambda x: MF.transpose(x, list(range(len(x.shape)))[::-1]),
  370. lambda x: torch.permute(x, list(range(len(x.shape)))[::-1]),
  371. [(100, 100)],
  372. [(64, 512, 16, 16)],
  373. True,
  374. 1000,
  375. ),
  376. (
  377. "where",
  378. lambda x: MF.where(x > 0.5, x, x),
  379. lambda x: torch.where(x > 0.5, x, x),
  380. [(100, 100)],
  381. [(64, 512, 16, 16)],
  382. True,
  383. 1000,
  384. ),
  385. (
  386. "uniform",
  387. lambda x: mge.random.uniform(0, 1, x.shape),
  388. lambda x: torch.rand(x.shape, device="cuda"),
  389. [(100, 100)],
  390. [(64, 512, 16, 16)],
  391. True,
  392. 1000,
  393. ),
  394. ]
  395. def perf_func(func, inps, reps, unpack_inps, is_mge):
  396. if is_mge:
  397. mge.sync()
  398. tik = time.time()
  399. for _ in range(reps):
  400. if unpack_inps:
  401. out = func(*inps)
  402. else:
  403. out = func(inps)
  404. mge.sync()
  405. else:
  406. torch.cuda.synchronize()
  407. with torch.no_grad():
  408. tik = time.time()
  409. for _ in range(reps):
  410. if unpack_inps:
  411. out = func(*inps)
  412. else:
  413. out = func(inps)
  414. torch.cuda.synchronize()
  415. return time.time() - tik
  416. def get_avg_time(func, inps, reps, unpack_inps, is_mge):
  417. # warm up
  418. for _ in range(2):
  419. t = perf_func(func, inps, reps, unpack_inps, is_mge)
  420. times = []
  421. for _ in range(5):
  422. t = perf_func(func, inps, reps, unpack_inps, is_mge)
  423. times.append(t)
  424. return np.mean(times)
  425. def get_perf_results(mge_func, torch_func, shapes, unpack_inps, reps):
  426. inps = [np.random.randn(*shape) for shape in shapes]
  427. inps_mge = [mge.tensor(inp, dtype="float32") for inp in inps]
  428. avg_time_mge = get_avg_time(mge_func, inps_mge, reps, unpack_inps, True)
  429. inps_torch = [torch.Tensor(inp).type(torch.float).cuda() for inp in inps]
  430. avg_time_torch = get_avg_time(torch_func, inps_torch, reps, unpack_inps, False)
  431. return avg_time_mge, avg_time_torch
  432. if __name__ == "__main__":
  433. header = [
  434. "opr_name",
  435. "time(mge/pytorch; small input)",
  436. "time(mge/pytorch; large input)",
  437. ]
  438. table = []
  439. for case in test_cases:
  440. assert len(case) == 7
  441. name, mge_func, torch_func, small_shapes, large_shapes, unpack_inps, reps = case
  442. data = []
  443. data.append(name)
  444. print("========== op: {}".format(name))
  445. avg_time_mge, avg_time_torch = get_perf_results(
  446. mge_func, torch_func, small_shapes, unpack_inps, reps
  447. )
  448. print("mge time: {}".format(avg_time_mge))
  449. print("torch time: {}".format(avg_time_torch))
  450. data.append("{:.2f}".format(avg_time_mge / avg_time_torch))
  451. avg_time_mge, avg_time_torch = get_perf_results(
  452. mge_func, torch_func, large_shapes, unpack_inps, reps
  453. )
  454. print("mge time: {}".format(avg_time_mge))
  455. print("torch time: {}".format(avg_time_torch))
  456. data.append("{:.2f}".format(avg_time_mge / avg_time_torch))
  457. table.append(data)
  458. print(tabulate(table, header, tablefmt="github"))