|
|
@@ -34,6 +34,7 @@ class _ConvBnActivation2d(Conv2d): |
|
|
|
groups: int = 1, |
|
|
|
conv_mode: str = "CROSS_CORRELATION", |
|
|
|
compute_mode: str = "DEFAULT", |
|
|
|
dtype=None, |
|
|
|
): |
|
|
|
super().__init__( |
|
|
|
in_channels, |
|
|
@@ -47,11 +48,7 @@ class _ConvBnActivation2d(Conv2d): |
|
|
|
conv_mode, |
|
|
|
compute_mode, |
|
|
|
) |
|
|
|
self.scale = 1.0 |
|
|
|
self.zero_point = 0.0 |
|
|
|
self.output_dtype = mgb.dtype.qint8(self.scale) |
|
|
|
self.weight = self.weight.astype(self.output_dtype) |
|
|
|
self.bias = self.bias.astype(mgb.dtype.qint32(self.scale)) |
|
|
|
self.output_dtype = dtype |
|
|
|
|
|
|
|
def calc_conv_quantized(self, inp, nonlinear_mode="IDENTITY"): |
|
|
|
inp_scale = mgb.dtype.get_scale(inp.dtype) |
|
|
@@ -87,6 +84,7 @@ class ConvBnRelu2d(_ConvBnActivation2d): |
|
|
|
|
|
|
|
|
|
|
|
def to_quantized(quantized_class, float_module): |
|
|
|
output_dtype = float_module.act_observer.get_dtype() |
|
|
|
qconv = quantized_class( |
|
|
|
float_module.conv.in_channels, |
|
|
|
float_module.conv.out_channels, |
|
|
@@ -95,15 +93,14 @@ def to_quantized(quantized_class, float_module): |
|
|
|
float_module.conv.padding, |
|
|
|
float_module.conv.dilation, |
|
|
|
float_module.conv.groups, |
|
|
|
dtype=output_dtype, |
|
|
|
) |
|
|
|
w_fold, b_fold = float_module.fold_weight_bias( |
|
|
|
float_module.bn.running_mean, float_module.bn.running_var |
|
|
|
) |
|
|
|
weight = w_fold.astype(float_module.weight_observer.get_dtype()) |
|
|
|
qconv.output_dtype = float_module.act_observer.get_dtype() |
|
|
|
qconv.weight = Parameter(weight.numpy()) |
|
|
|
qconv.bias = Parameter(b_fold.numpy()) |
|
|
|
qconv.scale, qconv.zero_point = float_module.act_observer.get_qparams() |
|
|
|
|
|
|
|
return qconv |
|
|
|
|
|
|
|