@@ -32,11 +32,13 @@ class Linear(QuantizedModule):
         inp_scale = mgb.dtype.get_scale(inp.dtype)
         w_scale = mgb.dtype.get_scale(self.weight.dtype)
         bias_dtype = mgb.dtype.qint32(inp_scale * w_scale)
-        return F.linear(
+        ret = F.linear(
             inp,
             self.weight,
             None if self.bias is None else self.bias.astype(bias_dtype),
-        ).astype(self.output_dtype)
+        )
+        ret = ret if self.output_dtype is None else ret.astype(self.output_dtype)
+        return ret
 
     @classmethod
     def from_qat_module(cls, qat_module: QAT.Linear):
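
For context, a brief stand-in for the control flow this hunk introduces: the cast to self.output_dtype now runs only when an output dtype is actually configured, so a module without one returns the result of F.linear unchanged. The snippet below is an illustration only, using NumPy arrays in place of quantized tensors; cast_if_configured is a hypothetical helper, not part of MegEngine.

import numpy as np

def cast_if_configured(ret, output_dtype):
    # Mirrors the new conditional in forward(): skip the cast when no
    # output dtype is configured, otherwise cast as before.
    return ret if output_dtype is None else ret.astype(output_dtype)

acc = np.array([1000, -2000], dtype=np.int32)  # stand-in for the un-cast F.linear result
print(cast_if_configured(acc, None).dtype)     # int32: returned as-is
print(cast_if_configured(acc, np.int8).dtype)  # int8: explicit output dtype still honored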