GitOrigin-RevId: d49015714c
tags/v1.3.1
@@ -105,6 +105,10 @@ public:
 *
 */
enum class Attribute : uint32_t {
+/**
+* \brief general algo.
+*/
+DEFAULT = 0,
/**
* \brief whether the execution result is
@@ -163,6 +167,8 @@ public:
bool contain_attribute(const Attribute& attr) const;
+static std::string attribute_str(const Attribute& attr);
Handle::HandleType handle_type() const { return m_handle_type; }
Info info() const {
return {{handle_type(), type(), param()}, name(), attribute()};
@@ -311,6 +317,7 @@ class MultiAlgoOpr<Opr, 3> : public MultiAlgoOpr<Opr, -1> {
public:
using Algorithm = detail::Algorithm;
using AlgorithmInfo = detail::Algorithm::Info;
+using AlgoAttribute = detail::Algorithm::Attribute;
//! get all possible algorithm descriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
@@ -335,9 +342,9 @@ public:
const TensorLayout& p2,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
-bool reproducible = false) {
+const AlgoAttribute& attr = AlgoAttribute::DEFAULT) {
return get_algorithm_heuristic(p0, p1, p2, workspace_limit_in_bytes,
-reproducible)
+attr)
->info();
}
@@ -360,7 +367,7 @@ protected:
const TensorLayout& p2,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
-bool reproducible = false) = 0;
+const AlgoAttribute& attr = AlgoAttribute::DEFAULT) = 0;
};
//! specialize for nargs == 4
@@ -369,6 +376,7 @@ class MultiAlgoOpr<Opr, 4> : public MultiAlgoOpr<Opr, -1> {
public:
using Algorithm = detail::Algorithm;
using AlgorithmInfo = detail::Algorithm::Info;
+using AlgoAttribute = detail::Algorithm::Attribute;
//! get all possible algorithm descriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
@@ -394,9 +402,9 @@ public:
const TensorLayout& p2, const TensorLayout& p3,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
-bool reproducible = false) {
+const AlgoAttribute& attr = AlgoAttribute::DEFAULT) {
return get_algorithm_heuristic(p0, p1, p2, p3, workspace_limit_in_bytes,
-reproducible)
+attr)
->info();
}
@@ -419,7 +427,7 @@ protected:
const TensorLayout& p2, const TensorLayout& p3,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
-bool reproducible = false) = 0;
+const AlgoAttribute& attr = AlgoAttribute::DEFAULT) = 0;
};
//! specialize for nargs == 5
@@ -428,6 +436,7 @@ class MultiAlgoOpr<Opr, 5> : public MultiAlgoOpr<Opr, -1> {
public:
using Algorithm = detail::Algorithm;
using AlgorithmInfo = detail::Algorithm::Info;
+using AlgoAttribute = detail::Algorithm::Attribute;
//! get all possible algorithm descriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
@@ -455,9 +464,9 @@ public:
const TensorLayout& p4,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
-bool reproducible = false) {
+const AlgoAttribute& attr = AlgoAttribute::DEFAULT) {
return get_algorithm_heuristic(p0, p1, p2, p3, p4,
-workspace_limit_in_bytes, reproducible)
+workspace_limit_in_bytes, attr)
->info();
}
@@ -482,7 +491,7 @@ protected:
const TensorLayout& p4,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
-bool reproducible = false) = 0;
+const AlgoAttribute& attr = AlgoAttribute::DEFAULT) = 0;
};
//! specialize for nargs == 8
@@ -491,6 +500,7 @@ class MultiAlgoOpr<Opr, 8> : public MultiAlgoOpr<Opr, -1> {
public:
using Algorithm = detail::Algorithm;
using AlgorithmInfo = detail::Algorithm::Info;
+using AlgoAttribute = detail::Algorithm::Attribute;
//! get all possible algorithm descriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(
@@ -518,9 +528,9 @@ public:
const TensorLayout& p6, const TensorLayout& p7,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
-bool reproducible = false) {
+const AlgoAttribute& attr = AlgoAttribute::DEFAULT) {
return get_algorithm_heuristic(p0, p1, p2, p3, p4, p5, p6, p7,
-workspace_limit_in_bytes, reproducible)
+workspace_limit_in_bytes, attr)
->info();
}
@@ -547,7 +557,7 @@ protected:
const TensorLayout& p6, const TensorLayout& p7,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
-bool reproducible = false) = 0;
+const AlgoAttribute& attr = AlgoAttribute::DEFAULT) = 0;
};
} // namespace detail
@@ -15,8 +15,39 @@
using namespace megdnn;
+#define FOREACH_ALGO_ATTRIBUTE(cb) \
+cb(DEFAULT) \
+cb(REPRODUCIBLE) \
+cb(NAIVE)
+namespace {
+inline const char* attr_str(const AlgoAttribute& attr) {
+#define cb(attr) \
+case AlgoAttribute::attr: \
+return #attr;
+switch (attr) { FOREACH_ALGO_ATTRIBUTE(cb) }
+#undef cb
return "unknown arch"; | |||||
+}
+} // namespace
+std::string Algorithm::attribute_str(const Attribute& attr) {
+std::string ret;
+uint32_t attr_val = static_cast<uint32_t>(attr);
+while (attr_val) {
+uint32_t mask = ~(attr_val & (attr_val - 1));
+Attribute sub_attr = static_cast<Attribute>(mask & attr_val);
+if (!ret.empty()) {
+ret.append(" | ");
+}
+ret.append(attr_str(sub_attr));
+attr_val = attr_val & (attr_val - 1);
+}
+return ret;
+}
bool Algorithm::contain_attribute(const Attribute& attr) const {
-return bool(attribute() & attr);
+return attr == static_cast<Attribute>(attribute() & attr);
}
// vim: syntax=cpp.doxygen
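For reference, a minimal standalone sketch of the bitmask semantics the new Attribute API relies on (the names, the bit operators, and the two flags below are assumptions for illustration, not megdnn's actual declarations): contain_attribute() succeeds only when every requested bit is set, which is also why DEFAULT (value 0) is trivially contained, and the string form is built by peeling off one set bit at a time.

#include <cstdint>
#include <iostream>
#include <string>

// Hypothetical stand-in for megdnn's Algorithm::Attribute bit flags.
enum class Attr : uint32_t { DEFAULT = 0, REPRODUCIBLE = 1u << 0, NAIVE = 1u << 1 };

constexpr uint32_t raw(Attr a) { return static_cast<uint32_t>(a); }
// Assumed combining operator; megdnn defines its own bit operators for the enum.
constexpr Attr operator|(Attr a, Attr b) { return static_cast<Attr>(raw(a) | raw(b)); }

// True iff every bit requested in `attr` is present in `have` (DEFAULT == 0 is
// therefore contained in everything).
bool contain_attribute(Attr have, Attr attr) {
    return raw(attr) == (raw(have) & raw(attr));
}

// Join the names of all set bits with " | ", mirroring Algorithm::attribute_str.
std::string attribute_str(Attr attr) {
    uint32_t v = raw(attr);
    std::string ret;
    while (v) {
        uint32_t lowest = v & ~(v - 1);  // extract the lowest set bit
        if (!ret.empty()) ret.append(" | ");
        ret.append(lowest == raw(Attr::REPRODUCIBLE) ? "REPRODUCIBLE" : "NAIVE");
        v &= v - 1;  // clear the lowest set bit
    }
    return ret;
}

int main() {
    Attr algo_attrs = Attr::REPRODUCIBLE | Attr::NAIVE;
    std::cout << attribute_str(algo_attrs) << '\n';                           // REPRODUCIBLE | NAIVE
    std::cout << contain_attribute(algo_attrs, Attr::REPRODUCIBLE) << '\n';   // 1
    std::cout << contain_attribute(Attr::NAIVE, Attr::REPRODUCIBLE) << '\n';  // 0
}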
@@ -32,7 +32,7 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) {
} else {
ret = opr->get_algorithm_info_heuristic(
std::forward<Args>(args)..., std::numeric_limits<size_t>::max(),
-false).desc;
+AlgoAttribute::DEFAULT).desc;
}
return static_cast<typename Opr::AlgoBase*>(
opr->get_algorithm_from_desc(ret));
@@ -51,7 +51,7 @@ typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) {
return static_cast<typename Opr::AlgoBase*>(
opr->get_algorithm_heuristic(std::forward<Args>(args)...,
std::numeric_limits<size_t>::max(),
-false));
+AlgoAttribute::DEFAULT));
}
}
@@ -74,37 +74,34 @@ std::vector<typename Opr::Algorithm*> get_all_algorithms(
}
/*!
-* \brief a helper function to get a reproducible algorithm. If require a
-* reproducible algorithm, and the given algorithm is reproducible, return the
-* given algorithm. Otherwise return nullptr
+* \brief a helper function to get an algorithm with a given attribute. If an
+* algorithm with the specified attribute is required and the given algorithm
+* has that attribute, return the given algorithm. Otherwise return nullptr
*/
template <typename Opr>
-typename Opr::Algorithm* get_reproducible_algo(typename Opr::AlgoBase* algo,
-bool reproducible) {
-if (reproducible) {
-if (algo->contain_attribute(AlgoAttribute::REPRODUCIBLE)) {
-return algo;
-}
-} else {
+typename Opr::Algorithm* get_algo_with_attribute(typename Opr::AlgoBase* algo,
+const AlgoAttribute& attr) {
+if (algo->contain_attribute(attr)) {
return algo;
}
return nullptr;
}
template <typename Opr>
-typename Opr::Algorithm* get_reproducible_algo(
+typename Opr::Algorithm* get_algo_with_attribute(
const std::vector<typename Opr::AlgoBase*>& algos,
const typename Opr::AlgoBase::SizeArgs& args,
-size_t workspace_limit_in_bytes, const char* name) {
+size_t workspace_limit_in_bytes, const char* name,
+const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE) {
size_t min_workspace_limit_in_bytes = std::numeric_limits<size_t>::max();
bool available_but_limited_by_workspace = false;
-bool available_but_not_reproducible = false;
+bool available_but_without_attribute = false;
for (auto i : algos) {
-if (i->is_available_reproducible(args, true,
+if (i->is_available_attribute(args, attr,
workspace_limit_in_bytes)) {
return i;
}
-if (i->is_available_reproducible(args)) {
+if (i->is_available_attribute(args)) {
if (i->get_workspace_in_bytes(args) > workspace_limit_in_bytes) {
available_but_limited_by_workspace = true;
min_workspace_limit_in_bytes =
@@ -113,20 +110,22 @@ typename Opr::Algorithm* get_reproducible_algo(
}
}
if (i->is_available(args)) {
-if (!i->contain_attribute(AlgoAttribute::REPRODUCIBLE))
-available_but_not_reproducible = true;
+if (!i->contain_attribute(attr))
+available_but_without_attribute = true;
}
}
MEGDNN_MARK_USED_VAR(name);
if (available_but_limited_by_workspace) {
megdnn_throw(ssprintf(
-"no reproducible %s algorithm: %s workspace limit %zu is "
+"no %s algorithm with attribute:%s : %s workspace limit %zu is "
"less than mini workspace limit %zu",
-name, args.to_string().c_str(), workspace_limit_in_bytes,
+name, Algorithm::attribute_str(attr).c_str(),
+args.to_string().c_str(), workspace_limit_in_bytes,
min_workspace_limit_in_bytes));
-} else if (available_but_not_reproducible) {
-megdnn_throw(ssprintf("no reproducible %s algorithm", name));
+} else if (available_but_without_attribute) {
+megdnn_throw(ssprintf("no %s algorithm with attribute:%s", name,
+Algorithm::attribute_str(attr).c_str()));
} else {
megdnn_throw(ssprintf("no usable %s algorithm", name));
}
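As context for the reworked helper, a minimal sketch of the selection pattern it follows (hypothetical types and names, not megdnn's API): take the first candidate that both carries every requested attribute bit and fits the workspace limit, and remember why the others were rejected so the thrown message can distinguish a workspace problem from a missing attribute.

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical candidate record standing in for an Opr::AlgoBase*.
struct Candidate {
    std::string name;
    uint32_t attrs;          // bitmask of attributes the algorithm guarantees
    size_t workspace_bytes;  // workspace it would need for these args
};

// Return the first candidate that has every requested attribute bit and fits
// the workspace limit; otherwise explain which constraint made the list empty.
const Candidate& pick(const std::vector<Candidate>& algos, uint32_t required_attrs,
                      size_t workspace_limit) {
    bool limited_by_workspace = false;
    bool missing_attribute = false;
    for (const Candidate& c : algos) {
        bool has_attrs = (c.attrs & required_attrs) == required_attrs;
        if (has_attrs && c.workspace_bytes <= workspace_limit) return c;
        if (has_attrs)
            limited_by_workspace = true;
        else
            missing_attribute = true;
    }
    if (limited_by_workspace)
        throw std::runtime_error("candidates exist but exceed the workspace limit");
    if (missing_attribute)
        throw std::runtime_error("candidates exist but lack the requested attribute");
    throw std::runtime_error("no usable candidate");
}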
@@ -65,12 +65,11 @@ public:
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
-bool is_available_reproducible(
-const SizeArgs& args, bool reproducible = true,
+bool is_available_attribute(
+const SizeArgs& args,
+const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) {
-return (!reproducible ||
-contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
-is_available_wk(args, limit);
+return contain_attribute(attr) && is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
@@ -22,21 +22,21 @@ BatchConvBiasForwardImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
-bool reproducible) {
+const AlgoAttribute& attr) {
AlgoBase::SizeArgs args(this, src, filter, bias, z, dst);
-if (sm_algo_pack.int8_nchw4_gemm_dotprod.is_available_reproducible(
-args, reproducible, workspace_limit_in_bytes)) {
+if (sm_algo_pack.int8_nchw4_gemm_dotprod.is_available_attribute(
+args, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.int8_nchw4_gemm_dotprod;
}
-if (sm_algo_pack.int8_nchw4_implicit_gemm_dotprod.is_available_reproducible(
-args, reproducible, workspace_limit_in_bytes)) {
+if (sm_algo_pack.int8_nchw4_implicit_gemm_dotprod.is_available_attribute(
+args, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.int8_nchw4_implicit_gemm_dotprod;
}
-megdnn_throw(
-ssprintf("no %s batch conv bias algorithm with args(%s) and "
-"workspace limit (%zu bytes)",
-reproducible ? "reproducible" : "usable",
-args.to_string().c_str(), workspace_limit_in_bytes));
+megdnn_throw(ssprintf(
+"no batch conv bias algorithm with attribute%s args(%s) and "
+"workspace limit (%zu bytes)",
+Algorithm::attribute_str(attr).c_str(), args.to_string().c_str(),
+workspace_limit_in_bytes));
}
std::vector<BatchConvBiasForwardImpl::Algorithm*>
@@ -48,7 +48,7 @@ protected:
const TensorLayout& z,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
-bool reproducible) override;
+const AlgoAttribute& attr) override;
private:
static AlgoPack sm_algo_pack;
@@ -68,12 +68,11 @@ public:
bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
-bool is_available_reproducible(
-const SizeArgs& args, bool reproducible = true,
+bool is_available_attribute(
+const SizeArgs& args,
+const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) {
-return (!reproducible ||
-contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
-is_available_wk(args, limit);
+return contain_attribute(attr) && is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
@@ -55,24 +55,21 @@ std::vector<Algorithm*> BatchedMatrixMulForwardImpl::get_all_algorithms(
Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
-size_t workspace_limit_in_bytes, bool reproducible) {
+size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes);
AlgoBase::SizeArgs args(this, A, B, C);
-if (sm_algo_pack.cublas.is_available_reproducible(args, reproducible)) {
+if (sm_algo_pack.cublas.is_available_attribute(args, attr)) {
return &sm_algo_pack.cublas;
}
#if CUDA_VERSION >= 10010
-else if (sm_algo_pack.cublasLt.is_available_reproducible(args,
-reproducible)) {
+else if (sm_algo_pack.cublasLt.is_available_attribute(args, attr)) {
return &sm_algo_pack.cublasLt;
}
#endif
-else if (sm_algo_pack.int8x8x32.is_available_reproducible(args,
-reproducible)) {
+else if (sm_algo_pack.int8x8x32.is_available_attribute(args, attr)) {
return &sm_algo_pack.int8x8x32;
} else {
-if (sm_algo_pack.brute_force.is_available_reproducible(args,
-reproducible)) {
+if (sm_algo_pack.brute_force.is_available_attribute(args, attr)) {
return &sm_algo_pack.brute_force;
}
}
@@ -49,7 +49,7 @@ protected:
const TensorLayout& B,
const TensorLayout& C,
size_t workspace_limit_in_bytes,
-bool reproducible) override;
+const AlgoAttribute& attr) override;
private:
static AlgoPack sm_algo_pack;
@@ -127,12 +127,11 @@ public:
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
-bool is_available_reproducible(
-const SizeArgs& args, bool reproducible = true,
+bool is_available_attribute(
+const SizeArgs& args,
+const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) {
-return (!reproducible ||
-contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
-is_available_wk(args, limit);
+return contain_attribute(attr) && is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
@@ -51,7 +51,7 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst, size_t workspace_limit_in_bytes,
-bool reproducible) {
+const AlgoAttribute& attr) {
using namespace conv_bias;
AlgoBase::SizeArgs args{this, src, filter, bias, z, dst};
auto dst_layout = *args.dst_layout;
@@ -74,7 +74,7 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
};
auto get_cudnn_algo =
-[this, &conv_args, &args, workspace_limit_in_bytes, reproducible](
+[this, &conv_args, &args, workspace_limit_in_bytes, attr](
const thin_function<AlgoBase*(cudnnConvolutionFwdAlgo_t)>&
cb) -> AlgoBase* {
auto cudnn_handle = cuda::cudnn_handle(this->handle());
@@ -92,8 +92,8 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
&ret_count, algo_perf.data()));
for (int i = 0; i < ret_count; ++i) {
auto conv_bias_algo = cb(algo_perf[i].algo);
-if (conv_bias_algo->is_available_reproducible(
-args, reproducible, workspace_limit_in_bytes))
+if (conv_bias_algo->is_available_attribute(
+args, attr, workspace_limit_in_bytes))
return conv_bias_algo;
}
#else
@@ -105,18 +105,18 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
workspace_limit_in_bytes, &algo));
auto conv_bias_algo = cb(algo);
-if (conv_bias_algo->is_available_reproducible(args, reproducible,
-workspace_limit_in_bytes))
+if (conv_bias_algo->is_available_attribute(args, attr,
+workspace_limit_in_bytes))
return conv_bias_algo;
#endif
return nullptr;
};
auto get_1x1_algo = [workspace_limit_in_bytes,
-reproducible](const AlgoBase::SizeArgs& size_arg)
+attr](const AlgoBase::SizeArgs& size_arg)
-> ConvBiasForwardImpl::AlgoBase* {
-if (sm_algo_pack.batched_matmul.is_available_reproducible(
-size_arg, reproducible, workspace_limit_in_bytes)) {
+if (sm_algo_pack.batched_matmul.is_available_attribute(
+size_arg, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.batched_matmul;
}
return nullptr;
@@ -144,11 +144,11 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
//! avoid bad case in cudnn, check dnn chanwise impl first
if (is_chanwise) {
if (prefer_dnn_chanwise) {
-if (sm_algo_pack.chanwise.is_available_reproducible(
-args, reproducible, workspace_limit_in_bytes))
+if (sm_algo_pack.chanwise.is_available_attribute(
+args, attr, workspace_limit_in_bytes))
return &sm_algo_pack.chanwise;
-if (sm_algo_pack.chanwise8x8x32.is_available_reproducible(
-args, reproducible, workspace_limit_in_bytes))
+if (sm_algo_pack.chanwise8x8x32.is_available_attribute(
+args, attr, workspace_limit_in_bytes))
return &sm_algo_pack.chanwise8x8x32;
} else {
conv_args.dst_layout = &dst_layout;
@@ -163,8 +163,7 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
//! Prefer CUDNN CONVBIAS.
bool cudnn_conv_bias_act_supported = false;
for (auto&& algo : sm_algo_pack.cudnn_conv_bias_activations) {
-if (algo.is_available_reproducible(args, reproducible,
-workspace_limit_in_bytes)) {
+if (algo.is_available_attribute(args, attr, workspace_limit_in_bytes)) {
cudnn_conv_bias_act_supported = true;
break;
}
@@ -201,26 +200,26 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
return algo;
}
-if (sm_algo_pack.fallback_nchw_qs8.is_available_reproducible(
-args, reproducible, workspace_limit_in_bytes)) {
+if (sm_algo_pack.fallback_nchw_qs8.is_available_attribute(
+args, attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.fallback_nchw_qs8;
}
if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) {
-if (reproducible) {
-return megdnn::get_reproducible_algo<ConvBiasForwardImpl>(
+if (attr != AlgoAttribute::DEFAULT) {
+return megdnn::get_algo_with_attribute<ConvBiasForwardImpl>(
sm_algo_pack.non_cudnn_algos, args,
-workspace_limit_in_bytes, "cuda convbias fwd");
+workspace_limit_in_bytes, "cuda convbias fwd", attr);
} else {
return megdnn::get_usable_algo<ConvBiasForwardImpl>(
sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda convbias fwd");
}
} else {
-if (reproducible) {
-return megdnn::get_reproducible_algo<ConvBiasForwardImpl>(
+if (attr != AlgoAttribute::DEFAULT) {
+return megdnn::get_algo_with_attribute<ConvBiasForwardImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
-"cuda convbias fwd");
+"cuda convbias fwd", attr);
} else {
return megdnn::get_usable_algo<ConvBiasForwardImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
@@ -82,7 +82,7 @@ public:
const TensorLayout& z,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
-bool reproducible) override;
+const AlgoAttribute& attr) override;
private:
static AlgoPack sm_algo_pack;
@@ -82,12 +82,11 @@ public:
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
-bool is_available_reproducible(
-const SizeArgs& args, bool reproducible = true,
+bool is_available_attribute(
+const SizeArgs& args,
+const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) {
-return (!reproducible ||
-contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
-is_available_wk(args, limit);
+return contain_attribute(attr) && is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
@@ -78,12 +78,11 @@ public:
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
-bool is_available_reproducible(
-const SizeArgs& args, bool reproducible = true,
+bool is_available_attribute(
+const SizeArgs& args,
+const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) {
-return (!reproducible ||
-contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
-is_available_wk(args, limit);
+return contain_attribute(attr) && is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
@@ -63,13 +63,13 @@ public:
bool is_available_wk(const SizeArgs& args, size_t limit) const {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
-bool is_available_reproducible(
-const SizeArgs& args, bool reproducible = true,
-size_t limit = std::numeric_limits<size_t>::max()) const {
-return (!reproducible ||
-contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
-is_available_wk(args, limit);
+bool is_available_attribute(
+const SizeArgs& args,
+const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
+size_t limit = std::numeric_limits<size_t>::max()) {
+return contain_attribute(attr) && is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
auto req = get_workspace_in_bytes(args);
@@ -12,6 +12,7 @@
#include "src/cuda/convolution/opr_impl.h"
#include "megdnn/dtype.h"
+#include "src/common/algo_chooser.h"
#include "src/cuda/convolution/helper.h"
#include "src/cuda/convolution/forward/algos.h"
#include "src/cuda/convolution/backward_data/algo.h"
@@ -36,10 +37,10 @@ ConvolutionForwardImpl::get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
-bool reproducible) {
+const AlgoAttribute& attr) {
AlgoBase::SizeArgs args{this, src, filter, dst};
MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes);
-MEGDNN_MARK_USED_VAR(reproducible);
+MEGDNN_MARK_USED_VAR(attr);
return &sm_algo_pack.algo_default;
}
@@ -100,32 +101,32 @@ ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
-bool reproducible) {
+const AlgoAttribute& attr) {
auto fm = check_layout_fwd(grad, filter, diff);
return get_algorithm_heuristic(filter, fm, diff, grad,
-workspace_limit_in_bytes, reproducible);
+workspace_limit_in_bytes, attr);
}
ConvolutionBackwardDataImpl::Algorithm*
-ConvolutionBackwardDataImpl::get_algorithm_heuristic(
-const TensorLayout& filter, const CanonizedFilterMeta& filter_meta,
-const TensorLayout& diff, const TensorLayout& grad,
-size_t workspace_limit_in_bytes, bool reproducible) {
+ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter,
+const CanonizedFilterMeta& filter_meta, const TensorLayout& diff,
+const TensorLayout& grad, size_t workspace_limit_in_bytes,
+const AlgoAttribute& attr) {
AlgoBase::SizeArgs args(this, filter, filter_meta, diff, grad);
if (args.filter_meta.group > 1 &&
-sm_algo_pack.chanwise.is_available_reproducible(
-args, reproducible, workspace_limit_in_bytes)) {
+sm_algo_pack.chanwise.is_available_attribute(
+args, attr, workspace_limit_in_bytes)) {
// prefer special chanwise impl
return &sm_algo_pack.chanwise;
}
if (args.filter_layout->dtype.enumv() ==
DTypeTrait<dtype::QuantizedS8>::enumv) {
-if (reproducible) {
-return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>(
+if (attr != AlgoAttribute::DEFAULT) {
+return megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>(
sm_algo_pack.int8_algos, args, workspace_limit_in_bytes,
-"cuda conv bwd_data");
+"cuda conv bwd_data", attr);
} else {
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>(
sm_algo_pack.int8_algos, args, workspace_limit_in_bytes,
@@ -133,9 +134,8 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
}
}
-auto get_cudnn_algo =
-[this, &args, workspace_limit_in_bytes,
-reproducible]() -> ConvolutionBackwardDataImpl::AlgoBase* {
+auto get_cudnn_algo = [this, &args, workspace_limit_in_bytes,
+attr]() -> ConvolutionBackwardDataImpl::AlgoBase* {
auto cudnn_handle = cuda::cudnn_handle(this->handle());
CUDNNBwdDataDescs desc;
args.init_desc(desc);
@@ -153,7 +153,7 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
for (int i = 0; i < ret_count; ++i) {
if (algo_perf[i].memory > workspace_limit_in_bytes)
continue;
-if (reproducible) {
+if (attr & AlgoAttribute::REPRODUCIBLE) {
if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) {
return reinterpret_cast<AlgoBase*>(
sm_algo_pack.cudnn_from_enum(algo_perf[i].algo));
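The cuDNN branches above translate a REPRODUCIBLE request into a check of cuDNN's reported determinism. A standalone sketch of that acceptance rule, with placeholder enums rather than the real cudnnDeterminism_t or megdnn attribute type:

#include <cstddef>
#include <cstdint>

// Placeholder attribute bit and a placeholder for cuDNN's determinism flag.
enum Attr : uint32_t { REPRODUCIBLE = 1u << 0 };
enum Determinism { NON_DETERMINISTIC = 0, DETERMINISTIC = 1 };

// Accept a candidate only if it fits the workspace limit and, when the caller
// asked for REPRODUCIBLE, the algorithm is reported as deterministic.
bool accept(uint32_t requested_attr, Determinism det, size_t workspace_bytes,
            size_t workspace_limit) {
    if (workspace_bytes > workspace_limit) return false;
    if (requested_attr & REPRODUCIBLE) return det == DETERMINISTIC;
    return true;
}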
@@ -174,8 +174,8 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
auto&& cast_algo =
reinterpret_cast<AlgoBase*>(sm_algo_pack.cudnn_from_enum(algo));
return reinterpret_cast<AlgoBase*>(
-megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>(
-cast_algo, reproducible));
+megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>(
+cast_algo, attr));
#endif
};
@@ -197,20 +197,20 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
if (args.filter_layout->dtype.enumv() !=
DTypeTrait<dtype::BFloat16>::enumv) {
-if (reproducible) {
-return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>(
+if (attr != AlgoAttribute::DEFAULT) {
+return megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>(
sm_algo_pack.non_cudnn_algos, args,
-workspace_limit_in_bytes, "cuda conv bwd_data");
+workspace_limit_in_bytes, "cuda conv bwd_data", attr);
} else {
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>(
sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda conv bwd_data");
}
} else {
-if (reproducible) {
-return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>(
+if (attr != AlgoAttribute::DEFAULT) {
+return megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
-"cuda conv bwd_data");
+"cuda conv bwd_data", attr);
} else {
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
@@ -255,29 +255,29 @@ ConvolutionBackwardFilterImpl::Algorithm*
ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
-bool reproducible) {
+const AlgoAttribute& attr) {
auto fm = check_layout_fwd(src, grad, diff);
return get_algorithm_heuristic(src, diff, grad, fm,
-workspace_limit_in_bytes, reproducible);
+workspace_limit_in_bytes, attr);
}
ConvolutionBackwardFilterImpl::Algorithm*
ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, const CanonizedFilterMeta& grad_meta,
-size_t workspace_limit_in_bytes, bool reproducible) {
+size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
AlgoBase::SizeArgs args(this, src, diff, grad, grad_meta);
if (args.grad_filter_meta.group > 1 &&
-sm_algo_pack.chanwise.is_available_reproducible(
-args, reproducible, workspace_limit_in_bytes)) {
+sm_algo_pack.chanwise.is_available_attribute(
+args, attr, workspace_limit_in_bytes)) {
// prefer special chanwise impl
return &sm_algo_pack.chanwise;
}
auto get_cudnn_algo =
[this, &args, workspace_limit_in_bytes,
-reproducible]() -> ConvolutionBackwardFilterImpl::AlgoBase* {
+attr]() -> ConvolutionBackwardFilterImpl::AlgoBase* {
auto cudnn_handle = cuda::cudnn_handle(this->handle());
CUDNNBwdFilterDescs desc;
args.init_desc(desc);
@@ -305,7 +305,7 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
for (int i = 0; i < ret_count; ++i) {
if (algo_perf[i].memory > workspace_limit_in_bytes)
continue;
-if (reproducible) {
+if (attr & AlgoAttribute::REPRODUCIBLE) {
if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) {
return reinterpret_cast<AlgoBase*>(
sm_algo_pack.cudnn_from_enum(algo_perf[i].algo));
@@ -326,8 +326,8 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
auto&& cast_algo =
reinterpret_cast<AlgoBase*>(sm_algo_pack.cudnn_from_enum(algo));
return reinterpret_cast<AlgoBase*>(
-megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>(
-cast_algo, reproducible));
+megdnn::get_algo_with_attribute<ConvolutionBackwardFilterImpl>(
+cast_algo, attr));
#endif
};
@@ -348,20 +348,22 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
}
if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) {
-if (reproducible) {
-return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>(
+if (attr != AlgoAttribute::DEFAULT) {
+return megdnn::get_algo_with_attribute<
+ConvolutionBackwardFilterImpl>(
sm_algo_pack.non_cudnn_algos, args,
-workspace_limit_in_bytes, "cuda conv bwd_filter");
+workspace_limit_in_bytes, "cuda conv bwd_filter", attr);
} else {
return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>(
sm_algo_pack.non_cudnn_algos, args,
workspace_limit_in_bytes, "cuda conv bwd_filter");
}
} else {
-if (reproducible) {
-return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>(
+if (attr != AlgoAttribute::DEFAULT) {
+return megdnn::get_algo_with_attribute<
+ConvolutionBackwardFilterImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
-"cuda conv bwd_filter");
+"cuda conv bwd_filter", attr);
} else {
return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>(
sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes,
@@ -63,7 +63,7 @@ protected:
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
-bool reproducible) override;
+const AlgoAttribute& attr) override;
private:
static AlgoPack sm_algo_pack;
@@ -77,9 +77,9 @@ public:
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& filter, const CanonizedFilterMeta& filter_meta,
const TensorLayout& diff, const TensorLayout& grad,
-size_t workspace_limit_in_bytes, bool reproducible) {
+size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
return get_algorithm_heuristic(filter, filter_meta, diff, grad,
-workspace_limit_in_bytes, reproducible)
+workspace_limit_in_bytes, attr)
->info();
}
@@ -87,9 +87,9 @@ public:
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
-bool reproducible) {
+const AlgoAttribute& attr) {
return get_algorithm_heuristic(filter, diff, grad,
-workspace_limit_in_bytes, reproducible)
+workspace_limit_in_bytes, attr)
->info();
}
@@ -122,7 +122,7 @@ protected:
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
-bool reproducible) override;
+const AlgoAttribute& attr) override;
private:
Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
@@ -130,7 +130,7 @@ private:
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
-bool reproducible);
+const AlgoAttribute& attr);
static AlgoPack sm_algo_pack;
};
@@ -146,9 +146,9 @@ public:
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, const CanonizedFilterMeta& grad_meta,
-size_t workspace_limit_in_bytes, bool reproducible) {
+size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
return get_algorithm_heuristic(src, diff, grad, grad_meta,
-workspace_limit_in_bytes, reproducible)
+workspace_limit_in_bytes, attr)
->info();
}
@@ -156,9 +156,9 @@ public:
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
-bool reproducible) {
+const AlgoAttribute& attr) {
return get_algorithm_heuristic(filter, diff, grad,
-workspace_limit_in_bytes, reproducible)
+workspace_limit_in_bytes, attr)
->info();
}
@@ -185,7 +185,7 @@ protected:
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
-bool reproducible) override;
+const AlgoAttribute& attr) override;
private:
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
@@ -193,7 +193,7 @@ private:
const TensorLayout& grad,
const CanonizedFilterMeta& grad_meta,
size_t workspace_limit_in_bytes,
-bool reproducible);
+const AlgoAttribute& attr);
static AlgoPack sm_algo_pack;
};
@@ -75,12 +75,11 @@ public:
bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
-bool is_available_reproducible(
-const SizeArgs& args, bool reproducible = true,
+bool is_available_attribute(
+const SizeArgs& args,
+const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) {
-return (!reproducible ||
-contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
-is_available_wk(args, limit);
+return contain_attribute(attr) && is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
@@ -69,12 +69,11 @@ public:
bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
-bool is_available_reproducible(
-const SizeArgs& args, bool reproducible = true,
+bool is_available_attribute(
+const SizeArgs& args,
+const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) {
-return (!reproducible ||
-contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
-is_available_wk(args, limit);
+return contain_attribute(attr) && is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
@@ -74,12 +74,11 @@ public:
bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
-bool is_available_reproducible(
-const SizeArgs& args, bool reproducible = true,
+bool is_available_attribute(
+const SizeArgs& args,
+const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
size_t limit = std::numeric_limits<size_t>::max()) {
-return (!reproducible ||
-contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
-is_available_wk(args, limit);
+return contain_attribute(attr) && is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
@@ -97,8 +97,8 @@ namespace convolution3d { | |||||
const cudnnConvolutionDescriptor_t conv_desc, | const cudnnConvolutionDescriptor_t conv_desc, | ||||
const cudnnTensorDescriptor_t y_desc, | const cudnnTensorDescriptor_t y_desc, | ||||
size_t workspace_limit_in_bytes, cudnnConvolutionFwdAlgo_t* algo, | size_t workspace_limit_in_bytes, cudnnConvolutionFwdAlgo_t* algo, | ||||
bool reproducible) { | |||||
MEGDNN_MARK_USED_VAR(reproducible); | |||||
const AlgoAttribute& attr) { | |||||
MEGDNN_MARK_USED_VAR(attr); | |||||
#if CUDNN_MAJOR >= 7 | #if CUDNN_MAJOR >= 7 | ||||
int algo_max_count = 0; | int algo_max_count = 0; | ||||
cudnn_check(cudnnGetConvolutionForwardAlgorithmMaxCount( | cudnn_check(cudnnGetConvolutionForwardAlgorithmMaxCount( | ||||
@@ -118,7 +118,7 @@ namespace convolution3d { | |||||
cudnn_handle, x_desc, w_desc, conv_desc, y_desc, | cudnn_handle, x_desc, w_desc, conv_desc, y_desc, | ||||
algo_perf[i].algo, &workspace_size)); | algo_perf[i].algo, &workspace_size)); | ||||
if (workspace_size > workspace_limit_in_bytes) continue; | if (workspace_size > workspace_limit_in_bytes) continue; | ||||
if (!reproducible) { | |||||
if (!(attr & AlgoAttribute::REPRODUCIBLE)) { | |||||
*algo = algo_perf[i].algo; | *algo = algo_perf[i].algo; | ||||
return true; | return true; | ||||
} else { | } else { | ||||
@@ -144,8 +144,8 @@ namespace convolution3d { | |||||
const cudnnConvolutionDescriptor_t conv_desc, | const cudnnConvolutionDescriptor_t conv_desc, | ||||
const cudnnTensorDescriptor_t dx_desc, | const cudnnTensorDescriptor_t dx_desc, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
cudnnConvolutionBwdDataAlgo_t* algo, bool reproducible) { | |||||
MEGDNN_MARK_USED_VAR(reproducible); | |||||
cudnnConvolutionBwdDataAlgo_t* algo, const AlgoAttribute& attr) { | |||||
MEGDNN_MARK_USED_VAR(attr); | |||||
#if CUDNN_MAJOR >= 7 | #if CUDNN_MAJOR >= 7 | ||||
int algo_max_count = 0; | int algo_max_count = 0; | ||||
cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount( | cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount( | ||||
@@ -166,7 +166,7 @@ namespace convolution3d { | |||||
cudnn_handle, w_desc, dy_desc, conv_desc, dx_desc, | cudnn_handle, w_desc, dy_desc, conv_desc, dx_desc, | ||||
algo_perf[i].algo, &workspace_size)); | algo_perf[i].algo, &workspace_size)); | ||||
if (workspace_size > workspace_limit_in_bytes) continue; | if (workspace_size > workspace_limit_in_bytes) continue; | ||||
if (!reproducible) { | |||||
if (!(attr & AlgoAttribute::REPRODUCIBLE)) { | |||||
*algo = algo_perf[i].algo; | *algo = algo_perf[i].algo; | ||||
return true; | return true; | ||||
} else { | } else { | ||||
@@ -193,8 +193,8 @@ namespace convolution3d { | |||||
const cudnnConvolutionDescriptor_t conv_desc, | const cudnnConvolutionDescriptor_t conv_desc, | ||||
const cudnnFilterDescriptor_t dw_desc, | const cudnnFilterDescriptor_t dw_desc, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
cudnnConvolutionBwdFilterAlgo_t* algo, bool reproducible) { | |||||
MEGDNN_MARK_USED_VAR(reproducible); | |||||
cudnnConvolutionBwdFilterAlgo_t* algo, const AlgoAttribute& attr) { | |||||
MEGDNN_MARK_USED_VAR(attr); | |||||
#if CUDNN_MAJOR >= 7 | #if CUDNN_MAJOR >= 7 | ||||
int algo_max_count = 0; | int algo_max_count = 0; | ||||
cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( | cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( | ||||
@@ -207,14 +207,15 @@ namespace convolution3d { | |||||
algo_max_count, &algo_count, algo_perf.data())); | algo_max_count, &algo_count, algo_perf.data())); | ||||
for (int i = 0; i < algo_count; ++i) { | for (int i = 0; i < algo_count; ++i) { | ||||
if (algo_perf[i].algo == | if (algo_perf[i].algo == | ||||
cudnnConvolutionBwdFilterAlgo_t::CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING) | |||||
cudnnConvolutionBwdFilterAlgo_t:: | |||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING) | |||||
continue; | continue; | ||||
size_t workspace_size = 0; | size_t workspace_size = 0; | ||||
cudnn_check(cudnnGetConvolutionBackwardFilterWorkspaceSize( | cudnn_check(cudnnGetConvolutionBackwardFilterWorkspaceSize( | ||||
cudnn_handle, x_desc, dy_desc, conv_desc, dw_desc, | cudnn_handle, x_desc, dy_desc, conv_desc, dw_desc, | ||||
algo_perf[i].algo, &workspace_size)); | algo_perf[i].algo, &workspace_size)); | ||||
if (workspace_size > workspace_limit_in_bytes) continue; | if (workspace_size > workspace_limit_in_bytes) continue; | ||||
if (!reproducible) { | |||||
if (!(attr & AlgoAttribute::REPRODUCIBLE)) { | |||||
*algo = algo_perf[i].algo; | *algo = algo_perf[i].algo; | ||||
return true; | return true; | ||||
} else { | } else { | ||||
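The three cuDNN helpers in these hunks test `attr & AlgoAttribute::REPRODUCIBLE` directly. For a scoped enum that expression needs a bitwise operator; the sketch below shows one plausible way to provide it (the project presumably defines such operators in its common headers, so treat these names and signatures as illustrative assumptions).

```cpp
#include <cstdint>

enum class AlgoAttribute : uint32_t {
    DEFAULT = 0,
    REPRODUCIBLE = 1u << 0,
};

// Return the raw mask so the result can be used in a boolean context,
// exactly as in `if (!(attr & AlgoAttribute::REPRODUCIBLE))` above.
inline uint32_t operator&(AlgoAttribute lhs, AlgoAttribute rhs) {
    return static_cast<uint32_t>(lhs) & static_cast<uint32_t>(rhs);
}

inline AlgoAttribute operator|(AlgoAttribute lhs, AlgoAttribute rhs) {
    return static_cast<AlgoAttribute>(static_cast<uint32_t>(lhs) |
                                      static_cast<uint32_t>(rhs));
}

// Mirrors the check used when a reproducible cuDNN algorithm is required.
inline bool wants_reproducible(AlgoAttribute attr) {
    return (attr & AlgoAttribute::REPRODUCIBLE) != 0;
}
```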
@@ -15,6 +15,7 @@ | |||||
#include "./forward/algo.h" | #include "./forward/algo.h" | ||||
#include "./helper.h" | #include "./helper.h" | ||||
#include "src/common/algo_chooser.h" | |||||
#include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
using namespace megdnn; | using namespace megdnn; | ||||
@@ -32,16 +33,16 @@ Convolution3DForwardImpl::Algorithm* | |||||
Convolution3DForwardImpl::get_algorithm_heuristic( | Convolution3DForwardImpl::get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
auto fm = check_layout_fwd(src, filter, dst); | auto fm = check_layout_fwd(src, filter, dst); | ||||
return get_algorithm_heuristic(src, fm, dst, workspace_limit_in_bytes, | return get_algorithm_heuristic(src, fm, dst, workspace_limit_in_bytes, | ||||
reproducible); | |||||
attr); | |||||
} | } | ||||
Convolution3DForwardImpl::Algorithm* | Convolution3DForwardImpl::Algorithm* | ||||
Convolution3DForwardImpl::get_algorithm_heuristic( | Convolution3DForwardImpl::get_algorithm_heuristic( | ||||
const TensorLayout& src, const CanonizedFilterMeta& filter, | const TensorLayout& src, const CanonizedFilterMeta& filter, | ||||
const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
AlgoBase::SizeArgs args(this, src, filter, dst); | AlgoBase::SizeArgs args(this, src, filter, dst); | ||||
#if CUDNN_MAJOR < 7 || (CUDNN_MAJOR == 7 && CUDNN_MINOR < 5) | #if CUDNN_MAJOR < 7 || (CUDNN_MAJOR == 7 && CUDNN_MINOR < 5) | ||||
@@ -49,26 +50,26 @@ Convolution3DForwardImpl::get_algorithm_heuristic( | |||||
// prefer special chanwise impl since the group conv of cudnn whose | // prefer special chanwise impl since the group conv of cudnn whose | ||||
// version is lower than v7.5.0 is still slower than our implementation | // version is lower than v7.5.0 is still slower than our implementation | ||||
// in many channel-wise cases | // in many channel-wise cases | ||||
if (sm_algo_pack.chanwise.is_available_reproducible( | |||||
args, reproducible, workspace_limit_in_bytes)) { | |||||
if (sm_algo_pack.chanwise.is_available_attribute( | |||||
args, attr, workspace_limit_in_bytes)) { | |||||
return &sm_algo_pack.chanwise; | return &sm_algo_pack.chanwise; | ||||
} | } | ||||
} | } | ||||
#endif | #endif | ||||
auto prefer_1x1x1 = [&args, reproducible, workspace_limit_in_bytes]() { | |||||
auto prefer_1x1x1 = [&args, attr, workspace_limit_in_bytes]() { | |||||
const size_t MAX_BATCH_SIZE_FOR_1x1x1_MAT_ALGO = 4; | const size_t MAX_BATCH_SIZE_FOR_1x1x1_MAT_ALGO = 4; | ||||
size_t batch_size = args.src_layout->shape[0]; | size_t batch_size = args.src_layout->shape[0]; | ||||
if (batch_size > MAX_BATCH_SIZE_FOR_1x1x1_MAT_ALGO) { | if (batch_size > MAX_BATCH_SIZE_FOR_1x1x1_MAT_ALGO) { | ||||
return false; | return false; | ||||
} | } | ||||
return sm_algo_pack.a1x1x1.is_available_reproducible( | |||||
args, reproducible, workspace_limit_in_bytes); | |||||
return sm_algo_pack.a1x1x1.is_available_attribute( | |||||
args, attr, workspace_limit_in_bytes); | |||||
}; | }; | ||||
auto get_cudnn_algo = | auto get_cudnn_algo = | ||||
[this, &args, workspace_limit_in_bytes, | [this, &args, workspace_limit_in_bytes, | ||||
reproducible]() -> Convolution3DForwardImpl::AlgoBase* { | |||||
attr]() -> Convolution3DForwardImpl::AlgoBase* { | |||||
auto cudnn_handle = cuda::cudnn_handle(this->handle()); | auto cudnn_handle = cuda::cudnn_handle(this->handle()); | ||||
cudnnConvolutionFwdAlgo_t algo; | cudnnConvolutionFwdAlgo_t algo; | ||||
CUDNNForwardDescs desc; | CUDNNForwardDescs desc; | ||||
@@ -77,11 +78,11 @@ Convolution3DForwardImpl::get_algorithm_heuristic( | |||||
bool got = cudnn_get_convolution_fwd_algo_helper( | bool got = cudnn_get_convolution_fwd_algo_helper( | ||||
cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc, | cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc, | ||||
desc.conv_desc.desc, desc.dst_desc.desc, | desc.conv_desc.desc, desc.dst_desc.desc, | ||||
workspace_limit_in_bytes, &algo, reproducible); | |||||
workspace_limit_in_bytes, &algo, attr); | |||||
if (got) { | if (got) { | ||||
return static_cast<AlgoBase*>( | return static_cast<AlgoBase*>( | ||||
megdnn::get_reproducible_algo<Convolution3DForwardImpl>( | |||||
sm_algo_pack.cudnn_from_enum(algo), reproducible)); | |||||
megdnn::get_algo_with_attribute<Convolution3DForwardImpl>( | |||||
sm_algo_pack.cudnn_from_enum(algo), attr)); | |||||
} else { | } else { | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
@@ -107,10 +108,10 @@ Convolution3DForwardImpl::get_algorithm_heuristic( | |||||
args = orig_args; | args = orig_args; | ||||
} | } | ||||
if (reproducible) { | |||||
return megdnn::get_reproducible_algo<Convolution3DForwardImpl>( | |||||
if (attr != AlgoAttribute::DEFAULT) { | |||||
return megdnn::get_algo_with_attribute<Convolution3DForwardImpl>( | |||||
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | ||||
"cuda conv3d fwd"); | |||||
"cuda conv3d fwd", attr); | |||||
} else { | } else { | ||||
return megdnn::get_usable_algo<Convolution3DForwardImpl>( | return megdnn::get_usable_algo<Convolution3DForwardImpl>( | ||||
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | ||||
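Away from cuDNN the heuristic now branches on `attr != AlgoAttribute::DEFAULT`: a non-default attribute routes through `megdnn::get_algo_with_attribute`, otherwise any usable algorithm is accepted. The following is a compact sketch of what such a filter can look like, with simplified types and a hypothetical `_sketch` suffix to mark that it is not the real template in src/common/algo_chooser.h.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

enum class AlgoAttribute : uint32_t { DEFAULT = 0, REPRODUCIBLE = 1u << 0 };

// Minimal algorithm record for illustration.
struct Algo {
    std::string name;
    uint32_t attrs = 0;    // bitmask of AlgoAttribute values
    size_t workspace = 0;  // workspace requirement in bytes

    bool contain_attribute(AlgoAttribute attr) const {
        auto mask = static_cast<uint32_t>(attr);
        return (attrs & mask) == mask;  // DEFAULT is always contained
    }
};

// Return the first algorithm that carries the requested attribute and fits
// the workspace limit; fail with a descriptive message otherwise -- roughly
// the role get_algo_with_attribute plays for "cuda conv3d fwd" above.
inline Algo* get_algo_with_attribute_sketch(std::vector<Algo>& algos,
                                            size_t workspace_limit,
                                            AlgoAttribute attr,
                                            const std::string& opr_name) {
    for (auto& algo : algos) {
        if (algo.contain_attribute(attr) && algo.workspace <= workspace_limit)
            return &algo;
    }
    throw std::runtime_error("no algorithm with requested attribute for " +
                             opr_name);
}
```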
@@ -168,28 +169,28 @@ Convolution3DBackwardDataImpl::Algorithm* | |||||
Convolution3DBackwardDataImpl::get_algorithm_heuristic( | Convolution3DBackwardDataImpl::get_algorithm_heuristic( | ||||
const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
auto fm = check_layout_fwd(grad, filter, diff); | auto fm = check_layout_fwd(grad, filter, diff); | ||||
return get_algorithm_heuristic(fm, diff, grad, workspace_limit_in_bytes, | return get_algorithm_heuristic(fm, diff, grad, workspace_limit_in_bytes, | ||||
reproducible); | |||||
attr); | |||||
} | } | ||||
Convolution3DBackwardDataImpl::Algorithm* | Convolution3DBackwardDataImpl::Algorithm* | ||||
Convolution3DBackwardDataImpl::get_algorithm_heuristic( | Convolution3DBackwardDataImpl::get_algorithm_heuristic( | ||||
const CanonizedFilterMeta& filter, const TensorLayout& diff, | const CanonizedFilterMeta& filter, const TensorLayout& diff, | ||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
AlgoBase::SizeArgs args(this, filter, diff, grad); | AlgoBase::SizeArgs args(this, filter, diff, grad); | ||||
if (args.filter_meta.group > 1 && | if (args.filter_meta.group > 1 && | ||||
sm_algo_pack.chanwise.is_available_reproducible( | |||||
args, reproducible, workspace_limit_in_bytes)) { | |||||
sm_algo_pack.chanwise.is_available_attribute( | |||||
args, attr, workspace_limit_in_bytes)) { | |||||
return &sm_algo_pack.chanwise; | return &sm_algo_pack.chanwise; | ||||
} | } | ||||
auto get_cudnn_algo = | auto get_cudnn_algo = | ||||
[this, &args, workspace_limit_in_bytes, | [this, &args, workspace_limit_in_bytes, | ||||
reproducible]() -> Convolution3DBackwardDataImpl::AlgoBase* { | |||||
attr]() -> Convolution3DBackwardDataImpl::AlgoBase* { | |||||
auto cudnn_handle = cuda::cudnn_handle(this->handle()); | auto cudnn_handle = cuda::cudnn_handle(this->handle()); | ||||
cudnnConvolutionBwdDataAlgo_t algo; | cudnnConvolutionBwdDataAlgo_t algo; | ||||
CUDNNBwdDataDescs desc; | CUDNNBwdDataDescs desc; | ||||
@@ -197,11 +198,11 @@ Convolution3DBackwardDataImpl::get_algorithm_heuristic( | |||||
bool got = cudnn_get_convolution_bwd_data_algo_helper( | bool got = cudnn_get_convolution_bwd_data_algo_helper( | ||||
cudnn_handle, desc.filter_desc.desc, desc.diff_desc.desc, | cudnn_handle, desc.filter_desc.desc, desc.diff_desc.desc, | ||||
desc.conv_desc.desc, desc.grad_desc.desc, | desc.conv_desc.desc, desc.grad_desc.desc, | ||||
workspace_limit_in_bytes, &algo, reproducible); | |||||
workspace_limit_in_bytes, &algo, attr); | |||||
if (got) { | if (got) { | ||||
return static_cast<AlgoBase*>(megdnn::get_reproducible_algo< | |||||
return static_cast<AlgoBase*>(megdnn::get_algo_with_attribute< | |||||
Convolution3DBackwardDataImpl>( | Convolution3DBackwardDataImpl>( | ||||
sm_algo_pack.cudnn_from_enum(algo), reproducible)); | |||||
sm_algo_pack.cudnn_from_enum(algo), attr)); | |||||
} else { | } else { | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
@@ -223,10 +224,10 @@ Convolution3DBackwardDataImpl::get_algorithm_heuristic( | |||||
args = orig_args; | args = orig_args; | ||||
} | } | ||||
if (reproducible) { | |||||
return megdnn::get_reproducible_algo<Convolution3DBackwardDataImpl>( | |||||
if (attr != AlgoAttribute::DEFAULT) { | |||||
return megdnn::get_algo_with_attribute<Convolution3DBackwardDataImpl>( | |||||
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | ||||
"cuda conv3d bwd data"); | |||||
"cuda conv3d bwd data", attr); | |||||
} else { | } else { | ||||
return megdnn::get_usable_algo<Convolution3DBackwardDataImpl>( | return megdnn::get_usable_algo<Convolution3DBackwardDataImpl>( | ||||
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | ||||
@@ -268,28 +269,28 @@ Convolution3DBackwardFilterImpl::Algorithm* | |||||
Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
auto fm = check_layout_fwd(src, grad, diff); | auto fm = check_layout_fwd(src, grad, diff); | ||||
return get_algorithm_heuristic(src, diff, fm, workspace_limit_in_bytes, | return get_algorithm_heuristic(src, diff, fm, workspace_limit_in_bytes, | ||||
reproducible); | |||||
attr); | |||||
} | } | ||||
Convolution3DBackwardFilterImpl::Algorithm* | Convolution3DBackwardFilterImpl::Algorithm* | ||||
Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes, | const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
AlgoBase::SizeArgs args(this, src, diff, grad); | AlgoBase::SizeArgs args(this, src, diff, grad); | ||||
if (args.grad_filter_meta.group > 1 && | if (args.grad_filter_meta.group > 1 && | ||||
sm_algo_pack.chanwise.is_available_reproducible( | |||||
args, reproducible, workspace_limit_in_bytes)) { | |||||
sm_algo_pack.chanwise.is_available_attribute( | |||||
args, attr, workspace_limit_in_bytes)) { | |||||
return &sm_algo_pack.chanwise; | return &sm_algo_pack.chanwise; | ||||
} | } | ||||
auto get_cudnn_algo = | auto get_cudnn_algo = | ||||
[this, &args, workspace_limit_in_bytes, | [this, &args, workspace_limit_in_bytes, | ||||
reproducible]() -> Convolution3DBackwardFilterImpl::AlgoBase* { | |||||
attr]() -> Convolution3DBackwardFilterImpl::AlgoBase* { | |||||
auto cudnn_handle = cuda::cudnn_handle(this->handle()); | auto cudnn_handle = cuda::cudnn_handle(this->handle()); | ||||
cudnnConvolutionBwdFilterAlgo_t algo; | cudnnConvolutionBwdFilterAlgo_t algo; | ||||
CUDNNBwdFilterDescs desc; | CUDNNBwdFilterDescs desc; | ||||
@@ -297,11 +298,11 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | |||||
bool got = cudnn_get_convolution_bwd_filter_algo_helper( | bool got = cudnn_get_convolution_bwd_filter_algo_helper( | ||||
cudnn_handle, desc.src_desc.desc, desc.diff_desc.desc, | cudnn_handle, desc.src_desc.desc, desc.diff_desc.desc, | ||||
desc.conv_desc.desc, desc.grad_desc.desc, | desc.conv_desc.desc, desc.grad_desc.desc, | ||||
workspace_limit_in_bytes, &algo, reproducible); | |||||
workspace_limit_in_bytes, &algo, attr); | |||||
if (got) { | if (got) { | ||||
return static_cast<AlgoBase*>(megdnn::get_reproducible_algo< | |||||
return static_cast<AlgoBase*>(megdnn::get_algo_with_attribute< | |||||
Convolution3DBackwardFilterImpl>( | Convolution3DBackwardFilterImpl>( | ||||
sm_algo_pack.cudnn_from_enum(algo), reproducible)); | |||||
sm_algo_pack.cudnn_from_enum(algo), attr)); | |||||
} else { | } else { | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
@@ -322,10 +323,10 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | |||||
args = orig_args; | args = orig_args; | ||||
} | } | ||||
if (reproducible) { | |||||
return megdnn::get_reproducible_algo<Convolution3DBackwardFilterImpl>( | |||||
if (attr != AlgoAttribute::DEFAULT) { | |||||
return megdnn::get_algo_with_attribute<Convolution3DBackwardFilterImpl>( | |||||
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | ||||
"cuda conv3d bwd filter"); | |||||
"cuda conv3d bwd filter", attr); | |||||
} else { | } else { | ||||
return megdnn::get_usable_algo<Convolution3DBackwardFilterImpl>( | return megdnn::get_usable_algo<Convolution3DBackwardFilterImpl>( | ||||
sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | ||||
@@ -25,9 +25,9 @@ public: | |||||
const CanonizedFilterMeta& filter, | const CanonizedFilterMeta& filter, | ||||
const TensorLayout& dst, | const TensorLayout& dst, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
return get_algorithm_heuristic(src, filter, dst, | return get_algorithm_heuristic(src, filter, dst, | ||||
workspace_limit_in_bytes, reproducible) | |||||
workspace_limit_in_bytes, attr) | |||||
->info(); | ->info(); | ||||
} | } | ||||
size_t get_workspace_in_bytes(const TensorLayout& src, | size_t get_workspace_in_bytes(const TensorLayout& src, | ||||
@@ -52,14 +52,14 @@ protected: | |||||
const TensorLayout& filter, | const TensorLayout& filter, | ||||
const TensorLayout& dst, | const TensorLayout& dst, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) override; | |||||
const AlgoAttribute& attr) override; | |||||
private: | private: | ||||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | Algorithm* get_algorithm_heuristic(const TensorLayout& src, | ||||
const CanonizedFilterMeta& filter, | const CanonizedFilterMeta& filter, | ||||
const TensorLayout& dst, | const TensorLayout& dst, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible); | |||||
const AlgoAttribute& attr); | |||||
static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
@@ -73,9 +73,9 @@ public: | |||||
AlgorithmInfo get_algorithm_info_heuristic( | AlgorithmInfo get_algorithm_info_heuristic( | ||||
const CanonizedFilterMeta& filter, const TensorLayout& diff, | const CanonizedFilterMeta& filter, const TensorLayout& diff, | ||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
return get_algorithm_heuristic(filter, diff, grad, | return get_algorithm_heuristic(filter, diff, grad, | ||||
workspace_limit_in_bytes, reproducible) | |||||
workspace_limit_in_bytes, attr) | |||||
->info(); | ->info(); | ||||
} | } | ||||
size_t get_workspace_in_bytes(const TensorLayout& filter, | size_t get_workspace_in_bytes(const TensorLayout& filter, | ||||
@@ -102,14 +102,14 @@ protected: | |||||
const TensorLayout& diff, | const TensorLayout& diff, | ||||
const TensorLayout& grad, | const TensorLayout& grad, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) override; | |||||
const AlgoAttribute& attr) override; | |||||
private: | private: | ||||
Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter, | Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter, | ||||
const TensorLayout& diff, | const TensorLayout& diff, | ||||
const TensorLayout& grad, | const TensorLayout& grad, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible); | |||||
const AlgoAttribute& attr); | |||||
static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
}; | }; | ||||
@@ -126,9 +126,9 @@ public: | |||||
const TensorLayout& diff, | const TensorLayout& diff, | ||||
const CanonizedFilterMeta& grad, | const CanonizedFilterMeta& grad, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
return get_algorithm_heuristic(src, diff, grad, | return get_algorithm_heuristic(src, diff, grad, | ||||
workspace_limit_in_bytes, reproducible) | |||||
workspace_limit_in_bytes, attr) | |||||
->info(); | ->info(); | ||||
} | } | ||||
@@ -153,14 +153,14 @@ protected: | |||||
const TensorLayout& diff, | const TensorLayout& diff, | ||||
const TensorLayout& grad, | const TensorLayout& grad, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) override; | |||||
const AlgoAttribute& attr) override; | |||||
private: | private: | ||||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | Algorithm* get_algorithm_heuristic(const TensorLayout& src, | ||||
const TensorLayout& diff, | const TensorLayout& diff, | ||||
const CanonizedFilterMeta& grad, | const CanonizedFilterMeta& grad, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible); | |||||
const AlgoAttribute& attr); | |||||
static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
}; | }; | ||||
@@ -80,12 +80,11 @@ public: | |||||
bool is_available_wk(const SizeArgs& args, size_t limit) { | bool is_available_wk(const SizeArgs& args, size_t limit) { | ||||
return is_available(args) && get_workspace_in_bytes(args) <= limit; | return is_available(args) && get_workspace_in_bytes(args) <= limit; | ||||
} | } | ||||
bool is_available_reproducible( | |||||
const SizeArgs& args, bool reproducible = true, | |||||
bool is_available_attribute( | |||||
const SizeArgs& args, | |||||
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | |||||
return contain_attribute(attr) && is_available_wk(args, limit); | |||||
} | } | ||||
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
const Workspace& workspace) { | const Workspace& workspace) { | ||||
@@ -73,12 +73,11 @@ public: | |||||
bool is_available_wk(const SizeArgs& args, size_t limit) { | bool is_available_wk(const SizeArgs& args, size_t limit) { | ||||
return is_available(args) && get_workspace_in_bytes(args) <= limit; | return is_available(args) && get_workspace_in_bytes(args) <= limit; | ||||
} | } | ||||
bool is_available_reproducible( | |||||
const SizeArgs& args, bool reproducible = true, | |||||
bool is_available_attribute( | |||||
const SizeArgs& args, | |||||
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | |||||
return contain_attribute(attr) && is_available_wk(args, limit); | |||||
} | } | ||||
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
const Workspace& workspace) { | const Workspace& workspace) { | ||||
@@ -68,12 +68,11 @@ public: | |||||
bool is_available_wk(const SizeArgs& args, size_t limit) { | bool is_available_wk(const SizeArgs& args, size_t limit) { | ||||
return is_available(args) && get_workspace_in_bytes(args) <= limit; | return is_available(args) && get_workspace_in_bytes(args) <= limit; | ||||
} | } | ||||
bool is_available_reproducible( | |||||
const SizeArgs& args, bool reproducible = true, | |||||
bool is_available_attribute( | |||||
const SizeArgs& args, | |||||
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | |||||
return contain_attribute(attr) && is_available_wk(args, limit); | |||||
} | } | ||||
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
const Workspace& workspace) { | const Workspace& workspace) { | ||||
@@ -59,10 +59,10 @@ AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im, | |||||
const TensorLayout& mask, | const TensorLayout& mask, | ||||
const TensorLayout& dst, | const TensorLayout& dst, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
auto fm = make_canonized_filter_meta(im.ndim, filter, offset); | auto fm = make_canonized_filter_meta(im.ndim, filter, offset); | ||||
return get_algorithm_heuristic(im, fm, offset, mask, dst, | return get_algorithm_heuristic(im, fm, offset, mask, dst, | ||||
workspace_limit_in_bytes, reproducible); | |||||
workspace_limit_in_bytes, attr); | |||||
} | } | ||||
AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im, | AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im, | ||||
@@ -71,17 +71,17 @@ AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im, | |||||
const TensorLayout& mask, | const TensorLayout& mask, | ||||
const TensorLayout& dst, | const TensorLayout& dst, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
AlgoBase::SizeArgs args(this, im, filter, offset, mask, dst); | AlgoBase::SizeArgs args(this, im, filter, offset, mask, dst); | ||||
if (sm_algo_pack.algo_matmul.is_available_reproducible( | |||||
args, reproducible, workspace_limit_in_bytes)) { | |||||
if (sm_algo_pack.algo_matmul.is_available_attribute( | |||||
args, attr, workspace_limit_in_bytes)) { | |||||
return &sm_algo_pack.algo_matmul; | return &sm_algo_pack.algo_matmul; | ||||
} | } | ||||
megdnn_throw( | |||||
ssprintf("no %s deformable conv fwd algorithm with args(%s) and " | |||||
"workspace limit (%zu bytes)", | |||||
reproducible ? "reproducible" : "usable", | |||||
args.to_string().c_str(), workspace_limit_in_bytes)); | |||||
megdnn_throw(ssprintf( | |||||
"no deformable conv fwd algorithm with attribute%s , args(%s) and " | |||||
"workspace limit (%zu bytes)", | |||||
Algorithm::attribute_str(attr).c_str(), args.to_string().c_str(), | |||||
workspace_limit_in_bytes)); | |||||
} | } | ||||
const char* Fwd::get_algorithm_set_name() const { | const char* Fwd::get_algorithm_set_name() const { | ||||
@@ -115,27 +115,28 @@ AlgoBwdFlt* BwdFlt::get_algorithm_heuristic( | |||||
const TensorLayout& im, const TensorLayout& offset, | const TensorLayout& im, const TensorLayout& offset, | ||||
const TensorLayout& mask, const TensorLayout& out_grad, | const TensorLayout& mask, const TensorLayout& out_grad, | ||||
const TensorLayout& filter_grad, | const TensorLayout& filter_grad, | ||||
size_t workspace_limit_in_bytes, bool reproducible) { | |||||
size_t workspace_limit_in_bytes, const AlgoAttribute& attr) { | |||||
auto fm = make_canonized_filter_meta(im.ndim, filter_grad, offset); | auto fm = make_canonized_filter_meta(im.ndim, filter_grad, offset); | ||||
return get_algorithm_heuristic(im, offset, mask, out_grad, fm, | return get_algorithm_heuristic(im, offset, mask, out_grad, fm, | ||||
workspace_limit_in_bytes, reproducible); | |||||
workspace_limit_in_bytes, attr); | |||||
} | } | ||||
AlgoBwdFlt* BwdFlt::get_algorithm_heuristic( | AlgoBwdFlt* BwdFlt::get_algorithm_heuristic( | ||||
const TensorLayout& im, const TensorLayout& offset, | const TensorLayout& im, const TensorLayout& offset, | ||||
const TensorLayout& mask, const TensorLayout& out_grad, | const TensorLayout& mask, const TensorLayout& out_grad, | ||||
const CanonizedFilterMeta& filter_grad, | const CanonizedFilterMeta& filter_grad, | ||||
size_t workspace_limit_in_bytes, bool reproducible) { | |||||
size_t workspace_limit_in_bytes, const AlgoAttribute& attr) { | |||||
AlgoBase::SizeArgs args(this, im, offset, mask, out_grad, filter_grad); | AlgoBase::SizeArgs args(this, im, offset, mask, out_grad, filter_grad); | ||||
if (sm_algo_pack.algo_matmul.is_available_reproducible( | |||||
args, reproducible, workspace_limit_in_bytes)) { | |||||
if (sm_algo_pack.algo_matmul.is_available_attribute( | |||||
args, attr, workspace_limit_in_bytes)) { | |||||
return &sm_algo_pack.algo_matmul; | return &sm_algo_pack.algo_matmul; | ||||
} | } | ||||
megdnn_throw(ssprintf( | |||||
"no %s deformable conv bwd filter algorithm with args(%s) and " | |||||
"workspace limit (%zu bytes)", | |||||
reproducible ? "reproducible" : "usable", args.to_string().c_str(), | |||||
workspace_limit_in_bytes)); | |||||
megdnn_throw( | |||||
ssprintf("no deformable conv bwd filter algorithm with " | |||||
"attribute%s, args(%s) and " | |||||
"workspace limit (%zu bytes)", | |||||
Algorithm::attribute_str(attr).c_str(), | |||||
args.to_string().c_str(), workspace_limit_in_bytes)); | |||||
} | } | ||||
size_t BwdFlt::get_workspace_in_bytes( | size_t BwdFlt::get_workspace_in_bytes( | ||||
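The rewritten error messages interpolate `Algorithm::attribute_str(attr)` where the old code printed a hard-coded "reproducible"/"usable". The real formatter is only declared in the public header earlier in this patch; the sketch below merely illustrates the kind of string it could produce and is not the actual implementation.

```cpp
#include <cstdint>
#include <string>

enum class AlgoAttribute : uint32_t { DEFAULT = 0, REPRODUCIBLE = 1u << 0 };

// Render an attribute mask for diagnostics, e.g. the "(REPRODUCIBLE)" part
// of "no deformable conv bwd filter algorithm with attribute...".
inline std::string attribute_str_sketch(AlgoAttribute attr) {
    const uint32_t mask = static_cast<uint32_t>(attr);
    if (mask == 0)
        return "(DEFAULT)";
    std::string ret = "(";
    if (mask & static_cast<uint32_t>(AlgoAttribute::REPRODUCIBLE))
        ret += "REPRODUCIBLE";
    return ret + ")";
}
```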
@@ -175,11 +176,11 @@ AlgoBwdData* BwdData::get_algorithm_heuristic( | |||||
const TensorLayout& offset, const TensorLayout& mask, | const TensorLayout& offset, const TensorLayout& mask, | ||||
const TensorLayout& out_grad, const TensorLayout& im_grad, | const TensorLayout& out_grad, const TensorLayout& im_grad, | ||||
const TensorLayout& offset_grad, const TensorLayout& mask_grad, | const TensorLayout& offset_grad, const TensorLayout& mask_grad, | ||||
size_t workspace_limit_in_bytes, bool reproducible) { | |||||
size_t workspace_limit_in_bytes, const AlgoAttribute& attr) { | |||||
auto fm = make_canonized_filter_meta(im.ndim, filter, offset); | auto fm = make_canonized_filter_meta(im.ndim, filter, offset); | ||||
return get_algorithm_heuristic(im, fm, offset, mask, out_grad, im_grad, | return get_algorithm_heuristic(im, fm, offset, mask, out_grad, im_grad, | ||||
offset_grad, mask_grad, | offset_grad, mask_grad, | ||||
workspace_limit_in_bytes, reproducible); | |||||
workspace_limit_in_bytes, attr); | |||||
} | } | ||||
AlgoBwdData* BwdData::get_algorithm_heuristic( | AlgoBwdData* BwdData::get_algorithm_heuristic( | ||||
@@ -187,18 +188,19 @@ AlgoBwdData* BwdData::get_algorithm_heuristic( | |||||
const TensorLayout& offset, const TensorLayout& mask, | const TensorLayout& offset, const TensorLayout& mask, | ||||
const TensorLayout& out_grad, const TensorLayout& im_grad, | const TensorLayout& out_grad, const TensorLayout& im_grad, | ||||
const TensorLayout& offset_grad, const TensorLayout& mask_grad, | const TensorLayout& offset_grad, const TensorLayout& mask_grad, | ||||
size_t workspace_limit_in_bytes, bool reproducible) { | |||||
size_t workspace_limit_in_bytes, const AlgoAttribute& attr) { | |||||
AlgoBase::SizeArgs args(this, im, filter, offset, mask, out_grad, im_grad, | AlgoBase::SizeArgs args(this, im, filter, offset, mask, out_grad, im_grad, | ||||
offset_grad, mask_grad); | offset_grad, mask_grad); | ||||
if (sm_algo_pack.algo_matmul.is_available_reproducible( | |||||
args, reproducible, workspace_limit_in_bytes)) { | |||||
if (sm_algo_pack.algo_matmul.is_available_attribute( | |||||
args, attr, workspace_limit_in_bytes)) { | |||||
return &sm_algo_pack.algo_matmul; | return &sm_algo_pack.algo_matmul; | ||||
} | } | ||||
megdnn_throw(ssprintf( | |||||
"no %s deformable conv bwd data algorithm with args(%s) and " | |||||
"workspace limit (%zu bytes)", | |||||
reproducible ? "reproducible" : "usable", args.to_string().c_str(), | |||||
workspace_limit_in_bytes)); | |||||
megdnn_throw( | |||||
ssprintf("no deformable conv bwd data algorithm with attribute%s, " | |||||
"args(%s) and " | |||||
"workspace limit (%zu bytes)", | |||||
Algorithm::attribute_str(attr).c_str(), | |||||
args.to_string().c_str(), workspace_limit_in_bytes)); | |||||
} | } | ||||
size_t BwdData::get_workspace_in_bytes( | size_t BwdData::get_workspace_in_bytes( | ||||
@@ -36,7 +36,7 @@ public: | |||||
const TensorLayout& mask, | const TensorLayout& mask, | ||||
const TensorLayout& dst, | const TensorLayout& dst, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible); | |||||
const AlgoAttribute& attr); | |||||
const char* get_algorithm_set_name() const override; | const char* get_algorithm_set_name() const override; | ||||
@@ -60,7 +60,7 @@ protected: | |||||
const TensorLayout& mask, | const TensorLayout& mask, | ||||
const TensorLayout& dst, | const TensorLayout& dst, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) override; | |||||
const AlgoAttribute& attr) override; | |||||
private: | private: | ||||
static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
@@ -81,7 +81,7 @@ public: | |||||
const TensorLayout& out_grad, | const TensorLayout& out_grad, | ||||
const CanonizedFilterMeta& filter_grad, | const CanonizedFilterMeta& filter_grad, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible); | |||||
const AlgoAttribute& attr); | |||||
size_t get_workspace_in_bytes(const TensorLayout& im, | size_t get_workspace_in_bytes(const TensorLayout& im, | ||||
const TensorLayout& offset, | const TensorLayout& offset, | ||||
@@ -111,7 +111,7 @@ protected: | |||||
const TensorLayout& out_grad, | const TensorLayout& out_grad, | ||||
const TensorLayout& filter_grad, | const TensorLayout& filter_grad, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) override; | |||||
const AlgoAttribute& attr) override; | |||||
private: | private: | ||||
static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
@@ -132,7 +132,7 @@ public: | |||||
const TensorLayout& offset, const TensorLayout& mask, | const TensorLayout& offset, const TensorLayout& mask, | ||||
const TensorLayout& out_grad, const TensorLayout& im_grad, | const TensorLayout& out_grad, const TensorLayout& im_grad, | ||||
const TensorLayout& offset_grad, const TensorLayout& mask_grad, | const TensorLayout& offset_grad, const TensorLayout& mask_grad, | ||||
size_t workspace_limit_in_bytes, bool reproducible); | |||||
size_t workspace_limit_in_bytes, const AlgoAttribute& attr); | |||||
size_t get_workspace_in_bytes(const TensorLayout& im, | size_t get_workspace_in_bytes(const TensorLayout& im, | ||||
const TensorLayout& filter, | const TensorLayout& filter, | ||||
@@ -166,7 +166,8 @@ protected: | |||||
const TensorLayout& offset, const TensorLayout& mask, | const TensorLayout& offset, const TensorLayout& mask, | ||||
const TensorLayout& out_grad, const TensorLayout& im_grad, | const TensorLayout& out_grad, const TensorLayout& im_grad, | ||||
const TensorLayout& offset_grad, const TensorLayout& mask_grad, | const TensorLayout& offset_grad, const TensorLayout& mask_grad, | ||||
size_t workspace_limit_in_bytes, bool reproducible) override; | |||||
size_t workspace_limit_in_bytes, | |||||
const AlgoAttribute& attr) override; | |||||
private: | private: | ||||
static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
@@ -59,12 +59,11 @@ public: | |||||
bool is_available_wk(const SizeArgs& args, size_t limit) { | bool is_available_wk(const SizeArgs& args, size_t limit) { | ||||
return is_available(args) && get_workspace_in_bytes(args) <= limit; | return is_available(args) && get_workspace_in_bytes(args) <= limit; | ||||
} | } | ||||
bool is_available_reproducible( | |||||
const SizeArgs& args, bool reproducible = true, | |||||
bool is_available_attribute( | |||||
const SizeArgs& args, | |||||
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | |||||
return contain_attribute(attr) && is_available_wk(args, limit); | |||||
} | } | ||||
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
const Workspace& workspace) { | const Workspace& workspace) { | ||||
@@ -59,12 +59,11 @@ public: | |||||
bool is_available_wk(const SizeArgs& args, size_t limit) { | bool is_available_wk(const SizeArgs& args, size_t limit) { | ||||
return is_available(args) && get_workspace_in_bytes(args) <= limit; | return is_available(args) && get_workspace_in_bytes(args) <= limit; | ||||
} | } | ||||
bool is_available_reproducible( | |||||
const SizeArgs& args, bool reproducible = true, | |||||
bool is_available_attribute( | |||||
const SizeArgs& args, | |||||
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | |||||
return contain_attribute(attr) && is_available_wk(args, limit); | |||||
} | } | ||||
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
const Workspace& workspace) { | const Workspace& workspace) { | ||||
@@ -60,12 +60,11 @@ public: | |||||
bool is_available_wk(const SizeArgs& args, size_t limit) { | bool is_available_wk(const SizeArgs& args, size_t limit) { | ||||
return is_available(args) && get_workspace_in_bytes(args) <= limit; | return is_available(args) && get_workspace_in_bytes(args) <= limit; | ||||
} | } | ||||
bool is_available_reproducible( | |||||
const SizeArgs& args, bool reproducible = true, | |||||
bool is_available_attribute( | |||||
const SizeArgs& args, | |||||
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | |||||
return contain_attribute(attr) && is_available_wk(args, limit); | |||||
} | } | ||||
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
const Workspace& workspace) { | const Workspace& workspace) { | ||||
@@ -24,26 +24,26 @@ LocalShareForwardImpl::get_algorithm_heuristic(const TensorLayout& src, | |||||
const TensorLayout& filter, | const TensorLayout& filter, | ||||
const TensorLayout& dst, | const TensorLayout& dst, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
AlgoBase::SizeArgs args(this, src, filter, dst); | AlgoBase::SizeArgs args(this, src, filter, dst); | ||||
if (sm_algo_pack.batch_size_aware_chwn_small_image | if (sm_algo_pack.batch_size_aware_chwn_small_image | ||||
.is_available_reproducible(args, reproducible, | |||||
.is_available_attribute(args, attr, | |||||
workspace_limit_in_bytes)) { | workspace_limit_in_bytes)) { | ||||
return &sm_algo_pack.batch_size_aware_chwn_small_image; | return &sm_algo_pack.batch_size_aware_chwn_small_image; | ||||
} | } | ||||
if (sm_algo_pack.batch_size_aware_chwn.is_available_reproducible( | |||||
args, reproducible, workspace_limit_in_bytes)) { | |||||
if (sm_algo_pack.batch_size_aware_chwn.is_available_attribute( | |||||
args, attr, workspace_limit_in_bytes)) { | |||||
return &sm_algo_pack.batch_size_aware_chwn; | return &sm_algo_pack.batch_size_aware_chwn; | ||||
} | } | ||||
if (sm_algo_pack.batched_matmul.is_available_reproducible( | |||||
args, reproducible, workspace_limit_in_bytes)) { | |||||
if (sm_algo_pack.batched_matmul.is_available_attribute( | |||||
args, attr, workspace_limit_in_bytes)) { | |||||
return &sm_algo_pack.batched_matmul; | return &sm_algo_pack.batched_matmul; | ||||
} | } | ||||
megdnn_throw( | |||||
ssprintf("no %s local share conv algorithm with args(%s) and " | |||||
"workspace limit (%zu bytes)", | |||||
reproducible ? "reproducible" : "usable", | |||||
args.to_string().c_str(), workspace_limit_in_bytes)); | |||||
megdnn_throw(ssprintf( | |||||
"no local share conv algorithm with attribute%s, args(%s) and " | |||||
"workspace limit (%zu bytes)", | |||||
Algorithm::attribute_str(attr).c_str(), args.to_string().c_str(), | |||||
workspace_limit_in_bytes)); | |||||
} | } | ||||
std::vector<LocalShareForwardImpl::Algorithm*> | std::vector<LocalShareForwardImpl::Algorithm*> | ||||
@@ -79,21 +79,21 @@ LocalShareBackwardDataImpl::Algorithm* | |||||
LocalShareBackwardDataImpl::get_algorithm_heuristic( | LocalShareBackwardDataImpl::get_algorithm_heuristic( | ||||
const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
AlgoBase::SizeArgs args(this, filter, diff, grad); | AlgoBase::SizeArgs args(this, filter, diff, grad); | ||||
if (sm_algo_pack.implicit_gemm.is_available_reproducible( | |||||
args, reproducible, workspace_limit_in_bytes)) { | |||||
if (sm_algo_pack.implicit_gemm.is_available_attribute( | |||||
args, attr, workspace_limit_in_bytes)) { | |||||
return &sm_algo_pack.implicit_gemm; | return &sm_algo_pack.implicit_gemm; | ||||
} | } | ||||
if (sm_algo_pack.batched_matmul.is_available_reproducible( | |||||
args, reproducible, workspace_limit_in_bytes)) { | |||||
if (sm_algo_pack.batched_matmul.is_available_attribute( | |||||
args, attr, workspace_limit_in_bytes)) { | |||||
return &sm_algo_pack.batched_matmul; | return &sm_algo_pack.batched_matmul; | ||||
} | } | ||||
megdnn_throw( | |||||
ssprintf("no %s local share bwd data algorithm with args(%s) and " | |||||
"workspace limit (%zu bytes)", | |||||
reproducible ? "reproducible" : "usable", | |||||
args.to_string().c_str(), workspace_limit_in_bytes)); | |||||
megdnn_throw(ssprintf( | |||||
"no local share bwd data algorithm with attribute%s args(%s) and " | |||||
"workspace limit (%zu bytes)", | |||||
Algorithm::attribute_str(attr).c_str(), args.to_string().c_str(), | |||||
workspace_limit_in_bytes)); | |||||
} | } | ||||
std::vector<LocalShareBackwardDataImpl::Algorithm*> | std::vector<LocalShareBackwardDataImpl::Algorithm*> | ||||
@@ -129,20 +129,21 @@ LocalShareBackwardFilterImpl::Algorithm* | |||||
LocalShareBackwardFilterImpl::get_algorithm_heuristic( | LocalShareBackwardFilterImpl::get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
AlgoBase::SizeArgs args(this, src, diff, grad); | AlgoBase::SizeArgs args(this, src, diff, grad); | ||||
if (sm_algo_pack.implicit_gemm.is_available_reproducible( | |||||
args, reproducible, workspace_limit_in_bytes)) { | |||||
if (sm_algo_pack.implicit_gemm.is_available_attribute( | |||||
args, attr, workspace_limit_in_bytes)) { | |||||
return &sm_algo_pack.implicit_gemm; | return &sm_algo_pack.implicit_gemm; | ||||
} | } | ||||
if (sm_algo_pack.batched_matmul.is_available_reproducible( | |||||
args, reproducible, workspace_limit_in_bytes)) { | |||||
if (sm_algo_pack.batched_matmul.is_available_attribute( | |||||
args, attr, workspace_limit_in_bytes)) { | |||||
return &sm_algo_pack.batched_matmul; | return &sm_algo_pack.batched_matmul; | ||||
} | } | ||||
megdnn_throw( | megdnn_throw( | ||||
ssprintf("no %s local share bwd filter algorithm with args(%s) and " | |||||
ssprintf("no local share bwd filter algorithm with attribute%s, " | |||||
"args(%s) and " | |||||
"workspace limit (%zu bytes)", | "workspace limit (%zu bytes)", | ||||
reproducible ? "reproducible" : "usable", | |||||
Algorithm::attribute_str(attr).c_str(), | |||||
args.to_string().c_str(), workspace_limit_in_bytes)); | args.to_string().c_str(), workspace_limit_in_bytes)); | ||||
} | } | ||||
@@ -43,7 +43,7 @@ protected: | |||||
const TensorLayout& filter, | const TensorLayout& filter, | ||||
const TensorLayout& dst, | const TensorLayout& dst, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) override; | |||||
const AlgoAttribute& attr) override; | |||||
private: | private: | ||||
static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
}; | }; | ||||
@@ -75,7 +75,7 @@ protected: | |||||
const TensorLayout& diff, | const TensorLayout& diff, | ||||
const TensorLayout& grad, | const TensorLayout& grad, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) override; | |||||
const AlgoAttribute& attr) override; | |||||
private: | private: | ||||
static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
@@ -108,7 +108,7 @@ protected: | |||||
const TensorLayout& diff, | const TensorLayout& diff, | ||||
const TensorLayout& grad, | const TensorLayout& grad, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) override; | |||||
const AlgoAttribute& attr) override; | |||||
private: | private: | ||||
static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
@@ -83,12 +83,11 @@ public: | |||||
bool is_available_wk(const SizeArgs& args, size_t limit) const { | bool is_available_wk(const SizeArgs& args, size_t limit) const { | ||||
return is_available(args) && get_workspace_in_bytes(args) <= limit; | return is_available(args) && get_workspace_in_bytes(args) <= limit; | ||||
} | } | ||||
bool is_available_reproducible( | |||||
const SizeArgs& args, bool reproducible = true, | |||||
bool is_available_attribute( | |||||
const SizeArgs& args, | |||||
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
size_t limit = std::numeric_limits<size_t>::max()) const { | size_t limit = std::numeric_limits<size_t>::max()) const { | ||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | |||||
return contain_attribute(attr) && is_available_wk(args, limit); | |||||
} | } | ||||
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
const Workspace& workspace) { | const Workspace& workspace) { | ||||
@@ -30,30 +30,30 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||||
MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | ||||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
size_t workspace_limit_in_bytes, bool reproducible) { | |||||
size_t workspace_limit_in_bytes, const AlgoAttribute& attr) { | |||||
AlgoBase::SizeArgs args{this, A, B, C}; | AlgoBase::SizeArgs args{this, A, B, C}; | ||||
if (sm_algo_pack.cublas.is_available_reproducible( | |||||
args, reproducible, workspace_limit_in_bytes)) { | |||||
if (sm_algo_pack.cublas.is_available_attribute(args, attr, | |||||
workspace_limit_in_bytes)) { | |||||
return &sm_algo_pack.cublas; | return &sm_algo_pack.cublas; | ||||
} | } | ||||
#if CUDA_VERSION >= 10010 | #if CUDA_VERSION >= 10010 | ||||
if (sm_algo_pack.cublas_lt.is_available_reproducible( | |||||
args, reproducible, workspace_limit_in_bytes)) { | |||||
if (sm_algo_pack.cublas_lt.is_available_attribute( | |||||
args, attr, workspace_limit_in_bytes)) { | |||||
return &sm_algo_pack.cublas_lt; | return &sm_algo_pack.cublas_lt; | ||||
} | } | ||||
#endif | #endif | ||||
#if CUDA_VERSION >= 10000 | #if CUDA_VERSION >= 10000 | ||||
if (sm_algo_pack.wmma_uint4x4x32.is_available_reproducible( | |||||
args, reproducible, workspace_limit_in_bytes)) { | |||||
if (sm_algo_pack.wmma_uint4x4x32.is_available_attribute( | |||||
args, attr, workspace_limit_in_bytes)) { | |||||
return &sm_algo_pack.wmma_uint4x4x32; | return &sm_algo_pack.wmma_uint4x4x32; | ||||
} | } | ||||
#endif | #endif | ||||
if (reproducible) { | |||||
return megdnn::get_reproducible_algo<MatrixMulForwardImpl>( | |||||
if (attr != AlgoAttribute::DEFAULT) { | |||||
return megdnn::get_algo_with_attribute<MatrixMulForwardImpl>( | |||||
sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | ||||
"matrix mul forward"); | |||||
"matrix mul forward", attr); | |||||
} else { | } else { | ||||
return megdnn::get_usable_algo<MatrixMulForwardImpl>( | return megdnn::get_usable_algo<MatrixMulForwardImpl>( | ||||
sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | ||||
@@ -61,7 +61,7 @@ protected: | |||||
const TensorLayout& B, | const TensorLayout& B, | ||||
const TensorLayout& C, | const TensorLayout& C, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) override; | |||||
const AlgoAttribute& attr) override; | |||||
private: | private: | ||||
static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
@@ -63,12 +63,11 @@ public: | |||||
bool is_available_wk(const SizeArgs& args, size_t limit) const { | bool is_available_wk(const SizeArgs& args, size_t limit) const { | ||||
return is_available(args) && get_workspace_in_bytes(args) <= limit; | return is_available(args) && get_workspace_in_bytes(args) <= limit; | ||||
} | } | ||||
bool is_available_reproducible( | |||||
const SizeArgs& args, bool reproducible = true, | |||||
bool is_available_attribute( | |||||
const SizeArgs& args, | |||||
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
size_t limit = std::numeric_limits<size_t>::max()) const { | size_t limit = std::numeric_limits<size_t>::max()) const { | ||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | |||||
return contain_attribute(attr) && is_available_wk(args, limit); | |||||
} | } | ||||
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
const Workspace& workspace) { | const Workspace& workspace) { | ||||
@@ -31,16 +31,16 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||||
BatchedMatrixMulForwardImpl::Algorithm* | BatchedMatrixMulForwardImpl::Algorithm* | ||||
BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | ||||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
size_t workspace_limit_in_bytes, bool reproducible) { | |||||
size_t workspace_limit_in_bytes, const AlgoAttribute& attr) { | |||||
AlgoBase::SizeArgs args{this, A, B, C}; | AlgoBase::SizeArgs args{this, A, B, C}; | ||||
if (sm_algo_pack.algo_default.is_available_reproducible( | |||||
args, reproducible, workspace_limit_in_bytes)) { | |||||
if (sm_algo_pack.algo_default.is_available_attribute( | |||||
args, attr, workspace_limit_in_bytes)) { | |||||
return &sm_algo_pack.algo_default; | return &sm_algo_pack.algo_default; | ||||
} | } | ||||
if (reproducible) { | |||||
return megdnn::get_reproducible_algo<BatchedMatrixMulForwardImpl>( | |||||
if (attr != AlgoAttribute::DEFAULT) { | |||||
return megdnn::get_algo_with_attribute<BatchedMatrixMulForwardImpl>( | |||||
sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | ||||
"batched matrix mul forward"); | |||||
"batched matrix mul forward", attr); | |||||
} else { | } else { | ||||
return megdnn::get_usable_algo<BatchedMatrixMulForwardImpl>( | return megdnn::get_usable_algo<BatchedMatrixMulForwardImpl>( | ||||
sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | ||||
@@ -40,7 +40,7 @@ private: | |||||
const TensorLayout& /*B*/, | const TensorLayout& /*B*/, | ||||
const TensorLayout& /*C*/, | const TensorLayout& /*C*/, | ||||
size_t /*workspace_limit_in_bytes*/, | size_t /*workspace_limit_in_bytes*/, | ||||
bool /*reproducible*/) override; | |||||
const AlgoAttribute& /*attr*/) override; | |||||
const char* get_algorithm_set_name() const override { | const char* get_algorithm_set_name() const override { | ||||
return "FALLBACK BATCHED MATMUL"; | return "FALLBACK BATCHED MATMUL"; | ||||
@@ -280,32 +280,29 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic( | |||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr); | auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr); | ||||
auto result = get_algorithm_heuristic_with_ncb( | auto result = get_algorithm_heuristic_with_ncb( | ||||
fparam, workspace_limit_in_bytes, reproducible); | |||||
fparam, workspace_limit_in_bytes, attr); | |||||
if (result == nullptr) { | if (result == nullptr) { | ||||
result = naive::ConvBiasForwardImpl::get_algorithm_heuristic( | result = naive::ConvBiasForwardImpl::get_algorithm_heuristic( | ||||
src, filter, bias, z, dst, workspace_limit_in_bytes, | |||||
reproducible); | |||||
src, filter, bias, z, dst, workspace_limit_in_bytes, attr); | |||||
} | } | ||||
return result; | return result; | ||||
} | } | ||||
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic_with_ncb( | ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic_with_ncb( | ||||
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
auto algo_data_type = param.deduce_algo_data_type(); | auto algo_data_type = param.deduce_algo_data_type(); | ||||
auto suggest_category_order = suggest_algo_category_order(param); | auto suggest_category_order = suggest_algo_category_order(param); | ||||
for (auto category : suggest_category_order) { | for (auto category : suggest_category_order) { | ||||
auto&& origin_algos = select_algo_type({algo_data_type, category}); | auto&& origin_algos = select_algo_type({algo_data_type, category}); | ||||
ConvBiasImpl::Algorithm* heuristic_algo = nullptr; | ConvBiasImpl::Algorithm* heuristic_algo = nullptr; | ||||
for (auto i : origin_algos) { | for (auto i : origin_algos) { | ||||
bool usable_reproducible = | |||||
static_cast<AlgoBase*>(i)->usable_reproducible( | |||||
param, AlgoSelectionStrategy::HEURISTIC, | |||||
reproducible); | |||||
if (usable_reproducible && | |||||
bool usable_attribute = static_cast<AlgoBase*>(i)->usable_attribute( | |||||
param, AlgoSelectionStrategy::HEURISTIC, attr); | |||||
if (usable_attribute && | |||||
static_cast<AlgoBase*>(i)->get_workspace(param) <= | static_cast<AlgoBase*>(i)->get_workspace(param) <= | ||||
workspace_limit_in_bytes) { | workspace_limit_in_bytes) { | ||||
//! store the first usable algo if no prefer algo, choose it as | //! store the first usable algo if no prefer algo, choose it as | ||||
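The fallback ConvBias heuristic keeps its category-ordered search; only the per-algorithm predicate changes from usable_reproducible to usable_attribute. The sketch below condenses that loop under simplified assumptions: `preferred` stands in for the algorithm-preference check of the real code, and the usability test itself is elided.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

enum class AlgoAttribute : uint32_t { DEFAULT = 0, REPRODUCIBLE = 1u << 0 };

struct NcbAlgo {
    uint32_t attrs = 0;
    size_t workspace = 0;
    bool preferred = false;  // placeholder for the real preference predicate

    bool usable_attribute(AlgoAttribute attr) const {
        auto mask = static_cast<uint32_t>(attr);
        return (attrs & mask) == mask;  // the actual usable() check is elided
    }
};

// Walk the candidates in suggested order: return a preferred algorithm that
// satisfies the attribute and the workspace limit as soon as one is found,
// otherwise remember the first merely-usable one as the heuristic fallback.
inline NcbAlgo* pick_heuristic_sketch(std::vector<NcbAlgo>& algos,
                                      size_t workspace_limit,
                                      AlgoAttribute attr) {
    NcbAlgo* heuristic = nullptr;
    for (auto& algo : algos) {
        if (algo.usable_attribute(attr) && algo.workspace <= workspace_limit) {
            if (algo.preferred)
                return &algo;
            if (!heuristic)
                heuristic = &algo;
        }
    }
    return heuristic;  // may be nullptr, letting the caller fall back to naive
}
```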
@@ -499,8 +496,8 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm( | |||||
} | } | ||||
if (!m_prev_selected_algo || | if (!m_prev_selected_algo || | ||||
memcmp(&m_prev_selected_algo_sizep, ¶m, sizeof(NCBKernSizeParam))) { | memcmp(&m_prev_selected_algo_sizep, ¶m, sizeof(NCBKernSizeParam))) { | ||||
m_prev_selected_algo = | |||||
get_algorithm_heuristic_with_ncb(param, workspace_size); | |||||
m_prev_selected_algo = get_algorithm_heuristic_with_ncb( | |||||
param, workspace_size, AlgoAttribute::DEFAULT); | |||||
m_prev_selected_algo_sizep = param; | m_prev_selected_algo_sizep = param; | ||||
} | } | ||||
return m_prev_selected_algo; | return m_prev_selected_algo; | ||||
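For call sites that previously had no reproducibility requirement, the cached getter above now spells the constraint out as `AlgoAttribute::DEFAULT`, which every algorithm trivially satisfies. The toy program below only shows how the old boolean maps onto the new parameter; `ToyOpr` is a made-up class, not part of the library.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <limits>

enum class AlgoAttribute : uint32_t { DEFAULT = 0, REPRODUCIBLE = 1u << 0 };

// Mimics only the call-site shape of the patched heuristic API.
struct ToyOpr {
    const char* get_algorithm_heuristic(size_t /*workspace_limit*/,
                                        AlgoAttribute attr) const {
        return attr == AlgoAttribute::DEFAULT
                       ? "fastest usable algorithm"
                       : "fastest algorithm carrying the attribute";
    }
};

int main() {
    ToyOpr opr;
    const size_t limit = std::numeric_limits<size_t>::max();
    // Old `reproducible = false` becomes DEFAULT (no constraint) ...
    std::printf("%s\n",
                opr.get_algorithm_heuristic(limit, AlgoAttribute::DEFAULT));
    // ... and `reproducible = true` becomes an explicit REPRODUCIBLE request.
    std::printf("%s\n",
                opr.get_algorithm_heuristic(limit, AlgoAttribute::REPRODUCIBLE));
}
```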
@@ -95,9 +95,7 @@ public: | |||||
const TensorLayout& z, | const TensorLayout& z, | ||||
const TensorLayout& dst, | const TensorLayout& dst, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) override; | |||||
const AlgoAttribute& attr) override; | |||||
//! size param for kernels with non-contiguous batch | //! size param for kernels with non-contiguous batch | ||||
struct NCBKernSizeParam : ConvolutionImpl::NCBKernSizeParam { | struct NCBKernSizeParam : ConvolutionImpl::NCBKernSizeParam { | ||||
@@ -321,11 +319,11 @@ public: | |||||
return false; | return false; | ||||
} | } | ||||
bool usable_reproducible(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy, | |||||
bool reproducible = true) const { | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
bool usable_attribute( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy, | |||||
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE) const { | |||||
return contain_attribute(attr) && | |||||
usable(param, algo_selection_strategy); | usable(param, algo_selection_strategy); | ||||
} | } | ||||
@@ -363,7 +361,7 @@ protected:
     virtual Algorithm* get_algorithm_heuristic_with_ncb(
             const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
-            bool reproducible = false);
+            const AlgoAttribute& attr);
     const char* get_algorithm_set_name() const override;
@@ -198,13 +198,13 @@ std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms(
 ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic(
         const TensorLayout& src, const TensorLayout& filter,
         const TensorLayout& dst, size_t workspace_limit_in_bytes,
-        bool reproducible) {
+        const AlgoAttribute& attr) {
     auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
     auto result = get_algorithm_heuristic_with_ncb(
-            fparam, workspace_limit_in_bytes, reproducible);
+            fparam, workspace_limit_in_bytes, attr);
     if (result == nullptr) {
         result = naive::ConvolutionForwardImpl::get_algorithm_heuristic(
-                src, filter, dst, workspace_limit_in_bytes, reproducible);
+                src, filter, dst, workspace_limit_in_bytes, attr);
     }
     return result;
 }
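
The call-site shape is otherwise unchanged: the fast path is tried first and the naive implementation is consulted only when it returns nullptr. For callers, the migration is from passing a bool to passing an attribute value; the sketch below illustrates that with a toy operator type whose method merely mirrors the signature in this hunk (it is not the MegDNN class):

    #include <cstddef>
    #include <cstdint>
    #include <limits>

    enum class AlgoAttribute : uint32_t { DEFAULT = 0, REPRODUCIBLE = 1u << 0 };
    struct Algorithm {};
    struct TensorLayout {};

    struct ConvLikeOpr {
        Algorithm* get_algorithm_heuristic(
                const TensorLayout&, const TensorLayout&, const TensorLayout&,
                std::size_t workspace_limit = std::numeric_limits<std::size_t>::max(),
                const AlgoAttribute& attr = AlgoAttribute::DEFAULT) {
            (void)workspace_limit;
            (void)attr;
            return nullptr;  // placeholder body for the sketch
        }
    };

    void demo(ConvLikeOpr& opr, const TensorLayout& src,
              const TensorLayout& filter, const TensorLayout& dst) {
        // Old style: the last argument was `true` for "must be reproducible".
        // New style: the requirement is spelled out as an attribute value.
        opr.get_algorithm_heuristic(src, filter, dst,
                                    std::numeric_limits<std::size_t>::max(),
                                    AlgoAttribute::REPRODUCIBLE);
        // No requirement at all: DEFAULT, the empty mask.
        opr.get_algorithm_heuristic(src, filter, dst);
    }
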
@@ -312,18 +312,16 @@ void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param,
 ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb(
         const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
-        bool reproducible) {
+        const AlgoAttribute& attr) {
     auto algo_data_type = param.deduce_algo_data_type();
     auto suggest_category_order = suggest_algo_category_order(param);
     for (auto category : suggest_category_order) {
         auto&& origin_algos = select_algo_type({algo_data_type, category});
         ConvolutionImpl::Algorithm* heuristic_algo = nullptr;
         for (auto i : origin_algos) {
-            bool usable_reproducible =
-                    static_cast<AlgoBase*>(i)->usable_reproducible(
-                            param, AlgoSelectionStrategy::HEURISTIC,
-                            reproducible);
-            if (usable_reproducible &&
+            bool usable_attribute = static_cast<AlgoBase*>(i)->usable_attribute(
+                    param, AlgoSelectionStrategy::HEURISTIC, attr);
+            if (usable_attribute &&
                 static_cast<AlgoBase*>(i)->get_workspace(param) <=
                         workspace_limit_in_bytes) {
                 //! store the first usable algo if no prefer algo, choose it as
@@ -392,8 +390,8 @@ ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm(
     }
     if (!m_prev_selected_algo ||
         memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
-        m_prev_selected_algo =
-                get_algorithm_heuristic_with_ncb(param, workspace_size);
+        m_prev_selected_algo = get_algorithm_heuristic_with_ncb(
+                param, workspace_size, AlgoAttribute::DEFAULT);
         m_prev_selected_algo_sizep = param;
     }
     return m_prev_selected_algo;
@@ -515,15 +513,15 @@ ConvolutionBackwardDataImpl::Algorithm*
 ConvolutionBackwardDataImpl::get_algorithm_heuristic(
         const TensorLayout& filter, const TensorLayout& diff,
         const TensorLayout& grad, size_t workspace_limit_in_bytes,
-        bool reproducible) {
+        const AlgoAttribute& attr) {
     if (param().format == param::Convolution::Format::NHWCD4 ||
         param().format == param::Convolution::Format::NCHW4) {
         return naive::ConvolutionBackwardDataImpl::get_algorithm_heuristic(
-                filter, diff, grad, workspace_limit_in_bytes, reproducible);
+                filter, diff, grad, workspace_limit_in_bytes, attr);
     }
     auto fparam = make_ncb_kern_size_param(filter, diff, grad);
     return get_algorithm_heuristic_with_ncb(fparam, workspace_limit_in_bytes,
-                                            reproducible);
+                                            attr);
 }
 ConvolutionBackwardDataImpl::NCBKernSizeParam
@@ -668,15 +666,15 @@ ConvolutionBackwardDataImpl::get_all_algorithms_with_ncb(
 ConvolutionBackwardDataImpl::Algorithm*
 ConvolutionBackwardDataImpl::get_algorithm_heuristic_with_ncb(
         const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
-        bool reproducible) {
+        const AlgoAttribute& attr) {
     if (param.filter_meta.group != 1) {
         auto p1g = param;
         p1g.filter_meta.group = 1;
         return ncb_1g_get_algorithm_heuristic(p1g, workspace_limit_in_bytes,
-                                              reproducible);
+                                              attr);
     }
     return ncb_1g_get_algorithm_heuristic(param, workspace_limit_in_bytes,
-                                          reproducible);
+                                          attr);
 }
 size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace(
@@ -731,14 +729,10 @@ ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(
 ConvolutionBackwardDataImpl::Algorithm*
 ConvolutionBackwardDataImpl::ncb_1g_get_algorithm_heuristic(
         const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
-        bool reproducible) {
+        const AlgoAttribute& attr) {
     for (auto i : ncb_1g_get_all_algorithms(param)) {
         if (ncb_1g_get_workspace(i, param) <= workspace_limit_in_bytes) {
-            if (reproducible) {
-                if (i->contain_attribute(AlgoAttribute::REPRODUCIBLE)) {
-                    return i;
-                }
-            } else {
+            if (i->contain_attribute(attr)) {
                 return i;
             }
         }
@@ -788,7 +782,8 @@ ConvolutionBackwardDataImpl::get_algorithm(const NCBKernSizeParam& param) {
     if (!m_prev_selected_algo ||
         memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
         m_prev_selected_algo = ncb_1g_get_algorithm_heuristic(
-                param, std::numeric_limits<size_t>::max());
+                param, std::numeric_limits<size_t>::max(),
+                AlgoAttribute::DEFAULT);
         m_prev_selected_algo_sizep = param;
     }
     return m_prev_selected_algo;
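
The nested if/else that used to special-case `reproducible` disappears because a single containment test covers both cases: DEFAULT (the empty mask) accepts every algorithm, exactly like the old else branch, and REPRODUCIBLE reproduces the old true branch. A tiny standalone check of that equivalence (bit values assumed as before):

    #include <cassert>
    #include <cstdint>
    #include <initializer_list>

    enum class AlgoAttribute : uint32_t { DEFAULT = 0, REPRODUCIBLE = 1u << 0 };

    inline bool contains(uint32_t algo_attrs, AlgoAttribute required) {
        return (algo_attrs & static_cast<uint32_t>(required)) ==
               static_cast<uint32_t>(required);
    }

    // Behaviour removed above: branch on the bool, test REPRODUCIBLE only.
    inline bool old_accept(bool reproducible, uint32_t algo_attrs) {
        if (reproducible)
            return contains(algo_attrs, AlgoAttribute::REPRODUCIBLE);
        return true;
    }

    // Behaviour added above: one containment test against the requested attribute.
    inline bool new_accept(AlgoAttribute required, uint32_t algo_attrs) {
        return contains(algo_attrs, required);
    }

    int main() {
        for (uint32_t algo_attrs : {0u, 1u}) {
            assert(old_accept(false, algo_attrs) ==
                   new_accept(AlgoAttribute::DEFAULT, algo_attrs));
            assert(old_accept(true, algo_attrs) ==
                   new_accept(AlgoAttribute::REPRODUCIBLE, algo_attrs));
        }
    }
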
@@ -90,7 +90,7 @@ public:
                     const TensorLayout& filter,
                     const TensorLayout& dst,
                     size_t workspace_limit_in_bytes,
-                    bool reproducible) override;
+                    const AlgoAttribute& attr) override;
     //! size param for kernels with non-contiguous batch
     struct NCBKernSizeParam {
@@ -238,11 +238,11 @@ public:
         return false;
     }
-    bool usable_reproducible(const NCBKernSizeParam& param,
-                             AlgoSelectionStrategy algo_selection_strategy,
-                             bool reproducible = true) const {
-        return (!reproducible ||
-                contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
+    bool usable_attribute(
+            const NCBKernSizeParam& param,
+            AlgoSelectionStrategy algo_selection_strategy,
+            const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE) const {
+        return contain_attribute(attr) &&
                usable(param, algo_selection_strategy);
     }
@@ -272,7 +272,7 @@ protected:
     virtual Algorithm* get_algorithm_heuristic_with_ncb(
             const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
-            bool reproducible = false);
+            const AlgoAttribute& attr);
     const char* get_algorithm_set_name() const override;
@@ -326,7 +326,7 @@ public:
                     const TensorLayout& diff,
                     const TensorLayout& grad,
                     size_t workspace_limit_in_bytes,
-                    bool reproducible) override;
+                    const AlgoAttribute& attr) override;
     const char* get_algorithm_set_name() const override;
     //! size param for kernels with non-contiguous batch
@@ -421,12 +421,10 @@ protected:
     virtual ncb_kern_t dispatch_kern(
             ConvolutionBackwardDataImpl* opr,
             const NCBKernSizeParam& param) const = 0;
-    bool usable_reproducible(ConvolutionBackwardDataImpl* opr,
-                             const NCBKernSizeParam& param,
-                             bool reproducible = true) const {
-        return (!reproducible ||
-                contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
-               usable(opr, param);
+    bool usable_attribute(
+            ConvolutionBackwardDataImpl* opr, const NCBKernSizeParam& param,
+            const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE) const {
+        return contain_attribute(attr) && usable(opr, param);
     }
     virtual bool is_preferred(const NCBKernSizeParam&) const {
         return false;
@@ -451,7 +449,7 @@ protected:
     //! default impl calls ncb_1g_get_algorithm_heuristic()
     virtual Algorithm* get_algorithm_heuristic_with_ncb(
             const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
-            bool reproducible = false);
+            const AlgoAttribute& attr);
     //! get kernel pointer for float32 non-contiguous batch 1-group kernel
     virtual ncb_kern_t ncb_1g_dispatch_kern(Algorithm* algo,
@@ -469,7 +467,7 @@ protected:
      */
     virtual Algorithm* ncb_1g_get_algorithm_heuristic(
             const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
-            bool reproducible = false);
+            const AlgoAttribute& attr);
     static bool is_matrix_mul_preferred(const NCBKernSizeParam& param);
     /**
@@ -131,19 +131,20 @@ MatrixMulImpl::Algorithm* MatrixMulImpl::get_algorithm_from_desc(
 MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic(
         const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
-        size_t workspace_limit_in_bytes, bool reproducible) {
+        size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
     auto kern_size_param = make_kern_size_param(A, B, C);
     if (auto algo = static_cast<AlgoBase*>(
                 get_algorithm_from_desc(execution_policy().algo))) {
         megdnn_assert(algo->get_workspace(kern_size_param) <
                       workspace_limit_in_bytes);
-        auto cur = megdnn::get_reproducible_algo<MatrixMulImpl>(algo,
-                                                                reproducible);
+        auto cur = megdnn::get_algo_with_attribute<MatrixMulImpl>(algo, attr);
         if (cur)
             return cur;
-        megdnn_throw(
-                "require reproducible algorithm, but given algorithm is not "
-                "reproducible");
+        megdnn_throw(ssprintf(
+                "require algorithm with attribute%s, but given algorithm with "
+                "attribute%s",
+                Algorithm::attribute_str(attr).c_str(),
+                Algorithm::attribute_str(algo->attribute()).c_str()));
     }
     AlgoTypePack algo_type;
     algo_type.data_type = kern_size_param.deduce_algo_data_type();
@@ -155,8 +156,8 @@ MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic(
         if (static_cast<AlgoBase*>(algo)->usable(kern_size_param) &&
             static_cast<AlgoBase*>(algo)->get_workspace(kern_size_param) <=
                     workspace_limit_in_bytes) {
-            if (static_cast<AlgoBase*>(algo)->preferred_reproducible(
-                        kern_size_param, reproducible)) {
+            if (static_cast<AlgoBase*>(algo)->preferred_attribute(
+                        kern_size_param, attr)) {
                 //! use gemv algo if it's prefered
                 if (algo->algoset() == AlgoBase::AlgoSet::ALGO_TYPE_GEMV) {
                     return algo;
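
The rewritten error message prints both the requested attribute and the attribute the chosen algorithm actually carries, via Algorithm::attribute_str (declared earlier in this change set). The real implementation is not part of this diff; a plausible standalone sketch of such a bitmask-to-string helper could look like this:

    #include <cstdint>
    #include <string>

    enum class AlgoAttribute : uint32_t { DEFAULT = 0, REPRODUCIBLE = 1u << 0 };

    // Render a bitmask attribute as a readable list, e.g. "REPRODUCIBLE".
    inline std::string attribute_str(AlgoAttribute attr) {
        uint32_t bits = static_cast<uint32_t>(attr);
        if (bits == 0)
            return "DEFAULT";
        std::string ret;
        auto append = [&](uint32_t bit, const char* name) {
            if (bits & bit) {
                if (!ret.empty())
                    ret += "|";
                ret += name;
            }
        };
        append(1u << 0, "REPRODUCIBLE");
        // further flags would be appended here as the enum grows
        return ret;
    }
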
@@ -214,8 +215,9 @@ MatrixMulImpl::KernParam MatrixMulImpl::make_kern_param(
 size_t MatrixMulImpl::get_workspace_in_bytes(const TensorLayout& A,
                                              const TensorLayout& B,
                                              const TensorLayout& C) {
-    if (auto algo = get_algorithm_heuristic(
-                A, B, C, std::numeric_limits<size_t>::max(), false)) {
+    if (auto algo = get_algorithm_heuristic(A, B, C,
+                                            std::numeric_limits<size_t>::max(),
+                                            AlgoAttribute::DEFAULT)) {
         auto kern_size_param = make_kern_size_param(A, B, C);
         return static_cast<AlgoBase*>(algo)->get_workspace(kern_size_param);
     }
@@ -228,7 +230,7 @@ void MatrixMulImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
     if (auto algo = get_algorithm_heuristic(A.layout, B.layout, C.layout,
                                             std::numeric_limits<size_t>::max(),
-                                            false)) {
+                                            AlgoAttribute::DEFAULT)) {
         auto kern_param = make_kern_param(A, B, C, workspace);
         auto kern = static_cast<AlgoBase*>(algo)->get_kern(kern_param);
         auto run = [kern, kern_param]() { kern(kern_param); };
@@ -223,11 +223,10 @@ public:
     virtual InnerBlockSize get_inner_block_size() const {
         megdnn_assert(0);
     };
-    bool preferred_reproducible(const KernSizeParam& param,
-                                bool reproducible = true) {
-        return (!reproducible ||
-                contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
-               preferred(param);
+    bool preferred_attribute(
+            const KernSizeParam& param,
+            const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE) {
+        return contain_attribute(attr) && preferred(param);
     };
     virtual MatmulDescription matmul_description() const = 0;
@@ -272,7 +271,7 @@ protected:
                     const TensorLayout& B,
                     const TensorLayout& C,
                     size_t workspace_limit_in_bytes,
-                    bool reproducible) override;
+                    const AlgoAttribute& attr) override;
 };
@@ -125,16 +125,14 @@ BatchConvBiasForwardImpl::get_algorithm_heuristic(
         const TensorLayout& /* bias */, const TensorLayout& /* z */,
         const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */
         ,
-        bool reproducible) {
+        const AlgoAttribute& attr) {
     auto algo = static_cast<HandleImpl*>(handle())
                         ->default_batch_conv_bias_fwd_algo();
-    if (reproducible) {
-        megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE),
-                      "require reproducible algorithm, but heuristic "
-                      "algorithm(%s) is not "
-                      "reproducible",
-                      algo->name());
-    }
+    megdnn_assert(algo->contain_attribute(attr),
+                  "require algorithm with attribute%s, but heuristic "
+                  "algorithm(%s) with attribute%s ",
+                  Algorithm::attribute_str(attr).c_str(), algo->name(),
+                  Algorithm::attribute_str(algo->attribute()).c_str());
     return algo;
 }
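
This shape repeats for every naive/default operator below: the backend has exactly one default algorithm, so instead of branching on the old boolean it simply asserts that this algorithm carries the requested attribute and reports both sides when it does not. A self-contained sketch of that check, with std::runtime_error standing in for megdnn_assert/ssprintf:

    #include <cstdint>
    #include <stdexcept>
    #include <string>

    enum class AlgoAttribute : uint32_t { DEFAULT = 0, REPRODUCIBLE = 1u << 0 };

    struct Algo {
        std::string name;
        AlgoAttribute attrs;
    };

    inline bool contain_attribute(AlgoAttribute owned, AlgoAttribute required) {
        return (static_cast<uint32_t>(owned) & static_cast<uint32_t>(required)) ==
               static_cast<uint32_t>(required);
    }

    inline std::string attr_name(AlgoAttribute a) {
        return static_cast<uint32_t>(a) ? "REPRODUCIBLE" : "DEFAULT";
    }

    // Either the single default algorithm satisfies the requirement, or
    // selection fails loudly with both attributes in the message.
    inline const Algo& require_default_algo(const Algo& algo,
                                            AlgoAttribute required) {
        if (!contain_attribute(algo.attrs, required)) {
            throw std::runtime_error(
                    "require algorithm with attribute " + attr_name(required) +
                    ", but heuristic algorithm(" + algo.name +
                    ") has attribute " + attr_name(algo.attrs));
        }
        return algo;
    }
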
@@ -37,7 +37,7 @@ public:
                     const TensorLayout& z,
                     const TensorLayout& dst,
                     size_t workspace_limit_in_bytes,
-                    bool reproducible) override;
+                    const AlgoAttribute& attr) override;
     Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
@@ -76,7 +76,7 @@ BatchedMatrixMulForward::Algorithm*
 BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
         const TensorLayout& /*A*/, const TensorLayout& /*B*/,
         const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,
-        bool /* reproducible */) {
+        const AlgoAttribute& /*attr*/) {
     return static_cast<HandleImpl*>(handle())
             ->default_batched_matmul_fwd_algo();
 }
@@ -32,7 +32,7 @@ public:
                     const TensorLayout& /*B*/,
                     const TensorLayout& /*C*/,
                     size_t /*workspace_limit_in_bytes*/,
-                    bool /* reproducible */) override;
+                    const AlgoAttribute& /*attr*/) override;
     Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
@@ -246,16 +246,14 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
         const TensorLayout& /* src */, const TensorLayout& /* filter */,
         const TensorLayout& /* bias */, const TensorLayout& /* z */,
         const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */,
-        bool reproducible) {
+        const AlgoAttribute& attr) {
     auto algo =
             static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo();
-    if (reproducible) {
-        megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE),
-                      "require reproducible algorithm, but heuristic "
-                      "algorithm(%s) is not "
-                      "reproducible",
-                      algo->name());
-    }
+    megdnn_assert(algo->contain_attribute(attr),
+                  "require algorithm with attribute%s, but heuristic "
+                  "algorithm(%s) with attribute%s ",
+                  Algorithm::attribute_str(attr).c_str(), algo->name(),
+                  Algorithm::attribute_str(algo->attribute()).c_str());
     return algo;
 }
@@ -37,7 +37,7 @@ public:
                     const TensorLayout& z,
                     const TensorLayout& dst,
                     size_t workspace_limit_in_bytes,
-                    bool reproducible) override;
+                    const AlgoAttribute& attr) override;
     size_t get_workspace_in_bytes(
             const TensorLayout& src, const TensorLayout& filter,
@@ -272,16 +272,14 @@ ConvolutionForwardImpl:: get_all_algorithms(const TensorLayout &,
 ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic(
         const TensorLayout& /* src */, const TensorLayout& /* diff */,
         const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
-        bool reproducible) {
+        const AlgoAttribute& attr) {
     auto algo =
             static_cast<HandleImpl*>(handle())->default_conv_fwd_algo();
-    if (reproducible) {
-        megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE),
-                      "require reproducible algorithm, but heuristic "
-                      "algorithm(%s) is not "
-                      "reproducible",
-                      algo->name());
-    }
+    megdnn_assert(algo->contain_attribute(attr),
+                  "require algorithm with attribute%s, but heuristic "
+                  "algorithm(%s) with attribute%s ",
+                  Algorithm::attribute_str(attr).c_str(), algo->name(),
+                  Algorithm::attribute_str(algo->attribute()).c_str());
     return algo;
 }
@@ -304,16 +302,14 @@ ConvolutionBackwardData::Algorithm*
 ConvolutionBackwardDataImpl::get_algorithm_heuristic(
         const TensorLayout& /* filter */, const TensorLayout& /* diff */,
         const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
-        bool reproducible) {
+        const AlgoAttribute& attr) {
     auto algo =
             static_cast<HandleImpl*>(handle())->default_conv_bwd_data_algo();
-    if (reproducible) {
-        megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE),
-                      "require reproducible algorithm, but heuristic "
-                      "algorithm(%s) is not "
-                      "reproducible",
-                      algo->name());
-    }
+    megdnn_assert(algo->contain_attribute(attr),
+                  "require algorithm with attribute%s, but heuristic "
+                  "algorithm(%s) with attribute%s ",
+                  Algorithm::attribute_str(attr).c_str(), algo->name(),
+                  Algorithm::attribute_str(algo->attribute()).c_str());
     return algo;
 }
@@ -337,16 +333,14 @@ ConvolutionBackwardFilter::Algorithm*
 ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
         const TensorLayout& /* src */, const TensorLayout& /* diff */,
         const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
-        bool reproducible) {
+        const AlgoAttribute& attr) {
     auto algo =
             static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo();
-    if (reproducible) {
-        megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE),
-                      "require reproducible algorithm, but heuristic "
-                      "algorithm(%s) is not "
-                      "reproducible",
-                      algo->name());
-    }
+    megdnn_assert(algo->contain_attribute(attr),
+                  "require algorithm with attribute%s, but heuristic "
+                  "algorithm(%s) with attribute%s ",
+                  Algorithm::attribute_str(attr).c_str(), algo->name(),
+                  Algorithm::attribute_str(algo->attribute()).c_str());
     return algo;
 }
@@ -29,7 +29,7 @@ class ConvolutionForwardImpl: public ConvolutionForward {
                     const TensorLayout& filter,
                     const TensorLayout& dst,
                     size_t workspace_limit_in_bytes,
-                    bool reproducible) override;
+                    const AlgoAttribute& attr) override;
     size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
                                   const TensorLayout&,
                                   const PreprocessedFilter*) override {
@@ -71,7 +71,7 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData {
                     const TensorLayout& diff,
                     const TensorLayout& grad,
                     size_t workspace_limit_in_bytes,
-                    bool reproducible) override;
+                    const AlgoAttribute& attr) override;
     size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
                                   const TensorLayout&) override;
@@ -94,7 +94,7 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter {
                     const TensorLayout& diff,
                     const TensorLayout& grad,
                     size_t workspace_limit_in_bytes,
-                    bool reproducible) override;
+                    const AlgoAttribute& attr) override;
     size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
                                   const TensorLayout&) override;
@@ -120,15 +120,13 @@ Convolution3DForward::Algorithm*
 Convolution3DForwardImpl::get_algorithm_heuristic(
         const TensorLayout& /* src */, const TensorLayout& /* filter */,
         const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */,
-        bool reproducible) {
+        const AlgoAttribute& attr) {
     auto algo = static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo();
-    if (reproducible) {
-        megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE),
-                      "require reproducible algorithm, but heuristic "
-                      "algorithm(%s) is not "
-                      "reproducible",
-                      algo->name());
-    }
+    megdnn_assert(algo->contain_attribute(attr),
+                  "require algorithm with attribute%s, but heuristic "
+                  "algorithm(%s) with attribute%s ",
+                  Algorithm::attribute_str(attr).c_str(), algo->name(),
+                  Algorithm::attribute_str(algo->attribute()).c_str());
     return algo;
 }
@@ -152,16 +150,14 @@ Convolution3DBackwardData::Algorithm*
 Convolution3DBackwardDataImpl::get_algorithm_heuristic(
         const TensorLayout& /* filter */, const TensorLayout& /* diff */,
         const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
-        bool reproducible) {
+        const AlgoAttribute& attr) {
     auto algo =
             static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo();
-    if (reproducible) {
-        megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE),
-                      "require reproducible algorithm, but heuristic "
-                      "algorithm(%s) is not "
-                      "reproducible",
-                      algo->name());
-    }
+    megdnn_assert(algo->contain_attribute(attr),
+                  "require algorithm with attribute%s, but heuristic "
+                  "algorithm(%s) with attribute%s ",
+                  Algorithm::attribute_str(attr).c_str(), algo->name(),
+                  Algorithm::attribute_str(algo->attribute()).c_str());
     return algo;
 }
@@ -187,16 +183,14 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic(
         const TensorLayout& /* src */, const TensorLayout& /* diff */,
         const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */
         ,
-        bool reproducible) {
+        const AlgoAttribute& attr) {
     auto algo = static_cast<HandleImpl*>(handle())
                         ->default_conv3d_bwd_filter_algo();
-    if (reproducible) {
-        megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE),
-                      "require reproducible algorithm, but heuristic "
-                      "algorithm(%s) is not "
-                      "reproducible",
-                      algo->name());
-    }
+    megdnn_assert(algo->contain_attribute(attr),
+                  "require algorithm with attribute%s, but heuristic "
+                  "algorithm(%s) with attribute%s ",
+                  Algorithm::attribute_str(attr).c_str(), algo->name(),
+                  Algorithm::attribute_str(algo->attribute()).c_str());
     return algo;
 }
@@ -26,7 +26,7 @@ public:
                     const TensorLayout& filter,
                     const TensorLayout& dst,
                     size_t workspace_limit_in_bytes,
-                    bool reproducible) override;
+                    const AlgoAttribute& attr) override;
     size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
                                   const TensorLayout&) override {
         return 0;
@@ -48,7 +48,7 @@ public:
                     const TensorLayout& diff,
                     const TensorLayout& grad,
                     size_t workspace_limit_in_bytes,
-                    bool reproducible) override;
+                    const AlgoAttribute& attr) override;
     size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
                                   const TensorLayout&) override {
         return 0;
@@ -70,7 +70,7 @@ public:
                     const TensorLayout& diff,
                     const TensorLayout& grad,
                     size_t workspace_limit_in_bytes,
-                    bool reproducible) override;
+                    const AlgoAttribute& attr) override;
     size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
                                   const TensorLayout&) override {
         return 0;
@@ -32,7 +32,7 @@ public:
                     const TensorLayout& /* mask */,
                     const TensorLayout& /* dst */,
                     size_t /* workspace_limit_in_bytes */,
-                    bool /* reproducible */) override {
+                    const AlgoAttribute& /*attr*/) override {
         return nullptr;
     };
@@ -74,7 +74,7 @@ public:
                     const TensorLayout& /* out_grad */,
                     const TensorLayout& /* filter_grad */,
                     size_t /* workspace_limit_in_bytes */,
-                    bool /* reproducible */) override {
+                    const AlgoAttribute& /*attr*/) override {
         return nullptr;
     };
@@ -121,7 +121,7 @@ public:
                     const TensorLayout& /* offset_grad */,
                     const TensorLayout& /* mask_grad */,
                     size_t /* workspace_limit_in_bytes */,
-                    bool /* reproducible */) override {
+                    const AlgoAttribute& /*attr*/) override {
         return nullptr;
     };
@@ -162,16 +162,14 @@ LocalShareForwardImpl::get_all_algorithms(const TensorLayout&,
 LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic(
         const TensorLayout& /* src */, const TensorLayout& /* diff */,
         const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
-        bool reproducible) {
+        const AlgoAttribute& attr) {
     auto algo =
             static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo();
-    if (reproducible) {
-        megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE),
-                      "require reproducible algorithm, but heuristic "
-                      "algorithm(%s) is not "
-                      "reproducible",
-                      algo->name());
-    }
+    megdnn_assert(algo->contain_attribute(attr),
+                  "require algorithm with attribute%s, but heuristic "
+                  "algorithm(%s) with attribute%s ",
+                  Algorithm::attribute_str(attr).c_str(), algo->name(),
+                  Algorithm::attribute_str(algo->attribute()).c_str());
     return algo;
 }
@@ -196,16 +194,14 @@ LocalShareBackwardData::Algorithm*
 LocalShareBackwardDataImpl::get_algorithm_heuristic(
         const TensorLayout& /* filter */, const TensorLayout& /* diff */,
         const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
-        bool reproducible) {
+        const AlgoAttribute& attr) {
     auto algo = static_cast<HandleImpl*>(handle())
                         ->default_local_share_bwd_data_algo();
-    if (reproducible) {
-        megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE),
-                      "require reproducible algorithm, but heuristic "
-                      "algorithm(%s) is not "
-                      "reproducible",
-                      algo->name());
-    }
+    megdnn_assert(algo->contain_attribute(attr),
+                  "require algorithm with attribute%s, but heuristic "
+                  "algorithm(%s) with attribute%s ",
+                  Algorithm::attribute_str(attr).c_str(), algo->name(),
+                  Algorithm::attribute_str(algo->attribute()).c_str());
     return algo;
 }
@@ -230,16 +226,14 @@ LocalShareBackwardFilter::Algorithm*
 LocalShareBackwardFilterImpl::get_algorithm_heuristic(
         const TensorLayout& /* src */, const TensorLayout& /* diff */,
         const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
-        bool reproducible) {
+        const AlgoAttribute& attr) {
     auto algo = static_cast<HandleImpl*>(handle())
                         ->default_local_share_bwd_filter_algo();
-    if (reproducible) {
-        megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE),
-                      "require reproducible algorithm, but heuristic "
-                      "algorithm(%s) is not "
-                      "reproducible",
-                      algo->name());
-    }
+    megdnn_assert(algo->contain_attribute(attr),
+                  "require algorithm with attribute%s, but heuristic "
+                  "algorithm(%s) with attribute%s ",
+                  Algorithm::attribute_str(attr).c_str(), algo->name(),
+                  Algorithm::attribute_str(algo->attribute()).c_str());
     return algo;
 }
@@ -34,7 +34,7 @@ public:
                     const TensorLayout& /*filter*/,
                     const TensorLayout& /*dst*/,
                     size_t /*workspace_limit_in_bytes*/,
-                    bool /*reproducible*/) override;
+                    const AlgoAttribute& /*attr*/) override;
     Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
     const char* get_algorithm_set_name() const override { return "DEFAULT"; }
@@ -59,7 +59,7 @@ public:
                     const TensorLayout& /*diff*/,
                     const TensorLayout& /*grad*/,
                     size_t /*workspace_limit_in_bytes*/,
-                    bool /*reproducible*/) override;
+                    const AlgoAttribute& /*attr*/) override;
     Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
     const char* get_algorithm_set_name() const override { return "DEFAULT"; }
@@ -84,7 +84,7 @@ public:
                     const TensorLayout& /*diff*/,
                     const TensorLayout& /*grad*/,
                     size_t /*workspace_limit_in_bytes*/,
-                    bool /*reproducible*/) override;
+                    const AlgoAttribute& /*attr*/) override;
     Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
     const char* get_algorithm_set_name() const override { return "DEFAULT"; }
@@ -91,7 +91,7 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/,
 MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
         const TensorLayout& /*A*/, const TensorLayout& /*B*/,
         const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,
-        bool /* reproducible */) {
+        const AlgoAttribute& /*attr*/) {
     return static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo();
 }
@@ -33,7 +33,7 @@ public:
                     const TensorLayout& /*B*/,
                     const TensorLayout& /*C*/,
                     size_t /*workspace_limit_in_bytes*/,
-                    bool /* reproducible */) override;
+                    const AlgoAttribute& /*attr*/) override;
     Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
@@ -70,12 +70,11 @@ public:
     bool is_available_wk(const SizeArgs& args, size_t limit) const {
         return is_available(args) && get_workspace_in_bytes(args) <= limit;
     }
-    bool is_available_reproducible(
-            const SizeArgs& args, bool reproducible = true,
+    bool is_available_attribute(
+            const SizeArgs& args,
+            const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
             size_t limit = std::numeric_limits<size_t>::max()) const {
-        return (!reproducible ||
-                contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
-               is_available_wk(args, limit);
+        return contain_attribute(attr) && is_available_wk(args, limit);
    }
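
Note that the new attribute parameter defaults to REPRODUCIBLE, just as the old boolean defaulted to true, so call sites that pass only SizeArgs keep their previous meaning. A minimal sketch of that default-argument migration; SizeArgs and the availability logic are placeholders, not the real CUDA/ROCm argument bundles:

    #include <cstddef>
    #include <cstdint>
    #include <limits>

    enum class AlgoAttribute : uint32_t { DEFAULT = 0, REPRODUCIBLE = 1u << 0 };

    struct SizeArgs {};  // placeholder for the real argument bundle

    struct AlgoSketch {
        AlgoAttribute attrs = AlgoAttribute::DEFAULT;

        bool contain_attribute(AlgoAttribute required) const {
            return (static_cast<uint32_t>(attrs) &
                    static_cast<uint32_t>(required)) ==
                   static_cast<uint32_t>(required);
        }
        bool is_available_wk(const SizeArgs&, std::size_t) const {
            return true;  // availability and workspace logic elided
        }
        // Same shape as the member above: the default keeps
        // is_available_attribute(args) equivalent to the old
        // is_available_reproducible(args).
        bool is_available_attribute(
                const SizeArgs& args,
                AlgoAttribute attr = AlgoAttribute::REPRODUCIBLE,
                std::size_t limit = std::numeric_limits<std::size_t>::max()) const {
            return contain_attribute(attr) && is_available_wk(args, limit);
        }
    };
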
     AlgoBase& check_workspace(const SizeArgs& args,
                               const Workspace& workspace) {
@@ -32,16 +32,16 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
 BatchedMatrixMulForwardImpl::Algorithm*
 BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
         const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
-        size_t workspace_limit_in_bytes, bool reproducible) {
+        size_t workspace_limit_in_bytes, const AlgoAttribute& attr) {
     AlgoBase::SizeArgs args{this, A, B, C};
-    if (sm_algo_pack.blas.is_available_reproducible(args, reproducible,
-                                                    workspace_limit_in_bytes)) {
+    if (sm_algo_pack.blas.is_available_attribute(args, attr,
+                                                 workspace_limit_in_bytes)) {
         return &sm_algo_pack.blas;
     }
-    if (reproducible) {
-        return megdnn::get_reproducible_algo<BatchedMatrixMulForwardImpl>(
+    if (attr != AlgoAttribute::DEFAULT) {
+        return megdnn::get_algo_with_attribute<BatchedMatrixMulForwardImpl>(
                 sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
-                "batched matrix mul forward");
+                "batched matrix mul forward", attr);
     } else {
         return megdnn::get_usable_algo<BatchedMatrixMulForwardImpl>(
                 sm_algo_pack.all_algos, args, workspace_limit_in_bytes,
@@ -40,7 +40,7 @@ private:
                     const TensorLayout& /*B*/,
                     const TensorLayout& /*C*/,
                     size_t /*workspace_limit_in_bytes*/,
-                    bool /*reproducible*/) override;
+                    const AlgoAttribute& /*attr*/) override;
     const char* get_algorithm_set_name() const override {
         return "ROCM BATCHED MATMUL";
@@ -74,12 +74,11 @@ public:
     bool is_available_wk(const SizeArgs& args, size_t limit) {
         return is_available(args) && get_workspace_in_bytes(args) <= limit;
     }
-    bool is_available_reproducible(
-            const SizeArgs& args, bool reproducible = true,
+    bool is_available_attribute(
+            const SizeArgs& args,
+            const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
             size_t limit = std::numeric_limits<size_t>::max()) {
-        return (!reproducible ||
-                contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
-               is_available_wk(args, limit);
+        return contain_attribute(attr) && is_available_wk(args, limit);
     }
     AlgoBase& check_workspace(const SizeArgs& args,
@@ -96,24 +95,20 @@ public:
 };
 class ConvolutionBackwardDataImpl::AlgoMIOpen final : public AlgoBase {
-    bool m_is_reproducible;
+    AlgoAttribute m_algo_attribute;
     const char* m_name;
     miopenConvBwdDataAlgorithm_t find_best_algo(const ExecArgs& args);
 public:
     AlgoMIOpen() = delete;
-    AlgoMIOpen(bool is_reproducible) : m_is_reproducible(is_reproducible) {}
+    AlgoMIOpen(AlgoAttribute attr) : m_algo_attribute(attr) {}
     bool is_available(const SizeArgs& args) const override;
     size_t get_workspace_in_bytes(const SizeArgs& args) const override;
     void exec(const ExecArgs& args) const override;
     AlgoAttribute attribute() const override {
-        auto ret = static_cast<AlgoAttribute>(0);
-        if (m_is_reproducible) {
-            ret |= AlgoAttribute::REPRODUCIBLE;
-        }
-        return ret;
+        return m_algo_attribute;
     }
     const char* name() const override {
@@ -124,7 +119,7 @@ public:
     MEGDNN_DECL_ALGO_TYPE(ROCM_MIOPEN)
     std::string param() const override {
         std::string ret;
-        serialize_write_pod(m_is_reproducible, ret);
+        serialize_write_pod(m_algo_attribute, ret);
         return ret;
     }
@@ -170,7 +165,7 @@ class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
 public:
     AlgoPack();
-    AlgoMIOpen miopen{true};
+    AlgoMIOpen miopen{AlgoAttribute::REPRODUCIBLE};
     AlgoMatmul matmul;
     AlgoChanwise chanwise;
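
The MIOpen wrappers now store the attribute itself instead of a bool that was expanded into a mask on demand: the constructor takes the attribute, attribute() returns it unchanged, and it is serialized into the identifying param() string. A compact standalone sketch of that pattern; the byte-append below merely mimics serialize_write_pod, whose definition is not shown in this diff:

    #include <cstdint>
    #include <cstring>
    #include <string>

    enum class AlgoAttribute : uint32_t { DEFAULT = 0, REPRODUCIBLE = 1u << 0 };

    class AlgoWrapperSketch {
        AlgoAttribute m_algo_attribute;

    public:
        explicit AlgoWrapperSketch(AlgoAttribute attr) : m_algo_attribute(attr) {}

        // The attribute is plain data now, not a branch on a stored bool.
        AlgoAttribute attribute() const { return m_algo_attribute; }

        // Serialize the POD attribute into the param string, mirroring
        // serialize_write_pod(m_algo_attribute, ret) above.
        std::string param() const {
            std::string ret;
            ret.resize(sizeof(m_algo_attribute));
            std::memcpy(&ret[0], &m_algo_attribute, sizeof(m_algo_attribute));
            return ret;
        }
    };

    // Usage mirroring the AlgoPack member: the wrapper is constructed with the
    // attribute set it is known to provide.
    static AlgoWrapperSketch miopen_like{AlgoAttribute::REPRODUCIBLE};
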
@@ -71,12 +71,11 @@ public:
     bool is_available_wk(const SizeArgs& args, size_t limit) {
         return is_available(args) && get_workspace_in_bytes(args) <= limit;
     }
-    bool is_available_reproducible(
-            const SizeArgs& args, bool reproducible = true,
+    bool is_available_attribute(
+            const SizeArgs& args,
+            const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
             size_t limit = std::numeric_limits<size_t>::max()) {
-        return (!reproducible ||
-                contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
-               is_available_wk(args, limit);
+        return contain_attribute(attr) && is_available_wk(args, limit);
     }
     AlgoBase& check_workspace(const SizeArgs& args,
@@ -93,25 +92,21 @@ public:
 };
 class ConvolutionBackwardFilterImpl::AlgoMIOpen final : public AlgoBase {
-    bool m_is_reproducible;
+    AlgoAttribute m_algo_attribute;
     const char* m_name;
     miopenConvBwdWeightsAlgorithm_t find_best_algo(const ExecArgs& args);
 public:
     AlgoMIOpen() = delete;
-    AlgoMIOpen(bool is_reproducible) : m_is_reproducible(is_reproducible) {}
+    AlgoMIOpen(AlgoAttribute attr) : m_algo_attribute(attr) {}
     bool is_available(const SizeArgs& args) const override;
     size_t get_workspace_in_bytes(const SizeArgs& args) const override;
     void exec(const ExecArgs& args) const override;
     AlgoAttribute attribute() const override {
-        auto ret = static_cast<AlgoAttribute>(0);
-        if (m_is_reproducible) {
-            ret |= AlgoAttribute::REPRODUCIBLE;
-        }
-        return ret;
+        return m_algo_attribute;
     }
     const char* name() const override {
         return "MIOpenConvolutionBackwardFilter";
@@ -121,7 +116,7 @@ public:
     MEGDNN_DECL_ALGO_TYPE(ROCM_MIOPEN)
     std::string param() const override {
         std::string ret;
-        serialize_write_pod(m_is_reproducible, ret);
+        serialize_write_pod(m_algo_attribute, ret);
         return ret;
     }
@@ -166,7 +161,7 @@ class ConvolutionBackwardFilterImpl::AlgoPack : NonCopyableObj {
 public:
     AlgoPack();
-    AlgoMIOpen miopen{true};
+    AlgoMIOpen miopen{AlgoAttribute::REPRODUCIBLE};
     AlgoMatmul matmul;
     AlgoChanwise chanwise;
@@ -73,12 +73,11 @@ public:
     bool is_available_wk(const SizeArgs& args, size_t limit) {
         return is_available(args) && get_workspace_in_bytes(args) <= limit;
     }
-    bool is_available_reproducible(
-            const SizeArgs& args, bool reproducible = true,
+    bool is_available_attribute(
+            const SizeArgs& args,
+            const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE,
             size_t limit = std::numeric_limits<size_t>::max()) {
-        return (!reproducible ||
-                contain_attribute(AlgoAttribute::REPRODUCIBLE)) &&
-               is_available_wk(args, limit);
+        return contain_attribute(attr) && is_available_wk(args, limit);
     }
     AlgoBase& check_workspace(const SizeArgs& args,
@@ -94,25 +93,21 @@ public:
 };
 class ConvolutionForwardImpl::AlgoMIOpen final : public AlgoBase {
-    bool m_is_reproducible;
+    AlgoAttribute m_algo_attribute;
     const char* m_name;
     miopenConvFwdAlgorithm_t find_best_algo(const ExecArgs& args);
 public:
     AlgoMIOpen() = delete;
-    AlgoMIOpen(bool is_reproducible) : m_is_reproducible(is_reproducible) {}
+    AlgoMIOpen(AlgoAttribute attr) : m_algo_attribute(attr) {}
     bool is_available(const SizeArgs& args) const override;
     size_t get_workspace_in_bytes(const SizeArgs& args) const override;
     void exec(const ExecArgs& args) const override;
     AlgoAttribute attribute() const override {
-        auto ret = static_cast<AlgoAttribute>(0);
-        if (m_is_reproducible) {
-            ret |= AlgoAttribute::REPRODUCIBLE;
-        }
-        return ret;
+        return m_algo_attribute;
     }
     const char* name() const override { return "MIOpenConvolutionForward"; }
@@ -121,7 +116,7 @@ public:
     MEGDNN_DECL_ALGO_TYPE(ROCM_MIOPEN)
     std::string param() const override {
         std::string ret;
-        serialize_write_pod(m_is_reproducible, ret);
+        serialize_write_pod(m_algo_attribute, ret);
         return ret;
     }
@@ -215,7 +210,7 @@ class ConvolutionForwardImpl::AlgoPack : NonCopyableObj {
 public:
     AlgoPack();
-    AlgoMIOpen miopen{true};
+    AlgoMIOpen miopen{AlgoAttribute::REPRODUCIBLE};
     AlgoMatmul matmul;
     AlgoInplaceMatmul inplace_matmul;
     Algo1x1 a1x1;
@@ -33,70 +33,69 @@ ConvolutionForwardImpl::get_algorithm_heuristic(const TensorLayout& src, | |||||
const TensorLayout& filter, | const TensorLayout& filter, | ||||
const TensorLayout& dst, | const TensorLayout& dst, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
auto fm = check_layout_fwd(src, filter, dst); | auto fm = check_layout_fwd(src, filter, dst); | ||||
return get_algorithm_heuristic(src, fm, dst, workspace_limit_in_bytes, | return get_algorithm_heuristic(src, fm, dst, workspace_limit_in_bytes, | ||||
reproducible); | |||||
attr); | |||||
} | } | ||||
ConvolutionForwardImpl::Algorithm* | ConvolutionForwardImpl::Algorithm* | ||||
ConvolutionForwardImpl::get_algorithm_heuristic( | ConvolutionForwardImpl::get_algorithm_heuristic( | ||||
const TensorLayout& src, const CanonizedFilterMeta& filter, | const TensorLayout& src, const CanonizedFilterMeta& filter, | ||||
const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
AlgoBase::SizeArgs args(this, src, filter, dst); | AlgoBase::SizeArgs args(this, src, filter, dst); | ||||
//! MIOpen auto-tuning needs to run with actual tensors, so we cannot get | //! MIOpen auto-tuning needs to run with actual tensors, so we cannot get | ||||
//! best algorithm here. | //! best algorithm here. | ||||
if (is_miopen_supported(args)) { | if (is_miopen_supported(args)) { | ||||
auto algo = megdnn::get_reproducible_algo<ConvolutionForwardImpl>( | |||||
sm_algo_pack.miopen_algos[0], reproducible); | |||||
auto algo = megdnn::get_algo_with_attribute<ConvolutionForwardImpl>( | |||||
sm_algo_pack.miopen_algos[0], attr); | |||||
if (algo) | if (algo) | ||||
return algo; | return algo; | ||||
} | } | ||||
if (args.filter_meta.group > 1) { | if (args.filter_meta.group > 1) { | ||||
if (sm_algo_pack.chanwise.is_available_reproducible( | |||||
args, reproducible, workspace_limit_in_bytes)) { | |||||
if (sm_algo_pack.chanwise.is_available_attribute( | |||||
args, attr, workspace_limit_in_bytes)) { | |||||
return &sm_algo_pack.chanwise; | return &sm_algo_pack.chanwise; | ||||
} | } | ||||
} | } | ||||
auto prefer_1x1 = [&args, reproducible, workspace_limit_in_bytes]() { | |||||
auto prefer_1x1 = [&args, attr, workspace_limit_in_bytes]() { | |||||
const size_t MAX_BATCH_SIZE_FOR_1x1_MAT_ALGO = 4; | const size_t MAX_BATCH_SIZE_FOR_1x1_MAT_ALGO = 4; | ||||
size_t batch_size = args.src_layout->shape[0]; | size_t batch_size = args.src_layout->shape[0]; | ||||
if (batch_size > MAX_BATCH_SIZE_FOR_1x1_MAT_ALGO) { | if (batch_size > MAX_BATCH_SIZE_FOR_1x1_MAT_ALGO) { | ||||
return false; | return false; | ||||
} | } | ||||
return sm_algo_pack.a1x1.is_available_reproducible( | |||||
args, reproducible, workspace_limit_in_bytes); | |||||
return sm_algo_pack.a1x1.is_available_attribute( | |||||
args, attr, workspace_limit_in_bytes); | |||||
}; | }; | ||||
if (prefer_1x1()) { | if (prefer_1x1()) { | ||||
return &sm_algo_pack.a1x1; | return &sm_algo_pack.a1x1; | ||||
} | } | ||||
auto prefer_1x1_large_batch = [&args, reproducible, | |||||
workspace_limit_in_bytes]() { | |||||
auto prefer_1x1_large_batch = [&args, attr, workspace_limit_in_bytes]() { | |||||
const size_t MIN_BATCH_SIZE_FOR_1x1_LARGE_BATCH_ALGO = 32; | const size_t MIN_BATCH_SIZE_FOR_1x1_LARGE_BATCH_ALGO = 32; | ||||
size_t batch_size = args.src_layout->shape[0]; | size_t batch_size = args.src_layout->shape[0]; | ||||
if (batch_size < MIN_BATCH_SIZE_FOR_1x1_LARGE_BATCH_ALGO) { | if (batch_size < MIN_BATCH_SIZE_FOR_1x1_LARGE_BATCH_ALGO) { | ||||
return false; | return false; | ||||
} | } | ||||
return sm_algo_pack.batched_matrix_mul.is_available_reproducible( | |||||
args, reproducible, workspace_limit_in_bytes); | |||||
return sm_algo_pack.batched_matrix_mul.is_available_attribute( | |||||
args, attr, workspace_limit_in_bytes); | |||||
}; | }; | ||||
if (prefer_1x1_large_batch()) { | if (prefer_1x1_large_batch()) { | ||||
return &sm_algo_pack.batched_matrix_mul; | return &sm_algo_pack.batched_matrix_mul; | ||||
} | } | ||||
if (reproducible) { | |||||
return megdnn::get_reproducible_algo<ConvolutionForwardImpl>( | |||||
if (attr != AlgoAttribute::DEFAULT) { | |||||
return megdnn::get_algo_with_attribute<ConvolutionForwardImpl>( | |||||
sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, | sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, | ||||
"rocm conv fwd"); | |||||
"rocm conv fwd", attr); | |||||
} else { | } else { | ||||
return megdnn::get_usable_algo<ConvolutionForwardImpl>( | return megdnn::get_usable_algo<ConvolutionForwardImpl>( | ||||
sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, | sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, | ||||
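The non-MIOpen fallback above dispatches on the requested attribute: a DEFAULT request keeps the old "any usable algorithm" behaviour, while any other attribute goes through megdnn::get_algo_with_attribute. That helper is not shown in this diff; the following is only a minimal sketch of the selection it appears to perform at these call sites. The function name is hypothetical, and Algo/SizeArgs/AlgoAttribute stand for the surrounding megdnn types.

    #include <vector>

    // Sketch only: return the first algorithm that carries the requested
    // attribute and fits the workspace limit, or nullptr if none does.
    // The caller would turn a nullptr into the "no usable ... algorithm
    // with attribute(...)" error seen elsewhere in this change.
    template <typename Algo, typename SizeArgs>
    Algo* pick_algo_with_attribute(const std::vector<Algo*>& algos,
                                   const SizeArgs& args,
                                   size_t workspace_limit_in_bytes,
                                   const AlgoAttribute& attr) {
        for (Algo* algo : algos) {
            if (algo->is_available_attribute(args, attr,
                                             workspace_limit_in_bytes))
                return algo;
        }
        return nullptr;
    }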
@@ -157,36 +156,36 @@ ConvolutionBackwardDataImpl::Algorithm* | |||||
ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ||||
const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
auto fm = check_layout_fwd(grad, filter, diff); | auto fm = check_layout_fwd(grad, filter, diff); | ||||
return get_algorithm_heuristic(fm, diff, grad, workspace_limit_in_bytes, | return get_algorithm_heuristic(fm, diff, grad, workspace_limit_in_bytes, | ||||
reproducible); | |||||
attr); | |||||
} | } | ||||
ConvolutionBackwardDataImpl::Algorithm* | ConvolutionBackwardDataImpl::Algorithm* | ||||
ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ||||
const CanonizedFilterMeta& filter, const TensorLayout& diff, | const CanonizedFilterMeta& filter, const TensorLayout& diff, | ||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
AlgoBase::SizeArgs args(this, filter, diff, grad); | AlgoBase::SizeArgs args(this, filter, diff, grad); | ||||
if (is_miopen_supported(args.as_fwd_args())) { | if (is_miopen_supported(args.as_fwd_args())) { | ||||
auto algo = megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>( | |||||
sm_algo_pack.miopen_algos[0], reproducible); | |||||
auto algo = megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>( | |||||
sm_algo_pack.miopen_algos[0], attr); | |||||
if (algo) | if (algo) | ||||
return algo; | return algo; | ||||
} | } | ||||
if (args.filter_meta.group > 1 && | if (args.filter_meta.group > 1 && | ||||
sm_algo_pack.chanwise.is_available_reproducible( | |||||
args, reproducible, workspace_limit_in_bytes)) { | |||||
sm_algo_pack.chanwise.is_available_attribute( | |||||
args, attr, workspace_limit_in_bytes)) { | |||||
return &sm_algo_pack.chanwise; | return &sm_algo_pack.chanwise; | ||||
} | } | ||||
if (reproducible) { | |||||
return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>( | |||||
if (attr != AlgoAttribute::DEFAULT) { | |||||
return megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>( | |||||
sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, | sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, | ||||
"rocm conv bwd_data"); | |||||
"rocm conv bwd_data", attr); | |||||
} else { | } else { | ||||
return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>( | return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>( | ||||
sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, | sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, | ||||
@@ -230,38 +229,38 @@ ConvolutionBackwardFilterImpl::Algorithm* | |||||
ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
auto fm = check_layout_fwd(src, grad, diff); | auto fm = check_layout_fwd(src, grad, diff); | ||||
return get_algorithm_heuristic(src, diff, fm, workspace_limit_in_bytes, | return get_algorithm_heuristic(src, diff, fm, workspace_limit_in_bytes, | ||||
reproducible); | |||||
attr); | |||||
} | } | ||||
ConvolutionBackwardFilterImpl::Algorithm* | ConvolutionBackwardFilterImpl::Algorithm* | ||||
ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes, | const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
AlgoBase::SizeArgs args(this, src, diff, grad); | AlgoBase::SizeArgs args(this, src, diff, grad); | ||||
if (is_miopen_supported(args.as_fwd_args())) { | if (is_miopen_supported(args.as_fwd_args())) { | ||||
auto algo = | auto algo = | ||||
megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>( | |||||
sm_algo_pack.miopen_algos[0], reproducible); | |||||
megdnn::get_algo_with_attribute<ConvolutionBackwardFilterImpl>( | |||||
sm_algo_pack.miopen_algos[0], attr); | |||||
if (algo) | if (algo) | ||||
return algo; | return algo; | ||||
} | } | ||||
if (args.grad_filter_meta.group > 1 && | if (args.grad_filter_meta.group > 1 && | ||||
sm_algo_pack.chanwise.is_available_reproducible( | |||||
args, reproducible, workspace_limit_in_bytes)) { | |||||
sm_algo_pack.chanwise.is_available_attribute( | |||||
args, attr, workspace_limit_in_bytes)) { | |||||
// prefer special chanwise impl | // prefer special chanwise impl | ||||
return &sm_algo_pack.chanwise; | return &sm_algo_pack.chanwise; | ||||
} | } | ||||
if (reproducible) { | |||||
return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>( | |||||
if (attr != AlgoAttribute::DEFAULT) { | |||||
return megdnn::get_algo_with_attribute<ConvolutionBackwardFilterImpl>( | |||||
sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, | sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, | ||||
"rocm conv bwd_filter"); | |||||
"rocm conv bwd_filter", attr); | |||||
} else { | } else { | ||||
return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>( | return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>( | ||||
sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, | sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, | ||||
@@ -26,9 +26,9 @@ public: | |||||
AlgorithmInfo get_algorithm_info_heuristic( | AlgorithmInfo get_algorithm_info_heuristic( | ||||
const TensorLayout& src, const CanonizedFilterMeta& filter, | const TensorLayout& src, const CanonizedFilterMeta& filter, | ||||
const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
return get_algorithm_heuristic(src, filter, dst, | return get_algorithm_heuristic(src, filter, dst, | ||||
workspace_limit_in_bytes, reproducible) | |||||
workspace_limit_in_bytes, attr) | |||||
->info(); | ->info(); | ||||
} | } | ||||
size_t get_workspace_in_bytes(const TensorLayout& src, | size_t get_workspace_in_bytes(const TensorLayout& src, | ||||
@@ -76,12 +76,12 @@ private: | |||||
const TensorLayout& filter, | const TensorLayout& filter, | ||||
const TensorLayout& dst, | const TensorLayout& dst, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) override; | |||||
const AlgoAttribute& attr) override; | |||||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | Algorithm* get_algorithm_heuristic(const TensorLayout& src, | ||||
const CanonizedFilterMeta& filter, | const CanonizedFilterMeta& filter, | ||||
const TensorLayout& dst, | const TensorLayout& dst, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible); | |||||
const AlgoAttribute& attr); | |||||
static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
}; | }; | ||||
@@ -94,9 +94,9 @@ public: | |||||
AlgorithmInfo get_algorithm_info_heuristic( | AlgorithmInfo get_algorithm_info_heuristic( | ||||
const CanonizedFilterMeta& filter, const TensorLayout& diff, | const CanonizedFilterMeta& filter, const TensorLayout& diff, | ||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
return get_algorithm_heuristic(filter, diff, grad, | return get_algorithm_heuristic(filter, diff, grad, | ||||
workspace_limit_in_bytes, reproducible) | |||||
workspace_limit_in_bytes, attr) | |||||
->info(); | ->info(); | ||||
} | } | ||||
size_t get_workspace_in_bytes(const TensorLayout& filter, | size_t get_workspace_in_bytes(const TensorLayout& filter, | ||||
@@ -122,12 +122,12 @@ private: | |||||
const TensorLayout& diff, | const TensorLayout& diff, | ||||
const TensorLayout& grad, | const TensorLayout& grad, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) override; | |||||
const AlgoAttribute& attr) override; | |||||
Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter, | Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter, | ||||
const TensorLayout& diff, | const TensorLayout& diff, | ||||
const TensorLayout& grad, | const TensorLayout& grad, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible); | |||||
const AlgoAttribute& attr); | |||||
static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
}; | }; | ||||
@@ -141,9 +141,9 @@ public: | |||||
const TensorLayout& diff, | const TensorLayout& diff, | ||||
const CanonizedFilterMeta& grad, | const CanonizedFilterMeta& grad, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | |||||
const AlgoAttribute& attr) { | |||||
return get_algorithm_heuristic(src, diff, grad, | return get_algorithm_heuristic(src, diff, grad, | ||||
workspace_limit_in_bytes, reproducible) | |||||
workspace_limit_in_bytes, attr) | |||||
->info(); | ->info(); | ||||
} | } | ||||
size_t get_workspace_in_bytes(const TensorLayout& src, | size_t get_workspace_in_bytes(const TensorLayout& src, | ||||
@@ -169,12 +169,12 @@ private: | |||||
const TensorLayout& diff, | const TensorLayout& diff, | ||||
const TensorLayout& grad, | const TensorLayout& grad, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) override; | |||||
const AlgoAttribute& attr) override; | |||||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | Algorithm* get_algorithm_heuristic(const TensorLayout& src, | ||||
const TensorLayout& diff, | const TensorLayout& diff, | ||||
const CanonizedFilterMeta& grad, | const CanonizedFilterMeta& grad, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible); | |||||
const AlgoAttribute& attr); | |||||
static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
}; | }; | ||||
@@ -70,12 +70,11 @@ public: | |||||
bool is_available_wk(const SizeArgs& args, size_t limit) const { | bool is_available_wk(const SizeArgs& args, size_t limit) const { | ||||
return is_available(args) && get_workspace_in_bytes(args) <= limit; | return is_available(args) && get_workspace_in_bytes(args) <= limit; | ||||
} | } | ||||
bool is_available_reproducible( | |||||
const SizeArgs& args, bool reproducible = true, | |||||
bool is_available_attribute( | |||||
const SizeArgs& args, | |||||
const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
size_t limit = std::numeric_limits<size_t>::max()) const { | size_t limit = std::numeric_limits<size_t>::max()) const { | ||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | |||||
return contain_attribute(attr) && is_available_wk(args, limit); | |||||
} | } | ||||
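is_available_attribute composes two predicates: the algorithm must advertise every requested attribute bit and must be usable within the workspace limit. Since Attribute is declared as a uint32_t enum and combined with |= elsewhere in this change, contain_attribute presumably reduces to a mask test. A self-contained illustration, where AlgoAttr and contains are stand-in names rather than the real API:

    #include <cassert>
    #include <cstdint>

    enum class AlgoAttr : uint32_t {
        DEFAULT = 0,             // no requirement
        REPRODUCIBLE = 1u << 0,  // further flags would occupy higher bits
    };

    constexpr bool contains(AlgoAttr have, AlgoAttr want) {
        // every bit requested in `want` must be set in `have`
        return (static_cast<uint32_t>(have) & static_cast<uint32_t>(want)) ==
               static_cast<uint32_t>(want);
    }

    int main() {
        AlgoAttr mine = AlgoAttr::REPRODUCIBLE;
        assert(contains(mine, AlgoAttr::REPRODUCIBLE));  // requirement met
        assert(contains(mine, AlgoAttr::DEFAULT));       // empty requirement
        return 0;
    }

Under this reading a DEFAULT request is always satisfied, which is consistent with the attr != AlgoAttribute::DEFAULT short-circuit in the heuristic functions above.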
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
const Workspace& workspace) { | const Workspace& workspace) { | ||||
@@ -29,16 +29,16 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||||
MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | ||||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
size_t workspace_limit_in_bytes, bool reproducible) { | |||||
size_t workspace_limit_in_bytes, const AlgoAttribute& attr) { | |||||
AlgoBase::SizeArgs args{this, A, B, C}; | AlgoBase::SizeArgs args{this, A, B, C}; | ||||
if (sm_algo_pack.blas.is_available_reproducible( | |||||
args, reproducible, workspace_limit_in_bytes)) { | |||||
if (sm_algo_pack.blas.is_available_attribute(args, attr, | |||||
workspace_limit_in_bytes)) { | |||||
return &sm_algo_pack.blas; | return &sm_algo_pack.blas; | ||||
} | } | ||||
if (reproducible) { | |||||
return megdnn::get_reproducible_algo<MatrixMulForwardImpl>( | |||||
if (attr != AlgoAttribute::DEFAULT) { | |||||
return megdnn::get_algo_with_attribute<MatrixMulForwardImpl>( | |||||
sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | ||||
"matrix mul forward"); | |||||
"matrix mul forward", attr); | |||||
} else { | } else { | ||||
return megdnn::get_usable_algo<MatrixMulForwardImpl>( | return megdnn::get_usable_algo<MatrixMulForwardImpl>( | ||||
sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | ||||
@@ -40,7 +40,7 @@ private: | |||||
const TensorLayout& /*B*/, | const TensorLayout& /*B*/, | ||||
const TensorLayout& /*C*/, | const TensorLayout& /*C*/, | ||||
size_t /*workspace_limit_in_bytes*/, | size_t /*workspace_limit_in_bytes*/, | ||||
bool /*reproducible*/) override; | |||||
const AlgoAttribute& /*attr*/) override; | |||||
const char* get_algorithm_set_name() const override { | const char* get_algorithm_set_name() const override { | ||||
return "ROCM MATMUL"; | return "ROCM MATMUL"; | ||||
@@ -278,6 +278,15 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space( | |||||
return ret; | return ret; | ||||
} | } | ||||
AlgoAttribute extract_algo_attribute_from_execution_strategy( | |||||
const ExecutionStrategy& strategy) { | |||||
AlgoAttribute ret = AlgoAttribute::DEFAULT; | |||||
if (strategy & ExecutionStrategy::REPRODUCIBLE) { | |||||
ret |= AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
return ret; | |||||
} | |||||
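The new helper centralizes the strategy-to-attribute translation that used to be an inline bool cast at each call site. In effect the heuristic entry points below now behave as in this fragment, where `strategy` stands for whatever ExecutionStrategy the user configured and the comments name the two selection paths shown earlier in the diff:

    // With REPRODUCIBLE requested, only algorithms whose attribute()
    // contains REPRODUCIBLE may be chosen; otherwise any usable algorithm
    // within the workspace limit is acceptable.
    auto required = extract_algo_attribute_from_execution_strategy(strategy);
    if (required != AlgoAttribute::DEFAULT) {
        // filtered selection (get_algo_with_attribute path)
    } else {
        // unrestricted selection (get_usable_algo path)
    }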
//! Test whether the algo attribute of an algo matches the required | //! Test whether the algo attribute of an algo matches the required | ||||
//! algo_strategy | //! algo_strategy | ||||
static bool algo_attribute_match_strategy(AlgoAttribute attribute, | static bool algo_attribute_match_strategy(AlgoAttribute attribute, | ||||
@@ -290,7 +299,6 @@ static bool algo_attribute_match_strategy(AlgoAttribute attribute, | |||||
} | } | ||||
return ret; | return ret; | ||||
} | } | ||||
} // namespace | } // namespace | ||||
namespace mgb { | namespace mgb { | ||||
@@ -303,9 +311,9 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, | |||||
return; | return; | ||||
AlgoChooserProfileCache::Result prof_rst; | AlgoChooserProfileCache::Result prof_rst; | ||||
std::string str_on_inp_shape = ssprintf( | |||||
"on input layouts (%s, %s)", ctx.layouts()[0].to_string().c_str(), | |||||
ctx.layouts()[1].to_string().c_str()); | |||||
auto target_attribute = | |||||
extract_algo_attribute_from_execution_strategy(selected_strategy); | |||||
std::string layouts_str = format_fixlayouts<Opr>(ctx.layouts(), arity_in, arity_out); | |||||
double cur_timeout = 0; | double cur_timeout = 0; | ||||
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | ||||
@@ -316,20 +324,22 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, | |||||
Maybe<AlgoChooserProfileCache::ResultEntry> cur_rst; | Maybe<AlgoChooserProfileCache::ResultEntry> cur_rst; | ||||
std::string msg = ssprintf("profiling %s algorithm %s %s", | std::string msg = ssprintf("profiling %s algorithm %s %s", | ||||
ctx.mgb_opr()->dyn_typeinfo()->name, | ctx.mgb_opr()->dyn_typeinfo()->name, | ||||
algo.name.c_str(), str_on_inp_shape.c_str()); | |||||
algo.name.c_str(), layouts_str.c_str()); | |||||
ImplExecutionPolicy policy; | ImplExecutionPolicy policy; | ||||
policy.algo = algo.desc; | policy.algo = algo.desc; | ||||
ctx.construct_execution_policy(selected_strategy, policy); | ctx.construct_execution_policy(selected_strategy, policy); | ||||
if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) { | if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) { | ||||
continue; | continue; | ||||
} | } | ||||
auto algo_attribute = ctx.megdnn_opr() | |||||
->get_algorithm_from_desc(policy.algo) | |||||
->attribute(); | |||||
if (!algo_attribute_match_strategy(algo_attribute, selected_strategy)) { | |||||
auto palgo = ctx.megdnn_opr()->get_algorithm_from_desc(policy.algo); | |||||
if (!algo_attribute_match_strategy(palgo->attribute(), | |||||
selected_strategy)) { | |||||
mgb_log_debug( | mgb_log_debug( | ||||
"skip algo %s, which is not match the profile strategy.", | |||||
algo.name.c_str()); | |||||
"skip algo %s with attribute%s, which is not match the " | |||||
"profile strategy required attribute%s.", | |||||
algo.name.c_str(), | |||||
Algorithm::attribute_str(palgo->attribute()).c_str(), | |||||
Algorithm::attribute_str(target_attribute).c_str()); | |||||
continue; | continue; | ||||
} | } | ||||
@@ -360,9 +370,10 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, | |||||
rst.workspace, rst.time); | rst.workspace, rst.time); | ||||
prof_rst.push_back(rst); | prof_rst.push_back(rst); | ||||
} | } | ||||
std::string msg = ssprintf("no usable %s algorithm %s", | |||||
ctx.mgb_opr()->dyn_typeinfo()->name, | |||||
str_on_inp_shape.c_str()); | |||||
std::string msg = | |||||
ssprintf("no usable %s algorithm %s with attribute(%s)", | |||||
ctx.mgb_opr()->dyn_typeinfo()->name, layouts_str.c_str(), | |||||
Algorithm::attribute_str(target_attribute).c_str()); | |||||
mgb_assert(!prof_rst.empty(), "%s", msg.c_str()); | mgb_assert(!prof_rst.empty(), "%s", msg.c_str()); | ||||
FixedTensorLayouts origin_layouts = ctx.layouts(); | FixedTensorLayouts origin_layouts = ctx.layouts(); | ||||
@@ -589,14 +600,15 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic( | |||||
"workspace_limit should not be setted if choose algo by " | "workspace_limit should not be setted if choose algo by " | ||||
"heuristic"); | "heuristic"); | ||||
} | } | ||||
bool reproducible = static_cast<bool>(selected_strategy & | |||||
ExecutionStrategy::REPRODUCIBLE); | |||||
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | ||||
owner_graph(), m_cn, m_execution_policy.workspace_limit); | owner_graph(), m_cn, m_execution_policy.workspace_limit); | ||||
ImplExecutionPolicy policy; | ImplExecutionPolicy policy; | ||||
policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | ||||
args..., workspace_limit, reproducible), | |||||
m_layouts).desc; | |||||
args..., workspace_limit, | |||||
extract_algo_attribute_from_execution_strategy( | |||||
selected_strategy)), | |||||
m_layouts) | |||||
.desc; | |||||
Algorithm* algo = m_megdnn_opr->get_algorithm_from_desc(policy.algo); | Algorithm* algo = m_megdnn_opr->get_algorithm_from_desc(policy.algo); | ||||
mgb_assert(algo, "Unknown algo description"); | mgb_assert(algo, "Unknown algo description"); | ||||
@@ -647,8 +659,6 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||||
ExecutionStrategy selected_strategy, | ExecutionStrategy selected_strategy, | ||||
typename AlgoChooser<Opr>::ImplExecutionPolicy& policy, | typename AlgoChooser<Opr>::ImplExecutionPolicy& policy, | ||||
bool retrive_from_cache) const { | bool retrive_from_cache) const { | ||||
bool reproducible = static_cast<bool>(selected_strategy & | |||||
ExecutionStrategy::REPRODUCIBLE); | |||||
if (!policy.algo.valid()) { | if (!policy.algo.valid()) { | ||||
if (retrive_from_cache) { | if (retrive_from_cache) { | ||||
policy.algo = | policy.algo = | ||||
@@ -656,11 +666,13 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||||
} else { | } else { | ||||
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | ||||
owner_graph(), m_cn, m_execution_policy.workspace_limit); | owner_graph(), m_cn, m_execution_policy.workspace_limit); | ||||
policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | |||||
args..., workspace_limit, | |||||
reproducible), | |||||
m_layouts) | |||||
.desc; | |||||
policy.algo = | |||||
APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | |||||
args..., workspace_limit, | |||||
extract_algo_attribute_from_execution_strategy( | |||||
selected_strategy)), | |||||
m_layouts) | |||||
.desc; | |||||
} | } | ||||
mgb_assert(policy.algo.valid(), | mgb_assert(policy.algo.valid(), | ||||
"No algo found from cache or heuristic, maybe some error " | "No algo found from cache or heuristic, maybe some error " | ||||
@@ -2375,7 +2375,7 @@ public: | |||||
AlgorithmInfo(const TensorLayout& p0, const TensorLayout& p1, | AlgorithmInfo(const TensorLayout& p0, const TensorLayout& p1, | ||||
const TensorLayout& p2, | const TensorLayout& p2, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible)); | |||||
const AlgoAttribute& attr)); | |||||
MOCK_METHOD3(get_all_algorithms, | MOCK_METHOD3(get_all_algorithms, | ||||
std::vector<Algorithm*>(const TensorLayout& p0, | std::vector<Algorithm*>(const TensorLayout& p0, | ||||
@@ -2385,7 +2385,7 @@ public: | |||||
Algorithm*(const TensorLayout& p0, const TensorLayout& p1, | Algorithm*(const TensorLayout& p0, const TensorLayout& p1, | ||||
const TensorLayout& p2, | const TensorLayout& p2, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible)); | |||||
const AlgoAttribute& attr)); | |||||
MOCK_METHOD1(get_algorithm_from_desc, | MOCK_METHOD1(get_algorithm_from_desc, | ||||
Algorithm*(const AlgorithmDesc&)); | Algorithm*(const AlgorithmDesc&)); | ||||
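On the test side, the mocked heuristic now takes the attribute instead of the old bool, so expectations have to match an AlgoAttribute argument. A minimal sketch assuming Google Mock; the class name MockOpr is illustrative only, and the mocked method is taken to be the five-argument get_algorithm_heuristic declared above:

    using ::testing::_;
    using ::testing::Return;

    MockOpr opr;  // hypothetical mock built from the MOCK_METHOD declarations above
    // The fifth argument used to be `bool reproducible`; it is now the
    // requested AlgoAttribute, so the matcher spells the enum value.
    EXPECT_CALL(opr, get_algorithm_heuristic(_, _, _, _,
                                             AlgoAttribute::REPRODUCIBLE))
            .WillOnce(Return(nullptr));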