Browse Source

refactor(dnn): optimize winograd input transpose

GitOrigin-RevId: a43077550c
tags/v0.3.2
Megvii Engine Team 5 years ago
parent
commit
c6eb2e8d71
7 changed files with 360 additions and 230 deletions
  1. +66
    -68
      dnn/src/common/winograd/winograd_helper.cpp
  2. +2
    -2
      dnn/src/common/winograd/winograd_helper.h
  3. +1
    -1
      dnn/src/fallback/conv_bias/opr_impl.cpp
  4. +211
    -88
      dnn/src/fallback/conv_bias/winograd/strategy.cpp
  5. +19
    -33
      dnn/src/fallback/conv_bias/winograd/winograd.h
  6. +31
    -19
      dnn/src/x86/conv_bias/f32/strategy_2x3_8x8.cpp
  7. +30
    -19
      dnn/src/x86/conv_bias/f32/strategy_6x3_8x8.cpp

+ 66
- 68
dnn/src/common/winograd/winograd_helper.cpp View File

@@ -247,33 +247,31 @@ void StrategyHelper<
Getter<ctype, input_filter_compute_type> getter(dtype); Getter<ctype, input_filter_compute_type> getter(dtype);
InputVisitor<layout, format> intput_visitor(IC); InputVisitor<layout, format> intput_visitor(IC);


rep(ic, IC) {
memset(mid_buf1, 0, alpha * alpha * sizeof(input_filter_compute_type));
rep(i, alpha) rep(j, alpha) {
int ih = ih_start + i;
int iw = iw_start + j;
if (ih >= 0 && ih < (int)IH && iw >= 0 && iw < (int)IW) {
mid_buf1[i * alpha + j] = getter(
input[intput_visitor.get(alpha, ic, IH, IW, ih, iw)]);
}
memset(mid_buf1, 0, alpha * alpha * sizeof(input_filter_compute_type));
rep(i, alpha) rep(j, alpha) {
int ih = ih_start + i;
int iw = iw_start + j;
if (ih >= 0 && ih < (int)IH && iw >= 0 && iw < (int)IW) {
mid_buf1[i * alpha + j] = getter(
input[intput_visitor.get(alpha, ic, IH, IW, ih, iw)]);
} }
}


megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type,
input_filter_compute_type, true,
false>(
winograd_coeff.B(rescale).data(), mid_buf1, mid_buf2, alpha,
alpha, alpha, alpha, alpha, alpha, dtype, dtype);
megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type,
input_filter_compute_type, false,
false>(
mid_buf2, winograd_coeff.B(rescale).data(), mid_buf1, alpha,
alpha, alpha, alpha, alpha, alpha, dtype, dtype);

rep(i, alpha) rep(j, alpha) {
input_transform_buf[intput_visitor.put(alpha, ic, nr_units_in_tile,
unit_idx, i, j)] =
mid_buf1[i * alpha + j];
}
megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type,
input_filter_compute_type, true,
false>(
winograd_coeff.B(rescale).data(), mid_buf1, mid_buf2, alpha,
alpha, alpha, alpha, alpha, alpha, dtype, dtype);
megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type,
input_filter_compute_type, false,
false>(
mid_buf2, winograd_coeff.B(rescale).data(), mid_buf1, alpha,
alpha, alpha, alpha, alpha, alpha, dtype, dtype);

rep(i, alpha) rep(j, alpha) {
input_transform_buf[intput_visitor.put(alpha, ic, nr_units_in_tile,
unit_idx, i, j)] =
mid_buf1[i * alpha + j];
} }
} }


@@ -287,7 +285,7 @@ void StrategyHelper<
output_compute_type* transform_mid_buf, BiasMode bmode, output_compute_type* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t oh_start, NonlineMode nonline_mode, size_t oh_start,
size_t ow_start, size_t OH, size_t OW, size_t oc_start, size_t ow_start, size_t OH, size_t OW, size_t oc_start,
size_t oc_end, size_t unit_idx, size_t nr_units_in_tile,
size_t oc_index, size_t unit_idx, size_t nr_units_in_tile,
size_t m, size_t r, size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype, const std::vector<float>& interp_points, DType dtype,
float input_filter_scale, float input_filter_rescale, float input_filter_scale, float input_filter_rescale,
@@ -300,49 +298,49 @@ void StrategyHelper<
OutputGetter<output_compute_type, dst_type> getter(dtype); OutputGetter<output_compute_type, dst_type> getter(dtype);
OutputVisitor<layout, format> output_visitor(oc_end - oc_start); OutputVisitor<layout, format> output_visitor(oc_end - oc_start);


for (size_t oc = oc_start; oc < oc_end; oc++) {
/* gather */
rep(i, alpha) rep(j, alpha) {
mid_buf1[i * alpha + j] = output_transform_buf[output_visitor.get(
alpha, oc - oc_start, oc, nr_units_in_tile, unit_idx, i,
j)];
}
/* A[alpha*m] M[alpha*alpha] */
megdnn::naive::run_matrix_mul_tpl<output_compute_type,
output_compute_type, true, false>(
winograd_coeff.A(rescale).data(), mid_buf1, mid_buf2, m, alpha,
alpha, m, alpha, alpha, dtype, dtype);
megdnn::naive::run_matrix_mul_tpl<output_compute_type,
output_compute_type, false, false>(
mid_buf2, winograd_coeff.A(rescale).data(), mid_buf1, m, m,
alpha, alpha, m, m, dtype, dtype);

rep(i, m) rep(j, m) {
auto oh = oh_start + i;
auto ow = ow_start + j;
if (oh < OH && ow < OW) {
float val = mid_buf1[i * m + j];
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
val += bias[oc] * input_filter_rescale *
input_filter_rescale;
} else if (bmode == BiasMode::BIAS) {
val += bias[output_visitor.put(oc, OH, OW, oh, ow)] *
input_filter_rescale * input_filter_rescale;
}
val = val * input_filter_scale /
(input_filter_rescale * input_filter_rescale * rescale *
rescale);
if (nonline_mode == NonlineMode::RELU) {
val = val > 0 ? val : 0;
} else if (nonline_mode == NonlineMode::SIGMOID) {
val = 1.f / (expf(-val) + 1.f);
} else if (nonline_mode == NonlineMode::H_SWISH) {
val = val * std::min(std::max(val + 3, 0.f), 6.f) / 6.f;
} else {
megdnn_assert(nonline_mode == NonlineMode::IDENTITY);
}
output[output_visitor.put(oc, OH, OW, oh, ow)] = getter(val);
size_t oc = oc_start + oc_index;

/* gather */
rep(i, alpha) rep(j, alpha) {
mid_buf1[i * alpha + j] = output_transform_buf[output_visitor.get(
alpha, oc_index, oc, nr_units_in_tile, unit_idx, i,
j)];
}
/* A[alpha*m] M[alpha*alpha] */
megdnn::naive::run_matrix_mul_tpl<output_compute_type,
output_compute_type, true, false>(
winograd_coeff.A(rescale).data(), mid_buf1, mid_buf2, m, alpha,
alpha, m, alpha, alpha, dtype, dtype);
megdnn::naive::run_matrix_mul_tpl<output_compute_type,
output_compute_type, false, false>(
mid_buf2, winograd_coeff.A(rescale).data(), mid_buf1, m, m,
alpha, alpha, m, m, dtype, dtype);

rep(i, m) rep(j, m) {
auto oh = oh_start + i;
auto ow = ow_start + j;
if (oh < OH && ow < OW) {
float val = mid_buf1[i * m + j];
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
val += bias[oc] * input_filter_rescale *
input_filter_rescale;
} else if (bmode == BiasMode::BIAS) {
val += bias[output_visitor.put(oc, OH, OW, oh, ow)] *
input_filter_rescale * input_filter_rescale;
}
val = val * input_filter_scale /
(input_filter_rescale * input_filter_rescale * rescale *
rescale);
if (nonline_mode == NonlineMode::RELU) {
val = val > 0 ? val : 0;
} else if (nonline_mode == NonlineMode::SIGMOID) {
val = 1.f / (expf(-val) + 1.f);
} else if (nonline_mode == NonlineMode::H_SWISH) {
val = val * std::min(std::max(val + 3, 0.f), 6.f) / 6.f;
} else {
megdnn_assert(nonline_mode == NonlineMode::IDENTITY);
} }
output[output_visitor.put(oc, OH, OW, oh, ow)] = getter(val);
} }
} }
}; };


+ 2
- 2
dnn/src/common/winograd/winograd_helper.h View File

@@ -44,7 +44,7 @@ public:
input_filter_compute_type* input_transform_buf, input_filter_compute_type* input_transform_buf,
input_filter_compute_type* transform_mid_buf, input_filter_compute_type* transform_mid_buf,
int ih_start, int iw_start, size_t IH, size_t IW, int ih_start, int iw_start, size_t IH, size_t IW,
size_t IC, size_t unit_idx, size_t nr_units_in_tile,
size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile,
size_t m, size_t r, size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype, const std::vector<float>& interp_points, DType dtype,
float rescale = 1.0f); float rescale = 1.0f);
@@ -54,7 +54,7 @@ public:
const output_compute_type* bias, dst_type* output, const output_compute_type* bias, dst_type* output,
output_compute_type* transform_mid_buf, BiasMode bmode, output_compute_type* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t oh_start, size_t ow_start, NonlineMode nonline_mode, size_t oh_start, size_t ow_start,
size_t OH, size_t OW, size_t oc_start, size_t oc_end,
size_t OH, size_t OW, size_t oc_start, size_t oc_index,
size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r, size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r,
const std::vector<float>& interp_points, DType dtype, const std::vector<float>& interp_points, DType dtype,
float input_filter_scale = 1.0f, // input_scale * filter_scale float input_filter_scale = 1.0f, // input_scale * filter_scale


+ 1
- 1
dnn/src/fallback/conv_bias/opr_impl.cpp View File

@@ -55,7 +55,7 @@ public:
ohw_tile_size)); ohw_tile_size));
all_algos.emplace_back(refhold.back().get()); all_algos.emplace_back(refhold.back().get());
} }
#if 0
#if 1
//! As these algos maybe very slow, it will make fastrun search slow, so //! As these algos maybe very slow, it will make fastrun search slow, so
//! we disable it, but for the test of strategyhelper, we just keep it. //! we disable it, but for the test of strategyhelper, we just keep it.
//! FIXME: I do not know a better way to do it. //! FIXME: I do not know a better way to do it.


+ 211
- 88
dnn/src/fallback/conv_bias/winograd/strategy.cpp View File

@@ -6,8 +6,7 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */


#include "src/fallback/conv_bias/winograd/strategy.h" #include "src/fallback/conv_bias/winograd/strategy.h"
@@ -31,27 +30,54 @@ void winograd_2x3_1x1_f::filter(const float* filter,
} }


void winograd_2x3_1x1_f::input(const float* input, float* input_transform_buf, void winograd_2x3_1x1_f::input(const float* input, float* input_transform_buf,
float* transform_mid_buf, int ih_start,
int iw_start, size_t IH, size_t IW, size_t IC,
size_t unit_idx, size_t nr_units_in_tile) {
::megdnn::winograd::StrategyHelper<float, float, float, float>::input(
input, input_transform_buf, transform_mid_buf, ih_start, iw_start,
IH, IW, IC, unit_idx, nr_units_in_tile, OUTPUT_BLOCK_SIZE,
KERNEL_SIZE, {0, 1, -1}, src_dtype);
float* transform_mid_buf, size_t IH, size_t IW,
size_t IC, size_t PH, size_t PW,
size_t unit_start_idx, size_t nr_units_in_tile) {
// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w =
div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
rep(ic, IC) {
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
size_t nh = index / units_w;
size_t nw = index % units_w;
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
::megdnn::winograd::StrategyHelper<float, float, float, float>::
input(input, input_transform_buf, transform_mid_buf,
ih_start, iw_start, IH, IW, IC, ic, unit_idx,
nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE,
{0, 1, -1}, src_dtype);
}
}
} }


void winograd_2x3_1x1_f::output(const float* output_transform_buf, void winograd_2x3_1x1_f::output(const float* output_transform_buf,
const float* bias, float* output, const float* bias, float* output,
float* transform_mid_buf, BiasMode bmode, float* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t oh_start,
size_t ow_start, size_t OH, size_t OW,
size_t oc_start, size_t oc_end, size_t unit_idx,
NonlineMode nonline_mode, size_t OH, size_t OW,
size_t oc_start, size_t oc_end,
size_t unit_start_idx,
size_t nr_units_in_tile) { size_t nr_units_in_tile) {
::megdnn::winograd::StrategyHelper<float, float, float, float>::output(
output_transform_buf, bias, output, transform_mid_buf, bmode,
nonline_mode, oh_start, ow_start, OH, OW, oc_start, oc_end,
unit_idx, nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE,
{0, 1, -1}, dst_dtype);
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
size_t OC = oc_end - oc_start;

for (size_t oc = oc_start; oc < oc_end; ++oc) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
::megdnn::winograd::StrategyHelper<float, float, float, float>::
output(output_transform_buf, bias, output,
transform_mid_buf, bmode, nonline_mode, oh_start,
ow_start, OH, OW, OC, oc_start, oc_index, unit_idx,
nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE,
{0, 1, -1}, dst_dtype);
}
}
} }


MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_f) MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_f)
@@ -71,38 +97,70 @@ void winograd_2x3_4x4_f::filter(const float* filter,
} }


void winograd_2x3_4x4_f::input(const float* input, float* input_transform_buf, void winograd_2x3_4x4_f::input(const float* input, float* input_transform_buf,
float* transform_mid_buf, int ih_start,
int iw_start, size_t IH, size_t IW, size_t IC,
size_t unit_idx, size_t nr_units_in_tile) {
::megdnn::winograd::StrategyHelper<
float, float, float, float, param::ConvBias::Format::NCHW,
param::MatrixMul::Format::MK4>::input(input, input_transform_buf,
transform_mid_buf, ih_start,
iw_start, IH, IW, IC,
unit_idx, nr_units_in_tile,
OUTPUT_BLOCK_SIZE,
KERNEL_SIZE, {0, 1, -1},
src_dtype);
float* transform_mid_buf, size_t IH, size_t IW,
size_t IC, size_t PH, size_t PW,
size_t unit_start_idx, size_t nr_units_in_tile) {
// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w =
div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
rep(ic, IC) {
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
size_t nh = index / units_w;
size_t nw = index % units_w;
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
::megdnn::winograd::StrategyHelper<
float, float, float, float, param::ConvBias::Format::NCHW,
param::MatrixMul::Format::MK4>::input(input,
input_transform_buf,
transform_mid_buf,
ih_start, iw_start,
IH, IW, IC, ic,
unit_idx,
nr_units_in_tile,
OUTPUT_BLOCK_SIZE,
KERNEL_SIZE,
{0, 1, -1},
src_dtype);
}
}
} }


void winograd_2x3_4x4_f::output(const float* output_transform_buf, void winograd_2x3_4x4_f::output(const float* output_transform_buf,
const float* bias, float* output, const float* bias, float* output,
float* transform_mid_buf, BiasMode bmode, float* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t oh_start,
size_t ow_start, size_t OH, size_t OW,
size_t oc_start, size_t oc_end, size_t unit_idx,
NonlineMode nonline_mode, size_t OH, size_t OW,
size_t oc_start, size_t oc_end,
size_t unit_start_idx,
size_t nr_units_in_tile) { size_t nr_units_in_tile) {
::megdnn::winograd::StrategyHelper<
float, float, float, float, param::ConvBias::Format::NCHW,
param::MatrixMul::Format::MK4>::output(output_transform_buf, bias,
output, transform_mid_buf,
bmode, nonline_mode,
oh_start, ow_start, OH, OW,
oc_start, oc_end, unit_idx,
nr_units_in_tile,
OUTPUT_BLOCK_SIZE,
KERNEL_SIZE, {0, 1, -1},
dst_dtype);
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
size_t OC = oc_end - oc_start;

for (size_t oc = oc_start; oc < oc_end; ++oc) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
::megdnn::winograd::StrategyHelper<
float, float, float, float, param::ConvBias::Format::NCHW,
param::MatrixMul::Format::MK4>::output(output_transform_buf,
bias, output,
transform_mid_buf,
bmode, nonline_mode,
oh_start, ow_start,
OH, OW, OC, oc_start,
oc_index, unit_idx,
nr_units_in_tile,
OUTPUT_BLOCK_SIZE,
KERNEL_SIZE,
{0, 1, -1},
dst_dtype);
}
}
} }


MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_1x1_qs8) MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_1x1_qs8)
@@ -119,29 +177,59 @@ void winograd_2x3_1x1_qs8::filter(const int8_t* filter,


void winograd_2x3_1x1_qs8::input(const int8_t* input, void winograd_2x3_1x1_qs8::input(const int8_t* input,
int16_t* input_transform_buf, int16_t* input_transform_buf,
int16_t* transform_mid_buf, int ih_start,
int iw_start, size_t IH, size_t IW, size_t IC,
size_t unit_idx, size_t nr_units_in_tile) {
::megdnn::winograd::StrategyHelper<int8_t, int8_t, int16_t, int>::input(
input, input_transform_buf, transform_mid_buf, ih_start, iw_start,
IH, IW, IC, unit_idx, nr_units_in_tile, OUTPUT_BLOCK_SIZE,
KERNEL_SIZE, {0, 1, -1}, src_dtype, 1.0f);
int16_t* transform_mid_buf, size_t IH,
size_t IW, size_t IC, size_t PH, size_t PW,
size_t unit_start_idx,
size_t nr_units_in_tile) {
// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w =
div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
rep(ic, IC) {
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
size_t nh = index / units_w;
size_t nw = index % units_w;
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
::megdnn::winograd::StrategyHelper<int8_t, int8_t, int16_t, int>::
input(input, input_transform_buf, transform_mid_buf,
ih_start, iw_start, IH, IW, IC, ic, unit_idx,
nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE,
{0, 1, -1}, src_dtype, 1.0f);
}
}
} }


void winograd_2x3_1x1_qs8::output(const int* output_transform_buf, void winograd_2x3_1x1_qs8::output(const int* output_transform_buf,
const int* bias, int8_t* output, const int* bias, int8_t* output,
int* transform_mid_buf, BiasMode bmode, int* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t oh_start,
size_t ow_start, size_t OH, size_t OW,
size_t oc_start, size_t oc_end,
size_t unit_idx, size_t nr_units_in_tile) {
NonlineMode nonline_mode, size_t OH,
size_t OW, size_t oc_start, size_t oc_end,
size_t unit_start_idx,
size_t nr_units_in_tile) {
float scale_input = src_dtype.param<dtype::QuantizedS8>().scale; float scale_input = src_dtype.param<dtype::QuantizedS8>().scale;
float scale_filter = filter_dtype.param<dtype::QuantizedS8>().scale; float scale_filter = filter_dtype.param<dtype::QuantizedS8>().scale;
::megdnn::winograd::StrategyHelper<int8_t, int8_t, int16_t, int>::output(
output_transform_buf, bias, output, transform_mid_buf, bmode,
nonline_mode, oh_start, ow_start, OH, OW, oc_start, oc_end,
unit_idx, nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE,
{0, 1, -1}, dst_dtype, scale_input * scale_filter, 2.0f, 1.0f);

auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
size_t OC = oc_end - oc_start;

for (size_t oc = oc_start; oc < oc_end; ++oc) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
::megdnn::winograd::StrategyHelper<int8_t, int8_t, int16_t, int>::
output(output_transform_buf, bias, output,
transform_mid_buf, bmode, nonline_mode, oh_start,
ow_start, OH, OW, OC, oc_start, oc_index, unit_idx,
nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE,
{0, 1, -1}, dst_dtype, scale_input * scale_filter,
2.0f, 1.0f);
}
}
} }


MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_8x8_qs8) MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_8x8_qs8)
@@ -162,27 +250,44 @@ void winograd_2x3_8x8_qs8::filter(const int8_t* filter,


void winograd_2x3_8x8_qs8::input(const int8_t* input, void winograd_2x3_8x8_qs8::input(const int8_t* input,
int16_t* input_transform_buf, int16_t* input_transform_buf,
int16_t* transform_mid_buf, int ih_start,
int iw_start, size_t IH, size_t IW, size_t IC,
size_t unit_idx, size_t nr_units_in_tile) {
::megdnn::winograd::StrategyHelper<
int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW,
param::MatrixMul::Format::MK8>::input(input, input_transform_buf,
transform_mid_buf, ih_start,
iw_start, IH, IW, IC,
unit_idx, nr_units_in_tile,
OUTPUT_BLOCK_SIZE,
KERNEL_SIZE, {0, 1, -1},
src_dtype, 1.0f);
int16_t* transform_mid_buf, size_t IH,
size_t IW, size_t IC, size_t PH, size_t PW,
size_t unit_start_idx,
size_t nr_units_in_tile) {
// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w =
div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
rep(ic, IC) {
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
size_t nh = index / units_w;
size_t nw = index % units_w;
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
::megdnn::winograd::StrategyHelper<
int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW,
param::MatrixMul::Format::MK8>::input(input,
input_transform_buf,
transform_mid_buf,
ih_start, iw_start,
IH, IW, IC, ic,
unit_idx,
nr_units_in_tile,
OUTPUT_BLOCK_SIZE,
KERNEL_SIZE,
{0, 1, -1}, src_dtype,
1.0f);
}
}
} }


void winograd_2x3_8x8_qs8::output(const int* output_transform_buf, void winograd_2x3_8x8_qs8::output(const int* output_transform_buf,
const int* bias, int8_t* output, const int* bias, int8_t* output,
int* transform_mid_buf, BiasMode bmode, int* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t oh_start,
size_t ow_start, size_t OH, size_t OW,
size_t oc_start, size_t oc_end,
size_t unit_idx, size_t nr_units_in_tile) {
NonlineMode nonline_mode, size_t OH,
size_t OW, size_t oc_start, size_t oc_end,
size_t unit_start_idx,
size_t nr_units_in_tile) {
float scale_input = src_dtype.param<dtype::QuantizedS8>().scale; float scale_input = src_dtype.param<dtype::QuantizedS8>().scale;
float scale_filter = 0.f; float scale_filter = 0.f;
if (filter_dtype.enumv() == DTypeEnum::QuantizedS8) { if (filter_dtype.enumv() == DTypeEnum::QuantizedS8) {
@@ -191,19 +296,37 @@ void winograd_2x3_8x8_qs8::output(const int* output_transform_buf,
megdnn_assert(filter_dtype.enumv() == DTypeEnum::QuantizedS16); megdnn_assert(filter_dtype.enumv() == DTypeEnum::QuantizedS16);
scale_filter = filter_dtype.param<dtype::QuantizedS16>().scale; scale_filter = filter_dtype.param<dtype::QuantizedS16>().scale;
} }
::megdnn::winograd::StrategyHelper<
int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW,
param::MatrixMul::Format::MK8>::output(output_transform_buf, bias,
output, transform_mid_buf,
bmode, nonline_mode,
oh_start, ow_start, OH, OW,
oc_start, oc_end, unit_idx,
nr_units_in_tile,
OUTPUT_BLOCK_SIZE,
KERNEL_SIZE, {0, 1, -1},
dst_dtype,
scale_input * scale_filter,
2.0f, 1.0f);

auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
size_t OC = oc_end - oc_start;

for (size_t oc = oc_start; oc < oc_end; ++oc) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
::megdnn::winograd::StrategyHelper<
int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW,
param::MatrixMul::Format::MK8>::output(output_transform_buf,
bias, output,
transform_mid_buf,
bmode, nonline_mode,
oh_start, ow_start,
OH, OW, OC, oc_start,
oc_index, unit_idx,
nr_units_in_tile,
OUTPUT_BLOCK_SIZE,
KERNEL_SIZE,
{0, 1, -1},
dst_dtype,
scale_input *
scale_filter,
2.0f, 1.0f);
}
}
} }


} // namespace winograd } // namespace winograd


+ 19
- 33
dnn/src/fallback/conv_bias/winograd/winograd.h View File

@@ -321,17 +321,10 @@ public:
"nr_tiles_in_unit: %zu TILE_SIZE:%zu", "nr_tiles_in_unit: %zu TILE_SIZE:%zu",
nr_tiles_in_unit, unit_tile_size); nr_tiles_in_unit, unit_tile_size);
} }
rep(unit_idx, nr_tiles_in_unit) {
size_t index = unit_start_idx + unit_idx;
size_t nh = index / units_w;
size_t nw = index % units_w;
int ih_start = nh * Strategy::OUTPUT_BLOCK_SIZE - PH;
int iw_start = nw * Strategy::OUTPUT_BLOCK_SIZE - PW;

strategy.input(src_ptr, input_transform_buf, transform_mid_buf,
ih_start, iw_start, IH, IW, IC, unit_idx,
nr_tiles_in_unit);
}
//! BTdB
strategy.input(src_ptr, input_transform_buf, transform_mid_buf,
IH, IW, IC, PH, PW, unit_start_idx, nr_tiles_in_unit);

rep(i, Strategy::ALPHA) rep(j, Strategy::ALPHA) { rep(i, Strategy::ALPHA) rep(j, Strategy::ALPHA) {
if (format == param::MatrixMul::Format::DEFAULT) { if (format == param::MatrixMul::Format::DEFAULT) {
matmul_param.A_ptr = matmul_param.A_ptr =
@@ -368,22 +361,14 @@ public:
} }
matmul_kern(matmul_param); matmul_kern(matmul_param);
} }
/* Y = ATmA */
rep(unit_idx, nr_tiles_in_unit) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * Strategy::OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * Strategy::OUTPUT_BLOCK_SIZE;
size_t oc_end_idx = oc_start_idx + nr_oc_in_unit;

strategy.output(
output_transform_buf, bias_ptr, dst_ptr,
reinterpret_cast<output_compute_type*>(transform_mid_buf),
ncb_param.bias_mode, ncb_param.nonlineMode, oh_start,
ow_start, OH, OW, oc_start_idx, oc_end_idx, unit_idx,
nr_tiles_in_unit);
}

//! Y = ATmA
size_t oc_end_idx = oc_start_idx + nr_oc_in_unit;
strategy.output(
output_transform_buf, bias_ptr, dst_ptr,
reinterpret_cast<output_compute_type*>(transform_mid_buf),
ncb_param.bias_mode, ncb_param.nonlineMode, OH, OW,
oc_start_idx, oc_end_idx, unit_start_idx, nr_tiles_in_unit);
}; };


SmallVector<NCBKern> get_kerns( SmallVector<NCBKern> get_kerns(
@@ -542,15 +527,16 @@ public:
size_t IC, size_t oc_start, size_t oc_end); \ size_t IC, size_t oc_start, size_t oc_end); \
void input(const stype* input, \ void input(const stype* input, \
input_filter_compute_type* input_transform_buf, \ input_filter_compute_type* input_transform_buf, \
input_filter_compute_type* transform_mid_buf, int ih_start, \
int iw_start, size_t IH, size_t IW, size_t IC, \
size_t unit_idx, size_t nr_tiles_in_unit); \
input_filter_compute_type* transform_mid_buf, \
size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, \
size_t unit_start_idx, size_t nr_tiles_in_unit); \
void output(const output_compute_type* output_transform_buf, \ void output(const output_compute_type* output_transform_buf, \
const output_compute_type* bias, dst_type* output, \ const output_compute_type* bias, dst_type* output, \
output_compute_type* transform_mid_buf, BiasMode bmode, \ output_compute_type* transform_mid_buf, BiasMode bmode, \
NonlineMode nonline_mode, size_t oh_start, \
size_t ow_start, size_t OH, size_t OW, size_t oc_start, \
size_t oc_end, size_t unit_idx, size_t nr_tiles_in_unit); \
NonlineMode nonline_mode, size_t OH, size_t OW, \
size_t oc_start, size_t oc_end, size_t unit_start_idx, \
size_t nr_tiles_in_unit); \

}; };


#define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \ #define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \


+ 31
- 19
dnn/src/x86/conv_bias/f32/strategy_2x3_8x8.cpp View File

@@ -274,31 +274,43 @@ void winograd_nchw88_2x3_8x8_f::filter(const float* filter,
transform_mid_buf, OC, IC, oc_start, transform_mid_buf, OC, IC, oc_start,
oc_end); oc_end);
} }

void winograd_nchw88_2x3_8x8_f::input(const float* input, void winograd_nchw88_2x3_8x8_f::input(const float* input,
float* input_transform_buf, float* input_transform_buf,
float* transform_mid_buf, int ih_start,
int iw_start, size_t IH, size_t IW,
size_t IC, size_t unit_idx,
float* transform_mid_buf, size_t IH,
size_t IW, size_t IC, size_t PH,
size_t PW, size_t unit_start_idx,
size_t nr_units_in_tile) { size_t nr_units_in_tile) {
megdnn_assert(IC % 8 == 0); megdnn_assert(IC % 8 == 0);

// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
float* patch = transform_mid_buf; float* patch = transform_mid_buf;
float* patchT = transform_mid_buf + 8 * alpha * alpha; float* patchT = transform_mid_buf + 8 * alpha * alpha;
if (ih_start >= 0 && ih_start + alpha <= static_cast<size_t>(IH) &&
iw_start >= 0 && iw_start + alpha <= static_cast<size_t>(IW)) {
for (size_t ic = 0; ic < IC; ic += 8) {
InputTransform2X3_NCHW88::prepare<true>(
input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC);
InputTransform2X3_NCHW88::transform(patchT, input_transform_buf,
unit_idx, nr_units_in_tile, ic,
IC);
}
} else {
for (size_t ic = 0; ic < IC; ic += 8) {
InputTransform2X3_NCHW88::prepare<false>(input, patch, patchT, ih_start,
iw_start, IH, IW, ic, IC);
InputTransform2X3_NCHW88::transform(patchT, input_transform_buf,
unit_idx, nr_units_in_tile, ic,
IC);

for (size_t ic = 0; ic < IC; ic += 8) {
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
size_t nh = index / units_w;
size_t nw = index % units_w;
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
if (ih_start >= 0 && ih_start + alpha <= static_cast<size_t>(IH) &&
iw_start >= 0 && iw_start + alpha <= static_cast<size_t>(IW)) {
InputTransform2X3_NCHW88::prepare<true>(input, patch, patchT,
ih_start, iw_start, IH,
IW, ic, IC);
InputTransform2X3_NCHW88::transform(patchT, input_transform_buf,
unit_idx, nr_units_in_tile,
ic, IC);
} else {
InputTransform2X3_NCHW88::prepare<false>(input, patch, patchT,
ih_start, iw_start, IH,
IW, ic, IC);
InputTransform2X3_NCHW88::transform(patchT, input_transform_buf,
unit_idx, nr_units_in_tile,
ic, IC);
}
} }
} }
} }


+ 30
- 19
dnn/src/x86/conv_bias/f32/strategy_6x3_8x8.cpp View File

@@ -338,32 +338,43 @@ void winograd_nchw88_6x3_8x8_f::filter(const float* filter,
transform_mid_buf, OC, IC, oc_start, transform_mid_buf, OC, IC, oc_start,
oc_end); oc_end);
} }

void winograd_nchw88_6x3_8x8_f::input(const float* input, void winograd_nchw88_6x3_8x8_f::input(const float* input,
float* input_transform_buf, float* input_transform_buf,
float* transform_mid_buf, int ih_start,
int iw_start, size_t IH, size_t IW,
size_t IC, size_t unit_idx,
float* transform_mid_buf, size_t IH,
size_t IW, size_t IC, size_t PH,
size_t PW, size_t unit_start_idx,
size_t nr_units_in_tile) { size_t nr_units_in_tile) {
megdnn_assert(IC % 8 == 0); megdnn_assert(IC % 8 == 0);


// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
float* patch = transform_mid_buf; float* patch = transform_mid_buf;
float* patchT = transform_mid_buf + 8 * alpha * alpha; float* patchT = transform_mid_buf + 8 * alpha * alpha;
if (ih_start >= 0 && ih_start + alpha <= static_cast<size_t>(IH) &&
iw_start >= 0 && iw_start + alpha <= static_cast<size_t>(IW)) {
for (size_t ic = 0; ic < IC; ic += 8) {
InputTransform6X3_NCHW88::prepare<true>(
input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC);
InputTransform6X3_NCHW88::transform(patchT, input_transform_buf,
unit_idx, nr_units_in_tile, ic,
IC);
}
} else {
for (size_t ic = 0; ic < IC; ic += 8) {
InputTransform6X3_NCHW88::prepare<false>(input, patch, patchT, ih_start,
iw_start, IH, IW, ic, IC);
InputTransform6X3_NCHW88::transform(patchT, input_transform_buf,
unit_idx, nr_units_in_tile, ic,
IC);

for (size_t ic = 0; ic < IC; ic += 8) {
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
size_t nh = index / units_w;
size_t nw = index % units_w;
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
if (ih_start >= 0 && ih_start + alpha <= static_cast<size_t>(IH) &&
iw_start >= 0 && iw_start + alpha <= static_cast<size_t>(IW)) {
InputTransform6X3_NCHW88::prepare<true>(input, patch, patchT,
ih_start, iw_start, IH,
IW, ic, IC);
InputTransform6X3_NCHW88::transform(patchT, input_transform_buf,
unit_idx, nr_units_in_tile,
ic, IC);
} else {
InputTransform6X3_NCHW88::prepare<false>(input, patch, patchT,
ih_start, iw_start, IH,
IW, ic, IC);
InputTransform6X3_NCHW88::transform(patchT, input_transform_buf,
unit_idx, nr_units_in_tile,
ic, IC);
}
} }
} }
} }


Loading…
Cancel
Save