Browse Source

fix(mge/opr): add opr_footprint support for PoolingBackward

GitOrigin-RevId: 5f1c64ef9a
tags/v1.6.0-rc1
Megvii Engine Team 3 years ago
parent
commit
2224a25205
1 changed file with 15 additions and 4 deletions
  1. +15
    -4
      src/plugin/impl/opr_footprint.cpp

+ 15
- 4
src/plugin/impl/opr_footprint.cpp View File

@@ -176,7 +176,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
};
if (param.format == Param::Format::NCHW4 ||
param.format == Param::Format::NCHW4_NCHW ||
param.format == Param::Format::NCHW4_NHWC ||
param.format == Param::Format::NCHW4_NHWC ||
param.format == Param::Format::NCHW4_NCHW32 ||
param.format == Param::Format::NCHW88 ||
param.format == Param::Format::NCHW44 ||
@@ -223,9 +223,9 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,

uint64_t fh = static_cast<uint64_t>(filter_shape[spatial_start]);
uint64_t fw = static_cast<uint64_t>(filter_shape[spatial_start + 1]);
// mul and add are counted as 2 operations
return dst_shape.total_nr_elems() * fh * fw *
static_cast<uint64_t>(src_shape[cpos]) / group * 2;
}
@@ -464,6 +464,14 @@ uint64_t opr_footprint_func<opr::PoolingForward>(cg::OperatorNodeBase* opr) {
return opr->output(0)->shape().total_nr_elems() * area;
}

// PoolingBackward
//
// Computation footprint of a PoolingBackward operator: the element count of
// the operator's first input multiplied by the pooling window area
// (window_h * window_w), both taken from the operator's param.
template <>
uint64_t opr_footprint_func<opr::PoolingBackward>(cg::OperatorNodeBase* opr) {
    const auto& pooling_param =
            opr->cast_final_safe<opr::PoolingBackward>().param();
    // The area is computed in the param fields' own integer type.
    const auto window_area = pooling_param.window_h * pooling_param.window_w;
    const auto& first_input_shape = opr->input()[0]->shape();
    return first_input_shape.total_nr_elems() * window_area;
}

// Concat
template <>
uint64_t opr_footprint_func<opr::Concat>(cg::OperatorNodeBase* opr) {
@@ -516,6 +524,7 @@ REGISTE_PARAM_JSON_FUNC(BatchedMatrixMul)
REGISTE_PARAM_JSON_FUNC(Dot)
REGISTE_PARAM_JSON_FUNC(MatrixInverse)
REGISTE_PARAM_JSON_FUNC(PoolingForward)
REGISTE_PARAM_JSON_FUNC(PoolingBackward)
REGISTE_PARAM_JSON_FUNC(SVD)
REGISTE_PARAM_JSON_FUNC(MaskConvolution)
REGISTE_PARAM_JSON_FUNC(Images2Neibs)
@@ -666,7 +675,7 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::standalone::NMSKeep>(
{"max_output", json::Number::make(nms_param.max_output)},
});
}

#endif // MGB_ENABLE_JSON

@@ -700,6 +709,7 @@ void OprFootprint::init_all_footprints() {
add_single_comp_footprint<opr::ConvolutionBackwardFilter>();
add_single_comp_footprint<opr::MatrixMul>();
add_single_comp_footprint<opr::PoolingForward>();
add_single_comp_footprint<opr::PoolingBackward>();
add_single_comp_footprint<opr::Concat>();
add_single_comp_footprint<opr::Dimshuffle>();
add_single_comp_footprint<opr::Reduce>();
@@ -725,6 +735,7 @@ void OprFootprint::init_all_footprints() {
add_single_param_json<opr::Dot>();
add_single_param_json<opr::MatrixInverse>();
add_single_param_json<opr::PoolingForward>();
add_single_param_json<opr::PoolingBackward>();
add_single_param_json<opr::SVD>();
add_single_param_json<opr::MaskConvolution>();
add_single_param_json<opr::Images2Neibs>();


Loading…
Cancel
Save