|
|
@@ -67,7 +67,8 @@ struct StrategyHashParamEqual { |
|
|
|
return flags; |
|
|
|
}; |
|
|
|
}; |
|
|
|
|
|
|
|
//! NOTE: must keep consistence with can_make_conv1x1_strategy when you modify |
|
|
|
//! this function |
|
|
|
std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( |
|
|
|
const ConvBiasImpl::NCBKernSizeParam& param, |
|
|
|
MatrixMulImpl::AlgoBase::PackMode pack_mode, |
|
|
@@ -211,14 +212,64 @@ Conv1x1StrategyBase* Conv1x1Factory::make_conv1x1_strategy( |
|
|
|
bool Conv1x1Factory::can_make_conv1x1_strategy( |
|
|
|
const ConvBiasImpl::NCBKernSizeParam& param, |
|
|
|
MatrixMulImpl::AlgoBase::PackMode pack_mode, param::ConvBias::Format) { |
|
|
|
bool ok_default_cb1 = |
|
|
|
param.src_type.enumv() == DTypeTrait<dt_float32>::enumv; |
|
|
|
bool ok_default_cb2 = |
|
|
|
param.filter_type.enumv() == param.src_type.enumv() && |
|
|
|
((param.src_type.enumv() == DTypeTrait<dt_int8>::enumv && |
|
|
|
param.dst_type.enumv() == DTypeTrait<dt_int32>::enumv) || |
|
|
|
(param.src_type.enumv() == DTypeTrait<dt_int8>::enumv && |
|
|
|
param.dst_type.enumv() == DTypeTrait<dt_int16>::enumv) || |
|
|
|
(param.src_type.enumv() == DTypeTrait<dtype::QuantizedS8>::enumv && |
|
|
|
param.dst_type.enumv() == |
|
|
|
DTypeTrait<dtype::QuantizedS32>::enumv) || |
|
|
|
(param.src_type.enumv() == DTypeTrait<dtype::QuantizedS8>::enumv && |
|
|
|
param.dst_type.enumv() == DTypeTrait<dtype::QuantizedS8>::enumv)); |
|
|
|
bool ok_default_cb1_fp16 = false; |
|
|
|
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC || !MEGDNN_DISABLE_FLOAT16 |
|
|
|
if ((pack_mode == MatrixMulImpl::AlgoBase::PackMode::NO_PACK || |
|
|
|
pack_mode == MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA) && |
|
|
|
param.src_type.enumv() == DTypeTrait<dt_float16>::enumv) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
ok_default_cb1_fp16 = |
|
|
|
param.src_type.enumv() == DTypeTrait<dt_float16>::enumv; |
|
|
|
#endif |
|
|
|
bool ok_default_cb2_arm = false; |
|
|
|
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 |
|
|
|
ok_default_cb2_arm = param.filter_type.enumv() == param.src_type.enumv() && |
|
|
|
((param.src_type.enumv() == |
|
|
|
DTypeTrait<dtype::Quantized8Asymm>::enumv && |
|
|
|
param.dst_type.enumv() == |
|
|
|
DTypeTrait<dtype::QuantizedS32>::enumv) || |
|
|
|
(param.src_type.enumv() == |
|
|
|
DTypeTrait<dtype::Quantized8Asymm>::enumv && |
|
|
|
param.dst_type.enumv() == |
|
|
|
DTypeTrait<dtype::Quantized8Asymm>::enumv)); |
|
|
|
#endif |
|
|
|
return true; |
|
|
|
|
|
|
|
bool ok_only_packa_cb1 = |
|
|
|
param.src_type.enumv() == DTypeTrait<dt_float32>::enumv; |
|
|
|
bool ok_no_pack_cb1 = |
|
|
|
param.src_type.enumv() == DTypeTrait<dt_float32>::enumv; |
|
|
|
bool ok_no_pack_cb2 = |
|
|
|
param.filter_type.enumv() == param.src_type.enumv() && |
|
|
|
((param.src_type.enumv() == DTypeTrait<dt_int8>::enumv && |
|
|
|
param.dst_type.enumv() == DTypeTrait<dt_int16>::enumv) || |
|
|
|
(param.src_type.enumv() == DTypeTrait<dt_int8>::enumv && |
|
|
|
param.dst_type.enumv() == DTypeTrait<dt_int32>::enumv) || |
|
|
|
(param.src_type.enumv() == DTypeTrait<dtype::QuantizedS8>::enumv && |
|
|
|
param.dst_type.enumv() == |
|
|
|
DTypeTrait<dtype::QuantizedS32>::enumv)); |
|
|
|
switch (pack_mode) { |
|
|
|
case MatrixMulImpl::AlgoBase::PackMode::DEFAULT: |
|
|
|
return ok_default_cb1 || ok_default_cb2 || ok_default_cb1_fp16 || |
|
|
|
ok_default_cb2_arm; |
|
|
|
break; |
|
|
|
case MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA: |
|
|
|
return ok_only_packa_cb1; |
|
|
|
break; |
|
|
|
case MatrixMulImpl::AlgoBase::PackMode::NO_PACK: |
|
|
|
return ok_no_pack_cb1 || ok_no_pack_cb2; |
|
|
|
break; |
|
|
|
default: |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
} // namespace conv1x1 |
|
|
|