|
|
@@ -98,6 +98,7 @@ size_t MatrixMul::get_workspace_size_bytes( |
|
|
|
param ^= 1; |
|
|
|
}; |
|
|
|
MGB_TRY { |
|
|
|
megdnn_opr()->execution_policy() = {}; |
|
|
|
a = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out}, |
|
|
|
megdnn_opr(), this); |
|
|
|
//! Here we just want to save the execution policy got from setup_algo, |
|
|
@@ -106,24 +107,28 @@ size_t MatrixMul::get_workspace_size_bytes( |
|
|
|
const_cast<MatrixMul*>(this) |
|
|
|
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = |
|
|
|
megdnn_opr()->execution_policy(); |
|
|
|
megdnn_opr()->execution_policy() = {}; |
|
|
|
transpose(i0, tparam.transposeA); |
|
|
|
b = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out}, |
|
|
|
megdnn_opr(), this); |
|
|
|
const_cast<MatrixMul*>(this) |
|
|
|
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = |
|
|
|
megdnn_opr()->execution_policy(); |
|
|
|
megdnn_opr()->execution_policy() = {}; |
|
|
|
transpose(i1, tparam.transposeB); |
|
|
|
c = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out}, |
|
|
|
megdnn_opr(), this); |
|
|
|
const_cast<MatrixMul*>(this) |
|
|
|
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = |
|
|
|
megdnn_opr()->execution_policy(); |
|
|
|
megdnn_opr()->execution_policy() = {}; |
|
|
|
transpose(i0, tparam.transposeA); |
|
|
|
d = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out}, |
|
|
|
megdnn_opr(), this); |
|
|
|
const_cast<MatrixMul*>(this) |
|
|
|
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = |
|
|
|
megdnn_opr()->execution_policy(); |
|
|
|
megdnn_opr()->execution_policy() = {}; |
|
|
|
} |
|
|
|
MGB_FINALLY({ tparam = this->param(); }); |
|
|
|
return std::max(std::max(a, b), std::max(c, d)); |
|
|
@@ -252,29 +257,34 @@ size_t BatchedMatrixMul::get_workspace_size_bytes( |
|
|
|
param ^= 1; |
|
|
|
}; |
|
|
|
MGB_TRY { |
|
|
|
megdnn_opr()->execution_policy() = {}; |
|
|
|
a = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo( |
|
|
|
{i0, i1, out}, megdnn_opr(), this); |
|
|
|
const_cast<BatchedMatrixMul*>(this) |
|
|
|
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = |
|
|
|
megdnn_opr()->execution_policy(); |
|
|
|
megdnn_opr()->execution_policy() = {}; |
|
|
|
transpose(i0, tparam.transposeA); |
|
|
|
b = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo( |
|
|
|
{i0, i1, out}, megdnn_opr(), this); |
|
|
|
const_cast<BatchedMatrixMul*>(this) |
|
|
|
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = |
|
|
|
megdnn_opr()->execution_policy(); |
|
|
|
megdnn_opr()->execution_policy() = {}; |
|
|
|
transpose(i1, tparam.transposeB); |
|
|
|
c = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo( |
|
|
|
{i0, i1, out}, megdnn_opr(), this); |
|
|
|
const_cast<BatchedMatrixMul*>(this) |
|
|
|
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = |
|
|
|
megdnn_opr()->execution_policy(); |
|
|
|
megdnn_opr()->execution_policy() = {}; |
|
|
|
transpose(i0, tparam.transposeA); |
|
|
|
d = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo( |
|
|
|
{i0, i1, out}, megdnn_opr(), this); |
|
|
|
const_cast<BatchedMatrixMul*>(this) |
|
|
|
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = |
|
|
|
megdnn_opr()->execution_policy(); |
|
|
|
megdnn_opr()->execution_policy() = {}; |
|
|
|
} |
|
|
|
MGB_FINALLY({ tparam = this->param(); }); |
|
|
|
return std::max(std::max(a, b), std::max(c, d)); |
|
|
|