GitOrigin-RevId: 2effad8e5f
tags/v1.3.0
@@ -14,28 +14,11 @@ decl_opr('BatchedMatrixMul', | |||||
'performed and output shape is (n, a, c)') | 'performed and output shape is (n, a, c)') | ||||
decl_opr('MatrixMul', | decl_opr('MatrixMul', | ||||
pyname='matrix_mul_v2', | |||||
inputs=['opr0', 'opr1'], | |||||
params='MatrixMul', | |||||
desc='matrix multiplication', | |||||
version=2, has_out_dtype=True) | |||||
decl_opr('BatchedMatrixMul', | |||||
pyname='batched_matrix_mul_v2', | |||||
inputs=['opr0', 'opr1'], | |||||
params='MatrixMul', | |||||
desc='batched matrix multiplication: input shapes should be ' | |||||
'(n, a, b) and (n, b, c) (assuming transposeA and transeposeB are ' | |||||
'False); then :math:`n` independent matrix multiplications would be ' | |||||
'performed and output shape is (n, a, c)', | |||||
version=2, has_out_dtype=True) | |||||
decl_opr('MatrixMul', | |||||
inputs=['opr0', 'opr1'], | inputs=['opr0', 'opr1'], | ||||
params=[('param', 'MatrixMul'), | params=[('param', 'MatrixMul'), | ||||
('execution_polity', 'ExecutionPolicy')], | ('execution_polity', 'ExecutionPolicy')], | ||||
desc='matrix multiplication', | desc='matrix multiplication', | ||||
version=3, has_out_dtype=True) | |||||
version=2, has_out_dtype=True) | |||||
decl_opr('BatchedMatrixMul', | decl_opr('BatchedMatrixMul', | ||||
inputs=['opr0', 'opr1'], | inputs=['opr0', 'opr1'], | ||||
@@ -45,7 +28,7 @@ decl_opr('BatchedMatrixMul', | |||||
'(n, a, b) and (n, b, c) (assuming transposeA and transeposeB are ' | '(n, a, b) and (n, b, c) (assuming transposeA and transeposeB are ' | ||||
'False); then :math:`n` independent matrix multiplications would be ' | 'False); then :math:`n` independent matrix multiplications would be ' | ||||
'performed and output shape is (n, a, c)', | 'performed and output shape is (n, a, c)', | ||||
version=3, has_out_dtype=True) | |||||
version=2, has_out_dtype=True) | |||||
decl_opr('Dot', | decl_opr('Dot', | ||||
inputs=['opr0', 'opr1'], | inputs=['opr0', 'opr1'], | ||||
@@ -51,7 +51,6 @@ struct MatrixMulLoadDumpImpl { | |||||
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { | static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { | ||||
auto&& opr = opr_.cast_final_safe<Opr>(); | auto&& opr = opr_.cast_final_safe<Opr>(); | ||||
ctx.write_param<megdnn::param::MatrixMul>(opr.param()); | ctx.write_param<megdnn::param::MatrixMul>(opr.param()); | ||||
ctx.write_param<megdnn::param::ExecutionPolicy>(opr.execution_policy()); | |||||
} | } | ||||
static VarNode* make(const cg::VarNodeArray& inputs, | static VarNode* make(const cg::VarNodeArray& inputs, | ||||
@@ -68,9 +67,7 @@ struct MatrixMulLoadDumpImpl { | |||||
const cg::VarNodeArray& inputs, | const cg::VarNodeArray& inputs, | ||||
const OperatorNodeConfig& config) { | const OperatorNodeConfig& config) { | ||||
auto param = ctx.read_param<megdnn::param::MatrixMul>(); | auto param = ctx.read_param<megdnn::param::MatrixMul>(); | ||||
auto execution_policy = | |||||
ctx.read_param<megdnn::param::ExecutionPolicy>(); | |||||
return make(inputs, param, execution_policy, config)->owner_opr(); | |||||
return make(inputs, param, {}, config)->owner_opr(); | |||||
} | } | ||||
}; | }; | ||||
@@ -90,10 +87,10 @@ struct OprLoadDumpImpl<opr::BatchedMatrixMul, 2> | |||||
namespace opr { | namespace opr { | ||||
using MatrixMulV3 = MatrixMul; | |||||
using BatchedMatrixMulV3 = BatchedMatrixMul; | |||||
MGB_SEREG_OPR(MatrixMulV3, 2); | |||||
MGB_SEREG_OPR(BatchedMatrixMulV3, 2); | |||||
using MatrixMulV2 = MatrixMul; | |||||
using BatchedMatrixMulV2 = BatchedMatrixMul; | |||||
MGB_SEREG_OPR(MatrixMulV2, 2); | |||||
MGB_SEREG_OPR(BatchedMatrixMulV2, 2); | |||||
MGB_SEREG_OPR(Dot, 2); | MGB_SEREG_OPR(Dot, 2); | ||||
MGB_SEREG_OPR(MatrixInverse, 1); | MGB_SEREG_OPR(MatrixInverse, 1); | ||||
MGB_SEREG_OPR(SVD, 1); | MGB_SEREG_OPR(SVD, 1); | ||||