
builder_util.py

import re
import math

import torch.nn as nn
from copy import deepcopy
from timm.utils import *
from timm.models.layers.activations import HardSwish, Swish
from timm.models.layers import CondConv2d, get_condconv_initializer


def parse_ksize(ss):
    """Parse a kernel-size token: a single int ('3') or a '.'-separated list ('3.5.7')."""
    if ss.isdigit():
        return int(ss)
    else:
        return [int(k) for k in ss.split('.')]
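
# A quick sketch of the two token forms parse_ksize accepts (values illustrative):
#   parse_ksize('3')      -> 3          # single kernel size
#   parse_ksize('3.5.7')  -> [3, 5, 7]  # mixed kernel sizes, e.g. MixNet-style blocks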

def decode_arch_def(
        arch_def,
        depth_multiplier=1.0,
        depth_trunc='ceil',
        experts_multiplier=1):
    """Decode a nested list of block strings into per-stage lists of block-arg dicts."""
    arch_args = []
    for stack_idx, block_strings in enumerate(arch_def):
        assert isinstance(block_strings, list)
        stack_args = []
        repeats = []
        for block_str in block_strings:
            assert isinstance(block_str, str)
            ba, rep = decode_block_str(block_str)
            if ba.get('num_experts', 0) > 0 and experts_multiplier > 1:
                ba['num_experts'] *= experts_multiplier
            stack_args.append(ba)
            repeats.append(rep)
        arch_args.append(
            scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
    return arch_args
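
# Hedged usage sketch (the stage definition below is illustrative, not from the source):
#   arch_def = [
#       ['ds_r1_k3_s1_c16'],            # stage 0: 1x depthwise-separable block
#       ['ir_r2_k3_s2_e6_c24_se0.25'],  # stage 1: 2x inverted-residual blocks
#   ]
#   block_args = decode_arch_def(arch_def)
#   # -> one list of block-arg dicts per stage, with repeats already expanded,
#   #    e.g. len(block_args[1]) == 2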

def modify_block_args(block_args, kernel_size, exp_ratio):
    """Override the kernel size and expansion ratio of a decoded block-args dict."""
    block_type = block_args['block_type']
    if block_type == 'cn':
        block_args['kernel_size'] = kernel_size
    elif block_type == 'er':
        block_args['exp_kernel_size'] = kernel_size
    else:
        block_args['dw_kernel_size'] = kernel_size
    if block_type == 'ir' or block_type == 'er':
        block_args['exp_ratio'] = exp_ratio
    return block_args
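
# Sketch: retargeting a decoded block (the block string and new values are illustrative):
#   ba, _ = decode_block_str('ir_r1_k3_s1_e6_c24')
#   ba = modify_block_args(ba, kernel_size=5, exp_ratio=4.0)
#   # -> ba['dw_kernel_size'] == 5 and ba['exp_ratio'] == 4.0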

def decode_block_str(block_str):
    """ Decode block definition string

    Gets a block-args dict through a string notation of arguments.
    E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip

    All args can exist in any order with the exception of the leading string which
    is assumed to indicate the block type.

    leading string - block type (
      ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act, cn = ConvBnAct)
    r - number of repeat blocks,
    k - kernel size,
    s - strides (1-9),
    e - expansion ratio,
    c - output channels,
    se - squeeze/excitation ratio
    n - activation fn ('re', 'r6', 'hs', or 'sw')

    Args:
        block_str: a string representation of block arguments.
    Returns:
        A tuple of a block-args dict and the number of times to repeat the block.
    Raises:
        ValueError: if the string def is not properly specified (TODO)
    """
    assert isinstance(block_str, str)
    ops = block_str.split('_')
    block_type = ops[0]  # take the block type off the front
    ops = ops[1:]
    options = {}
    noskip = False
    for op in ops:
        # string options being checked on individual basis, combine if they grow
        if op == 'noskip':
            noskip = True
        elif op.startswith('n'):
            # activation fn
            key = op[0]
            v = op[1:]
            if v == 're':
                value = nn.ReLU
            elif v == 'r6':
                value = nn.ReLU6
            elif v == 'hs':
                value = HardSwish
            elif v == 'sw':
                value = Swish
            else:
                continue
            options[key] = value
        else:
            # all numeric options
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

    # if act_layer is None, the model default (passed to model init) will be used
    act_layer = options['n'] if 'n' in options else None
    exp_kernel_size = parse_ksize(options['a']) if 'a' in options else 1
    pw_kernel_size = parse_ksize(options['p']) if 'p' in options else 1
    # FIXME hack to deal with in_chs issue in TPU def
    fake_in_chs = int(options['fc']) if 'fc' in options else 0

    num_repeat = int(options['r'])
    # each type of block has different valid arguments, fill accordingly
    if block_type == 'ir':
        block_args = dict(
            block_type=block_type,
            dw_kernel_size=parse_ksize(options['k']),
            exp_kernel_size=exp_kernel_size,
            pw_kernel_size=pw_kernel_size,
            out_chs=int(options['c']),
            exp_ratio=float(options['e']),
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=int(options['s']),
            act_layer=act_layer,
            noskip=noskip,
        )
        if 'cc' in options:
            block_args['num_experts'] = int(options['cc'])
    elif block_type == 'ds' or block_type == 'dsa':
        block_args = dict(
            block_type=block_type,
            dw_kernel_size=parse_ksize(options['k']),
            pw_kernel_size=pw_kernel_size,
            out_chs=int(options['c']),
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=int(options['s']),
            act_layer=act_layer,
            pw_act=block_type == 'dsa',
            noskip=block_type == 'dsa' or noskip,
        )
    elif block_type == 'cn':
        block_args = dict(
            block_type=block_type,
            kernel_size=int(options['k']),
            out_chs=int(options['c']),
            stride=int(options['s']),
            act_layer=act_layer,
        )
    else:
        assert False, 'Unknown block type (%s)' % block_type

    return block_args, num_repeat
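
# Hedged sketch of a decode (field values follow from the parsing rules above):
#   decode_block_str('ir_r2_k3_s2_e6_c40_se0.25')
#   -> ({'block_type': 'ir', 'dw_kernel_size': 3, 'exp_kernel_size': 1,
#        'pw_kernel_size': 1, 'out_chs': 40, 'exp_ratio': 6.0, 'se_ratio': 0.25,
#        'stride': 2, 'act_layer': None, 'noskip': False}, 2)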

def scale_stage_depth(
        stack_args,
        repeats,
        depth_multiplier=1.0,
        depth_trunc='ceil'):
    """ Per-stage depth scaling

    Scales the block repeats in each stage. This depth scaling impl maintains
    compatibility with the EfficientNet scaling method, while allowing sensible
    scaling for other models that may have multiple block arg definitions in each stage.
    """
    # We scale the total repeat count for each stage, there may be multiple
    # block arg defs per stage so we need to sum.
    num_repeat = sum(repeats)
    if depth_trunc == 'round':
        # Truncating to int by rounding allows stages with few repeats to remain
        # proportionally smaller for longer. This is a good choice when stage definitions
        # include single repeat stages that we'd prefer to keep that way as long as possible.
        num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
    else:
        # The default for EfficientNet truncates repeats to int via 'ceil'.
        # Any multiplier > 1.0 will result in an increased depth for every stage.
        num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))

    # Proportionally distribute repeat count scaling to each block definition in the stage.
    # Allocation is done in reverse as it results in the first block being less likely to be scaled.
    # The first block makes less sense to repeat in most of the arch definitions.
    repeats_scaled = []
    for r in repeats[::-1]:
        rs = max(1, round((r / num_repeat * num_repeat_scaled)))
        repeats_scaled.append(rs)
        num_repeat -= r
        num_repeat_scaled -= rs
    repeats_scaled = repeats_scaled[::-1]

    # Apply the calculated scaling to each block arg in the stage
    sa_scaled = []
    for ba, rep in zip(stack_args, repeats_scaled):
        sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
    return sa_scaled
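
# Worked sketch: repeats [1, 4] with depth_multiplier=1.2 and 'ceil' truncation.
#   num_repeat = 5, num_repeat_scaled = ceil(5 * 1.2) = 6
#   Allocating in reverse: r=4 -> round(4/5 * 6) = 5; then r=1 -> round(1/1 * 1) = 1
#   repeats_scaled == [1, 5], so the first block of the stage stays unscaled.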

def init_weight_goog(m, n='', fix_group_fanout=True, last_bn=None):
    """ Weight initialization as per Tensorflow official implementations.

    Args:
        m (nn.Module): module to init
        n (str): module name
        fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs
        last_bn (list): names of BatchNorm2d modules whose gamma should be zero-initialized

    Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc.:
    * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
    * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
    """
    if isinstance(m, CondConv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        if fix_group_fanout:
            fan_out //= m.groups
        init_weight_fn = get_condconv_initializer(
            lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)),
            m.num_experts, m.weight_shape)
        init_weight_fn(m.weight)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.Conv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        if fix_group_fanout:
            fan_out //= m.groups
        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        if last_bn and n in last_bn:
            # zero-init gamma for the last BN of a block (the zero_gamma trick)
            m.weight.data.zero_()
            m.bias.data.zero_()
        else:
            m.weight.data.fill_(1.0)
            m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        fan_out = m.weight.size(0)  # fan-out
        fan_in = 0
        if 'routing_fn' in n:
            fan_in = m.weight.size(1)
        init_range = 1.0 / math.sqrt(fan_in + fan_out)
        m.weight.data.uniform_(-init_range, init_range)
        m.bias.data.zero_()
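
# Sketch of the fan-out rule above for a grouped conv (numbers illustrative):
# a 3x3 nn.Conv2d with out_channels=64 and groups=2 gets
# fan_out = 3 * 3 * 64 // 2 = 288, i.e. weight ~ N(0, sqrt(2.0 / 288)).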

def efficientnet_init_weights(
        model: nn.Module,
        init_fn=None,
        zero_gamma=False):
    last_bn = []
    if zero_gamma:
        # Collect the name of the last BatchNorm2d within each parent module so
        # its gamma can be zero-initialized by init_weight_goog.
        prev_n = ''
        for n, m in model.named_modules():
            if isinstance(m, nn.BatchNorm2d):
                if ''.join(prev_n.split('.')[:-1]) != ''.join(n.split('.')[:-1]):
                    last_bn.append(prev_n)
                prev_n = n
        last_bn.append(prev_n)

    init_fn = init_fn or init_weight_goog
    for n, m in model.named_modules():
        init_fn(m, n, last_bn=last_bn)
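
# Hedged usage sketch; `build_my_efficientnet` is a hypothetical constructor, any
# nn.Module with Conv2d/BatchNorm2d/Linear children works:
#   model = build_my_efficientnet()
#   efficientnet_init_weights(model, zero_gamma=True)
#   # zero_gamma=True zero-inits the scale (gamma) of the last BatchNorm2d in each
#   # block, a common trick for more stable early training of residual nets.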
