GitOrigin-RevId: 144ff547d1
tags/v1.4.0-rc1
@@ -136,16 +136,16 @@ public: | |||||
uint32_t type = INVALID_ALGO_TYPE; | uint32_t type = INVALID_ALGO_TYPE; | ||||
//! serialized param of the algo type | //! serialized param of the algo type | ||||
std::string param; | std::string param; | ||||
//! algorithm name | |||||
std::string name; | |||||
bool valid() const { return type != INVALID_ALGO_TYPE; } | bool valid() const { return type != INVALID_ALGO_TYPE; } | ||||
void reset() { type = INVALID_ALGO_TYPE; } | void reset() { type = INVALID_ALGO_TYPE; } | ||||
bool operator==(const Desc& rhs) const { | bool operator==(const Desc& rhs) const { | ||||
return handle_type == rhs.handle_type && type == rhs.type && | return handle_type == rhs.handle_type && type == rhs.type && | ||||
param == rhs.param; | |||||
param == rhs.param && name == rhs.name; | |||||
} | } | ||||
} desc; | } desc; | ||||
//! algorithm name | |||||
std::string name; | |||||
Attribute attribute; | Attribute attribute; | ||||
bool valid() const { return desc.valid(); } | bool valid() const { return desc.valid(); } | ||||
void reset() { desc.reset(); } | void reset() { desc.reset(); } | ||||
@@ -178,12 +178,12 @@ public: | |||||
static std::string attribute_str(const Attribute& attr); | static std::string attribute_str(const Attribute& attr); | ||||
Handle::HandleType handle_type() const { return m_handle_type; } | Handle::HandleType handle_type() const { return m_handle_type; } | ||||
Info::Desc desc() const { return {handle_type(), type(), param(), name()}; } | |||||
Info info() const { | Info info() const { | ||||
return {{handle_type(), type(), param()}, name(), attribute()}; | |||||
return {desc(), attribute()}; | |||||
} | } | ||||
Info::Desc desc() const { return {handle_type(), type(), param()}; } | |||||
template <typename T> | template <typename T> | ||||
static void serialize_write_pod(const T& val, std::string& result) { | static void serialize_write_pod(const T& val, std::string& result) { | ||||
static_assert(std::is_trivially_copyable<T>::value, | static_assert(std::is_trivially_copyable<T>::value, | ||||
@@ -116,8 +116,10 @@ struct hash<megdnn::detail::Algorithm::Info::Desc> { | |||||
const megdnn::detail::Algorithm::Info::Desc& desc) const { | const megdnn::detail::Algorithm::Info::Desc& desc) const { | ||||
return megdnn::hash_combine<size_t>( | return megdnn::hash_combine<size_t>( | ||||
megdnn::hash_combine<size_t>( | megdnn::hash_combine<size_t>( | ||||
std::hash<std::string>()(desc.param), | |||||
std::hash<uint32_t>()(desc.type)), | |||||
std::hash<std::string>()(desc.name), | |||||
megdnn::hash_combine<size_t>( | |||||
std::hash<std::string>()(desc.param), | |||||
std::hash<uint32_t>()(desc.type))), | |||||
std::hash<uint32_t>()(static_cast<uint32_t>(desc.handle_type))); | std::hash<uint32_t>()(static_cast<uint32_t>(desc.handle_type))); | ||||
} | } | ||||
}; | }; | ||||
@@ -439,12 +439,6 @@ public: | |||||
TensorLayout& dst_pg, TensorLayout& bias_pg); | TensorLayout& dst_pg, TensorLayout& bias_pg); | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | ||||
std::string param() const override { | |||||
std::string ret; | |||||
serialize_write_pod(m_impl->name(), ret); | |||||
return ret; | |||||
} | |||||
private: | private: | ||||
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | ||||
AlgoBase* m_impl; | AlgoBase* m_impl; | ||||
@@ -237,12 +237,6 @@ public: | |||||
} | } | ||||
return ret; | return ret; | ||||
} | } | ||||
std::string param() const override { | |||||
std::string ret; | |||||
serialize_write_pod(m_impl->name(), ret); | |||||
return ret; | |||||
} | |||||
}; | }; | ||||
class ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm final | class ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm final | ||||
@@ -222,12 +222,6 @@ public: | |||||
} | } | ||||
return ret; | return ret; | ||||
} | } | ||||
std::string param() const override { | |||||
std::string ret; | |||||
serialize_write_pod(m_impl->name(), ret); | |||||
return ret; | |||||
} | |||||
}; | }; | ||||
class ConvolutionBackwardFilterImpl::AlgoPack : NonCopyableObj { | class ConvolutionBackwardFilterImpl::AlgoPack : NonCopyableObj { | ||||
@@ -174,14 +174,8 @@ public: | |||||
} | } | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | ||||
std::string param() const override { | |||||
std::string ret; | |||||
serialize_write_pod(m_impl->name(), ret); | |||||
return ret; | |||||
} | |||||
}; | }; | ||||
class Convolution3DBackwardDataImpl::AlgoPack : NonCopyableObj { | class Convolution3DBackwardDataImpl::AlgoPack : NonCopyableObj { | ||||
// defined in cudnn.cpp | // defined in cudnn.cpp | ||||
void fill_cudnn_algos(); | void fill_cudnn_algos(); | ||||
@@ -183,11 +183,6 @@ public: | |||||
TensorLayout& diff_pg); | TensorLayout& diff_pg); | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | ||||
std::string param() const override { | |||||
std::string ret; | |||||
serialize_write_pod(m_impl->name(), ret); | |||||
return ret; | |||||
} | |||||
}; | }; | ||||
class Convolution3DBackwardFilterImpl::AlgoPack : NonCopyableObj { | class Convolution3DBackwardFilterImpl::AlgoPack : NonCopyableObj { | ||||
@@ -135,11 +135,6 @@ public: | |||||
static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | ||||
TensorLayout& dst_pg); | TensorLayout& dst_pg); | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | ||||
std::string param() const override { | |||||
std::string ret; | |||||
serialize_write_pod(m_impl->name(), ret); | |||||
return ret; | |||||
} | |||||
}; | }; | ||||
class Convolution3DForwardImpl::AlgoCUDNN final : public AlgoBase { | class Convolution3DForwardImpl::AlgoCUDNN final : public AlgoBase { | ||||
@@ -65,11 +65,6 @@ public: | |||||
return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD}; | return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD}; | ||||
} | } | ||||
MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_F32) | MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_F32) | ||||
std::string param() const override { | |||||
std::string ret; | |||||
serialize_write_pod(m_matmul_algo->name(), ret); | |||||
return ret; | |||||
} | |||||
private: | private: | ||||
MatrixMulImpl::AlgoBase* m_matmul_algo; | MatrixMulImpl::AlgoBase* m_matmul_algo; | ||||
@@ -101,11 +96,6 @@ public: | |||||
return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD}; | return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD}; | ||||
} | } | ||||
MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_4X4_F32) | MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_4X4_F32) | ||||
std::string param() const override { | |||||
std::string ret; | |||||
serialize_write_pod(m_matmul_algo->name(), ret); | |||||
return ret; | |||||
} | |||||
private: | private: | ||||
MatrixMulImpl::AlgoBase* m_matmul_algo; | MatrixMulImpl::AlgoBase* m_matmul_algo; | ||||
@@ -137,11 +127,6 @@ public: | |||||
return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD}; | return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD}; | ||||
} | } | ||||
MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_QS8) | MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_QS8) | ||||
std::string param() const override { | |||||
std::string ret; | |||||
serialize_write_pod(m_matmul_algo->name(), ret); | |||||
return ret; | |||||
} | |||||
private: | private: | ||||
MatrixMulImpl::AlgoBase* m_matmul_algo; | MatrixMulImpl::AlgoBase* m_matmul_algo; | ||||
@@ -173,11 +158,6 @@ public: | |||||
return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD}; | return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD}; | ||||
} | } | ||||
MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_8X8_QS8) | MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_8X8_QS8) | ||||
std::string param() const override { | |||||
std::string ret; | |||||
serialize_write_pod(m_matmul_algo->name(), ret); | |||||
return ret; | |||||
} | |||||
private: | private: | ||||
MatrixMulImpl::AlgoBase* m_matmul_algo; | MatrixMulImpl::AlgoBase* m_matmul_algo; | ||||
@@ -157,7 +157,6 @@ using BiasMode = ConvBiasForward::BiasMode; | |||||
} \ | } \ | ||||
std::string param() const override { \ | std::string param() const override { \ | ||||
std::string ret; \ | std::string ret; \ | ||||
serialize_write_pod(m_matmul_algo->name(), ret); \ | |||||
serialize_write_pod(m_tile_size, ret); \ | serialize_write_pod(m_tile_size, ret); \ | ||||
return ret; \ | return ret; \ | ||||
} \ | } \ | ||||
@@ -62,10 +62,9 @@ public: | |||||
return {m_matmul_algo->matmul_description().algo_type.data_type, | return {m_matmul_algo->matmul_description().algo_type.data_type, | ||||
AlgoCategory::IM2COL}; | AlgoCategory::IM2COL}; | ||||
} | } | ||||
MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_8X8_QS8) | |||||
MEGDNN_DECL_ALGO_TYPE(FB_CONV1x1) | |||||
std::string param() const override { | std::string param() const override { | ||||
std::string ret; | std::string ret; | ||||
serialize_write_pod(m_matmul_algo->name(), ret); | |||||
serialize_write_pod(m_oc_block_size, ret); | serialize_write_pod(m_oc_block_size, ret); | ||||
return ret; | return ret; | ||||
} | } | ||||
@@ -74,7 +74,6 @@ public: | |||||
std::string param() const override { | std::string param() const override { | ||||
std::string ret; | std::string ret; | ||||
serialize_write_pod(m_matmul_algo->name(), ret); | |||||
serialize_write_pod(m_ohw_tile_size, ret); | serialize_write_pod(m_ohw_tile_size, ret); | ||||
return ret; | return ret; | ||||
} | } | ||||
@@ -155,12 +155,6 @@ public: | |||||
//! select matmul to the highest preference | //! select matmul to the highest preference | ||||
bool is_preferred(const NCBKernSizeParam& param) const override; | bool is_preferred(const NCBKernSizeParam& param) const override; | ||||
std::string param() const override { | |||||
std::string ret; | |||||
serialize_write_pod(m_algorithm->name(), ret); | |||||
return ret; | |||||
} | |||||
static ConvBiasImpl::NCBKernSizeParam init_conv_bias_param( | static ConvBiasImpl::NCBKernSizeParam init_conv_bias_param( | ||||
const NCBKernSizeParam& param); | const NCBKernSizeParam& param); | ||||
@@ -380,13 +380,13 @@ float algo_benchmark(Benchmarker<Opr, T>& benchmark, TensorLayoutArray layouts, | |||||
float min_used = std::numeric_limits<float>::max(); | float min_used = std::numeric_limits<float>::max(); | ||||
bool execed = false; | bool execed = false; | ||||
for (auto i : algos) { | for (auto i : algos) { | ||||
if (std::regex_match(i.name, | |||||
if (std::regex_match(i.desc.name, | |||||
std::regex("(" + algo_base + ")(.*)"))) { | std::regex("(" + algo_base + ")(.*)"))) { | ||||
opr->execution_policy().algo = i.desc; | opr->execution_policy().algo = i.desc; | ||||
auto used = benchmark.exec(layouts); | auto used = benchmark.exec(layouts); | ||||
min_used = std::min(min_used, used); | min_used = std::min(min_used, used); | ||||
printf("run algo: %s used: %f ms min_used: %f ms\n", i.name.c_str(), | |||||
used, min_used); | |||||
printf("run algo: %s used: %f ms min_used: %f ms\n", | |||||
i.desc.name.c_str(), used, min_used); | |||||
execed = true; | execed = true; | ||||
} | } | ||||
} | } | ||||
@@ -482,7 +482,7 @@ public: | |||||
AlgoProxy<Opr, OprTrait<Opr>::arity>::get_all_algorithms_info( | AlgoProxy<Opr, OprTrait<Opr>::arity>::get_all_algorithms_info( | ||||
opr.get(), layouts)) { | opr.get(), layouts)) { | ||||
if (std::regex_match( | if (std::regex_match( | ||||
algo_info.name, | |||||
algo_info.desc.name, | |||||
std::regex("(" + policy_name.name + ")(.*)"))) { | std::regex("(" + policy_name.name + ")(.*)"))) { | ||||
ret.algo = algo_info.desc; | ret.algo = algo_info.desc; | ||||
} else { | } else { | ||||
@@ -495,7 +495,7 @@ public: | |||||
if (sub_items.size() != policy_name.sub_policy_names.size()) { | if (sub_items.size() != policy_name.sub_policy_names.size()) { | ||||
printf("Invalid sub_policy_names in %s, expected %zu but got " | printf("Invalid sub_policy_names in %s, expected %zu but got " | ||||
"%zu\n", | "%zu\n", | ||||
algo_info.name.c_str(), sub_items.size(), | |||||
algo_info.desc.name.c_str(), sub_items.size(), | |||||
policy_name.sub_policy_names.size()); | policy_name.sub_policy_names.size()); | ||||
return {}; | return {}; | ||||
} | } | ||||
@@ -528,7 +528,7 @@ public: | |||||
auto algo = | auto algo = | ||||
OprAlgoProxy::get_algorithm_info_heuristic(opr, layouts); | OprAlgoProxy::get_algorithm_info_heuristic(opr, layouts); | ||||
ASSERT_STREQ(opr->get_algorithm_from_desc(m_policy.algo)->name(), | ASSERT_STREQ(opr->get_algorithm_from_desc(m_policy.algo)->name(), | ||||
algo.name.c_str()); | |||||
algo.desc.name.c_str()); | |||||
} else { | } else { | ||||
opr->execution_policy() = m_policy; | opr->execution_policy() = m_policy; | ||||
} | } | ||||
@@ -629,11 +629,10 @@ Checker<Convolution> checker(handle); | |||||
out_type = inp_type; | out_type = inp_type; | ||||
} | } | ||||
checker | |||||
.set_dtype(0, inp_type) | |||||
.set_dtype(1, inp_type) | |||||
.set_dtype(2, out_type) | |||||
.set_param(param); | |||||
checker.set_dtype(0, inp_type) | |||||
.set_dtype(1, inp_type) | |||||
.set_dtype(2, out_type) | |||||
.set_param(param); | |||||
auto opr = checker.opr(); | auto opr = checker.opr(); | ||||
opr->param() = param; | opr->param() = param; | ||||
std::string param_str; | std::string param_str; | ||||
@@ -642,7 +641,8 @@ Checker<Convolution> checker(handle); | |||||
oly.dtype = out_type; | oly.dtype = out_type; | ||||
opr->deduce_layout(ily, fly, oly); | opr->deduce_layout(ily, fly, oly); | ||||
int channel_start = 1; | int channel_start = 1; | ||||
if (format) channel_start = 3; | |||||
if (format) | |||||
channel_start = 3; | |||||
float scale = 1.0f / sqrt(fshp[channel_start] * FH * FW); | float scale = 1.0f / sqrt(fshp[channel_start] * FH * FW); | ||||
UniformFloatRNG rng(scale, 2 * scale); | UniformFloatRNG rng(scale, 2 * scale); | ||||
checker.set_rng(0, &rng).set_rng(1, &rng); | checker.set_rng(0, &rng).set_rng(1, &rng); | ||||
@@ -653,11 +653,11 @@ Checker<Convolution> checker(handle); | |||||
construct_sub_execution_policy_heuristic<ConvolutionForward>( | construct_sub_execution_policy_heuristic<ConvolutionForward>( | ||||
opr->execution_policy(), {ily, fly, oly}, param_str, | opr->execution_policy(), {ily, fly, oly}, param_str, | ||||
opr->handle()); | opr->handle()); | ||||
checker | |||||
.set_epsilon(eps_getter(dtype == 1, 0, algo.name.c_str())) | |||||
.execs({ishp, fshp, {}}); | |||||
checker.set_epsilon( | |||||
eps_getter(dtype == 1, 0, algo.desc.name.c_str())) | |||||
.execs({ishp, fshp, {}}); | |||||
opr->execution_policy() = {}; | opr->execution_policy() = {}; | ||||
ASSERT_TRUE(checker.prev_succ()) << errmsg(algo.name.c_str()); | |||||
ASSERT_TRUE(checker.prev_succ()) << errmsg(algo.desc.name.c_str()); | |||||
} | } | ||||
if (test_backward) { | if (test_backward) { | ||||
@@ -671,7 +671,7 @@ Checker<Convolution> checker(handle); | |||||
opr->param() = param; | opr->param() = param; | ||||
std::string param_str; | std::string param_str; | ||||
Algorithm::serialize_write_pod(opr->param(), param_str); | Algorithm::serialize_write_pod(opr->param(), param_str); | ||||
for (auto algo: opr->get_all_algorithms_info(fly, oly, ily)) { | |||||
for (auto algo : opr->get_all_algorithms_info(fly, oly, ily)) { | |||||
used_algos_bwd_data.insert(algo.desc); | used_algos_bwd_data.insert(algo.desc); | ||||
opr->execution_policy().algo = algo.desc; | opr->execution_policy().algo = algo.desc; | ||||
construct_sub_execution_policy_heuristic< | construct_sub_execution_policy_heuristic< | ||||
@@ -679,26 +679,26 @@ Checker<Convolution> checker(handle); | |||||
{fly, oly, ily}, param_str, | {fly, oly, ily}, param_str, | ||||
opr->handle()); | opr->handle()); | ||||
checker_bwd_data | checker_bwd_data | ||||
.set_epsilon(eps_getter(dtype == 1, 1, algo.name.c_str())) | |||||
.execl({fly, oly, ily}); | |||||
.set_epsilon(eps_getter(dtype == 1, 1, | |||||
algo.desc.name.c_str())) | |||||
.execl({fly, oly, ily}); | |||||
opr->execution_policy() = {}; | opr->execution_policy() = {}; | ||||
ASSERT_TRUE(checker_bwd_data.prev_succ()) << | |||||
errmsg(algo.name.c_str()); | |||||
ASSERT_TRUE(checker_bwd_data.prev_succ()) | |||||
<< errmsg(algo.desc.name.c_str()); | |||||
} | } | ||||
} | } | ||||
if (test_backward) { | if (test_backward) { | ||||
// backward filter | // backward filter | ||||
checker_bwd_filter | |||||
.set_dtype(0, inp_type) | |||||
.set_dtype(1, out_type) | |||||
.set_dtype(2, inp_type) | |||||
.set_param(param); | |||||
checker_bwd_filter.set_dtype(0, inp_type) | |||||
.set_dtype(1, out_type) | |||||
.set_dtype(2, inp_type) | |||||
.set_param(param); | |||||
auto opr = checker_bwd_filter.opr(); | auto opr = checker_bwd_filter.opr(); | ||||
opr->param() = param; | opr->param() = param; | ||||
std::string param_str; | std::string param_str; | ||||
Algorithm::serialize_write_pod(opr->param(), param_str); | Algorithm::serialize_write_pod(opr->param(), param_str); | ||||
for (auto algo: opr->get_all_algorithms_info(ily, oly, fly)) { | |||||
for (auto algo : opr->get_all_algorithms_info(ily, oly, fly)) { | |||||
used_algos_bwd_flt.insert(algo.desc); | used_algos_bwd_flt.insert(algo.desc); | ||||
opr->execution_policy().algo = algo.desc; | opr->execution_policy().algo = algo.desc; | ||||
construct_sub_execution_policy_heuristic< | construct_sub_execution_policy_heuristic< | ||||
@@ -706,11 +706,12 @@ Checker<Convolution> checker(handle); | |||||
{ily, oly, fly}, param_str, | {ily, oly, fly}, param_str, | ||||
opr->handle()); | opr->handle()); | ||||
checker_bwd_filter | checker_bwd_filter | ||||
.set_epsilon(eps_getter(dtype == 1, 2, algo.name.c_str())) | |||||
.execl({ily, oly, fly}); | |||||
.set_epsilon(eps_getter(dtype == 1, 2, | |||||
algo.desc.name.c_str())) | |||||
.execl({ily, oly, fly}); | |||||
opr->execution_policy() = {}; | opr->execution_policy() = {}; | ||||
ASSERT_TRUE(checker_bwd_filter.prev_succ()) << | |||||
errmsg(algo.name.c_str()); | |||||
ASSERT_TRUE(checker_bwd_filter.prev_succ()) | |||||
<< errmsg(algo.desc.name.c_str()); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -400,7 +400,7 @@ struct OprProxyProfilingBase | |||||
megcoreSynchronize(opr->handle()->megcore_computing_handle()); | megcoreSynchronize(opr->handle()->megcore_computing_handle()); | ||||
timer.stop(); | timer.stop(); | ||||
megdnn_log("%.3fms %s", timer.get_time_in_us() / 1e3, | megdnn_log("%.3fms %s", timer.get_time_in_us() / 1e3, | ||||
algo.name.c_str()); | |||||
algo.desc.name.c_str()); | |||||
if (min_time > timer.get_time_in_us()) { | if (min_time > timer.get_time_in_us()) { | ||||
min_time = timer.get_time_in_us(); | min_time = timer.get_time_in_us(); | ||||
best_algo = algo.desc; | best_algo = algo.desc; | ||||
@@ -522,7 +522,7 @@ struct OprWeightPreprocessProxyImpl : public OprProxyProfilingBase<Opr> { | |||||
megcoreSynchronize(opr->handle()->megcore_computing_handle()); | megcoreSynchronize(opr->handle()->megcore_computing_handle()); | ||||
timer.stop(); | timer.stop(); | ||||
printf("%.3fms %s\n", timer.get_time_in_us() / 1e3, | printf("%.3fms %s\n", timer.get_time_in_us() / 1e3, | ||||
algo.name.c_str()); | |||||
algo.desc.name.c_str()); | |||||
if (min_time > timer.get_time_in_us()) { | if (min_time > timer.get_time_in_us()) { | ||||
min_time = timer.get_time_in_us(); | min_time = timer.get_time_in_us(); | ||||
Base::target_execution_policy.algo = algo.desc; | Base::target_execution_policy.algo = algo.desc; | ||||
@@ -88,7 +88,7 @@ void test_multibatchsize( | |||||
A_tensor.layout(), B_tensor.layout(), | A_tensor.layout(), B_tensor.layout(), | ||||
C_tensor.layout())) { | C_tensor.layout())) { | ||||
if (std::regex_match( | if (std::regex_match( | ||||
i.name.c_str(), | |||||
i.desc.name.c_str(), | |||||
std::regex("(" + std::string(algo) + ")(.*)"))) { | std::regex("(" + std::string(algo) + ")(.*)"))) { | ||||
opr_reference->execution_policy().algo = i.desc; | opr_reference->execution_policy().algo = i.desc; | ||||
break; | break; | ||||
@@ -117,7 +117,7 @@ void test_multibatchsize( | |||||
A_tensor_prime.layout(), B_tensor.layout(), | A_tensor_prime.layout(), B_tensor.layout(), | ||||
C_tensor_batch.layout())) { | C_tensor_batch.layout())) { | ||||
if (std::regex_match( | if (std::regex_match( | ||||
i.name.c_str(), | |||||
i.desc.name.c_str(), | |||||
std::regex("(" + std::string(algo) + ")(.*)"))) { | std::regex("(" + std::string(algo) + ")(.*)"))) { | ||||
opr_reference->execution_policy().algo = i.desc; | opr_reference->execution_policy().algo = i.desc; | ||||
break; | break; | ||||
@@ -318,7 +318,7 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, | |||||
Maybe<AlgoChooserProfileCache::ResultEntry> cur_rst; | Maybe<AlgoChooserProfileCache::ResultEntry> cur_rst; | ||||
std::string msg = ssprintf("profiling %s algorithm %s %s", | std::string msg = ssprintf("profiling %s algorithm %s %s", | ||||
ctx.mgb_opr()->dyn_typeinfo()->name, | ctx.mgb_opr()->dyn_typeinfo()->name, | ||||
algo.name.c_str(), layouts_str.c_str()); | |||||
algo.desc.name.c_str(), layouts_str.c_str()); | |||||
ImplExecutionPolicy policy; | ImplExecutionPolicy policy; | ||||
policy.algo = algo.desc; | policy.algo = algo.desc; | ||||
ctx.construct_execution_policy(selected_strategy, policy); | ctx.construct_execution_policy(selected_strategy, policy); | ||||
@@ -327,12 +327,12 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, | |||||
} | } | ||||
auto palgo = ctx.megdnn_opr()->get_algorithm_from_desc(policy.algo); | auto palgo = ctx.megdnn_opr()->get_algorithm_from_desc(policy.algo); | ||||
if (!(palgo->contain_attribute_all(target_attr.first) && | if (!(palgo->contain_attribute_all(target_attr.first) && | ||||
!palgo->contain_attribute_any(target_attr.second))) { | |||||
!palgo->contain_attribute_any(target_attr.second))) { | |||||
mgb_log_debug( | mgb_log_debug( | ||||
"skip algo %s with attribute(%s), which is not match the " | "skip algo %s with attribute(%s), which is not match the " | ||||
"profile strategy required contain attribute(%s) and not " | "profile strategy required contain attribute(%s) and not " | ||||
"contain attribute(%s).", | "contain attribute(%s).", | ||||
algo.name.c_str(), | |||||
algo.desc.name.c_str(), | |||||
Algorithm::attribute_str(palgo->attribute()).c_str(), | Algorithm::attribute_str(palgo->attribute()).c_str(), | ||||
Algorithm::attribute_str(target_attr.first).c_str(), | Algorithm::attribute_str(target_attr.first).c_str(), | ||||
Algorithm::attribute_str(target_attr.second).c_str()); | Algorithm::attribute_str(target_attr.second).c_str()); | ||||
@@ -552,8 +552,8 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | |||||
auto&& prof = rst.val(); | auto&& prof = rst.val(); | ||||
std::unordered_map<std::string, ImplAlgo> algo_map; | std::unordered_map<std::string, ImplAlgo> algo_map; | ||||
for (auto i : get_all_candidates()) { | for (auto i : get_all_candidates()) { | ||||
auto ins = algo_map.emplace(i.name.c_str(), i); | |||||
mgb_assert(ins.second, "duplicated algo name: %s", i.name.c_str()); | |||||
auto ins = algo_map.emplace(i.desc.name.c_str(), i); | |||||
mgb_assert(ins.second, "duplicated algo name: %s", i.desc.name.c_str()); | |||||
} | } | ||||
if (prof.empty()) | if (prof.empty()) | ||||
@@ -41,8 +41,11 @@ std::string serialize_policy(const megdnn::ExecutionPolicy& policy) { | |||||
megdnn::Algorithm::serialize_write_pod(policy.algo.handle_type, ret); | megdnn::Algorithm::serialize_write_pod(policy.algo.handle_type, ret); | ||||
megdnn::Algorithm::serialize_write_pod(policy.algo.type, ret); | megdnn::Algorithm::serialize_write_pod(policy.algo.type, ret); | ||||
uint32_t param_size = policy.algo.param.size(); | uint32_t param_size = policy.algo.param.size(); | ||||
uint32_t name_size = policy.algo.name.size(); | |||||
megdnn::Algorithm::serialize_write_pod<uint32_t>(param_size, ret); | megdnn::Algorithm::serialize_write_pod<uint32_t>(param_size, ret); | ||||
megdnn::Algorithm::serialize_write_pod<uint32_t>(name_size, ret); | |||||
ret += policy.algo.param; | ret += policy.algo.param; | ||||
ret += policy.algo.name; | |||||
//! serialize sub_policy | //! serialize sub_policy | ||||
uint32_t size = policy.sub_policy.size(); | uint32_t size = policy.sub_policy.size(); | ||||
@@ -64,11 +67,17 @@ megdnn::ExecutionPolicy deserialize_policy(const char* buf, uint32_t size, | |||||
cb(ret.algo.type, uint32_t); | cb(ret.algo.type, uint32_t); | ||||
uint32_t param_size = 0; | uint32_t param_size = 0; | ||||
uint32_t name_size = 0; | |||||
cb(param_size, uint32_t); | cb(param_size, uint32_t); | ||||
cb(name_size, uint32_t); | |||||
if (param_size > 0) { | if (param_size > 0) { | ||||
ret.algo.param = std::string(buf + offset, param_size); | ret.algo.param = std::string(buf + offset, param_size); | ||||
offset += param_size; | offset += param_size; | ||||
} | } | ||||
if (name_size > 0) { | |||||
ret.algo.name = std::string(buf + offset, name_size); | |||||
offset += name_size; | |||||
} | |||||
uint32_t nr_policy = 0; | uint32_t nr_policy = 0; | ||||
cb(nr_policy, uint32_t); | cb(nr_policy, uint32_t); | ||||