
weight_scaler.py

import types
from functools import partial

from .. import functional as F
from .. import module as M
from ..utils.module_utils import set_module_mode_safe


def get_norm_mod_value(weight, norm_value):
    # Compute a power-of-two scale that brings the weight's L2 norm
    # close to norm_value.
    weight = weight.reshape(-1)
    norm = F.norm(weight)
    scale = norm_value / norm
    # Round the scale down to the nearest power of two.
    round_log = F.floor(F.log(scale) / F.log(2))
    rounded_scale = 2 ** round_log
    return rounded_scale.detach()
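

# For intuition, a worked example of the rounding above (the numbers are
# illustrative, not from the source): with norm_value = 1.0 and a weight
# whose flattened L2 norm is 0.3, scale = 1.0 / 0.3 ≈ 3.33 and
# floor(log2(3.33)) = 1, so the returned scale is 2 ** 1 = 2.0. Rounding
# to a power of two means the scaling only shifts the floating-point
# exponent, so it introduces no rounding error of its own.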


def get_scaled_model(model, scale_submodel, input_shape=None):
    submodule_list = None
    scale_value = None
    accumulated_scale = 1.0

    def scale_calc(mod_calc_func):
        # Wrap a module's compute function so that, in training mode,
        # it runs on scaled copies of the weight and bias.
        def calcfun(self, inp, weight, bias):
            scaled_weight = weight
            scaled_bias = bias
            if self.training:
                scaled_weight = (
                    weight * self.weight_scale if weight is not None else None
                )
                scaled_bias = bias * self.bias_scale if bias is not None else None
            return mod_calc_func(inp, scaled_weight, scaled_bias)

        return calcfun
    def scale_module_structure(
        scale_list: list = None, scale_value: tuple = None,
    ):
        nonlocal accumulated_scale
        for i in range(len(scale_list)):
            key, mod = scale_list[i]
            w_scale_value = scale_value[1]
            # "CONST" uses the given value directly; any other tag derives
            # a power-of-two scale from the module's weight norm.
            if scale_value[0] != "CONST":
                w_scale_value = get_norm_mod_value(mod.weight, scale_value[1])
            accumulated_scale *= w_scale_value
            mod.weight_scale = w_scale_value
            mod.bias_scale = accumulated_scale
            if isinstance(mod, M.conv.Conv2d):
                mod.calc_conv = types.MethodType(scale_calc(mod.calc_conv), mod)
            else:
                mod._calc_linear = types.MethodType(scale_calc(mod._calc_linear), mod)
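
    # Note on the two scales set above: weight_scale is this module's own
    # factor, while bias_scale is the running product of every weight scale
    # applied since the chain started, because the bias of layer k is added
    # to an activation that has already passed through all of the preceding
    # scaled weights.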

    def forward_hook(submodel, inputs, outputs, modelname=""):
        nonlocal submodule_list
        nonlocal scale_value
        nonlocal accumulated_scale
        if modelname in scale_submodel:
            scale_value = scale_submodel[modelname]
            if isinstance(submodel, (M.conv.Conv2d, M.linear.Linear)):
                # A conv/linear module named directly is scaled on its own.
                scale_module_structure([(modelname, submodel)], scale_value)
            else:
                # A named container starts collecting its conv/linear children.
                submodule_list = []
        if isinstance(submodel, (M.conv.Conv2d, M.linear.Linear)) and (
            submodule_list is not None
        ):
            submodule_list.append((modelname, submodel))
        if isinstance(submodel, M.batchnorm.BatchNorm2d) and (
            submodule_list is not None
        ):
            # A BatchNorm2d closes the collected chain: scale the chain as a
            # group, then reset the collection state.
            scale_module_structure(submodule_list, scale_value)
            submodule_list = None
            scale_value = None
            accumulated_scale = 1.0

    if input_shape is None:
        raise ValueError("input_shape is required for calculating scale value")
    input = F.zeros(input_shape)
    hooks = []
    for modelname, submodel in model.named_modules():
        hooks.append(
            submodel.register_forward_pre_hook(
                partial(forward_hook, modelname=modelname, outputs=None)
            )
        )
    # Run one dummy forward pass in eval mode so the hooks fire and attach
    # the scales, then remove the hooks.
    with set_module_mode_safe(model, training=False) as model:
        model(input)
    for hook in hooks:
        hook.remove()
    return model
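
A minimal usage sketch, assuming the MegEngine-style API implied by the relative imports and that this file is importable as weight_scaler; the SimpleNet model and the "conv1" key below are hypothetical, made up for illustration:

import megengine.functional as F
import megengine.module as M

from weight_scaler import get_scaled_model

# Hypothetical two-layer model, defined only to demonstrate the call.
class SimpleNet(M.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = M.Conv2d(3, 8, 3, padding=1)
        self.bn1 = M.BatchNorm2d(8)

    def forward(self, x):
        return F.relu(self.bn1(self.conv1(x)))

net = SimpleNet()
# Scale conv1's weight by a constant 2.0; any tag other than "CONST" would
# instead derive a power-of-two scale from the weight's norm.
net = get_scaled_model(net, {"conv1": ("CONST", 2.0)}, input_shape=(1, 3, 32, 32))

After this call, the patched calc_conv multiplies conv1's weight by 2.0 and its bias by the accumulated scale whenever the model is in training mode; in eval mode the wrappers pass the parameters through unchanged.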