GitOrigin-RevId: b7a1dc62d8
release-1.10
@@ -51,15 +51,6 @@ PoolingImpl::AlgoPack PoolingImpl::sm_algo_pack; | |||||
size_t PoolingImpl::get_workspace_in_bytes( | size_t PoolingImpl::get_workspace_in_bytes( | ||||
const TensorLayout& src, const TensorLayout& dst) { | const TensorLayout& src, const TensorLayout& dst) { | ||||
TensorLayoutArray layouts{src, dst}; | |||||
AlgorithmCache::Key key{this->handle(), this->get_opr_type(), | |||||
layouts.data(), layouts.size(), | |||||
&this->param(), sizeof(this->param())}; | |||||
auto rst = AlgorithmCache::instance().get(key); | |||||
if (rst.policy.algo.valid()) { | |||||
return rst.workspace; | |||||
} | |||||
auto param = make_pooling_kern_szie_param(this, src, dst); | auto param = make_pooling_kern_szie_param(this, src, dst); | ||||
auto algo = get_algorithm(this, src, dst); | auto algo = get_algorithm(this, src, dst); | ||||
if (!is_fallback_algo(algo)) { | if (!is_fallback_algo(algo)) { | ||||
@@ -13,14 +13,6 @@ namespace megdnn { | |||||
template <class Opr, typename... Args> | template <class Opr, typename... Args> | ||||
size_t get_dnn_workspace(Opr* opr, Args&&... args) { | size_t get_dnn_workspace(Opr* opr, Args&&... args) { | ||||
TensorLayoutArray layouts{{args...}}; | |||||
AlgorithmCache::Key key{opr->handle(), opr->get_opr_type(), layouts.data(), | |||||
layouts.size(), &opr->param(), sizeof(opr->param())}; | |||||
auto rst = AlgorithmCache::instance().get(key); | |||||
if (rst.policy.algo.valid()) { | |||||
return rst.workspace; | |||||
} | |||||
typename Opr::AlgoBase::SizeArgs size_args(opr, std::forward<Args>(args)...); | typename Opr::AlgoBase::SizeArgs size_args(opr, std::forward<Args>(args)...); | ||||
return get_algorithm(opr, std::forward<Args>(args)...) | return get_algorithm(opr, std::forward<Args>(args)...) | ||||
->get_workspace_in_bytes(size_args); | ->get_workspace_in_bytes(size_args); | ||||
@@ -32,6 +24,7 @@ size_t get_dnn_workspace(Opr* opr, Args&&... args) { | |||||
template <class Opr, typename... Args> | template <class Opr, typename... Args> | ||||
typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { | typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { | ||||
typename Opr::AlgorithmDesc ret; | typename Opr::AlgorithmDesc ret; | ||||
// first check self configured algorithm | |||||
auto set = opr->execution_policy().algo; | auto set = opr->execution_policy().algo; | ||||
if (set.valid()) { | if (set.valid()) { | ||||
ret = set; | ret = set; | ||||
@@ -40,10 +33,12 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { | |||||
AlgorithmCache::Key key{opr->handle(), opr->get_opr_type(), | AlgorithmCache::Key key{opr->handle(), opr->get_opr_type(), | ||||
layouts.data(), layouts.size(), | layouts.data(), layouts.size(), | ||||
&opr->param(), sizeof(opr->param())}; | &opr->param(), sizeof(opr->param())}; | ||||
// then get from global algorithm cache | |||||
auto rst = AlgorithmCache::instance().get(key); | auto rst = AlgorithmCache::instance().get(key); | ||||
if (rst.policy.algo.valid()) { | if (rst.policy.algo.valid()) { | ||||
ret = rst.policy.algo; | ret = rst.policy.algo; | ||||
} else { | } else { | ||||
// finally get pre-defined heuristic algorithm | |||||
ret = opr->get_algorithm_info_heuristic( | ret = opr->get_algorithm_info_heuristic( | ||||
std::forward<Args>(args)..., | std::forward<Args>(args)..., | ||||
std::numeric_limits<size_t>::max(), AlgoAttribute::DEFAULT, | std::numeric_limits<size_t>::max(), AlgoAttribute::DEFAULT, | ||||
@@ -44,14 +44,6 @@ WorkspaceBundle BatchConvBiasForwardImpl::get_workspace_bundle( | |||||
size_t BatchConvBiasForwardImpl::get_workspace_in_bytes( | size_t BatchConvBiasForwardImpl::get_workspace_in_bytes( | ||||
const TensorLayout& src, const TensorLayout& flt, const TensorLayout& bias, | const TensorLayout& src, const TensorLayout& flt, const TensorLayout& bias, | ||||
const TensorLayout& z, const TensorLayout& dst) { | const TensorLayout& z, const TensorLayout& dst) { | ||||
TensorLayoutArray layouts{src, flt, bias, z, dst}; | |||||
AlgorithmCache::Key key{this->handle(), this->get_opr_type(), | |||||
layouts.data(), layouts.size(), | |||||
&this->param(), sizeof(this->param())}; | |||||
auto rst = AlgorithmCache::instance().get(key); | |||||
if (rst.policy.algo.valid()) { | |||||
return rst.workspace; | |||||
} | |||||
return get_workspace_bundle(nullptr, src, flt, bias, z, dst).total_size_in_bytes(); | return get_workspace_bundle(nullptr, src, flt, bias, z, dst).total_size_in_bytes(); | ||||
} | } | ||||
@@ -187,15 +187,6 @@ void forward_bias<dt_quint4, dt_qint4, dt_qint32, dt_qint32>( | |||||
size_t ConvBiasForwardImpl::get_workspace_in_bytes( | size_t ConvBiasForwardImpl::get_workspace_in_bytes( | ||||
const TensorLayout& src, const TensorLayout& flt, const TensorLayout& bias, | const TensorLayout& src, const TensorLayout& flt, const TensorLayout& bias, | ||||
const TensorLayout& z, const TensorLayout& dst, const PreprocessedFilter*) { | const TensorLayout& z, const TensorLayout& dst, const PreprocessedFilter*) { | ||||
TensorLayoutArray layouts{src, flt, bias, z, dst}; | |||||
AlgorithmCache::Key key{this->handle(), this->get_opr_type(), | |||||
layouts.data(), layouts.size(), | |||||
&this->param(), sizeof(this->param())}; | |||||
auto rst = AlgorithmCache::instance().get(key); | |||||
if (rst.policy.algo.valid()) { | |||||
return rst.workspace; | |||||
} | |||||
size_t float_workspace_size = 0; | size_t float_workspace_size = 0; | ||||
if (z.ndim > 0 && z.dtype.category() != DTypeCategory::FLOAT) { | if (z.ndim > 0 && z.dtype.category() != DTypeCategory::FLOAT) { | ||||
@@ -66,15 +66,6 @@ void ConvolutionForwardImpl::exec( | |||||
size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes( | size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes( | ||||
const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
const TensorLayout& grad) { | const TensorLayout& grad) { | ||||
TensorLayoutArray layouts{filter, diff, grad}; | |||||
AlgorithmCache::Key key{this->handle(), this->get_opr_type(), | |||||
layouts.data(), layouts.size(), | |||||
&this->param(), sizeof(this->param())}; | |||||
auto rst = AlgorithmCache::instance().get(key); | |||||
if (rst.policy.algo.valid()) { | |||||
return rst.workspace; | |||||
} | |||||
size_t workspace_size = 0; | size_t workspace_size = 0; | ||||
auto flt_dt = filter.dtype.enumv(); | auto flt_dt = filter.dtype.enumv(); | ||||
auto grad_dt = grad.dtype.enumv(); | auto grad_dt = grad.dtype.enumv(); | ||||
@@ -178,15 +169,6 @@ size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes( | |||||
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) { | const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) { | ||||
size_t workspace_size = 0; | size_t workspace_size = 0; | ||||
#if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
TensorLayoutArray layouts{src, diff, grad}; | |||||
AlgorithmCache::Key key{this->handle(), this->get_opr_type(), | |||||
layouts.data(), layouts.size(), | |||||
&this->param(), sizeof(this->param())}; | |||||
auto rst = AlgorithmCache::instance().get(key); | |||||
if (rst.policy.algo.valid()) { | |||||
return rst.workspace; | |||||
} | |||||
auto src_dt = src.dtype.enumv(); | auto src_dt = src.dtype.enumv(); | ||||
auto grad_dt = grad.dtype.enumv(); | auto grad_dt = grad.dtype.enumv(); | ||||
auto diff_dt = diff.dtype.enumv(); | auto diff_dt = diff.dtype.enumv(); | ||||
@@ -397,14 +397,6 @@ WorkspaceBundle PoolingForwardImpl::get_workspace_bundle( | |||||
size_t PoolingForwardImpl::get_workspace_in_bytes( | size_t PoolingForwardImpl::get_workspace_in_bytes( | ||||
const TensorLayout& src, const TensorLayout& dst) { | const TensorLayout& src, const TensorLayout& dst) { | ||||
TensorLayoutArray layouts{src, dst}; | |||||
AlgorithmCache::Key key{this->handle(), this->get_opr_type(), | |||||
layouts.data(), layouts.size(), | |||||
&this->param(), sizeof(this->param())}; | |||||
auto rst = AlgorithmCache::instance().get(key); | |||||
if (rst.policy.algo.valid()) { | |||||
return rst.workspace; | |||||
} | |||||
return get_workspace_bundle(nullptr, src, dst).total_size_in_bytes(); | return get_workspace_bundle(nullptr, src, dst).total_size_in_bytes(); | ||||
} | } | ||||
namespace { | namespace { | ||||
@@ -649,14 +641,6 @@ WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle( | |||||
size_t PoolingBackwardImpl::get_workspace_in_bytes( | size_t PoolingBackwardImpl::get_workspace_in_bytes( | ||||
const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, | ||||
const TensorLayout& grad) { | const TensorLayout& grad) { | ||||
TensorLayoutArray layouts{src, dst, diff, grad}; | |||||
AlgorithmCache::Key key{this->handle(), this->get_opr_type(), | |||||
layouts.data(), layouts.size(), | |||||
&this->param(), sizeof(this->param())}; | |||||
auto rst = AlgorithmCache::instance().get(key); | |||||
if (rst.policy.algo.valid()) { | |||||
return rst.workspace; | |||||
} | |||||
return get_workspace_bundle(nullptr, src, dst, diff, grad).total_size_in_bytes(); | return get_workspace_bundle(nullptr, src, dst, diff, grad).total_size_in_bytes(); | ||||
} | } | ||||
@@ -104,15 +104,6 @@ std::vector<ConvolutionForwardImpl::Algorithm*> ConvolutionForwardImpl:: | |||||
size_t ConvolutionForwardImpl::get_workspace_in_bytes( | size_t ConvolutionForwardImpl::get_workspace_in_bytes( | ||||
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, | ||||
const PreprocessedFilter*) { | const PreprocessedFilter*) { | ||||
TensorLayoutArray layouts{src, filter, dst}; | |||||
AlgorithmCache::Key key{this->handle(), this->get_opr_type(), | |||||
layouts.data(), layouts.size(), | |||||
&this->param(), sizeof(this->param())}; | |||||
auto rst = AlgorithmCache::instance().get(key); | |||||
if (rst.policy.algo.valid()) { | |||||
return rst.workspace; | |||||
} | |||||
AlgoBase::SizeArgs args(this, src, filter, dst); | AlgoBase::SizeArgs args(this, src, filter, dst); | ||||
return get_algorithm(this, src, filter, dst)->get_workspace_in_bytes(args); | return get_algorithm(this, src, filter, dst)->get_workspace_in_bytes(args); | ||||
} | } | ||||
@@ -198,15 +189,6 @@ ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl:: | |||||
size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes( | size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes( | ||||
const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
const TensorLayout& grad) { | const TensorLayout& grad) { | ||||
TensorLayoutArray layouts{filter, diff, grad}; | |||||
AlgorithmCache::Key key{this->handle(), this->get_opr_type(), | |||||
layouts.data(), layouts.size(), | |||||
&this->param(), sizeof(this->param())}; | |||||
auto rst = AlgorithmCache::instance().get(key); | |||||
if (rst.policy.algo.valid()) { | |||||
return rst.workspace; | |||||
} | |||||
AlgoBase::SizeArgs args(this, filter, diff, grad); | AlgoBase::SizeArgs args(this, filter, diff, grad); | ||||
return get_algorithm(this, filter, diff, grad)->get_workspace_in_bytes(args); | return get_algorithm(this, filter, diff, grad)->get_workspace_in_bytes(args); | ||||
} | } | ||||
@@ -282,15 +264,6 @@ ConvolutionBackwardFilterImpl::Algorithm* ConvolutionBackwardFilterImpl:: | |||||
size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes( | size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes( | ||||
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) { | const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) { | ||||
TensorLayoutArray layouts{src, diff, grad}; | |||||
AlgorithmCache::Key key{this->handle(), this->get_opr_type(), | |||||
layouts.data(), layouts.size(), | |||||
&this->param(), sizeof(this->param())}; | |||||
auto rst = AlgorithmCache::instance().get(key); | |||||
if (rst.policy.algo.valid()) { | |||||
return rst.workspace; | |||||
} | |||||
AlgoBase::SizeArgs args(this, src, diff, grad); | AlgoBase::SizeArgs args(this, src, diff, grad); | ||||
return get_algorithm(this, src, diff, grad)->get_workspace_in_bytes(args); | return get_algorithm(this, src, diff, grad)->get_workspace_in_bytes(args); | ||||
} | } | ||||
@@ -35,15 +35,6 @@ WorkspaceBundle megdnn::x86::get_bundle( | |||||
size_t PoolingImpl::get_workspace_in_bytes( | size_t PoolingImpl::get_workspace_in_bytes( | ||||
const TensorLayout& src, const TensorLayout& dst) { | const TensorLayout& src, const TensorLayout& dst) { | ||||
TensorLayoutArray layouts{src, dst}; | |||||
AlgorithmCache::Key key{this->handle(), this->get_opr_type(), | |||||
layouts.data(), layouts.size(), | |||||
&this->param(), sizeof(this->param())}; | |||||
auto rst = AlgorithmCache::instance().get(key); | |||||
if (rst.policy.algo.valid()) { | |||||
return rst.workspace; | |||||
} | |||||
auto algo = get_algorithm(this, src, dst); | auto algo = get_algorithm(this, src, dst); | ||||
if (!is_fallback_algo(algo)) { | if (!is_fallback_algo(algo)) { | ||||
if (is_supported(SIMDType::SSE) && src.dtype == dtype::Float32() && | if (is_supported(SIMDType::SSE) && src.dtype == dtype::Float32() && | ||||
@@ -351,7 +351,7 @@ class TimedFuncInvokerImpl final : public TimedFuncInvoker { | |||||
} else { | } else { | ||||
CHECK_SYS_ERR(cur_recv); | CHECK_SYS_ERR(cur_recv); | ||||
} | } | ||||
mgb_assert(cur_recv > 0); | |||||
mgb_assert(cur_recv >= 0); | |||||
dest += cur_recv; | dest += cur_recv; | ||||
size -= cur_recv; | size -= cur_recv; | ||||
} | } | ||||
@@ -950,10 +950,12 @@ void AlgoChooser<Opr>::AlgoChooserHelper::profile( | |||||
algo.desc.name.c_str(), layouts_str.c_str()); | algo.desc.name.c_str(), layouts_str.c_str()); | ||||
timer.reset(); | timer.reset(); | ||||
MGB_TRY { cur_rst = profile_single_algo(policy, cur_timeout); } | MGB_TRY { cur_rst = profile_single_algo(policy, cur_timeout); } | ||||
// megbrain catched exception | |||||
MGB_CATCH(std::exception & exc, { | MGB_CATCH(std::exception & exc, { | ||||
mgb_log_warn("caught exception during %s: %s", msg.c_str(), exc.what()); | |||||
mgb_log_debug("caught exception during %s: %s", msg.c_str(), exc.what()); | |||||
continue; | continue; | ||||
}) | }) | ||||
// megbrain uncatched exception | |||||
MGB_CATCH(..., { | MGB_CATCH(..., { | ||||
mgb_log_warn("caught exception during %s", msg.c_str()); | mgb_log_warn("caught exception during %s", msg.c_str()); | ||||
continue; | continue; | ||||