Browse Source

fix(mgb): fix matmul model compat in flatbuffer

GitOrigin-RevId: 2effad8e5f
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
d62dabe5f8
2 changed files with 7 additions and 27 deletions
  1. +2
    -19
      src/opr/impl/blas.oprdecl
  2. +5
    -8
      src/opr/impl/blas.sereg.h

+ 2
- 19
src/opr/impl/blas.oprdecl View File

@@ -14,28 +14,11 @@ decl_opr('BatchedMatrixMul',
'performed and output shape is (n, a, c)')

decl_opr('MatrixMul',
pyname='matrix_mul_v2',
inputs=['opr0', 'opr1'],
params='MatrixMul',
desc='matrix multiplication',
version=2, has_out_dtype=True)

decl_opr('BatchedMatrixMul',
pyname='batched_matrix_mul_v2',
inputs=['opr0', 'opr1'],
params='MatrixMul',
desc='batched matrix multiplication: input shapes should be '
'(n, a, b) and (n, b, c) (assuming transposeA and transeposeB are '
'False); then :math:`n` independent matrix multiplications would be '
'performed and output shape is (n, a, c)',
version=2, has_out_dtype=True)

decl_opr('MatrixMul',
inputs=['opr0', 'opr1'],
params=[('param', 'MatrixMul'),
('execution_polity', 'ExecutionPolicy')],
desc='matrix multiplication',
version=3, has_out_dtype=True)
version=2, has_out_dtype=True)

decl_opr('BatchedMatrixMul',
inputs=['opr0', 'opr1'],
@@ -45,7 +28,7 @@ decl_opr('BatchedMatrixMul',
'(n, a, b) and (n, b, c) (assuming transposeA and transeposeB are '
'False); then :math:`n` independent matrix multiplications would be '
'performed and output shape is (n, a, c)',
version=3, has_out_dtype=True)
version=2, has_out_dtype=True)

decl_opr('Dot',
inputs=['opr0', 'opr1'],


+ 5
- 8
src/opr/impl/blas.sereg.h View File

@@ -51,7 +51,6 @@ struct MatrixMulLoadDumpImpl {
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
auto&& opr = opr_.cast_final_safe<Opr>();
ctx.write_param<megdnn::param::MatrixMul>(opr.param());
ctx.write_param<megdnn::param::ExecutionPolicy>(opr.execution_policy());
}

static VarNode* make(const cg::VarNodeArray& inputs,
@@ -68,9 +67,7 @@ struct MatrixMulLoadDumpImpl {
const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto param = ctx.read_param<megdnn::param::MatrixMul>();
auto execution_policy =
ctx.read_param<megdnn::param::ExecutionPolicy>();
return make(inputs, param, execution_policy, config)->owner_opr();
return make(inputs, param, {}, config)->owner_opr();
}
};

@@ -90,10 +87,10 @@ struct OprLoadDumpImpl<opr::BatchedMatrixMul, 2>

namespace opr {

using MatrixMulV3 = MatrixMul;
using BatchedMatrixMulV3 = BatchedMatrixMul;
MGB_SEREG_OPR(MatrixMulV3, 2);
MGB_SEREG_OPR(BatchedMatrixMulV3, 2);
using MatrixMulV2 = MatrixMul;
using BatchedMatrixMulV2 = BatchedMatrixMul;
MGB_SEREG_OPR(MatrixMulV2, 2);
MGB_SEREG_OPR(BatchedMatrixMulV2, 2);
MGB_SEREG_OPR(Dot, 2);
MGB_SEREG_OPR(MatrixInverse, 1);
MGB_SEREG_OPR(SVD, 1);


Loading…
Cancel
Save