|
|
@@ -41,48 +41,56 @@ enum class StrategyType : uint32_t { |
|
|
|
}; |
|
|
|
|
|
|
|
struct StrategyHashParam { |
|
|
|
fallback::ConvBiasImpl::NCBKernSizeParam param; |
|
|
|
param::ConvBias::Format format; |
|
|
|
fallback::MatrixMulImpl::AlgoBase::PackMode packmode; |
|
|
|
bool is_xcorr; |
|
|
|
bool is_square; //! kernel_h == kernel_w, stride_h = stride_w |
|
|
|
size_t block_m; |
|
|
|
size_t block_n; |
|
|
|
size_t block_k; |
|
|
|
size_t kernel; |
|
|
|
size_t stride; |
|
|
|
|
|
|
|
fallback::ConvBiasImpl::NCBKernSizeParam param; |
|
|
|
param::ConvBias::Format format; |
|
|
|
fallback::MatrixMulImpl::AlgoBase::PackMode packmode; |
|
|
|
}; |
|
|
|
|
|
|
|
struct StrategyHashParamHash { |
|
|
|
std::size_t operator()(const StrategyHashParam& sparam) const { |
|
|
|
constexpr size_t base = 1; //! avoid hashkey is zero |
|
|
|
std::size_t result = |
|
|
|
static_cast<std::size_t>(sparam.param.src_type.enumv()) + base; |
|
|
|
uint64_t operator()(const StrategyHashParam& sparam) const { |
|
|
|
constexpr uint64_t base = 1; //! avoid hashkey is zero |
|
|
|
uint64_t result = |
|
|
|
static_cast<uint64_t>(sparam.param.src_type.enumv()) + base; |
|
|
|
result = result ^ |
|
|
|
((static_cast<std::size_t>(sparam.param.dst_type.enumv()) + |
|
|
|
base) |
|
|
|
((static_cast<uint64_t>(sparam.param.dst_type.enumv()) + base) |
|
|
|
<< 3); |
|
|
|
result = result ^ |
|
|
|
((static_cast<std::size_t>(sparam.param.filter_type.enumv()) + |
|
|
|
((static_cast<uint64_t>(sparam.param.filter_type.enumv()) + |
|
|
|
base) |
|
|
|
<< 6); |
|
|
|
result = result ^ |
|
|
|
((static_cast<std::size_t>(sparam.param.bias_type.enumv()) + |
|
|
|
base) |
|
|
|
((static_cast<uint64_t>(sparam.param.bias_type.enumv()) + base) |
|
|
|
<< 9); |
|
|
|
result = result ^ ((static_cast<uint64_t>(sparam.format) + base) << 12); |
|
|
|
result = result ^ |
|
|
|
((static_cast<std::size_t>(sparam.format) + base) << 12); |
|
|
|
result = result ^ |
|
|
|
((static_cast<std::size_t>(sparam.packmode) + base) << 15); |
|
|
|
result = result ^ |
|
|
|
((static_cast<std::size_t>(sparam.block_m) + base) << 18); |
|
|
|
((static_cast<uint64_t>(sparam.packmode) + base) << 15); |
|
|
|
result = |
|
|
|
result ^ ((static_cast<uint64_t>(sparam.block_m) + base) << 18); |
|
|
|
result = |
|
|
|
result ^ ((static_cast<uint64_t>(sparam.block_n) + base) << 22); |
|
|
|
result = |
|
|
|
result ^ ((static_cast<uint64_t>(sparam.block_k) + base) << 26); |
|
|
|
result = result ^ ((static_cast<uint64_t>(sparam.kernel) + base) << 30); |
|
|
|
result = result ^ ((static_cast<uint64_t>(sparam.stride) + base) << 34); |
|
|
|
result = result ^ |
|
|
|
((static_cast<std::size_t>(sparam.block_n) + base) << 22); |
|
|
|
((static_cast<uint64_t>(sparam.is_square) + base) << 35); |
|
|
|
result = result ^ |
|
|
|
((static_cast<std::size_t>(sparam.block_k) + base) << 26); |
|
|
|
((static_cast<uint64_t>(sparam.is_xcorr) + base) << 36); |
|
|
|
return result; |
|
|
|
}; |
|
|
|
}; |
|
|
|
|
|
|
|
struct StrategyHashParamEqual { |
|
|
|
std::size_t operator()(const StrategyHashParam& param1, |
|
|
|
const StrategyHashParam& param2) const { |
|
|
|
bool operator()(const StrategyHashParam& param1, |
|
|
|
const StrategyHashParam& param2) const { |
|
|
|
bool flags = true; |
|
|
|
flags = param1.param.src_type == param2.param.src_type && flags; |
|
|
|
flags = param1.param.filter_type == param2.param.filter_type && flags; |
|
|
@@ -93,6 +101,10 @@ struct StrategyHashParamEqual { |
|
|
|
flags = param1.block_m == param2.block_m && flags; |
|
|
|
flags = param1.block_n == param2.block_n && flags; |
|
|
|
flags = param1.block_k == param2.block_k && flags; |
|
|
|
flags = param1.kernel == param2.kernel && flags; |
|
|
|
flags = param1.stride == param2.stride && flags; |
|
|
|
flags = param1.is_square == param2.is_square && flags; |
|
|
|
flags = param1.is_xcorr == param2.is_xcorr && flags; |
|
|
|
return flags; |
|
|
|
}; |
|
|
|
}; |
|
|
@@ -484,10 +496,15 @@ Strategy* StrategyDelegationStorage::get( |
|
|
|
sparam.block_m = block_m; |
|
|
|
sparam.block_n = block_n; |
|
|
|
sparam.block_k = block_k; |
|
|
|
sparam.kernel = param.filter_meta.spatial[0]; |
|
|
|
sparam.stride = param.filter_meta.stride[0]; |
|
|
|
//! is_square means kernel_h == kernel_w and stride_h == stride_w. Index 0
//! must be compared against index 1; comparing spatial[0] with itself is
//! always true and would let non-square filters share a cached strategy
//! with square ones.
sparam.is_square =
        param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
        param.filter_meta.stride[0] == param.filter_meta.stride[1];
|
|
|
sparam.is_xcorr = param.filter_meta.should_flip; |
|
|
|
MEGDNN_LOCK_GUARD(m_mtx); |
|
|
|
if (map_strategys.find(sparam) == map_strategys.end()) { |
|
|
|
MEGDNN_LOCK_GUARD(m_mtx); |
|
|
|
auto strategy = Factory::make_strategy(matmul_algo, packmode, |
|
|
|
param, stype); |
|
|
|
auto strategy = |
|
|
|
Factory::make_strategy(matmul_algo, packmode, param, stype); |
|
|
|
map_strategys[sparam] = std::move(strategy); |
|
|
|
} |
|
|
|
return static_cast<Strategy*>(map_strategys[sparam].get()); |
|
|
|