Browse Source

fix(mgb/plugin): fix opr footprint for conv with NCHW32_NCHW4 format

GitOrigin-RevId: 9881b7971c
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
5a44255324
1 changed files with 5 additions and 3 deletions
  1. +5
    -3
      src/plugin/impl/opr_footprint.cpp

+ 5
- 3
src/plugin/impl/opr_footprint.cpp View File

@@ -127,7 +127,8 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
src_shape[1] / group * 2; src_shape[1] / group * 2;
return hybird_nchwx ? computation : computation * 4; return hybird_nchwx ? computation : computation * 4;
} }
if (param.format == Param::Format::NCHW32) {
if (param.format == Param::Format::NCHW32 ||
param.format == Param::Format::NCHW32_NCHW4) {
return dst_shape.total_nr_elems() * fh * fw * src_shape[1] * 32 / return dst_shape.total_nr_elems() * fh * fw * src_shape[1] * 32 /
group * 2; group * 2;
} }
@@ -157,11 +158,12 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
}; };
if (param.format == Param::Format::NCHW4 || if (param.format == Param::Format::NCHW4 ||
param.format == Param::Format::NCHW4_NCHW || param.format == Param::Format::NCHW4_NCHW ||
param.format == Param::Format::NCHW4_NCHW32 ||
param.format == Param::Format::NCHW4_NCHW32 ||
param.format == Param::Format::NCHW88 || param.format == Param::Format::NCHW88 ||
param.format == Param::Format::NCHW44 || param.format == Param::Format::NCHW44 ||
param.format == Param::Format::NCHW44_DOT || param.format == Param::Format::NCHW44_DOT ||
param.format == Param::Format::NCHW32) {
param.format == Param::Format::NCHW32 ||
param.format == Param::Format::NCHW32_NCHW4) {
return eval_conv_computation_nchwx(); return eval_conv_computation_nchwx();
} }
if (param.format == Param::Format::CHWN4) { if (param.format == Param::Format::CHWN4) {


Loading…
Cancel
Save