GitOrigin-RevId: 39cad612cc
release-1.7
@@ -32,13 +32,8 @@ namespace megdnn { | |||||
*/ | */ | ||||
template <class Opr, typename... Args> | template <class Opr, typename... Args> | ||||
bool has_available_algo(Opr* opr, Args&&... args) { | bool has_available_algo(Opr* opr, Args&&... args) { | ||||
const typename Opr::AlgoBase::SizeArgs size_args(opr, std::forward<Args>(args)...); | |||||
for (auto i : Opr::algo_pack().all_algos) { | |||||
if (i->is_available(size_args)) { | |||||
return true; | |||||
} | |||||
} | |||||
return false; | |||||
auto&& all_algos = opr->get_all_algorithms_info(std::forward<Args>(args)...); | |||||
return !all_algos.empty(); | |||||
} | } | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -157,7 +157,6 @@ struct ConvMaker<opr::BatchConvBiasForward> | |||||
MakeConvCaller4<megdnn::BatchConvBiasForward>, | MakeConvCaller4<megdnn::BatchConvBiasForward>, | ||||
megdnn::param::BatchConvBias> {}; | megdnn::param::BatchConvBias> {}; | ||||
#if 0 | |||||
#include "../../opr/impl/internal/invoke.h" | #include "../../opr/impl/internal/invoke.h" | ||||
template <typename Opr> | template <typename Opr> | ||||
struct MultiAlgoOprTrait; | struct MultiAlgoOprTrait; | ||||
@@ -202,7 +201,6 @@ INST(ConvolutionBackwardData) | |||||
INST(PoolingForward) | INST(PoolingForward) | ||||
#undef APPLY | #undef APPLY | ||||
#undef INST | #undef INST | ||||
#endif | |||||
} // namespace | } // namespace | ||||
namespace mgb { | namespace mgb { | ||||
@@ -291,9 +289,7 @@ VarNode* modify_opr_format( | |||||
#undef cb | #undef cb | ||||
} | } | ||||
#if 0 | |||||
bool has_available_algo(const VarNodeArray& i, | |||||
const cg::OperatorNodeBase* opr) { | |||||
bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr) { | |||||
#define cb(_Opr) \ | #define cb(_Opr) \ | ||||
if (opr->dyn_typeinfo() == _Opr::typeinfo()) { \ | if (opr->dyn_typeinfo() == _Opr::typeinfo()) { \ | ||||
MGB_MARK_USED_VAR(MultiAlgoOprTrait<_Opr>::has_algo); \ | MGB_MARK_USED_VAR(MultiAlgoOprTrait<_Opr>::has_algo); \ | ||||
@@ -301,13 +297,12 @@ bool has_available_algo(const VarNodeArray& i, | |||||
_.emplace_back(opr->output(0)); \ | _.emplace_back(opr->output(0)); \ | ||||
return MultiAlgoOprTrait<_Opr>::has_available_algo(_, opr); \ | return MultiAlgoOprTrait<_Opr>::has_available_algo(_, opr); \ | ||||
} else | } else | ||||
cb(Convolution) cb(ConvBiasForward) cb(ConvolutionBackwardData) | |||||
cb(PoolingForward) { | |||||
mgb_throw(InternalError, "invalid multi-algo operator(got:%s)", | |||||
opr->dyn_typeinfo()->name); | |||||
cb(Convolution) cb(ConvBiasForward) cb(ConvolutionBackwardData) cb(PoolingForward) { | |||||
mgb_throw( | |||||
InternalError, "invalid multi-algo operator(got:%s)", | |||||
opr->dyn_typeinfo()->name); | |||||
} | } | ||||
} | } | ||||
#endif | |||||
} // namespace intl | } // namespace intl | ||||
} // namespace gopt | } // namespace gopt | ||||
@@ -21,9 +21,7 @@ namespace intl { | |||||
#define FOREACH_FORMAT_AWARE_OPR(cb) \ | #define FOREACH_FORMAT_AWARE_OPR(cb) \ | ||||
cb(Convolution) cb(ConvBiasForward) cb(ConvolutionBackwardData) cb(PoolingForward) \ | cb(Convolution) cb(ConvBiasForward) cb(ConvolutionBackwardData) cb(PoolingForward) \ | ||||
cb(WarpPerspective) cb(Resize) | cb(WarpPerspective) cb(Resize) | ||||
#if 0 | |||||
bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr); | bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr); | ||||
#endif | |||||
VarNode* modify_opr_format( | VarNode* modify_opr_format( | ||||
opr::ConvBias::Param::Format opr_format, const VarNodeArray& i, | opr::ConvBias::Param::Format opr_format, const VarNodeArray& i, | ||||
@@ -43,7 +43,8 @@ static inline size_t extra_alignment( | |||||
size_t dtype_bits = dt.is_low_bit() ? dt.low_bit() : dt.size(1) * 8; | size_t dtype_bits = dt.is_low_bit() ? dt.low_bit() : dt.size(1) * 8; | ||||
size_t extra_alignment = | size_t extra_alignment = | ||||
alignment_in_bits >= dtype_bits ? alignment_in_bits / dtype_bits : 1; | alignment_in_bits >= dtype_bits ? alignment_in_bits / dtype_bits : 1; | ||||
if (target_formats == TensorFormats::NHWC) | |||||
if (target_formats == TensorFormats::NHWC || | |||||
target_formats == TensorFormats::KRSC) | |||||
channel_alignment = extra_alignment * channel_alignment / | channel_alignment = extra_alignment * channel_alignment / | ||||
gcd(channel_alignment, extra_alignment); | gcd(channel_alignment, extra_alignment); | ||||
return channel_alignment; | return channel_alignment; | ||||
@@ -60,10 +61,12 @@ static inline std::tuple<size_t, size_t> extra_alignment( | |||||
size_t dtype_bits = dt.is_low_bit() ? dt.low_bit() : dt.size(1) * 8; | size_t dtype_bits = dt.is_low_bit() ? dt.low_bit() : dt.size(1) * 8; | ||||
size_t extra_alignment = | size_t extra_alignment = | ||||
alignment_in_bits >= dtype_bits ? alignment_in_bits / dtype_bits : 1; | alignment_in_bits >= dtype_bits ? alignment_in_bits / dtype_bits : 1; | ||||
if (key.input_format == TensorFormats::NHWC) | |||||
if (key.input_format == TensorFormats::NHWC || | |||||
key.input_format == TensorFormats::KRSC) | |||||
input_channel_alignment = input_channel_alignment * extra_alignment / | input_channel_alignment = input_channel_alignment * extra_alignment / | ||||
gcd(input_channel_alignment, extra_alignment); | gcd(input_channel_alignment, extra_alignment); | ||||
if (key.output_format == TensorFormats::NHWC) | |||||
if (key.output_format == TensorFormats::NHWC || | |||||
key.output_format == TensorFormats::KRSC) | |||||
output_channel_alignment = output_channel_alignment * extra_alignment / | output_channel_alignment = output_channel_alignment * extra_alignment / | ||||
gcd(output_channel_alignment, extra_alignment); | gcd(output_channel_alignment, extra_alignment); | ||||
return std::make_tuple(input_channel_alignment, output_channel_alignment); | return std::make_tuple(input_channel_alignment, output_channel_alignment); | ||||
@@ -62,6 +62,16 @@ enum class TensorFormats : uint32_t { | |||||
KCRS = 24, ///< [K, C, R, S] | KCRS = 24, ///< [K, C, R, S] | ||||
GKCRS = 25, ///< [G, K, C, R, S] | GKCRS = 25, ///< [G, K, C, R, S] | ||||
C11RS = 26, ///< [C, 1, 1, R, S] | C11RS = 26, ///< [C, 1, 1, R, S] | ||||
// NHWC | |||||
KRSC = 27, /// < [K, R, S, C] | |||||
// NCHW32 | |||||
KCRSc32 = 28, ///<[K, C/32, R, S, C%32] | |||||
// NCHW64 | |||||
KCRSc64 = 29, ///<[K, C/64, R, S, C%64] | |||||
// CHWN4 | |||||
CRSKc4 = 30, ///< [C/4, R, S, K, C%4] | |||||
}; | }; | ||||
class ReformatManager : public NonCopyableObj { | class ReformatManager : public NonCopyableObj { | ||||