Browse Source

fix(mgb/plugin): fix opr footprint for conv with NCHW32_NCHW4 format

GitOrigin-RevId: 9881b7971c
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
5a44255324
1 changed files with 5 additions and 3 deletions
  1. +5
    -3
      src/plugin/impl/opr_footprint.cpp

+ 5
- 3
src/plugin/impl/opr_footprint.cpp View File

@@ -127,7 +127,8 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
src_shape[1] / group * 2;
return hybird_nchwx ? computation : computation * 4;
}
if (param.format == Param::Format::NCHW32) {
if (param.format == Param::Format::NCHW32 ||
param.format == Param::Format::NCHW32_NCHW4) {
return dst_shape.total_nr_elems() * fh * fw * src_shape[1] * 32 /
group * 2;
}
@@ -157,11 +158,12 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
};
if (param.format == Param::Format::NCHW4 ||
param.format == Param::Format::NCHW4_NCHW ||
param.format == Param::Format::NCHW4_NCHW32 ||
param.format == Param::Format::NCHW4_NCHW32 ||
param.format == Param::Format::NCHW88 ||
param.format == Param::Format::NCHW44 ||
param.format == Param::Format::NCHW44_DOT ||
param.format == Param::Format::NCHW32) {
param.format == Param::Format::NCHW32 ||
param.format == Param::Format::NCHW32_NCHW4) {
return eval_conv_computation_nchwx();
}
if (param.format == Param::Format::CHWN4) {


Loading…
Cancel
Save