|
|
@@ -40,7 +40,9 @@ class _ConvBnActivation2d(Conv2d): |
|
|
|
) |
|
|
|
weight = w_fold.astype(qat_module.get_weight_dtype()) |
|
|
|
qconv.weight = Parameter(weight.numpy(), name=qat_module.conv.weight.name) |
|
|
|
qconv.bias = Parameter(b_fold.numpy(), name=qat_module.conv.bias.name) |
|
|
|
qconv.bias = Parameter(b_fold.numpy()) |
|
|
|
if qat_module.conv.bias is not None: |
|
|
|
qconv.bias.name = qat_module.conv.bias.name |
|
|
|
return qconv |
|
|
|
|
|
|
|
|
|
|
|