|
- import types
- from functools import partial
-
- from .. import functional as F
- from .. import module as M
- from ..utils.module_utils import set_module_mode_safe
-
-
def get_norm_mod_value(weight, norm_value):
    """Return a power-of-two scale that moves the norm of ``weight`` towards ``norm_value``.

    The weight is flattened, its norm is taken with ``F.norm`` (default
    arguments), and the ratio ``norm_value / norm`` is rounded *down* to the
    nearest power of two.

    Args:
        weight: tensor-like object supporting ``reshape(-1)``.
        norm_value: target norm used as the numerator of the ratio.

    Returns:
        ``2 ** floor(log2(norm_value / norm))``, detached from the graph.
    """
    flat_norm = F.norm(weight.reshape(-1))
    # log2 of the ratio via change of base, floored to an integer exponent.
    exponent = F.floor(F.log(norm_value / flat_norm) / F.log(2))
    return (2 ** exponent).detach()
-
-
def get_scaled_model(model, scale_submodel, input_shape=None):
    """Attach weight/bias scale factors to the Conv2d/Linear layers of ``model``.

    A forward pre-hook is registered on every module yielded by
    ``model.named_modules()``. The model is then run once on a zero tensor of
    shape ``input_shape`` inside ``set_module_mode_safe(model, training=False)``
    so the hooks see the modules in execution order:

    * If a module's name is a key of ``scale_submodel`` and the module is a
      Conv2d/Linear, it is scaled immediately.
    * If the name matches but the module is of another type, every
      Conv2d/Linear executed afterwards is collected until the next
      BatchNorm2d, at which point the collected layers are scaled and the
      running state is reset.

    Scaling a layer sets ``weight_scale`` and ``bias_scale`` on it and wraps
    its ``calc_conv`` (Conv2d) or ``_calc_linear`` (otherwise) so that, while
    ``self.training`` is true, the weight is multiplied by ``weight_scale``
    and the bias by ``bias_scale`` before the original computation runs.
    ``bias_scale`` is the running product of the weight scales applied since
    the last reset.

    Args:
        model: module to patch; its layers are modified in place.
        scale_submodel: mapping from module name to an indexable
            ``(mode, value)`` pair. With ``mode == "CONST"`` ``value`` is used
            directly as the weight scale; otherwise ``value`` is the target
            norm handed to ``get_norm_mod_value``.
        input_shape: shape of the dummy input for the dry-run forward pass.

    Returns:
        The object yielded by ``set_module_mode_safe`` (presumably ``model``
        itself -- confirm against that helper).

    Raises:
        ValueError: if ``input_shape`` is None.
    """
    if input_shape is None:
        raise ValueError("input_shape is required for calculating scale value")

    # State shared between hook invocations during the dry-run forward pass.
    submodule_list = None  # layers collected since a matching container was seen
    scale_value = None  # (mode, value) pair currently in effect
    accumulated_scale = 1.0  # product of weight scales since the last reset

    def scale_calc(mod_calc_func):
        """Wrap a bound compute method so it applies the stored scales in training."""

        def calcfun(self, inp, weight, bias):
            scaled_weight = weight
            scaled_bias = bias
            if self.training:
                scaled_weight = (
                    weight * self.weight_scale if weight is not None else None
                )
                scaled_bias = bias * self.bias_scale if bias is not None else None
            return mod_calc_func(inp, scaled_weight, scaled_bias)

        return calcfun

    def scale_module_structure(scale_list, scale_value):
        """Assign scales to each ``(name, module)`` in ``scale_list`` and patch it."""
        nonlocal accumulated_scale
        for _, mod in scale_list:
            w_scale_value = scale_value[1]
            # Compare by value: identity comparison against a string literal
            # depends on interning and is not a reliable equality test.
            if scale_value[0] != "CONST":
                w_scale_value = get_norm_mod_value(mod.weight, scale_value[1])

            accumulated_scale *= w_scale_value

            mod.weight_scale = w_scale_value
            mod.bias_scale = accumulated_scale

            if isinstance(mod, M.conv.Conv2d):
                mod.calc_conv = types.MethodType(scale_calc(mod.calc_conv), mod)
            else:
                mod._calc_linear = types.MethodType(
                    scale_calc(mod._calc_linear), mod
                )

    def pre_forward_hook(submodel, inputs, modelname=""):
        nonlocal submodule_list
        nonlocal scale_value
        nonlocal accumulated_scale
        is_scalable = isinstance(submodel, (M.conv.Conv2d, M.linear.Linear))

        if modelname in scale_submodel:
            scale_value = scale_submodel[modelname]
            if is_scalable:
                scale_module_structure([(modelname, submodel)], scale_value)
            else:
                # Start collecting the layers that execute after this module.
                submodule_list = []

        if is_scalable and submodule_list is not None:
            submodule_list.append((modelname, submodel))

        # NOTE(review): collected layers are only scaled when a BatchNorm2d is
        # reached; a group with no following BatchNorm2d is never scaled.
        if isinstance(submodel, M.batchnorm.BatchNorm2d) and (
            submodule_list is not None
        ):
            scale_module_structure(submodule_list, scale_value)
            submodule_list = None
            scale_value = None
            accumulated_scale = 1.0

    dummy_input = F.zeros(input_shape)

    hooks = []
    try:
        for modelname, submodel in model.named_modules():
            hooks.append(
                submodel.register_forward_pre_hook(
                    partial(pre_forward_hook, modelname=modelname)
                )
            )

        with set_module_mode_safe(model, training=False) as scaled_model:
            scaled_model(dummy_input)
    finally:
        # Always detach the hooks, even if the dry run fails part-way.
        for hook in hooks:
            hook.remove()

    return scaled_model
|