Browse Source

feat(mgb): adapt to imperative runtime

GitOrigin-RevId: 3bccc17b62
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
e9a7be46b1
3 changed files with 9 additions and 5 deletions
  1. +2
    -0
      imperative/python/megengine/functional/nn.py
  2. +5
    -3
      imperative/src/impl/ops/specializations.cpp
  3. +2
    -2
      src/core/include/megbrain/ir/ops.td

+ 2
- 0
imperative/python/megengine/functional/nn.py View File

@@ -1106,6 +1106,7 @@ def matmul(
transposeB=transpose_b,
compute_mode=compute_mode,
format=format,
strategy=get_conv_execution_strategy(),
)
else:
op = builtin.MatrixMul(
@@ -1113,6 +1114,7 @@ def matmul(
transposeB=transpose_b,
compute_mode=compute_mode,
format=format,
strategy=get_conv_execution_strategy(),
)

(result,) = apply(op, inp1, inp2)


+ 5
- 3
imperative/src/impl/ops/specializations.cpp View File

@@ -243,7 +243,8 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& matmul = static_cast<const MatrixMul&>(def);
mgb_assert(inputs.size() == 2);
return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param());
return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param(),
matmul.policy());
}
OP_TRAIT_REG(MatrixMul, MatrixMul)
.apply_on_var_node(apply_on_var_node)
@@ -256,7 +257,8 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& matmul = static_cast<const BatchedMatrixMul&>(def);
mgb_assert(inputs.size() == 2);
return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param());
return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param(),
matmul.policy());
}
OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
.apply_on_var_node(apply_on_var_node)
@@ -428,7 +430,7 @@ auto apply_on_var_node(
return opr::AssertEqual::make(inputs[0],inputs[1],op.param());

}
OP_TRAIT_REG(AssertEqual, AssertEqual)
.apply_on_var_node(apply_on_var_node)
.fallback();


+ 2
- 2
src/core/include/megbrain/ir/ops.td View File

@@ -34,9 +34,9 @@ def TypeCvt: MgbHashableOp<"TypeCvt", [], [NoSideEffect]> {
let results = (outs AnyType);
}

def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam]>;
def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>;

def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam]>;
def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>;

def Dot: MgbHashableOp<"Dot", [EmptyParam]>;



Loading…
Cancel
Save