GitOrigin-RevId: a43077550c
tags/v0.3.2
@@ -247,33 +247,31 @@ void StrategyHelper< | |||
Getter<ctype, input_filter_compute_type> getter(dtype); | |||
InputVisitor<layout, format> intput_visitor(IC); | |||
rep(ic, IC) { | |||
memset(mid_buf1, 0, alpha * alpha * sizeof(input_filter_compute_type)); | |||
rep(i, alpha) rep(j, alpha) { | |||
int ih = ih_start + i; | |||
int iw = iw_start + j; | |||
if (ih >= 0 && ih < (int)IH && iw >= 0 && iw < (int)IW) { | |||
mid_buf1[i * alpha + j] = getter( | |||
input[intput_visitor.get(alpha, ic, IH, IW, ih, iw)]); | |||
} | |||
memset(mid_buf1, 0, alpha * alpha * sizeof(input_filter_compute_type)); | |||
rep(i, alpha) rep(j, alpha) { | |||
int ih = ih_start + i; | |||
int iw = iw_start + j; | |||
if (ih >= 0 && ih < (int)IH && iw >= 0 && iw < (int)IW) { | |||
mid_buf1[i * alpha + j] = getter( | |||
input[intput_visitor.get(alpha, ic, IH, IW, ih, iw)]); | |||
} | |||
} | |||
megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type, | |||
input_filter_compute_type, true, | |||
false>( | |||
winograd_coeff.B(rescale).data(), mid_buf1, mid_buf2, alpha, | |||
alpha, alpha, alpha, alpha, alpha, dtype, dtype); | |||
megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type, | |||
input_filter_compute_type, false, | |||
false>( | |||
mid_buf2, winograd_coeff.B(rescale).data(), mid_buf1, alpha, | |||
alpha, alpha, alpha, alpha, alpha, dtype, dtype); | |||
rep(i, alpha) rep(j, alpha) { | |||
input_transform_buf[intput_visitor.put(alpha, ic, nr_units_in_tile, | |||
unit_idx, i, j)] = | |||
mid_buf1[i * alpha + j]; | |||
} | |||
megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type, | |||
input_filter_compute_type, true, | |||
false>( | |||
winograd_coeff.B(rescale).data(), mid_buf1, mid_buf2, alpha, | |||
alpha, alpha, alpha, alpha, alpha, dtype, dtype); | |||
megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type, | |||
input_filter_compute_type, false, | |||
false>( | |||
mid_buf2, winograd_coeff.B(rescale).data(), mid_buf1, alpha, | |||
alpha, alpha, alpha, alpha, alpha, dtype, dtype); | |||
rep(i, alpha) rep(j, alpha) { | |||
input_transform_buf[intput_visitor.put(alpha, ic, nr_units_in_tile, | |||
unit_idx, i, j)] = | |||
mid_buf1[i * alpha + j]; | |||
} | |||
} | |||
@@ -287,7 +285,7 @@ void StrategyHelper< | |||
output_compute_type* transform_mid_buf, BiasMode bmode, | |||
NonlineMode nonline_mode, size_t oh_start, | |||
size_t ow_start, size_t OH, size_t OW, size_t oc_start, | |||
size_t oc_end, size_t unit_idx, size_t nr_units_in_tile, | |||
size_t oc_index, size_t unit_idx, size_t nr_units_in_tile, | |||
size_t m, size_t r, | |||
const std::vector<float>& interp_points, DType dtype, | |||
float input_filter_scale, float input_filter_rescale, | |||
@@ -300,49 +298,49 @@ void StrategyHelper< | |||
OutputGetter<output_compute_type, dst_type> getter(dtype); | |||
OutputVisitor<layout, format> output_visitor(oc_end - oc_start); | |||
for (size_t oc = oc_start; oc < oc_end; oc++) { | |||
/* gather */ | |||
rep(i, alpha) rep(j, alpha) { | |||
mid_buf1[i * alpha + j] = output_transform_buf[output_visitor.get( | |||
alpha, oc - oc_start, oc, nr_units_in_tile, unit_idx, i, | |||
j)]; | |||
} | |||
/* A[alpha*m] M[alpha*alpha] */ | |||
megdnn::naive::run_matrix_mul_tpl<output_compute_type, | |||
output_compute_type, true, false>( | |||
winograd_coeff.A(rescale).data(), mid_buf1, mid_buf2, m, alpha, | |||
alpha, m, alpha, alpha, dtype, dtype); | |||
megdnn::naive::run_matrix_mul_tpl<output_compute_type, | |||
output_compute_type, false, false>( | |||
mid_buf2, winograd_coeff.A(rescale).data(), mid_buf1, m, m, | |||
alpha, alpha, m, m, dtype, dtype); | |||
rep(i, m) rep(j, m) { | |||
auto oh = oh_start + i; | |||
auto ow = ow_start + j; | |||
if (oh < OH && ow < OW) { | |||
float val = mid_buf1[i * m + j]; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
val += bias[oc] * input_filter_rescale * | |||
input_filter_rescale; | |||
} else if (bmode == BiasMode::BIAS) { | |||
val += bias[output_visitor.put(oc, OH, OW, oh, ow)] * | |||
input_filter_rescale * input_filter_rescale; | |||
} | |||
val = val * input_filter_scale / | |||
(input_filter_rescale * input_filter_rescale * rescale * | |||
rescale); | |||
if (nonline_mode == NonlineMode::RELU) { | |||
val = val > 0 ? val : 0; | |||
} else if (nonline_mode == NonlineMode::SIGMOID) { | |||
val = 1.f / (expf(-val) + 1.f); | |||
} else if (nonline_mode == NonlineMode::H_SWISH) { | |||
val = val * std::min(std::max(val + 3, 0.f), 6.f) / 6.f; | |||
} else { | |||
megdnn_assert(nonline_mode == NonlineMode::IDENTITY); | |||
} | |||
output[output_visitor.put(oc, OH, OW, oh, ow)] = getter(val); | |||
size_t oc = oc_start + oc_index; | |||
/* gather */ | |||
rep(i, alpha) rep(j, alpha) { | |||
mid_buf1[i * alpha + j] = output_transform_buf[output_visitor.get( | |||
alpha, oc_index, oc, nr_units_in_tile, unit_idx, i, | |||
j)]; | |||
} | |||
/* A[alpha*m] M[alpha*alpha] */ | |||
megdnn::naive::run_matrix_mul_tpl<output_compute_type, | |||
output_compute_type, true, false>( | |||
winograd_coeff.A(rescale).data(), mid_buf1, mid_buf2, m, alpha, | |||
alpha, m, alpha, alpha, dtype, dtype); | |||
megdnn::naive::run_matrix_mul_tpl<output_compute_type, | |||
output_compute_type, false, false>( | |||
mid_buf2, winograd_coeff.A(rescale).data(), mid_buf1, m, m, | |||
alpha, alpha, m, m, dtype, dtype); | |||
rep(i, m) rep(j, m) { | |||
auto oh = oh_start + i; | |||
auto ow = ow_start + j; | |||
if (oh < OH && ow < OW) { | |||
float val = mid_buf1[i * m + j]; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
val += bias[oc] * input_filter_rescale * | |||
input_filter_rescale; | |||
} else if (bmode == BiasMode::BIAS) { | |||
val += bias[output_visitor.put(oc, OH, OW, oh, ow)] * | |||
input_filter_rescale * input_filter_rescale; | |||
} | |||
val = val * input_filter_scale / | |||
(input_filter_rescale * input_filter_rescale * rescale * | |||
rescale); | |||
if (nonline_mode == NonlineMode::RELU) { | |||
val = val > 0 ? val : 0; | |||
} else if (nonline_mode == NonlineMode::SIGMOID) { | |||
val = 1.f / (expf(-val) + 1.f); | |||
} else if (nonline_mode == NonlineMode::H_SWISH) { | |||
val = val * std::min(std::max(val + 3, 0.f), 6.f) / 6.f; | |||
} else { | |||
megdnn_assert(nonline_mode == NonlineMode::IDENTITY); | |||
} | |||
output[output_visitor.put(oc, OH, OW, oh, ow)] = getter(val); | |||
} | |||
} | |||
}; | |||
@@ -44,7 +44,7 @@ public: | |||
input_filter_compute_type* input_transform_buf, | |||
input_filter_compute_type* transform_mid_buf, | |||
int ih_start, int iw_start, size_t IH, size_t IW, | |||
size_t IC, size_t unit_idx, size_t nr_units_in_tile, | |||
size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile, | |||
size_t m, size_t r, | |||
const std::vector<float>& interp_points, DType dtype, | |||
float rescale = 1.0f); | |||
@@ -54,7 +54,7 @@ public: | |||
const output_compute_type* bias, dst_type* output, | |||
output_compute_type* transform_mid_buf, BiasMode bmode, | |||
NonlineMode nonline_mode, size_t oh_start, size_t ow_start, | |||
size_t OH, size_t OW, size_t oc_start, size_t oc_end, | |||
size_t OH, size_t OW, size_t oc_start, size_t oc_index, | |||
size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r, | |||
const std::vector<float>& interp_points, DType dtype, | |||
float input_filter_scale = 1.0f, // input_scale * filter_scale | |||
@@ -55,7 +55,7 @@ public: | |||
ohw_tile_size)); | |||
all_algos.emplace_back(refhold.back().get()); | |||
} | |||
#if 0 | |||
#if 1 | |||
//! As these algos may be very slow and would make the fastrun search slow, | |||
//! we disable them; they are kept here only for testing StrategyHelper. | |||
//! FIXME: I do not know a better way to do this. | |||
@@ -6,8 +6,7 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/fallback/conv_bias/winograd/strategy.h" | |||
@@ -31,27 +30,54 @@ void winograd_2x3_1x1_f::filter(const float* filter, | |||
} | |||
void winograd_2x3_1x1_f::input(const float* input, float* input_transform_buf, | |||
float* transform_mid_buf, int ih_start, | |||
int iw_start, size_t IH, size_t IW, size_t IC, | |||
size_t unit_idx, size_t nr_units_in_tile) { | |||
::megdnn::winograd::StrategyHelper<float, float, float, float>::input( | |||
input, input_transform_buf, transform_mid_buf, ih_start, iw_start, | |||
IH, IW, IC, unit_idx, nr_units_in_tile, OUTPUT_BLOCK_SIZE, | |||
KERNEL_SIZE, {0, 1, -1}, src_dtype); | |||
float* transform_mid_buf, size_t IH, size_t IW, | |||
size_t IC, size_t PH, size_t PW, | |||
size_t unit_start_idx, size_t nr_units_in_tile) { | |||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||
auto units_w = | |||
div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||
rep(ic, IC) { | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
size_t nh = index / units_w; | |||
size_t nw = index % units_w; | |||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||
::megdnn::winograd::StrategyHelper<float, float, float, float>:: | |||
input(input, input_transform_buf, transform_mid_buf, | |||
ih_start, iw_start, IH, IW, IC, ic, unit_idx, | |||
nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, | |||
{0, 1, -1}, src_dtype); | |||
} | |||
} | |||
} | |||
void winograd_2x3_1x1_f::output(const float* output_transform_buf, | |||
const float* bias, float* output, | |||
float* transform_mid_buf, BiasMode bmode, | |||
NonlineMode nonline_mode, size_t oh_start, | |||
size_t ow_start, size_t OH, size_t OW, | |||
size_t oc_start, size_t oc_end, size_t unit_idx, | |||
NonlineMode nonline_mode, size_t OH, size_t OW, | |||
size_t oc_start, size_t oc_end, | |||
size_t unit_start_idx, | |||
size_t nr_units_in_tile) { | |||
::megdnn::winograd::StrategyHelper<float, float, float, float>::output( | |||
output_transform_buf, bias, output, transform_mid_buf, bmode, | |||
nonline_mode, oh_start, ow_start, OH, OW, oc_start, oc_end, | |||
unit_idx, nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, | |||
{0, 1, -1}, dst_dtype); | |||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||
size_t OC = oc_end - oc_start; | |||
for (size_t oc = oc_start; oc < oc_end; ++oc) { | |||
size_t oc_index = oc - oc_start; | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
auto nh = index / units_w; | |||
auto nw = index % units_w; | |||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||
::megdnn::winograd::StrategyHelper<float, float, float, float>:: | |||
output(output_transform_buf, bias, output, | |||
transform_mid_buf, bmode, nonline_mode, oh_start, | |||
ow_start, OH, OW, OC, oc_start, oc_index, unit_idx, | |||
nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, | |||
{0, 1, -1}, dst_dtype); | |||
} | |||
} | |||
} | |||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_f) | |||
@@ -71,38 +97,70 @@ void winograd_2x3_4x4_f::filter(const float* filter, | |||
} | |||
void winograd_2x3_4x4_f::input(const float* input, float* input_transform_buf, | |||
float* transform_mid_buf, int ih_start, | |||
int iw_start, size_t IH, size_t IW, size_t IC, | |||
size_t unit_idx, size_t nr_units_in_tile) { | |||
::megdnn::winograd::StrategyHelper< | |||
float, float, float, float, param::ConvBias::Format::NCHW, | |||
param::MatrixMul::Format::MK4>::input(input, input_transform_buf, | |||
transform_mid_buf, ih_start, | |||
iw_start, IH, IW, IC, | |||
unit_idx, nr_units_in_tile, | |||
OUTPUT_BLOCK_SIZE, | |||
KERNEL_SIZE, {0, 1, -1}, | |||
src_dtype); | |||
float* transform_mid_buf, size_t IH, size_t IW, | |||
size_t IC, size_t PH, size_t PW, | |||
size_t unit_start_idx, size_t nr_units_in_tile) { | |||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||
auto units_w = | |||
div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||
rep(ic, IC) { | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
size_t nh = index / units_w; | |||
size_t nw = index % units_w; | |||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||
::megdnn::winograd::StrategyHelper< | |||
float, float, float, float, param::ConvBias::Format::NCHW, | |||
param::MatrixMul::Format::MK4>::input(input, | |||
input_transform_buf, | |||
transform_mid_buf, | |||
ih_start, iw_start, | |||
IH, IW, IC, ic, | |||
unit_idx, | |||
nr_units_in_tile, | |||
OUTPUT_BLOCK_SIZE, | |||
KERNEL_SIZE, | |||
{0, 1, -1}, | |||
src_dtype); | |||
} | |||
} | |||
} | |||
void winograd_2x3_4x4_f::output(const float* output_transform_buf, | |||
const float* bias, float* output, | |||
float* transform_mid_buf, BiasMode bmode, | |||
NonlineMode nonline_mode, size_t oh_start, | |||
size_t ow_start, size_t OH, size_t OW, | |||
size_t oc_start, size_t oc_end, size_t unit_idx, | |||
NonlineMode nonline_mode, size_t OH, size_t OW, | |||
size_t oc_start, size_t oc_end, | |||
size_t unit_start_idx, | |||
size_t nr_units_in_tile) { | |||
::megdnn::winograd::StrategyHelper< | |||
float, float, float, float, param::ConvBias::Format::NCHW, | |||
param::MatrixMul::Format::MK4>::output(output_transform_buf, bias, | |||
output, transform_mid_buf, | |||
bmode, nonline_mode, | |||
oh_start, ow_start, OH, OW, | |||
oc_start, oc_end, unit_idx, | |||
nr_units_in_tile, | |||
OUTPUT_BLOCK_SIZE, | |||
KERNEL_SIZE, {0, 1, -1}, | |||
dst_dtype); | |||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||
size_t OC = oc_end - oc_start; | |||
for (size_t oc = oc_start; oc < oc_end; ++oc) { | |||
size_t oc_index = oc - oc_start; | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
auto nh = index / units_w; | |||
auto nw = index % units_w; | |||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||
::megdnn::winograd::StrategyHelper< | |||
float, float, float, float, param::ConvBias::Format::NCHW, | |||
param::MatrixMul::Format::MK4>::output(output_transform_buf, | |||
bias, output, | |||
transform_mid_buf, | |||
bmode, nonline_mode, | |||
oh_start, ow_start, | |||
OH, OW, OC, oc_start, | |||
oc_index, unit_idx, | |||
nr_units_in_tile, | |||
OUTPUT_BLOCK_SIZE, | |||
KERNEL_SIZE, | |||
{0, 1, -1}, | |||
dst_dtype); | |||
} | |||
} | |||
} | |||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_1x1_qs8) | |||
@@ -119,29 +177,59 @@ void winograd_2x3_1x1_qs8::filter(const int8_t* filter, | |||
void winograd_2x3_1x1_qs8::input(const int8_t* input, | |||
int16_t* input_transform_buf, | |||
int16_t* transform_mid_buf, int ih_start, | |||
int iw_start, size_t IH, size_t IW, size_t IC, | |||
size_t unit_idx, size_t nr_units_in_tile) { | |||
::megdnn::winograd::StrategyHelper<int8_t, int8_t, int16_t, int>::input( | |||
input, input_transform_buf, transform_mid_buf, ih_start, iw_start, | |||
IH, IW, IC, unit_idx, nr_units_in_tile, OUTPUT_BLOCK_SIZE, | |||
KERNEL_SIZE, {0, 1, -1}, src_dtype, 1.0f); | |||
int16_t* transform_mid_buf, size_t IH, | |||
size_t IW, size_t IC, size_t PH, size_t PW, | |||
size_t unit_start_idx, | |||
size_t nr_units_in_tile) { | |||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||
auto units_w = | |||
div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||
rep(ic, IC) { | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
size_t nh = index / units_w; | |||
size_t nw = index % units_w; | |||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||
::megdnn::winograd::StrategyHelper<int8_t, int8_t, int16_t, int>:: | |||
input(input, input_transform_buf, transform_mid_buf, | |||
ih_start, iw_start, IH, IW, IC, ic, unit_idx, | |||
nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, | |||
{0, 1, -1}, src_dtype, 1.0f); | |||
} | |||
} | |||
} | |||
void winograd_2x3_1x1_qs8::output(const int* output_transform_buf, | |||
const int* bias, int8_t* output, | |||
int* transform_mid_buf, BiasMode bmode, | |||
NonlineMode nonline_mode, size_t oh_start, | |||
size_t ow_start, size_t OH, size_t OW, | |||
size_t oc_start, size_t oc_end, | |||
size_t unit_idx, size_t nr_units_in_tile) { | |||
NonlineMode nonline_mode, size_t OH, | |||
size_t OW, size_t oc_start, size_t oc_end, | |||
size_t unit_start_idx, | |||
size_t nr_units_in_tile) { | |||
float scale_input = src_dtype.param<dtype::QuantizedS8>().scale; | |||
float scale_filter = filter_dtype.param<dtype::QuantizedS8>().scale; | |||
::megdnn::winograd::StrategyHelper<int8_t, int8_t, int16_t, int>::output( | |||
output_transform_buf, bias, output, transform_mid_buf, bmode, | |||
nonline_mode, oh_start, ow_start, OH, OW, oc_start, oc_end, | |||
unit_idx, nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, | |||
{0, 1, -1}, dst_dtype, scale_input * scale_filter, 2.0f, 1.0f); | |||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||
size_t OC = oc_end - oc_start; | |||
for (size_t oc = oc_start; oc < oc_end; ++oc) { | |||
size_t oc_index = oc - oc_start; | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
auto nh = index / units_w; | |||
auto nw = index % units_w; | |||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||
::megdnn::winograd::StrategyHelper<int8_t, int8_t, int16_t, int>:: | |||
output(output_transform_buf, bias, output, | |||
transform_mid_buf, bmode, nonline_mode, oh_start, | |||
ow_start, OH, OW, OC, oc_start, oc_index, unit_idx, | |||
nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, | |||
{0, 1, -1}, dst_dtype, scale_input * scale_filter, | |||
2.0f, 1.0f); | |||
} | |||
} | |||
} | |||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_8x8_qs8) | |||
@@ -162,27 +250,44 @@ void winograd_2x3_8x8_qs8::filter(const int8_t* filter, | |||
void winograd_2x3_8x8_qs8::input(const int8_t* input, | |||
int16_t* input_transform_buf, | |||
int16_t* transform_mid_buf, int ih_start, | |||
int iw_start, size_t IH, size_t IW, size_t IC, | |||
size_t unit_idx, size_t nr_units_in_tile) { | |||
::megdnn::winograd::StrategyHelper< | |||
int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW, | |||
param::MatrixMul::Format::MK8>::input(input, input_transform_buf, | |||
transform_mid_buf, ih_start, | |||
iw_start, IH, IW, IC, | |||
unit_idx, nr_units_in_tile, | |||
OUTPUT_BLOCK_SIZE, | |||
KERNEL_SIZE, {0, 1, -1}, | |||
src_dtype, 1.0f); | |||
int16_t* transform_mid_buf, size_t IH, | |||
size_t IW, size_t IC, size_t PH, size_t PW, | |||
size_t unit_start_idx, | |||
size_t nr_units_in_tile) { | |||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||
auto units_w = | |||
div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||
rep(ic, IC) { | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
size_t nh = index / units_w; | |||
size_t nw = index % units_w; | |||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||
::megdnn::winograd::StrategyHelper< | |||
int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW, | |||
param::MatrixMul::Format::MK8>::input(input, | |||
input_transform_buf, | |||
transform_mid_buf, | |||
ih_start, iw_start, | |||
IH, IW, IC, ic, | |||
unit_idx, | |||
nr_units_in_tile, | |||
OUTPUT_BLOCK_SIZE, | |||
KERNEL_SIZE, | |||
{0, 1, -1}, src_dtype, | |||
1.0f); | |||
} | |||
} | |||
} | |||
void winograd_2x3_8x8_qs8::output(const int* output_transform_buf, | |||
const int* bias, int8_t* output, | |||
int* transform_mid_buf, BiasMode bmode, | |||
NonlineMode nonline_mode, size_t oh_start, | |||
size_t ow_start, size_t OH, size_t OW, | |||
size_t oc_start, size_t oc_end, | |||
size_t unit_idx, size_t nr_units_in_tile) { | |||
NonlineMode nonline_mode, size_t OH, | |||
size_t OW, size_t oc_start, size_t oc_end, | |||
size_t unit_start_idx, | |||
size_t nr_units_in_tile) { | |||
float scale_input = src_dtype.param<dtype::QuantizedS8>().scale; | |||
float scale_filter = 0.f; | |||
if (filter_dtype.enumv() == DTypeEnum::QuantizedS8) { | |||
@@ -191,19 +296,37 @@ void winograd_2x3_8x8_qs8::output(const int* output_transform_buf, | |||
megdnn_assert(filter_dtype.enumv() == DTypeEnum::QuantizedS16); | |||
scale_filter = filter_dtype.param<dtype::QuantizedS16>().scale; | |||
} | |||
::megdnn::winograd::StrategyHelper< | |||
int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW, | |||
param::MatrixMul::Format::MK8>::output(output_transform_buf, bias, | |||
output, transform_mid_buf, | |||
bmode, nonline_mode, | |||
oh_start, ow_start, OH, OW, | |||
oc_start, oc_end, unit_idx, | |||
nr_units_in_tile, | |||
OUTPUT_BLOCK_SIZE, | |||
KERNEL_SIZE, {0, 1, -1}, | |||
dst_dtype, | |||
scale_input * scale_filter, | |||
2.0f, 1.0f); | |||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||
size_t OC = oc_end - oc_start; | |||
for (size_t oc = oc_start; oc < oc_end; ++oc) { | |||
size_t oc_index = oc - oc_start; | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
auto nh = index / units_w; | |||
auto nw = index % units_w; | |||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||
::megdnn::winograd::StrategyHelper< | |||
int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW, | |||
param::MatrixMul::Format::MK8>::output(output_transform_buf, | |||
bias, output, | |||
transform_mid_buf, | |||
bmode, nonline_mode, | |||
oh_start, ow_start, | |||
OH, OW, OC, oc_start, | |||
oc_index, unit_idx, | |||
nr_units_in_tile, | |||
OUTPUT_BLOCK_SIZE, | |||
KERNEL_SIZE, | |||
{0, 1, -1}, | |||
dst_dtype, | |||
scale_input * | |||
scale_filter, | |||
2.0f, 1.0f); | |||
} | |||
} | |||
} | |||
} // namespace winograd | |||
@@ -321,17 +321,10 @@ public: | |||
"nr_tiles_in_unit: %zu TILE_SIZE:%zu", | |||
nr_tiles_in_unit, unit_tile_size); | |||
} | |||
rep(unit_idx, nr_tiles_in_unit) { | |||
size_t index = unit_start_idx + unit_idx; | |||
size_t nh = index / units_w; | |||
size_t nw = index % units_w; | |||
int ih_start = nh * Strategy::OUTPUT_BLOCK_SIZE - PH; | |||
int iw_start = nw * Strategy::OUTPUT_BLOCK_SIZE - PW; | |||
strategy.input(src_ptr, input_transform_buf, transform_mid_buf, | |||
ih_start, iw_start, IH, IW, IC, unit_idx, | |||
nr_tiles_in_unit); | |||
} | |||
//! BTdB | |||
strategy.input(src_ptr, input_transform_buf, transform_mid_buf, | |||
IH, IW, IC, PH, PW, unit_start_idx, nr_tiles_in_unit); | |||
rep(i, Strategy::ALPHA) rep(j, Strategy::ALPHA) { | |||
if (format == param::MatrixMul::Format::DEFAULT) { | |||
matmul_param.A_ptr = | |||
@@ -368,22 +361,14 @@ public: | |||
} | |||
matmul_kern(matmul_param); | |||
} | |||
/* Y = ATmA */ | |||
rep(unit_idx, nr_tiles_in_unit) { | |||
size_t index = unit_start_idx + unit_idx; | |||
auto nh = index / units_w; | |||
auto nw = index % units_w; | |||
size_t oh_start = nh * Strategy::OUTPUT_BLOCK_SIZE; | |||
size_t ow_start = nw * Strategy::OUTPUT_BLOCK_SIZE; | |||
size_t oc_end_idx = oc_start_idx + nr_oc_in_unit; | |||
strategy.output( | |||
output_transform_buf, bias_ptr, dst_ptr, | |||
reinterpret_cast<output_compute_type*>(transform_mid_buf), | |||
ncb_param.bias_mode, ncb_param.nonlineMode, oh_start, | |||
ow_start, OH, OW, oc_start_idx, oc_end_idx, unit_idx, | |||
nr_tiles_in_unit); | |||
} | |||
//! Y = ATmA | |||
size_t oc_end_idx = oc_start_idx + nr_oc_in_unit; | |||
strategy.output( | |||
output_transform_buf, bias_ptr, dst_ptr, | |||
reinterpret_cast<output_compute_type*>(transform_mid_buf), | |||
ncb_param.bias_mode, ncb_param.nonlineMode, OH, OW, | |||
oc_start_idx, oc_end_idx, unit_start_idx, nr_tiles_in_unit); | |||
}; | |||
SmallVector<NCBKern> get_kerns( | |||
@@ -542,15 +527,16 @@ public: | |||
size_t IC, size_t oc_start, size_t oc_end); \ | |||
void input(const stype* input, \ | |||
input_filter_compute_type* input_transform_buf, \ | |||
input_filter_compute_type* transform_mid_buf, int ih_start, \ | |||
int iw_start, size_t IH, size_t IW, size_t IC, \ | |||
size_t unit_idx, size_t nr_tiles_in_unit); \ | |||
input_filter_compute_type* transform_mid_buf, \ | |||
size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, \ | |||
size_t unit_start_idx, size_t nr_tiles_in_unit); \ | |||
void output(const output_compute_type* output_transform_buf, \ | |||
const output_compute_type* bias, dst_type* output, \ | |||
output_compute_type* transform_mid_buf, BiasMode bmode, \ | |||
NonlineMode nonline_mode, size_t oh_start, \ | |||
size_t ow_start, size_t OH, size_t OW, size_t oc_start, \ | |||
size_t oc_end, size_t unit_idx, size_t nr_tiles_in_unit); \ | |||
NonlineMode nonline_mode, size_t OH, size_t OW, \ | |||
size_t oc_start, size_t oc_end, size_t unit_start_idx, \ | |||
size_t nr_tiles_in_unit); \ | |||
}; | |||
#define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \ | |||
@@ -274,31 +274,43 @@ void winograd_nchw88_2x3_8x8_f::filter(const float* filter, | |||
transform_mid_buf, OC, IC, oc_start, | |||
oc_end); | |||
} | |||
void winograd_nchw88_2x3_8x8_f::input(const float* input, | |||
float* input_transform_buf, | |||
float* transform_mid_buf, int ih_start, | |||
int iw_start, size_t IH, size_t IW, | |||
size_t IC, size_t unit_idx, | |||
float* transform_mid_buf, size_t IH, | |||
size_t IW, size_t IC, size_t PH, | |||
size_t PW, size_t unit_start_idx, | |||
size_t nr_units_in_tile) { | |||
megdnn_assert(IC % 8 == 0); | |||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||
float* patch = transform_mid_buf; | |||
float* patchT = transform_mid_buf + 8 * alpha * alpha; | |||
if (ih_start >= 0 && ih_start + alpha <= static_cast<size_t>(IH) && | |||
iw_start >= 0 && iw_start + alpha <= static_cast<size_t>(IW)) { | |||
for (size_t ic = 0; ic < IC; ic += 8) { | |||
InputTransform2X3_NCHW88::prepare<true>( | |||
input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC); | |||
InputTransform2X3_NCHW88::transform(patchT, input_transform_buf, | |||
unit_idx, nr_units_in_tile, ic, | |||
IC); | |||
} | |||
} else { | |||
for (size_t ic = 0; ic < IC; ic += 8) { | |||
InputTransform2X3_NCHW88::prepare<false>(input, patch, patchT, ih_start, | |||
iw_start, IH, IW, ic, IC); | |||
InputTransform2X3_NCHW88::transform(patchT, input_transform_buf, | |||
unit_idx, nr_units_in_tile, ic, | |||
IC); | |||
for (size_t ic = 0; ic < IC; ic += 8) { | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
size_t nh = index / units_w; | |||
size_t nw = index % units_w; | |||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||
if (ih_start >= 0 && ih_start + alpha <= static_cast<size_t>(IH) && | |||
iw_start >= 0 && iw_start + alpha <= static_cast<size_t>(IW)) { | |||
InputTransform2X3_NCHW88::prepare<true>(input, patch, patchT, | |||
ih_start, iw_start, IH, | |||
IW, ic, IC); | |||
InputTransform2X3_NCHW88::transform(patchT, input_transform_buf, | |||
unit_idx, nr_units_in_tile, | |||
ic, IC); | |||
} else { | |||
InputTransform2X3_NCHW88::prepare<false>(input, patch, patchT, | |||
ih_start, iw_start, IH, | |||
IW, ic, IC); | |||
InputTransform2X3_NCHW88::transform(patchT, input_transform_buf, | |||
unit_idx, nr_units_in_tile, | |||
ic, IC); | |||
} | |||
} | |||
} | |||
} | |||
@@ -338,32 +338,43 @@ void winograd_nchw88_6x3_8x8_f::filter(const float* filter, | |||
transform_mid_buf, OC, IC, oc_start, | |||
oc_end); | |||
} | |||
void winograd_nchw88_6x3_8x8_f::input(const float* input, | |||
float* input_transform_buf, | |||
float* transform_mid_buf, int ih_start, | |||
int iw_start, size_t IH, size_t IW, | |||
size_t IC, size_t unit_idx, | |||
float* transform_mid_buf, size_t IH, | |||
size_t IW, size_t IC, size_t PH, | |||
size_t PW, size_t unit_start_idx, | |||
size_t nr_units_in_tile) { | |||
megdnn_assert(IC % 8 == 0); | |||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||
float* patch = transform_mid_buf; | |||
float* patchT = transform_mid_buf + 8 * alpha * alpha; | |||
if (ih_start >= 0 && ih_start + alpha <= static_cast<size_t>(IH) && | |||
iw_start >= 0 && iw_start + alpha <= static_cast<size_t>(IW)) { | |||
for (size_t ic = 0; ic < IC; ic += 8) { | |||
InputTransform6X3_NCHW88::prepare<true>( | |||
input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC); | |||
InputTransform6X3_NCHW88::transform(patchT, input_transform_buf, | |||
unit_idx, nr_units_in_tile, ic, | |||
IC); | |||
} | |||
} else { | |||
for (size_t ic = 0; ic < IC; ic += 8) { | |||
InputTransform6X3_NCHW88::prepare<false>(input, patch, patchT, ih_start, | |||
iw_start, IH, IW, ic, IC); | |||
InputTransform6X3_NCHW88::transform(patchT, input_transform_buf, | |||
unit_idx, nr_units_in_tile, ic, | |||
IC); | |||
for (size_t ic = 0; ic < IC; ic += 8) { | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
size_t nh = index / units_w; | |||
size_t nw = index % units_w; | |||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||
if (ih_start >= 0 && ih_start + alpha <= static_cast<size_t>(IH) && | |||
iw_start >= 0 && iw_start + alpha <= static_cast<size_t>(IW)) { | |||
InputTransform6X3_NCHW88::prepare<true>(input, patch, patchT, | |||
ih_start, iw_start, IH, | |||
IW, ic, IC); | |||
InputTransform6X3_NCHW88::transform(patchT, input_transform_buf, | |||
unit_idx, nr_units_in_tile, | |||
ic, IC); | |||
} else { | |||
InputTransform6X3_NCHW88::prepare<false>(input, patch, patchT, | |||
ih_start, iw_start, IH, | |||
IW, ic, IC); | |||
InputTransform6X3_NCHW88::transform(patchT, input_transform_buf, | |||
unit_idx, nr_units_in_tile, | |||
ic, IC); | |||
} | |||
} | |||
} | |||
} | |||