GitOrigin-RevId: b7a1dc62d8
release-1.10
@@ -51,15 +51,6 @@ PoolingImpl::AlgoPack PoolingImpl::sm_algo_pack;
 size_t PoolingImpl::get_workspace_in_bytes(
         const TensorLayout& src, const TensorLayout& dst) {
-    TensorLayoutArray layouts{src, dst};
-    AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
-                            layouts.data(), layouts.size(),
-                            &this->param(), sizeof(this->param())};
-    auto rst = AlgorithmCache::instance().get(key);
-    if (rst.policy.algo.valid()) {
-        return rst.workspace;
-    }
     auto param = make_pooling_kern_szie_param(this, src, dst);
     auto algo = get_algorithm(this, src, dst);
     if (!is_fallback_algo(algo)) {
@@ -13,14 +13,6 @@ namespace megdnn {
 template <class Opr, typename... Args>
 size_t get_dnn_workspace(Opr* opr, Args&&... args) {
-    TensorLayoutArray layouts{{args...}};
-    AlgorithmCache::Key key{opr->handle(), opr->get_opr_type(), layouts.data(),
-                            layouts.size(), &opr->param(), sizeof(opr->param())};
-    auto rst = AlgorithmCache::instance().get(key);
-    if (rst.policy.algo.valid()) {
-        return rst.workspace;
-    }
     typename Opr::AlgoBase::SizeArgs size_args(opr, std::forward<Args>(args)...);
     return get_algorithm(opr, std::forward<Args>(args)...)
             ->get_workspace_in_bytes(size_args);
@@ -32,6 +24,7 @@ size_t get_dnn_workspace(Opr* opr, Args&&... args) {
 template <class Opr, typename... Args>
 typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) {
     typename Opr::AlgorithmDesc ret;
+    // first check the explicitly configured algorithm
     auto set = opr->execution_policy().algo;
     if (set.valid()) {
         ret = set;
@@ -40,10 +33,12 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) {
         AlgorithmCache::Key key{opr->handle(), opr->get_opr_type(),
                                 layouts.data(), layouts.size(),
                                 &opr->param(), sizeof(opr->param())};
+        // then look up the global algorithm cache
         auto rst = AlgorithmCache::instance().get(key);
         if (rst.policy.algo.valid()) {
             ret = rst.policy.algo;
         } else {
+            // finally fall back to the pre-defined heuristic algorithm
             ret = opr->get_algorithm_info_heuristic(
                     std::forward<Args>(args)...,
                     std::numeric_limits<size_t>::max(), AlgoAttribute::DEFAULT,
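Taken together with the removals in the surrounding hunks, the effect is that the backend-specific get_workspace_in_bytes() overrides no longer probe the global AlgorithmCache themselves; selection is concentrated in get_algorithm() above, which tries an explicitly configured algorithm first, then the global cache, and finally the heuristic. The snippet below is a minimal, self-contained illustration of that three-tier order only; Desc, g_cache, heuristic() and choose() are hypothetical stand-ins, not the MegDNN API.

```cpp
#include <cstdio>
#include <map>
#include <string>

struct Desc {
    std::string name;
    bool valid() const { return !name.empty(); }
};

// stand-in for the global AlgorithmCache: key -> previously profiled choice
static std::map<std::string, Desc> g_cache;

// stand-in for the backend heuristic: always produces some answer
static Desc heuristic(const std::string& key) {
    return {"heuristic_for_" + key};
}

// lookup order: 1. explicitly configured algorithm, 2. global cache, 3. heuristic
static Desc choose(const Desc& configured, const std::string& key) {
    if (configured.valid())
        return configured;
    auto it = g_cache.find(key);
    if (it != g_cache.end() && it->second.valid())
        return it->second;
    return heuristic(key);
}

int main() {
    g_cache["conv_3x3_f32"] = {"winograd_f23"};
    std::printf("%s\n", choose({}, "conv_3x3_f32").name.c_str());        // cached choice
    std::printf("%s\n", choose({}, "pooling_2x2").name.c_str());         // heuristic fallback
    std::printf("%s\n", choose({"naive"}, "conv_3x3_f32").name.c_str()); // configured wins
    return 0;
}
```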
@@ -44,14 +44,6 @@ WorkspaceBundle BatchConvBiasForwardImpl::get_workspace_bundle(
 size_t BatchConvBiasForwardImpl::get_workspace_in_bytes(
         const TensorLayout& src, const TensorLayout& flt, const TensorLayout& bias,
         const TensorLayout& z, const TensorLayout& dst) {
-    TensorLayoutArray layouts{src, flt, bias, z, dst};
-    AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
-                            layouts.data(), layouts.size(),
-                            &this->param(), sizeof(this->param())};
-    auto rst = AlgorithmCache::instance().get(key);
-    if (rst.policy.algo.valid()) {
-        return rst.workspace;
-    }
     return get_workspace_bundle(nullptr, src, flt, bias, z, dst).total_size_in_bytes();
 }
@@ -187,15 +187,6 @@ void forward_bias<dt_quint4, dt_qint4, dt_qint32, dt_qint32>(
 size_t ConvBiasForwardImpl::get_workspace_in_bytes(
         const TensorLayout& src, const TensorLayout& flt, const TensorLayout& bias,
         const TensorLayout& z, const TensorLayout& dst, const PreprocessedFilter*) {
-    TensorLayoutArray layouts{src, flt, bias, z, dst};
-    AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
-                            layouts.data(), layouts.size(),
-                            &this->param(), sizeof(this->param())};
-    auto rst = AlgorithmCache::instance().get(key);
-    if (rst.policy.algo.valid()) {
-        return rst.workspace;
-    }
     size_t float_workspace_size = 0;
     if (z.ndim > 0 && z.dtype.category() != DTypeCategory::FLOAT) {
@@ -66,15 +66,6 @@ void ConvolutionForwardImpl::exec(
 size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
         const TensorLayout& filter, const TensorLayout& diff,
         const TensorLayout& grad) {
-    TensorLayoutArray layouts{filter, diff, grad};
-    AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
-                            layouts.data(), layouts.size(),
-                            &this->param(), sizeof(this->param())};
-    auto rst = AlgorithmCache::instance().get(key);
-    if (rst.policy.algo.valid()) {
-        return rst.workspace;
-    }
     size_t workspace_size = 0;
     auto flt_dt = filter.dtype.enumv();
     auto grad_dt = grad.dtype.enumv();
@@ -178,15 +169,6 @@ size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes(
         const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) {
     size_t workspace_size = 0;
 #if !MEGDNN_DISABLE_FLOAT16
-    TensorLayoutArray layouts{src, diff, grad};
-    AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
-                            layouts.data(), layouts.size(),
-                            &this->param(), sizeof(this->param())};
-    auto rst = AlgorithmCache::instance().get(key);
-    if (rst.policy.algo.valid()) {
-        return rst.workspace;
-    }
     auto src_dt = src.dtype.enumv();
     auto grad_dt = grad.dtype.enumv();
     auto diff_dt = diff.dtype.enumv();
@@ -397,14 +397,6 @@ WorkspaceBundle PoolingForwardImpl::get_workspace_bundle(
 size_t PoolingForwardImpl::get_workspace_in_bytes(
         const TensorLayout& src, const TensorLayout& dst) {
-    TensorLayoutArray layouts{src, dst};
-    AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
-                            layouts.data(), layouts.size(),
-                            &this->param(), sizeof(this->param())};
-    auto rst = AlgorithmCache::instance().get(key);
-    if (rst.policy.algo.valid()) {
-        return rst.workspace;
-    }
     return get_workspace_bundle(nullptr, src, dst).total_size_in_bytes();
 }
 namespace {
@@ -649,14 +641,6 @@ WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle(
 size_t PoolingBackwardImpl::get_workspace_in_bytes(
         const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
         const TensorLayout& grad) {
-    TensorLayoutArray layouts{src, dst, diff, grad};
-    AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
-                            layouts.data(), layouts.size(),
-                            &this->param(), sizeof(this->param())};
-    auto rst = AlgorithmCache::instance().get(key);
-    if (rst.policy.algo.valid()) {
-        return rst.workspace;
-    }
     return get_workspace_bundle(nullptr, src, dst, diff, grad).total_size_in_bytes();
 }
@@ -104,15 +104,6 @@ std::vector<ConvolutionForwardImpl::Algorithm*> ConvolutionForwardImpl::
 size_t ConvolutionForwardImpl::get_workspace_in_bytes(
         const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
         const PreprocessedFilter*) {
-    TensorLayoutArray layouts{src, filter, dst};
-    AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
-                            layouts.data(), layouts.size(),
-                            &this->param(), sizeof(this->param())};
-    auto rst = AlgorithmCache::instance().get(key);
-    if (rst.policy.algo.valid()) {
-        return rst.workspace;
-    }
     AlgoBase::SizeArgs args(this, src, filter, dst);
     return get_algorithm(this, src, filter, dst)->get_workspace_in_bytes(args);
 }
@@ -198,15 +189,6 @@ ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::
 size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
         const TensorLayout& filter, const TensorLayout& diff,
         const TensorLayout& grad) {
-    TensorLayoutArray layouts{filter, diff, grad};
-    AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
-                            layouts.data(), layouts.size(),
-                            &this->param(), sizeof(this->param())};
-    auto rst = AlgorithmCache::instance().get(key);
-    if (rst.policy.algo.valid()) {
-        return rst.workspace;
-    }
     AlgoBase::SizeArgs args(this, filter, diff, grad);
     return get_algorithm(this, filter, diff, grad)->get_workspace_in_bytes(args);
 }
@@ -282,15 +264,6 @@ ConvolutionBackwardFilterImpl::Algorithm* ConvolutionBackwardFilterImpl::
 size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes(
         const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) {
-    TensorLayoutArray layouts{src, diff, grad};
-    AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
-                            layouts.data(), layouts.size(),
-                            &this->param(), sizeof(this->param())};
-    auto rst = AlgorithmCache::instance().get(key);
-    if (rst.policy.algo.valid()) {
-        return rst.workspace;
-    }
     AlgoBase::SizeArgs args(this, src, diff, grad);
     return get_algorithm(this, src, diff, grad)->get_workspace_in_bytes(args);
 }
@@ -35,15 +35,6 @@ WorkspaceBundle megdnn::x86::get_bundle(
 size_t PoolingImpl::get_workspace_in_bytes(
         const TensorLayout& src, const TensorLayout& dst) {
-    TensorLayoutArray layouts{src, dst};
-    AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
-                            layouts.data(), layouts.size(),
-                            &this->param(), sizeof(this->param())};
-    auto rst = AlgorithmCache::instance().get(key);
-    if (rst.policy.algo.valid()) {
-        return rst.workspace;
-    }
     auto algo = get_algorithm(this, src, dst);
     if (!is_fallback_algo(algo)) {
         if (is_supported(SIMDType::SSE) && src.dtype == dtype::Float32() &&
@@ -351,7 +351,7 @@ class TimedFuncInvokerImpl final : public TimedFuncInvoker {
             } else {
                 CHECK_SYS_ERR(cur_recv);
             }
-            mgb_assert(cur_recv > 0);
+            mgb_assert(cur_recv >= 0);
             dest += cur_recv;
             size -= cur_recv;
         }
@@ -950,10 +950,12 @@ void AlgoChooser<Opr>::AlgoChooserHelper::profile(
                 algo.desc.name.c_str(), layouts_str.c_str());
         timer.reset();
         MGB_TRY { cur_rst = profile_single_algo(policy, cur_timeout); }
+        // std::exception raised from megbrain
         MGB_CATCH(std::exception & exc, {
-            mgb_log_warn("caught exception during %s: %s", msg.c_str(), exc.what());
+            mgb_log_debug("caught exception during %s: %s", msg.c_str(), exc.what());
             continue;
         })
+        // non-std::exception raised from megbrain
        MGB_CATCH(..., {
             mgb_log_warn("caught exception during %s", msg.c_str());
             continue;
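Side note on the hunk above: downgrading the std::exception message from mgb_log_warn to mgb_log_debug is consistent with the usual exhaustive-profiling pattern, in which candidates that throw during a trial run are expected and simply skipped, while truly unknown failures stay at warning level. The sketch below is a minimal, self-contained illustration of that pattern; Candidate, pick_best, and profile_once are hypothetical stand-ins, not the MegBrain AlgoChooser API.

```cpp
#include <cstddef>
#include <cstdio>
#include <functional>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>

struct Candidate {
    std::string name;
};

// Profile every candidate, skip the ones whose trial run throws, and return
// the index of the fastest successful one (or -1 if none succeeded).
int pick_best(const std::vector<Candidate>& cands,
              const std::function<double(const Candidate&)>& profile_once) {
    int best = -1;
    double best_time = std::numeric_limits<double>::max();
    for (std::size_t i = 0; i < cands.size(); ++i) {
        double t;
        try {
            t = profile_once(cands[i]);
        } catch (const std::exception& exc) {
            // expected for unsupported cases: worth a debug-level log at most
            std::fprintf(stderr, "skip %s: %s\n", cands[i].name.c_str(), exc.what());
            continue;
        } catch (...) {
            // anything else is more suspicious and kept at warning level
            std::fprintf(stderr, "skip %s: unknown exception\n", cands[i].name.c_str());
            continue;
        }
        if (t < best_time) {
            best_time = t;
            best = static_cast<int>(i);
        }
    }
    return best;
}

int main() {
    std::vector<Candidate> cands{{"algo_a"}, {"algo_b"}};
    int best = pick_best(cands, [](const Candidate& c) -> double {
        if (c.name == "algo_a")
            throw std::runtime_error("not supported for this layout");
        return 1.0;  // pretend algo_b took 1.0 ms
    });
    std::printf("best candidate index: %d\n", best);  // prints 1
    return 0;
}
```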