
weight_scaler.py 3.3 kB

import types
from functools import partial

import megengine.functional as F
import megengine.module as M
from megengine.functional.tensor import zeros
from megengine.utils.module_utils import set_module_mode_safe
def get_norm_mod_value(weight, norm_value):
    """Return a power-of-two factor that rescales `weight` so that its L2
    norm approaches `norm_value`."""
    weight = weight.reshape(-1)
    norm = F.norm(weight)
    scale = norm_value / norm
    # Round the factor down to the nearest power of two so that applying it
    # only shifts the floating-point exponent and is otherwise exact.
    round_log = F.floor(F.log(scale) / F.log(2))
    rounded_scale = 2 ** round_log
    return rounded_scale.detach()
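
# Worked example (illustrative figures, not from the original file): for a
# weight vector with L2 norm 10.0 and norm_value 1.0, the exact ratio is
# 0.1; floor(log2(0.1)) = -4, so the returned factor is 2 ** -4 = 0.0625,
# the largest power of two not exceeding 0.1.
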
def get_scaled_model(model, scale_submodel, input_shape=None):
    submodule_list = None
    scale_value = None
    accumulated_scale = 1.0

    def scale_calc(mod_calc_func):
        # Wrap a module's bound compute function so that, while training,
        # weight and bias are multiplied by the scales attached below.
        def calcfun(self, inp, weight, bias):
            scaled_weight = weight
            scaled_bias = bias
            if self.training:
                scaled_weight = (
                    weight * self.weight_scale if weight is not None else None
                )
                scaled_bias = bias * self.bias_scale if bias is not None else None
            return mod_calc_func(inp, scaled_weight, scaled_bias)

        return calcfun
    def scale_module_structure(
        scale_list: list = None, scale_value: tuple = None,
    ):
        nonlocal accumulated_scale
        for key, mod in scale_list:
            w_scale_value = scale_value[1]
            if scale_value[0] != "CONST":
                w_scale_value = get_norm_mod_value(mod.weight, scale_value[1])
            # Bias scales accumulate across consecutive scaled modules so the
            # bias stays consistent with the already-scaled activations.
            accumulated_scale *= w_scale_value
            mod.weight_scale = w_scale_value
            mod.bias_scale = accumulated_scale
            if isinstance(mod, M.conv.Conv2d):
                mod.calc_conv = types.MethodType(scale_calc(mod.calc_conv), mod)
            else:
                mod._calc_linear = types.MethodType(scale_calc(mod._calc_linear), mod)
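
    # Expected scale_submodel values, inferred from the branch above (the
    # "NORM" tag is illustrative; any tag other than "CONST" takes that
    # path): ("CONST", 0.5) applies the factor 0.5 verbatim, while
    # ("NORM", 1.0) rescales each weight toward an L2 norm of 1.0 via
    # get_norm_mod_value.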
    def forward_hook(submodel, inputs, outputs, modelname=""):
        nonlocal submodule_list
        nonlocal scale_value
        nonlocal accumulated_scale
        if modelname in scale_submodel:
            scale_value = scale_submodel[modelname]
            if isinstance(submodel, (M.conv.Conv2d, M.linear.Linear)):
                # A conv/linear module was named directly: scale it alone.
                scale_module_structure([(modelname, submodel)], scale_value)
            else:
                # A container was named: collect its conv/linear children
                # until the next BatchNorm2d closes the group.
                submodule_list = []
        if isinstance(submodel, (M.conv.Conv2d, M.linear.Linear)) and (
            submodule_list is not None
        ):
            submodule_list.append((modelname, submodel))
        if isinstance(submodel, M.batchnorm.BatchNorm2d) and (
            submodule_list is not None
        ):
            scale_module_structure(submodule_list, scale_value)
            submodule_list = None
            scale_value = None
            accumulated_scale = 1.0
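
    # Hook-order sketch (module names are illustrative, not from the original
    # file): with scale_submodel = {"block": ("CONST", 0.5)} and a "block"
    # container holding a conv followed by a batch norm, the pre-hook on
    # "block" opens a group, the hook on "block.conv" appends that module,
    # and the hook on "block.bn" scales the collected group and resets the
    # bookkeeping state.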
    if input_shape is None:
        raise ValueError("input_shape is required for calculating scale value")
    inp = zeros(input_shape)
    hooks = []
    for modelname, submodel in model.named_modules():
        hooks.append(
            submodel.register_forward_pre_hook(
                partial(forward_hook, modelname=modelname, outputs=None)
            )
        )
    # One dummy forward pass in eval mode triggers the hooks; the wrapped
    # calc functions only apply their scales in training mode, so this pass
    # does not double-scale anything.
    with set_module_mode_safe(model, training=False) as model:
        model(inp)
    for hook in hooks:
        hook.remove()
    return model
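
# Usage sketch (a minimal, illustrative example: TinyNet and the
# scale_submodel entry below are assumptions, not part of the original file).
if __name__ == "__main__":
    class TinyNet(M.Module):
        def __init__(self):
            super().__init__()
            self.conv = M.Conv2d(3, 8, 3, padding=1)
            self.bn = M.BatchNorm2d(8)
            self.fc = M.Linear(8 * 32 * 32, 10)

        def forward(self, x):
            x = self.bn(self.conv(x))
            return self.fc(F.flatten(x, 1))

    net = TinyNet()
    # "conv" must match a name yielded by net.named_modules(); the conv
    # weight is scaled by the constant factor 0.5 during later training.
    scaled = get_scaled_model(
        net, {"conv": ("CONST", 0.5)}, input_shape=(1, 3, 32, 32)
    )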