GitOrigin-RevId: d24f73193e
@@ -384,8 +384,15 @@ AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx,
     MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile")))
     if (ctx.owner_graph()->options().no_profiling_on_shape_change) {
         auto policy = ctx.megdnn_opr()->execution_policy();
-        if (policy.algo.valid())
+        if (policy.algo.valid()) {
             return policy;
+        }
+        if (!algo_usable_on_shape_change<Opr>()) {
+            mgb_log_warn(
+                    "choose algo by heuristic, which may cause performance "
+                    "regression.");
+            return ctx.choose_by_heuristic(selected_strategy);
+        }
     }
 
     if (enable_update) {
@@ -89,6 +89,13 @@ constexpr bool opr_contain_bias() {
     return std::is_same<Opr, megdnn::ConvBias>::value;
 }
 
+//! MatrixMul and BatchedMatrixMul algos may not be usable once the shape changes
+template <typename Opr>
+constexpr bool algo_usable_on_shape_change() {
+    return !(std::is_same<Opr, megdnn::MatrixMul>::value ||
+             std::is_same<Opr, megdnn::BatchedMatrixMul>::value);
+}
+
 template <typename Opr, bool has_prep>
 struct PreprocessFilterImpl {
     using T = union {};
@@ -885,5 +885,44 @@ TEST(TestOprBlas, SingularValueDecompositionZeroGrad) {
     run_svd_empty_grad_test<1, 1, 1>();
 }
 
+#if MGB_ENABLE_FASTRUN
+TEST(TestOprBlas, MatrixMulExePolicy) {
+    using Param = opr::MatrixMul::Param;
+    Param param;
+    using Policy = opr::MatrixMul::ExecutionPolicy;
+    using S = Policy::Strategy;
+    Policy policy;
+    policy.strategy = S::PROFILE;
+
+    auto cn = CompNode::load("cpux");
+    int nr_get = 0;
+    auto on_get = [&nr_get](const std::string&, const void*, size_t,
+                            const void*, size_t) { ++nr_get; };
+    PersistentCacheHook cache_hook{on_get};
+
+    auto graph = ComputingGraph::make();
+    HostTensorGenerator<> gen;
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp), cn).rename(name);
+    };
+    auto a = mkvar("a", {20, 50});
+    auto b = mkvar("b", {50, 40});
+    auto matmul = opr::MatrixMul::make(a, b, param, policy, {});
+    HostTensorND host_y;
+    // with profiling on shape change disabled, matmul falls back to the
+    // heuristic, so the persistent algo cache must not be queried
+    graph->options().no_profiling_on_shape_change = true;
+    auto func = graph->compile({make_callback_copy(matmul, host_y)});
+    func->execute();
+    ASSERT_EQ(nr_get, 0);
+
+    // with profiling allowed again, algo selection queries the cache
+    graph->options().no_profiling_on_shape_change = false;
+    func = graph->compile({make_callback_copy(matmul, host_y)});
+    func->execute();
+    ASSERT_GT(nr_get, 0);
+}
+#endif
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
 //
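
For reference, a minimal standalone sketch (not part of this commit) of how the algo_usable_on_shape_change trait behaves: the megdnn opr types here are empty stand-ins so the snippet compiles without MegBrain headers, but the trait body is the same as in the diff above.

// Standalone sketch; megdnn opr types are stand-ins, not the real classes.
#include <type_traits>

namespace megdnn {
struct MatrixMul {};
struct BatchedMatrixMul {};
struct ConvBias {};
}  // namespace megdnn

// Same definition as in the diff: only the two matmul oprs are excluded,
// since their profiled algo may become invalid after a shape change.
template <typename Opr>
constexpr bool algo_usable_on_shape_change() {
    return !(std::is_same<Opr, megdnn::MatrixMul>::value ||
             std::is_same<Opr, megdnn::BatchedMatrixMul>::value);
}

static_assert(!algo_usable_on_shape_change<megdnn::MatrixMul>(),
              "matmul falls back to heuristic on shape change");
static_assert(!algo_usable_on_shape_change<megdnn::BatchedMatrixMul>(),
              "batched matmul falls back to heuristic on shape change");
static_assert(algo_usable_on_shape_change<megdnn::ConvBias>(),
              "other oprs keep their profiled algo");

int main() {}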