Browse Source

fix(mgb): fix execution_policy set of matmul

GitOrigin-RevId: 90f539b0be
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
169fa53d54
2 changed files with 10 additions and 1 deletions
  1. +10
    -0
      src/opr/impl/blas.cpp
  2. +0
    -1
      src/opr/impl/search_policy/algo_chooser.cpp

+ 10
- 0
src/opr/impl/blas.cpp View File

@@ -98,6 +98,7 @@ size_t MatrixMul::get_workspace_size_bytes(
param ^= 1;
};
MGB_TRY {
megdnn_opr()->execution_policy() = {};
a = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
megdnn_opr(), this);
//! Here we just want to save the execution policy got from setup_algo,
@@ -106,24 +107,28 @@ size_t MatrixMul::get_workspace_size_bytes(
const_cast<MatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
transpose(i0, tparam.transposeA);
b = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
megdnn_opr(), this);
const_cast<MatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
transpose(i1, tparam.transposeB);
c = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
megdnn_opr(), this);
const_cast<MatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
transpose(i0, tparam.transposeA);
d = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
megdnn_opr(), this);
const_cast<MatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
}
MGB_FINALLY({ tparam = this->param(); });
return std::max(std::max(a, b), std::max(c, d));
@@ -252,29 +257,34 @@ size_t BatchedMatrixMul::get_workspace_size_bytes(
param ^= 1;
};
MGB_TRY {
megdnn_opr()->execution_policy() = {};
a = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
{i0, i1, out}, megdnn_opr(), this);
const_cast<BatchedMatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
transpose(i0, tparam.transposeA);
b = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
{i0, i1, out}, megdnn_opr(), this);
const_cast<BatchedMatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
transpose(i1, tparam.transposeB);
c = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
{i0, i1, out}, megdnn_opr(), this);
const_cast<BatchedMatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
transpose(i0, tparam.transposeA);
d = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
{i0, i1, out}, megdnn_opr(), this);
const_cast<BatchedMatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
}
MGB_FINALLY({ tparam = this->param(); });
return std::max(std::max(a, b), std::max(c, d));


+ 0
- 1
src/opr/impl/search_policy/algo_chooser.cpp View File

@@ -266,7 +266,6 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const {
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
opr->owner_graph(), opr->comp_node(),
opr->execution_policy().workspace_limit);
m_megdnn_opr->execution_policy() = {};
return APPLY(m_megdnn_opr->get_algorithm_info_heuristic(
args..., workspace_limit, reproducible),
m_layouts);


Loading…
Cancel
Save