|
|
@@ -243,7 +243,8 @@ auto apply_on_var_node( |
|
|
|
const VarNodeArray& inputs) { |
|
|
|
auto&& matmul = static_cast<const MatrixMul&>(def); |
|
|
|
mgb_assert(inputs.size() == 2); |
|
|
|
return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param()); |
|
|
|
return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param(), |
|
|
|
matmul.policy()); |
|
|
|
} |
|
|
|
OP_TRAIT_REG(MatrixMul, MatrixMul) |
|
|
|
.apply_on_var_node(apply_on_var_node) |
|
|
@@ -256,7 +257,8 @@ auto apply_on_var_node( |
|
|
|
const VarNodeArray& inputs) { |
|
|
|
auto&& matmul = static_cast<const BatchedMatrixMul&>(def); |
|
|
|
mgb_assert(inputs.size() == 2); |
|
|
|
return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param()); |
|
|
|
return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param(), |
|
|
|
matmul.policy()); |
|
|
|
} |
|
|
|
OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul) |
|
|
|
.apply_on_var_node(apply_on_var_node) |
|
|
@@ -428,7 +430,7 @@ auto apply_on_var_node( |
|
|
|
return opr::AssertEqual::make(inputs[0],inputs[1],op.param()); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
OP_TRAIT_REG(AssertEqual, AssertEqual) |
|
|
|
.apply_on_var_node(apply_on_var_node) |
|
|
|
.fallback(); |
|
|
|