|
|
@@ -14,28 +14,11 @@ decl_opr('BatchedMatrixMul', |
|
|
|
'performed and output shape is (n, a, c)') |
|
|
|
|
|
|
|
decl_opr('MatrixMul', |
|
|
|
pyname='matrix_mul_v2', |
|
|
|
inputs=['opr0', 'opr1'], |
|
|
|
params='MatrixMul', |
|
|
|
desc='matrix multiplication', |
|
|
|
version=2, has_out_dtype=True) |
|
|
|
|
|
|
|
decl_opr('BatchedMatrixMul', |
|
|
|
pyname='batched_matrix_mul_v2', |
|
|
|
inputs=['opr0', 'opr1'], |
|
|
|
params='MatrixMul', |
|
|
|
desc='batched matrix multiplication: input shapes should be ' |
|
|
|
'(n, a, b) and (n, b, c) (assuming transposeA and transeposeB are ' |
|
|
|
'False); then :math:`n` independent matrix multiplications would be ' |
|
|
|
'performed and output shape is (n, a, c)', |
|
|
|
version=2, has_out_dtype=True) |
|
|
|
|
|
|
|
decl_opr('MatrixMul', |
|
|
|
inputs=['opr0', 'opr1'], |
|
|
|
params=[('param', 'MatrixMul'), |
|
|
|
('execution_polity', 'ExecutionPolicy')], |
|
|
|
desc='matrix multiplication', |
|
|
|
version=3, has_out_dtype=True) |
|
|
|
version=2, has_out_dtype=True) |
|
|
|
|
|
|
|
decl_opr('BatchedMatrixMul', |
|
|
|
inputs=['opr0', 'opr1'], |
|
|
@@ -45,7 +28,7 @@ decl_opr('BatchedMatrixMul', |
|
|
|
'(n, a, b) and (n, b, c) (assuming transposeA and transeposeB are ' |
|
|
|
'False); then :math:`n` independent matrix multiplications would be ' |
|
|
|
'performed and output shape is (n, a, c)', |
|
|
|
version=3, has_out_dtype=True) |
|
|
|
version=2, has_out_dtype=True) |
|
|
|
|
|
|
|
decl_opr('Dot', |
|
|
|
inputs=['opr0', 'opr1'], |
|
|
|