|
@@ -31,7 +31,7 @@ class Linear(QuantizedModule): |
|
|
inp_scale = dtype.get_scale(inp.dtype) |
|
|
inp_scale = dtype.get_scale(inp.dtype) |
|
|
w_scale = dtype.get_scale(self.weight.dtype) |
|
|
w_scale = dtype.get_scale(self.weight.dtype) |
|
|
bias_dtype = dtype.qint32(inp_scale * w_scale) |
|
|
bias_dtype = dtype.qint32(inp_scale * w_scale) |
|
|
return F.linear( |
|
|
|
|
|
|
|
|
return F.nn.linear( |
|
|
inp, |
|
|
inp, |
|
|
self.weight, |
|
|
self.weight, |
|
|
None if self.bias is None else self.bias.astype(bias_dtype), |
|
|
None if self.bias is None else self.bias.astype(bias_dtype), |
|
|