|
|
@@ -44,7 +44,7 @@ def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor |
|
|
|
ret = mgb.opr.matrix_mul(inp, weight, transposeB=True) |
|
|
|
ret = ret.reshape(orig_shape[:-1], weight.shape[0]) |
|
|
|
if bias is not None: |
|
|
|
ret += bias |
|
|
|
ret += bias.reshape(1, bias.shape[0]) |
|
|
|
return ret |
|
|
|
|
|
|
|
|
|
|
|