|
|
@@ -220,6 +220,28 @@ struct MultiAlgoOprTrait; |
|
|
|
::megdnn::has_available_algo(megdnn_opr, args...), array_layouts); \ |
|
|
|
MIDOUT_E \ |
|
|
|
} \ |
|
|
|
static bool has_no_naive_heuristic_algo( \ |
|
|
|
const VarNodeArray& i, const cg::OperatorNodeBase* opr_) { \ |
|
|
|
MIDOUT_B( \ |
|
|
|
midout_iv(MGB_HASH_STR(#_Opr)), \ |
|
|
|
midout_iv(MGB_HASH_STR("has_no_naive_heuristic_algo"))) \ |
|
|
|
auto&& opr = opr_->cast_final_safe<_Opr>(); \ |
|
|
|
auto&& megdnn_opr = reinterpret_cast<MegDNNOpr*>(opr.megdnn_opr()); \ |
|
|
|
FixedTensorLayouts array_layouts; \ |
|
|
|
size_t in = i.size() - 1; \ |
|
|
|
for (size_t idx = 0; idx < in; idx++) { \ |
|
|
|
const auto& v = i[idx]; \ |
|
|
|
array_layouts[idx] = \ |
|
|
|
TensorLayout{v->shape(), v->dtype(), v->format()}; \ |
|
|
|
} \ |
|
|
|
const auto& v = i[in]; \ |
|
|
|
array_layouts[arity - 1] = \ |
|
|
|
TensorLayout{v->shape(), v->dtype(), v->format()}; \ |
|
|
|
return APPLY( \ |
|
|
|
::megdnn::has_no_naive_heuristic_algo(megdnn_opr, args...), \ |
|
|
|
array_layouts); \ |
|
|
|
MIDOUT_E \ |
|
|
|
} \ |
|
|
|
}; |
|
|
|
INST(Convolution) |
|
|
|
INST(ConvBiasForward) |
|
|
@@ -365,6 +387,23 @@ bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr) |
|
|
|
#undef cb |
|
|
|
} |
|
|
|
|
|
|
|
bool has_no_naive_heuristic_algo( |
|
|
|
const VarNodeArray& i, const cg::OperatorNodeBase* opr) { |
|
|
|
#define cb(_Opr) \ |
|
|
|
if (opr->dyn_typeinfo() == _Opr::typeinfo()) { \ |
|
|
|
MGB_MARK_USED_VAR(MultiAlgoOprTrait<_Opr>::has_algo); \ |
|
|
|
VarNodeArray _ = i; \ |
|
|
|
_.emplace_back(opr->output(0)); \ |
|
|
|
return MultiAlgoOprTrait<_Opr>::has_no_naive_heuristic_algo(_, opr); \ |
|
|
|
} else |
|
|
|
cb(Convolution) cb(ConvBiasForward) cb(ConvolutionBackwardData) cb(PoolingForward) { |
|
|
|
mgb_throw( |
|
|
|
InternalError, "invalid multi-algo operator(got:%s)", |
|
|
|
opr->dyn_typeinfo()->name); |
|
|
|
} |
|
|
|
#undef cb |
|
|
|
} |
|
|
|
|
|
|
|
bool has_opr_format(const cg::OperatorNodeBase* opr) { |
|
|
|
bool ret = false; |
|
|
|
#define cb(_Opr) ret |= opr->dyn_typeinfo() == _Opr::typeinfo(); |
|
|
|