
benchmark_op.py 9.6 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import time

import numpy as np

import megengine as mge
import megengine.module as MM
import megengine.functional as MF

import torch
import torch.nn as nn
import torch.nn.functional as TF

from tabulate import tabulate

module_cache = {
    "conv2d": (MM.Conv2d(32, 32, 3, 1, 0), nn.Conv2d(32, 32, 3, 1, 0).cuda()),
    "dw_conv2d": (MM.Conv2d(32, 32, 3, 1, 0, groups=32), nn.Conv2d(32, 32, 3, 1, 0, groups=32).cuda()),
    "conv3d": (MM.Conv3d(32, 32, 3, 1, 0), nn.Conv3d(32, 32, 3, 1, 0).cuda()),
    "ConvTranspose2d": (MM.ConvTranspose2d(32, 32, 3, 1, 0), nn.ConvTranspose2d(32, 32, 3, 1, 0).cuda()),
    "BatchNorm2d": (MM.BatchNorm2d(64), nn.BatchNorm2d(64).cuda()),
    "Linear": (MM.Linear(1000, 1000), nn.Linear(1000, 1000).cuda()),
}
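# NOTE: stateful operators are instantiated once in module_cache and reused by
# the lambdas in test_cases below, so module construction, parameter
# initialization, and the torch .cuda() transfer stay outside the timed loops.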

test_cases = [
    # (name, mge op, torch op, small inps, large inps, unpack_inps, reps)
    ("adaptive_avg_pool2d", lambda x: MF.adaptive_avg_pool2d(x, (7, 7)), lambda x: TF.adaptive_avg_pool2d(x, (7, 7)), [(2, 32, 16, 16)], [(64, 512, 16, 16)], True, 1000),
    ("adaptive_max_pool2d", lambda x: MF.adaptive_max_pool2d(x, (7, 7)), lambda x: TF.adaptive_max_pool2d(x, (7, 7)), [(2, 32, 16, 16)], [(64, 512, 16, 16)], True, 1000),
    ("argsort", MF.argsort, torch.argsort, [(1000,)], [(1000, 1000)], True, 1000),
    ("avg_pool2d", lambda x: MF.avg_pool2d(x, 2), lambda x: TF.avg_pool2d(x, 2), [(2, 32, 16, 16)], [(64, 512, 16, 16)], True, 1000),
    ("broadcast", lambda x: MF.broadcast_to(x, (5,) + x.shape), lambda x: torch.broadcast_to(x, (5,) + x.shape), [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("batchedmatmul", MF.matmul, torch.matmul, [(8, 64, 32), (8, 32, 64)], [(8, 2048, 512), (8, 512, 2048)], True, 1000),
    ("batchnorm2d", lambda x: module_cache["BatchNorm2d"][0](x), lambda x: module_cache["BatchNorm2d"][1](x), [(2, 64, 16, 16)], [(64, 64, 128, 128)], True, 1000),
    ("concat", MF.concat, torch.cat, [(20, 100), (50, 100), (30, 100)], [(64, 512, 16, 16), (64, 512, 16, 16), (64, 512, 16, 16)], False, 1000),
    ("conv2d", lambda x: module_cache["conv2d"][0](x), lambda x: module_cache["conv2d"][1](x), [(2, 32, 16, 16)], [(32, 32, 128, 128)], True, 1000),
    ("conv3d", lambda x: module_cache["conv3d"][0](x), lambda x: module_cache["conv3d"][1](x), [(2, 32, 8, 8, 8)], [(32, 32, 16, 16, 16)], True, 1000),
    ("convTranspose2d", lambda x: module_cache["ConvTranspose2d"][0](x), lambda x: module_cache["ConvTranspose2d"][1](x), [(2, 32, 16, 16)], [(32, 32, 128, 128)], True, 1000),
    ("dropout", lambda x: MF.dropout(x, 0.5), TF.dropout, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("dw_conv2d", lambda x: module_cache["dw_conv2d"][0](x), lambda x: module_cache["dw_conv2d"][1](x), [(2, 32, 16, 16)], [(32, 32, 128, 128)], True, 1000),
    ("elemwise.unary", MF.log, torch.log, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("elemwise.binary", MF.add, torch.add, [(100, 100), (100, 100)], [(64, 512, 16, 16), (64, 512, 16, 16)], True, 1000),
    ("expand_dims", lambda x: MF.expand_dims(x, 0), lambda x: torch.unsqueeze(x, 0), [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("gelu", MF.gelu, TF.gelu, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("hswish", MF.hswish, TF.hardswish, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("hsigmoid", MF.hsigmoid, TF.hardsigmoid, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("isinf", MF.isinf, torch.isinf, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("indexingMultiAxisVec", lambda x: x[[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]], lambda x: x[[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]], [(10, 10, 10, 10)], [(64, 512, 16, 16)], True, 1000),
    ("logsigmoid", MF.logsigmoid, TF.logsigmoid, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("leaky_relu", lambda x: MF.leaky_relu(x, 0.5), lambda x: TF.leaky_relu(x, 0.5), [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("linear", lambda x: module_cache["Linear"][0](x), lambda x: module_cache["Linear"][1](x), [(10, 1000)], [(64, 128, 1000)], True, 1000),
    ("matinv", MF.matinv, torch.inverse, [(10, 10)], [(30, 30)], True, 1000),
    ("matmul", MF.matmul, torch.matmul, [(64, 32), (32, 64)], [(2048, 1024), (1024, 2048)], True, 1000),
    ("max_pool2d", lambda x: MF.max_pool2d(x, 2), lambda x: TF.max_pool2d(x, 2), [(2, 32, 16, 16)], [(64, 512, 16, 16)], True, 1000),
    ("normal", lambda x: mge.random.normal(0, 1, x.shape), lambda x: torch.randn(x.shape, device="cuda"), [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("prelu", MF.prelu, TF.prelu, [(100, 100), (1,)], [(64, 512, 16, 16), (1,)], True, 1000),
    ("reduce.max", lambda x: MF.max(x, 0), lambda x: torch.max(x, 0), [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("reduce.mean", lambda x: MF.mean(x, 0), lambda x: torch.mean(x, 0), [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("relu", MF.relu, TF.relu, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("relu6", MF.relu6, TF.relu6, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("repeat", lambda x: MF.repeat(x, 5), lambda x: torch.repeat_interleave(x, 5), [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("silu", MF.silu, TF.silu, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("split", lambda x: MF.split(x, 5), lambda x: torch.split(x, 5), [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("sigmoid", MF.sigmoid, TF.sigmoid, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("softmax", lambda x: MF.softmax(x, axis=1), lambda x: TF.softmax(x, dim=1), [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("softplus", MF.softplus, TF.softplus, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("squeeze", lambda x: MF.squeeze(x, 0), lambda x: torch.squeeze(x, 0), [(1, 100, 100)], [(1, 64, 512, 16, 16)], True, 1000),
    ("stack", MF.stack, torch.stack, [(100, 100), (100, 100)], [(64, 512, 16, 16), (64, 512, 16, 16)], False, 10000),
    ("subtensor", lambda x: x[0:20, 10:60], lambda x: x[0:20, 10:60], [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("topk", lambda x: MF.topk(x, 10), lambda x: torch.topk(x, 10), [(100, 100)], [(1000, 1000)], True, 1000),
    ("tile", lambda x: MF.tile(x, (2,) * len(x.shape)), lambda x: torch.tile(x, (2,) * len(x.shape)), [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("transpose", lambda x: MF.transpose(x, list(range(len(x.shape)))[::-1]), lambda x: torch.permute(x, list(range(len(x.shape)))[::-1]), [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("where", lambda x: MF.where(x > 0.5, x, x), lambda x: torch.where(x > 0.5, x, x), [(100, 100)], [(64, 512, 16, 16)], True, 1000),
    ("uniform", lambda x: mge.random.uniform(0, 1, x.shape), lambda x: torch.rand(x.shape, device="cuda"), [(100, 100)], [(64, 512, 16, 16)], True, 1000),
]
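# NOTE: when unpack_inps is True the input list is splatted as func(*inps);
# when it is False the whole list is passed as a single argument, which is
# what sequence-taking ops such as concat and stack expect.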

def perf_func(func, inps, reps, unpack_inps, is_mge):
    if is_mge:
        # flush pending MegEngine kernels so the timed loop starts clean
        mge._full_sync()
        tik = time.time()
        for _ in range(reps):
            if unpack_inps:
                out = func(*inps)
            else:
                out = func(inps)
        # block until all queued kernels finish before reading the clock
        mge._full_sync()
    else:
        torch.cuda.synchronize()
        with torch.no_grad():
            tik = time.time()
            for _ in range(reps):
                if unpack_inps:
                    out = func(*inps)
                else:
                    out = func(inps)
            torch.cuda.synchronize()
    return time.time() - tik

def get_avg_time(func, inps, reps, unpack_inps, is_mge):
    # warm up
    for _ in range(2):
        t = perf_func(func, inps, reps, unpack_inps, is_mge)

    times = []
    for _ in range(5):
        t = perf_func(func, inps, reps, unpack_inps, is_mge)
        times.append(t)

    return np.mean(times)

def get_perf_results(mge_func, torch_func, shapes, unpack_inps, reps):
    inps = [np.random.randn(*shape) for shape in shapes]

    inps_mge = [mge.tensor(inp, dtype="float32") for inp in inps]
    avg_time_mge = get_avg_time(mge_func, inps_mge, reps, unpack_inps, True)

    inps_torch = [torch.Tensor(inp).type(torch.float).cuda() for inp in inps]
    avg_time_torch = get_avg_time(torch_func, inps_torch, reps, unpack_inps, False)

    return avg_time_mge, avg_time_torch

if __name__ == "__main__":
    header = ["opr_name", "time(mge/pytorch; small input)", "time(mge/pytorch; large input)"]
    table = []
    for case in test_cases:
        assert len(case) == 7
        name, mge_func, torch_func, small_shapes, large_shapes, unpack_inps, reps = case
        data = []
        data.append(name)
        print("========== op: {}".format(name))

        avg_time_mge, avg_time_torch = get_perf_results(
            mge_func, torch_func, small_shapes, unpack_inps, reps
        )
        print("mge time: {}".format(avg_time_mge))
        print("torch time: {}".format(avg_time_torch))
        data.append("{:.2f}".format(avg_time_mge / avg_time_torch))

        avg_time_mge, avg_time_torch = get_perf_results(
            mge_func, torch_func, large_shapes, unpack_inps, reps
        )
        print("mge time: {}".format(avg_time_mge))
        print("torch time: {}".format(avg_time_torch))
        data.append("{:.2f}".format(avg_time_mge / avg_time_torch))

        table.append(data)

    print(tabulate(table, header, tablefmt="github"))
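
The script assumes a CUDA-capable GPU with both MegEngine and PyTorch installed; run it as python3 benchmark_op.py. Each table cell reports the ratio avg_time_mge / avg_time_torch for one input size, so values below 1.00 mean MegEngine ran faster on that case. Covering another operator only takes one more 7-tuple appended to test_cases; a minimal sketch, where the elemwise.abs entry and its shapes are illustrative and not part of the original file:

    # hypothetical extra case: benchmark elementwise abs on both frameworks
    test_cases.append(
        ("elemwise.abs", MF.abs, torch.abs, [(100, 100)], [(64, 512, 16, 16)], True, 1000)
    )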