|
|
@@ -160,6 +160,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, |
|
|
|
spatial_start = 2; |
|
|
|
break; |
|
|
|
case Param::Format::NCHW_WINOGRAD: |
|
|
|
case Param::Format::NCHW44_WINOGRAD: |
|
|
|
case Param::Format::NCHW88_WINOGRAD: |
|
|
|
cpos = 1; |
|
|
|
spatial_start = 0; |
|
|
@@ -191,9 +192,10 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, |
|
|
|
uint64_t fh = static_cast<uint64_t>(filter_shape[spatial_start]); |
|
|
|
uint64_t fw = static_cast<uint64_t>(filter_shape[spatial_start + 1]); |
|
|
|
if (param.format == Param::Format::NCHW_WINOGRAD || |
|
|
|
param.format == Param::Format::NCHW44_WINOGRAD || |
|
|
|
param.format == Param::Format::NCHW88_WINOGRAD) { |
|
|
|
mgb_assert(opr->same_type<opr::ConvBias>(), |
|
|
|
"Only conv bias support NCHW_WINOGRAD"); |
|
|
|
"Only conv bias support WINOGRAD"); |
|
|
|
auto&& conv_bias_opr = opr->cast_final_safe<opr::ConvBias>(); |
|
|
|
uint32_t output_block_size = conv_bias_opr.param().output_block_size; |
|
|
|
mgb_assert(fh == fw, |
|
|
@@ -208,6 +210,10 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, |
|
|
|
return dst_shape.total_nr_elems() * fh * fw * |
|
|
|
static_cast<uint64_t>(src_shape[cpos] * 8) / group * 2; |
|
|
|
} |
|
|
|
if (param.format == Param::Format::NCHW44_WINOGRAD) { |
|
|
|
return dst_shape.total_nr_elems() * fh * fw * |
|
|
|
static_cast<uint64_t>(src_shape[cpos] * 4) / group * 2; |
|
|
|
} |
|
|
|
return dst_shape.total_nr_elems() * fh * fw * |
|
|
|
static_cast<uint64_t>(src_shape[cpos]) / group * 2; |
|
|
|
} |
|
|
|