|
|
@@ -103,7 +103,9 @@ struct OprArityTrait; |
|
|
|
#define cb_ref(x) (&(x)) |
|
|
|
#define cb_dnn(x) ((x).as_megdnn()) |
|
|
|
|
|
|
|
#define INST_ARITY(_Opr, _in, _out) \ |
|
|
|
#define WS_ARG_true ,nullptr |
|
|
|
#define WS_ARG_false |
|
|
|
#define INST_ARITY(_Opr, _in, _out, _has_preprocessed_filter) \ |
|
|
|
template <> \ |
|
|
|
struct OprArityTrait<_Opr> { \ |
|
|
|
static constexpr int arity_in = _in; \ |
|
|
@@ -114,7 +116,8 @@ struct OprArityTrait; |
|
|
|
_Opr* opr, typename _Opr::Algorithm* algo, \ |
|
|
|
const TensorLayoutArray& layouts) { \ |
|
|
|
opr->execution_policy() = {algo}; \ |
|
|
|
return opr->get_workspace_in_bytes(LAYOUTS(cb)); \ |
|
|
|
return opr->get_workspace_in_bytes( \ |
|
|
|
LAYOUTS(cb) WS_ARG_##_has_preprocessed_filter); \ |
|
|
|
} \ |
|
|
|
\ |
|
|
|
static std::vector<typename _Opr::Algorithm*> get_all_algorithms( \ |
|
|
@@ -138,8 +141,7 @@ struct OprArityTrait; |
|
|
|
|
|
|
|
#define TENSORS(cb) cb(inp_val[0]), cb(inp_val[1]), cb(out_val[0]) |
|
|
|
#define LAYOUTS(cb) cb(layouts[0]), cb(layouts[1]), cb(layouts[2]) |
|
|
|
#define INST_ARITY_2_1(Opr) INST_ARITY(Opr, 2, 1) |
|
|
|
INST_ARITY_2_1(megdnn::Convolution); |
|
|
|
#define INST_ARITY_2_1(Opr) INST_ARITY(Opr, 2, 1, false) |
|
|
|
INST_ARITY_2_1(megdnn::ConvolutionBackwardData); |
|
|
|
INST_ARITY_2_1(megdnn::ConvolutionBackwardFilter); |
|
|
|
INST_ARITY_2_1(megdnn::Convolution3DForward); |
|
|
@@ -149,6 +151,9 @@ INST_ARITY_2_1(megdnn::LocalShareForward); |
|
|
|
INST_ARITY_2_1(megdnn::LocalShareBackwardData); |
|
|
|
INST_ARITY_2_1(megdnn::LocalShareBackwardFilter); |
|
|
|
#undef TENSORS |
|
|
|
#define TENSORS(cb) cb(inp_val[0]), cb(inp_val[1]), cb(out_val[0]), nullptr |
|
|
|
INST_ARITY(megdnn::Convolution, 2, 1, true); |
|
|
|
#undef TENSORS |
|
|
|
#undef LAYOUTS |
|
|
|
#undef INST_ARITY_2_1 |
|
|
|
|
|
|
@@ -158,12 +163,16 @@ INST_ARITY_2_1(megdnn::LocalShareBackwardFilter); |
|
|
|
#define LAYOUTS(cb) \ |
|
|
|
cb(layouts[0]), cb(layouts[1]), cb(layouts[2]), cb(layouts[3]), \ |
|
|
|
cb(layouts[4]) |
|
|
|
#define INST_ARITY_4_1(Opr) INST_ARITY(Opr, 4, 1) |
|
|
|
INST_ARITY_4_1(megdnn::ConvBias); |
|
|
|
#define INST_ARITY_4_1(Opr) INST_ARITY(Opr, 4, 1, false) |
|
|
|
INST_ARITY_4_1(megdnn::DeformableConvForward); |
|
|
|
INST_ARITY_4_1(megdnn::DeformableConvBackwardFilter); |
|
|
|
INST_ARITY_4_1(megdnn::BatchConvBiasForward); |
|
|
|
#undef TENSORS |
|
|
|
#define TENSORS(cb) \ |
|
|
|
cb(inp_val[0]), cb(inp_val[1]), cb(inp_val[2]), cb(inp_val[3]), \ |
|
|
|
cb(out_val[0]), nullptr |
|
|
|
INST_ARITY(megdnn::ConvBias, 4, 1, true); |
|
|
|
#undef TENSORS |
|
|
|
#undef LAYOUTS |
|
|
|
#undef INST_ARITY_4_1 |
|
|
|
|
|
|
@@ -174,7 +183,7 @@ INST_ARITY_4_1(megdnn::BatchConvBiasForward); |
|
|
|
cb(layouts[3]), cb(layouts[4]), cb(layouts[5]), \ |
|
|
|
cb(layouts[6]), cb(layouts[7]) |
|
|
|
|
|
|
|
#define INST_ARITY_5_3(Opr) INST_ARITY(Opr, 5, 3) |
|
|
|
#define INST_ARITY_5_3(Opr) INST_ARITY(Opr, 5, 3, false) |
|
|
|
INST_ARITY_5_3(megdnn::DeformableConvBackwardData); |
|
|
|
#undef TENSORS |
|
|
|
#undef LAYOUTS |
|
|
@@ -183,6 +192,8 @@ INST_ARITY_5_3(megdnn::DeformableConvBackwardData); |
|
|
|
#undef cb_ref |
|
|
|
#undef cb_dnn |
|
|
|
#undef INST_ARITY |
|
|
|
#undef WS_ARG_true |
|
|
|
#undef WS_ARG_false |
|
|
|
|
|
|
|
// timeout delta to be added with fastest known algorithm for new algos |
|
|
|
constexpr double TIMEOUT_TOLERANCE = 2; |
|
|
@@ -924,6 +935,41 @@ void ConvolutionForward::init_output_format() { |
|
|
|
output(0)->format(input(0)->format()); |
|
|
|
} |
|
|
|
|
|
|
|
void ConvolutionForward::scn_do_execute() { |
|
|
|
megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(), |
|
|
|
input(1)->dev_tensor().as_megdnn(), |
|
|
|
output(0)->dev_tensor().as_megdnn(), nullptr, |
|
|
|
intl::get_megdnn_workspace_from_var(output().back())); |
|
|
|
} |
|
|
|
|
|
|
|
void ConvolutionForward::add_input_layout_constraint() { |
|
|
|
mixin::megdnn_utils::add_input_layout_constraint_contig(*this); |
|
|
|
} |
|
|
|
|
|
|
|
void ConvolutionForward::init_output_static_infer_desc() { |
|
|
|
Super::set_nr_managed_outputs(this->output().size() - 1); |
|
|
|
Super::init_output_static_infer_desc(); |
|
|
|
init_output_static_infer_desc_workspace( |
|
|
|
intl::AutoAddWorkspaceNeedLimitGetter< |
|
|
|
megdnn::ConvolutionForward>::val); |
|
|
|
} |
|
|
|
|
|
|
|
void ConvolutionForward::get_output_var_shape( |
|
|
|
const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const { |
|
|
|
TensorLayout input_layout{inp_shape[0], input(0)->dtype(), |
|
|
|
input(0)->format()}; |
|
|
|
TensorLayout filter_layout{inp_shape[1], input(1)->dtype(), |
|
|
|
input(1)->format()}; |
|
|
|
TensorLayout dst_layout{output(0)->dtype(), output(0)->format()}; |
|
|
|
megdnn_opr()->deduce_layout(input_layout, filter_layout, dst_layout); |
|
|
|
out_shape[0] = dst_layout; |
|
|
|
} |
|
|
|
|
|
|
|
void ConvolutionForward::record_execute_deps( |
|
|
|
cg::GraphExecutable::ExecDependencyArray& deps) { |
|
|
|
record_megdnn_opr(deps); |
|
|
|
} |
|
|
|
|
|
|
|
/* ==================== ConvolutionBackwardData ==================== */ |
|
|
|
IMPL_CONV(ConvolutionBackwardData, "conv_bwd_data"); |
|
|
|
|
|
|
@@ -1429,6 +1475,7 @@ void ConvBiasForward::scn_do_execute() { |
|
|
|
mo->exec(inp[0]->dev_tensor().as_megdnn(), |
|
|
|
inp[1]->dev_tensor().as_megdnn(), bias_tensor, z_tensor, |
|
|
|
output(0)->dev_tensor().as_megdnn(), |
|
|
|
nullptr, |
|
|
|
intl::get_megdnn_workspace_from_var(output().back())); |
|
|
|
|
|
|
|
} else if (inp.size() == 3) { |
|
|
@@ -1441,6 +1488,7 @@ void ConvBiasForward::scn_do_execute() { |
|
|
|
inp[1]->dev_tensor().as_megdnn(), |
|
|
|
inp[2]->dev_tensor().as_megdnn(), z_tensor, |
|
|
|
output(0)->dev_tensor().as_megdnn(), |
|
|
|
nullptr, |
|
|
|
intl::get_megdnn_workspace_from_var(output().back())); |
|
|
|
} else { |
|
|
|
mgb_assert(inp.size() == 4); |
|
|
@@ -1449,6 +1497,7 @@ void ConvBiasForward::scn_do_execute() { |
|
|
|
inp[2]->dev_tensor().as_megdnn(), |
|
|
|
inp[3]->dev_tensor().as_megdnn(), |
|
|
|
output(0)->dev_tensor().as_megdnn(), |
|
|
|
nullptr, |
|
|
|
intl::get_megdnn_workspace_from_var(output().back())); |
|
|
|
} |
|
|
|
} |
|
|
|