You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

flops_table.py 2.8 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. # Copyright (c) Microsoft Corporation.
  2. # Licensed under the MIT License.
  3. # Written by Hao Du and Houwen Peng
  4. # email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com
  5. import torch
  6. from ptflops import get_model_complexity_info
  7. class FlopsEst(object):
  8. def __init__(self, model, input_shape=(2, 3, 224, 224), device='cpu'):
  9. self.block_num = len(model.blocks)
  10. self.choice_num = len(model.blocks[0])
  11. self.flops_dict = {}
  12. self.params_dict = {}
  13. if device == 'cpu':
  14. model = model.cpu()
  15. else:
  16. model = model.cuda()
  17. self.params_fixed = 0
  18. self.flops_fixed = 0
  19. input = torch.randn(input_shape)
  20. flops, params = get_model_complexity_info(
  21. model.conv_stem, (3, 224, 224), as_strings=False, print_per_layer_stat=False)
  22. self.params_fixed += params / 1e6
  23. self.flops_fixed += flops / 1e6
  24. input = model.conv_stem(input)
  25. for block_id, block in enumerate(model.blocks):
  26. self.flops_dict[block_id] = {}
  27. self.params_dict[block_id] = {}
  28. for module_id, module in enumerate(block):
  29. flops, params = get_model_complexity_info(module, tuple(
  30. input.shape[1:]), as_strings=False, print_per_layer_stat=False)
  31. # Flops(M)
  32. self.flops_dict[block_id][module_id] = flops / 1e6
  33. # Params(M)
  34. self.params_dict[block_id][module_id] = params / 1e6
  35. input = module(input)
  36. # conv_last
  37. flops, params = get_model_complexity_info(model.global_pool, tuple(
  38. input.shape[1:]), as_strings=False, print_per_layer_stat=False)
  39. self.params_fixed += params / 1e6
  40. self.flops_fixed += flops / 1e6
  41. input = model.global_pool(input)
  42. # globalpool
  43. flops, params = get_model_complexity_info(model.conv_head, tuple(
  44. input.shape[1:]), as_strings=False, print_per_layer_stat=False)
  45. self.params_fixed += params / 1e6
  46. self.flops_fixed += flops / 1e6
  47. # return params (M)
  48. def get_params(self, arch):
  49. params = 0
  50. for block_id, block in enumerate(arch):
  51. if block == -1:
  52. continue
  53. params += self.params_dict[block_id][block]
  54. return params + self.params_fixed
  55. # return flops (M)
  56. def get_flops(self, arch):
  57. flops = 0
  58. for block_id, block in enumerate(arch):
  59. if block == 'LayerChoice1' or block_id == 'LayerChoice23':
  60. continue
  61. for idx, choice in enumerate(arch[block]):
  62. flops += self.flops_dict[block_id][idx] * (1 if choice else 0)
  63. return flops + self.flops_fixed

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能