GitOrigin-RevId: a43077550c
tags/v0.3.2
@@ -247,33 +247,31 @@ void StrategyHelper< | |||||
Getter<ctype, input_filter_compute_type> getter(dtype); | Getter<ctype, input_filter_compute_type> getter(dtype); | ||||
InputVisitor<layout, format> intput_visitor(IC); | InputVisitor<layout, format> intput_visitor(IC); | ||||
rep(ic, IC) { | |||||
memset(mid_buf1, 0, alpha * alpha * sizeof(input_filter_compute_type)); | |||||
rep(i, alpha) rep(j, alpha) { | |||||
int ih = ih_start + i; | |||||
int iw = iw_start + j; | |||||
if (ih >= 0 && ih < (int)IH && iw >= 0 && iw < (int)IW) { | |||||
mid_buf1[i * alpha + j] = getter( | |||||
input[intput_visitor.get(alpha, ic, IH, IW, ih, iw)]); | |||||
} | |||||
memset(mid_buf1, 0, alpha * alpha * sizeof(input_filter_compute_type)); | |||||
rep(i, alpha) rep(j, alpha) { | |||||
int ih = ih_start + i; | |||||
int iw = iw_start + j; | |||||
if (ih >= 0 && ih < (int)IH && iw >= 0 && iw < (int)IW) { | |||||
mid_buf1[i * alpha + j] = getter( | |||||
input[intput_visitor.get(alpha, ic, IH, IW, ih, iw)]); | |||||
} | } | ||||
} | |||||
megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type, | |||||
input_filter_compute_type, true, | |||||
false>( | |||||
winograd_coeff.B(rescale).data(), mid_buf1, mid_buf2, alpha, | |||||
alpha, alpha, alpha, alpha, alpha, dtype, dtype); | |||||
megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type, | |||||
input_filter_compute_type, false, | |||||
false>( | |||||
mid_buf2, winograd_coeff.B(rescale).data(), mid_buf1, alpha, | |||||
alpha, alpha, alpha, alpha, alpha, dtype, dtype); | |||||
rep(i, alpha) rep(j, alpha) { | |||||
input_transform_buf[intput_visitor.put(alpha, ic, nr_units_in_tile, | |||||
unit_idx, i, j)] = | |||||
mid_buf1[i * alpha + j]; | |||||
} | |||||
megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type, | |||||
input_filter_compute_type, true, | |||||
false>( | |||||
winograd_coeff.B(rescale).data(), mid_buf1, mid_buf2, alpha, | |||||
alpha, alpha, alpha, alpha, alpha, dtype, dtype); | |||||
megdnn::naive::run_matrix_mul_tpl<input_filter_compute_type, | |||||
input_filter_compute_type, false, | |||||
false>( | |||||
mid_buf2, winograd_coeff.B(rescale).data(), mid_buf1, alpha, | |||||
alpha, alpha, alpha, alpha, alpha, dtype, dtype); | |||||
rep(i, alpha) rep(j, alpha) { | |||||
input_transform_buf[intput_visitor.put(alpha, ic, nr_units_in_tile, | |||||
unit_idx, i, j)] = | |||||
mid_buf1[i * alpha + j]; | |||||
} | } | ||||
} | } | ||||
@@ -287,7 +285,7 @@ void StrategyHelper< | |||||
output_compute_type* transform_mid_buf, BiasMode bmode, | output_compute_type* transform_mid_buf, BiasMode bmode, | ||||
NonlineMode nonline_mode, size_t oh_start, | NonlineMode nonline_mode, size_t oh_start, | ||||
size_t ow_start, size_t OH, size_t OW, size_t oc_start, | size_t ow_start, size_t OH, size_t OW, size_t oc_start, | ||||
size_t oc_end, size_t unit_idx, size_t nr_units_in_tile, | |||||
size_t oc_index, size_t unit_idx, size_t nr_units_in_tile, | |||||
size_t m, size_t r, | size_t m, size_t r, | ||||
const std::vector<float>& interp_points, DType dtype, | const std::vector<float>& interp_points, DType dtype, | ||||
float input_filter_scale, float input_filter_rescale, | float input_filter_scale, float input_filter_rescale, | ||||
@@ -300,49 +298,49 @@ void StrategyHelper< | |||||
OutputGetter<output_compute_type, dst_type> getter(dtype); | OutputGetter<output_compute_type, dst_type> getter(dtype); | ||||
OutputVisitor<layout, format> output_visitor(oc_end - oc_start); | OutputVisitor<layout, format> output_visitor(oc_end - oc_start); | ||||
for (size_t oc = oc_start; oc < oc_end; oc++) { | |||||
/* gather */ | |||||
rep(i, alpha) rep(j, alpha) { | |||||
mid_buf1[i * alpha + j] = output_transform_buf[output_visitor.get( | |||||
alpha, oc - oc_start, oc, nr_units_in_tile, unit_idx, i, | |||||
j)]; | |||||
} | |||||
/* A[alpha*m] M[alpha*alpha] */ | |||||
megdnn::naive::run_matrix_mul_tpl<output_compute_type, | |||||
output_compute_type, true, false>( | |||||
winograd_coeff.A(rescale).data(), mid_buf1, mid_buf2, m, alpha, | |||||
alpha, m, alpha, alpha, dtype, dtype); | |||||
megdnn::naive::run_matrix_mul_tpl<output_compute_type, | |||||
output_compute_type, false, false>( | |||||
mid_buf2, winograd_coeff.A(rescale).data(), mid_buf1, m, m, | |||||
alpha, alpha, m, m, dtype, dtype); | |||||
rep(i, m) rep(j, m) { | |||||
auto oh = oh_start + i; | |||||
auto ow = ow_start + j; | |||||
if (oh < OH && ow < OW) { | |||||
float val = mid_buf1[i * m + j]; | |||||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||||
val += bias[oc] * input_filter_rescale * | |||||
input_filter_rescale; | |||||
} else if (bmode == BiasMode::BIAS) { | |||||
val += bias[output_visitor.put(oc, OH, OW, oh, ow)] * | |||||
input_filter_rescale * input_filter_rescale; | |||||
} | |||||
val = val * input_filter_scale / | |||||
(input_filter_rescale * input_filter_rescale * rescale * | |||||
rescale); | |||||
if (nonline_mode == NonlineMode::RELU) { | |||||
val = val > 0 ? val : 0; | |||||
} else if (nonline_mode == NonlineMode::SIGMOID) { | |||||
val = 1.f / (expf(-val) + 1.f); | |||||
} else if (nonline_mode == NonlineMode::H_SWISH) { | |||||
val = val * std::min(std::max(val + 3, 0.f), 6.f) / 6.f; | |||||
} else { | |||||
megdnn_assert(nonline_mode == NonlineMode::IDENTITY); | |||||
} | |||||
output[output_visitor.put(oc, OH, OW, oh, ow)] = getter(val); | |||||
size_t oc = oc_start + oc_index; | |||||
/* gather */ | |||||
rep(i, alpha) rep(j, alpha) { | |||||
mid_buf1[i * alpha + j] = output_transform_buf[output_visitor.get( | |||||
alpha, oc_index, oc, nr_units_in_tile, unit_idx, i, | |||||
j)]; | |||||
} | |||||
/* A[alpha*m] M[alpha*alpha] */ | |||||
megdnn::naive::run_matrix_mul_tpl<output_compute_type, | |||||
output_compute_type, true, false>( | |||||
winograd_coeff.A(rescale).data(), mid_buf1, mid_buf2, m, alpha, | |||||
alpha, m, alpha, alpha, dtype, dtype); | |||||
megdnn::naive::run_matrix_mul_tpl<output_compute_type, | |||||
output_compute_type, false, false>( | |||||
mid_buf2, winograd_coeff.A(rescale).data(), mid_buf1, m, m, | |||||
alpha, alpha, m, m, dtype, dtype); | |||||
rep(i, m) rep(j, m) { | |||||
auto oh = oh_start + i; | |||||
auto ow = ow_start + j; | |||||
if (oh < OH && ow < OW) { | |||||
float val = mid_buf1[i * m + j]; | |||||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||||
val += bias[oc] * input_filter_rescale * | |||||
input_filter_rescale; | |||||
} else if (bmode == BiasMode::BIAS) { | |||||
val += bias[output_visitor.put(oc, OH, OW, oh, ow)] * | |||||
input_filter_rescale * input_filter_rescale; | |||||
} | |||||
val = val * input_filter_scale / | |||||
(input_filter_rescale * input_filter_rescale * rescale * | |||||
rescale); | |||||
if (nonline_mode == NonlineMode::RELU) { | |||||
val = val > 0 ? val : 0; | |||||
} else if (nonline_mode == NonlineMode::SIGMOID) { | |||||
val = 1.f / (expf(-val) + 1.f); | |||||
} else if (nonline_mode == NonlineMode::H_SWISH) { | |||||
val = val * std::min(std::max(val + 3, 0.f), 6.f) / 6.f; | |||||
} else { | |||||
megdnn_assert(nonline_mode == NonlineMode::IDENTITY); | |||||
} | } | ||||
output[output_visitor.put(oc, OH, OW, oh, ow)] = getter(val); | |||||
} | } | ||||
} | } | ||||
}; | }; | ||||
@@ -44,7 +44,7 @@ public: | |||||
input_filter_compute_type* input_transform_buf, | input_filter_compute_type* input_transform_buf, | ||||
input_filter_compute_type* transform_mid_buf, | input_filter_compute_type* transform_mid_buf, | ||||
int ih_start, int iw_start, size_t IH, size_t IW, | int ih_start, int iw_start, size_t IH, size_t IW, | ||||
size_t IC, size_t unit_idx, size_t nr_units_in_tile, | |||||
size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile, | |||||
size_t m, size_t r, | size_t m, size_t r, | ||||
const std::vector<float>& interp_points, DType dtype, | const std::vector<float>& interp_points, DType dtype, | ||||
float rescale = 1.0f); | float rescale = 1.0f); | ||||
@@ -54,7 +54,7 @@ public: | |||||
const output_compute_type* bias, dst_type* output, | const output_compute_type* bias, dst_type* output, | ||||
output_compute_type* transform_mid_buf, BiasMode bmode, | output_compute_type* transform_mid_buf, BiasMode bmode, | ||||
NonlineMode nonline_mode, size_t oh_start, size_t ow_start, | NonlineMode nonline_mode, size_t oh_start, size_t ow_start, | ||||
size_t OH, size_t OW, size_t oc_start, size_t oc_end, | |||||
size_t OH, size_t OW, size_t oc_start, size_t oc_index, | |||||
size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r, | size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r, | ||||
const std::vector<float>& interp_points, DType dtype, | const std::vector<float>& interp_points, DType dtype, | ||||
float input_filter_scale = 1.0f, // input_scale * filter_scale | float input_filter_scale = 1.0f, // input_scale * filter_scale | ||||
@@ -55,7 +55,7 @@ public: | |||||
ohw_tile_size)); | ohw_tile_size)); | ||||
all_algos.emplace_back(refhold.back().get()); | all_algos.emplace_back(refhold.back().get()); | ||||
} | } | ||||
#if 0 | |||||
#if 1 | |||||
//! As these algos maybe very slow, it will make fastrun search slow, so | //! As these algos maybe very slow, it will make fastrun search slow, so | ||||
//! we disable it, but for the test of strategyhelper, we just keep it. | //! we disable it, but for the test of strategyhelper, we just keep it. | ||||
//! FIXME: I do not know a better way to do it. | //! FIXME: I do not know a better way to do it. | ||||
@@ -6,8 +6,7 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | */ | ||||
#include "src/fallback/conv_bias/winograd/strategy.h" | #include "src/fallback/conv_bias/winograd/strategy.h" | ||||
@@ -31,27 +30,54 @@ void winograd_2x3_1x1_f::filter(const float* filter, | |||||
} | } | ||||
void winograd_2x3_1x1_f::input(const float* input, float* input_transform_buf, | void winograd_2x3_1x1_f::input(const float* input, float* input_transform_buf, | ||||
float* transform_mid_buf, int ih_start, | |||||
int iw_start, size_t IH, size_t IW, size_t IC, | |||||
size_t unit_idx, size_t nr_units_in_tile) { | |||||
::megdnn::winograd::StrategyHelper<float, float, float, float>::input( | |||||
input, input_transform_buf, transform_mid_buf, ih_start, iw_start, | |||||
IH, IW, IC, unit_idx, nr_units_in_tile, OUTPUT_BLOCK_SIZE, | |||||
KERNEL_SIZE, {0, 1, -1}, src_dtype); | |||||
float* transform_mid_buf, size_t IH, size_t IW, | |||||
size_t IC, size_t PH, size_t PW, | |||||
size_t unit_start_idx, size_t nr_units_in_tile) { | |||||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||||
auto units_w = | |||||
div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
rep(ic, IC) { | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
size_t nh = index / units_w; | |||||
size_t nw = index % units_w; | |||||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||||
::megdnn::winograd::StrategyHelper<float, float, float, float>:: | |||||
input(input, input_transform_buf, transform_mid_buf, | |||||
ih_start, iw_start, IH, IW, IC, ic, unit_idx, | |||||
nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, | |||||
{0, 1, -1}, src_dtype); | |||||
} | |||||
} | |||||
} | } | ||||
void winograd_2x3_1x1_f::output(const float* output_transform_buf, | void winograd_2x3_1x1_f::output(const float* output_transform_buf, | ||||
const float* bias, float* output, | const float* bias, float* output, | ||||
float* transform_mid_buf, BiasMode bmode, | float* transform_mid_buf, BiasMode bmode, | ||||
NonlineMode nonline_mode, size_t oh_start, | |||||
size_t ow_start, size_t OH, size_t OW, | |||||
size_t oc_start, size_t oc_end, size_t unit_idx, | |||||
NonlineMode nonline_mode, size_t OH, size_t OW, | |||||
size_t oc_start, size_t oc_end, | |||||
size_t unit_start_idx, | |||||
size_t nr_units_in_tile) { | size_t nr_units_in_tile) { | ||||
::megdnn::winograd::StrategyHelper<float, float, float, float>::output( | |||||
output_transform_buf, bias, output, transform_mid_buf, bmode, | |||||
nonline_mode, oh_start, ow_start, OH, OW, oc_start, oc_end, | |||||
unit_idx, nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, | |||||
{0, 1, -1}, dst_dtype); | |||||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||||
size_t OC = oc_end - oc_start; | |||||
for (size_t oc = oc_start; oc < oc_end; ++oc) { | |||||
size_t oc_index = oc - oc_start; | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
auto nh = index / units_w; | |||||
auto nw = index % units_w; | |||||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||||
::megdnn::winograd::StrategyHelper<float, float, float, float>:: | |||||
output(output_transform_buf, bias, output, | |||||
transform_mid_buf, bmode, nonline_mode, oh_start, | |||||
ow_start, OH, OW, OC, oc_start, oc_index, unit_idx, | |||||
nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, | |||||
{0, 1, -1}, dst_dtype); | |||||
} | |||||
} | |||||
} | } | ||||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_f) | MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_f) | ||||
@@ -71,38 +97,70 @@ void winograd_2x3_4x4_f::filter(const float* filter, | |||||
} | } | ||||
void winograd_2x3_4x4_f::input(const float* input, float* input_transform_buf, | void winograd_2x3_4x4_f::input(const float* input, float* input_transform_buf, | ||||
float* transform_mid_buf, int ih_start, | |||||
int iw_start, size_t IH, size_t IW, size_t IC, | |||||
size_t unit_idx, size_t nr_units_in_tile) { | |||||
::megdnn::winograd::StrategyHelper< | |||||
float, float, float, float, param::ConvBias::Format::NCHW, | |||||
param::MatrixMul::Format::MK4>::input(input, input_transform_buf, | |||||
transform_mid_buf, ih_start, | |||||
iw_start, IH, IW, IC, | |||||
unit_idx, nr_units_in_tile, | |||||
OUTPUT_BLOCK_SIZE, | |||||
KERNEL_SIZE, {0, 1, -1}, | |||||
src_dtype); | |||||
float* transform_mid_buf, size_t IH, size_t IW, | |||||
size_t IC, size_t PH, size_t PW, | |||||
size_t unit_start_idx, size_t nr_units_in_tile) { | |||||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||||
auto units_w = | |||||
div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
rep(ic, IC) { | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
size_t nh = index / units_w; | |||||
size_t nw = index % units_w; | |||||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||||
::megdnn::winograd::StrategyHelper< | |||||
float, float, float, float, param::ConvBias::Format::NCHW, | |||||
param::MatrixMul::Format::MK4>::input(input, | |||||
input_transform_buf, | |||||
transform_mid_buf, | |||||
ih_start, iw_start, | |||||
IH, IW, IC, ic, | |||||
unit_idx, | |||||
nr_units_in_tile, | |||||
OUTPUT_BLOCK_SIZE, | |||||
KERNEL_SIZE, | |||||
{0, 1, -1}, | |||||
src_dtype); | |||||
} | |||||
} | |||||
} | } | ||||
void winograd_2x3_4x4_f::output(const float* output_transform_buf, | void winograd_2x3_4x4_f::output(const float* output_transform_buf, | ||||
const float* bias, float* output, | const float* bias, float* output, | ||||
float* transform_mid_buf, BiasMode bmode, | float* transform_mid_buf, BiasMode bmode, | ||||
NonlineMode nonline_mode, size_t oh_start, | |||||
size_t ow_start, size_t OH, size_t OW, | |||||
size_t oc_start, size_t oc_end, size_t unit_idx, | |||||
NonlineMode nonline_mode, size_t OH, size_t OW, | |||||
size_t oc_start, size_t oc_end, | |||||
size_t unit_start_idx, | |||||
size_t nr_units_in_tile) { | size_t nr_units_in_tile) { | ||||
::megdnn::winograd::StrategyHelper< | |||||
float, float, float, float, param::ConvBias::Format::NCHW, | |||||
param::MatrixMul::Format::MK4>::output(output_transform_buf, bias, | |||||
output, transform_mid_buf, | |||||
bmode, nonline_mode, | |||||
oh_start, ow_start, OH, OW, | |||||
oc_start, oc_end, unit_idx, | |||||
nr_units_in_tile, | |||||
OUTPUT_BLOCK_SIZE, | |||||
KERNEL_SIZE, {0, 1, -1}, | |||||
dst_dtype); | |||||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||||
size_t OC = oc_end - oc_start; | |||||
for (size_t oc = oc_start; oc < oc_end; ++oc) { | |||||
size_t oc_index = oc - oc_start; | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
auto nh = index / units_w; | |||||
auto nw = index % units_w; | |||||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||||
::megdnn::winograd::StrategyHelper< | |||||
float, float, float, float, param::ConvBias::Format::NCHW, | |||||
param::MatrixMul::Format::MK4>::output(output_transform_buf, | |||||
bias, output, | |||||
transform_mid_buf, | |||||
bmode, nonline_mode, | |||||
oh_start, ow_start, | |||||
OH, OW, OC, oc_start, | |||||
oc_index, unit_idx, | |||||
nr_units_in_tile, | |||||
OUTPUT_BLOCK_SIZE, | |||||
KERNEL_SIZE, | |||||
{0, 1, -1}, | |||||
dst_dtype); | |||||
} | |||||
} | |||||
} | } | ||||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_1x1_qs8) | MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_1x1_qs8) | ||||
@@ -119,29 +177,59 @@ void winograd_2x3_1x1_qs8::filter(const int8_t* filter, | |||||
void winograd_2x3_1x1_qs8::input(const int8_t* input, | void winograd_2x3_1x1_qs8::input(const int8_t* input, | ||||
int16_t* input_transform_buf, | int16_t* input_transform_buf, | ||||
int16_t* transform_mid_buf, int ih_start, | |||||
int iw_start, size_t IH, size_t IW, size_t IC, | |||||
size_t unit_idx, size_t nr_units_in_tile) { | |||||
::megdnn::winograd::StrategyHelper<int8_t, int8_t, int16_t, int>::input( | |||||
input, input_transform_buf, transform_mid_buf, ih_start, iw_start, | |||||
IH, IW, IC, unit_idx, nr_units_in_tile, OUTPUT_BLOCK_SIZE, | |||||
KERNEL_SIZE, {0, 1, -1}, src_dtype, 1.0f); | |||||
int16_t* transform_mid_buf, size_t IH, | |||||
size_t IW, size_t IC, size_t PH, size_t PW, | |||||
size_t unit_start_idx, | |||||
size_t nr_units_in_tile) { | |||||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||||
auto units_w = | |||||
div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
rep(ic, IC) { | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
size_t nh = index / units_w; | |||||
size_t nw = index % units_w; | |||||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||||
::megdnn::winograd::StrategyHelper<int8_t, int8_t, int16_t, int>:: | |||||
input(input, input_transform_buf, transform_mid_buf, | |||||
ih_start, iw_start, IH, IW, IC, ic, unit_idx, | |||||
nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, | |||||
{0, 1, -1}, src_dtype, 1.0f); | |||||
} | |||||
} | |||||
} | } | ||||
void winograd_2x3_1x1_qs8::output(const int* output_transform_buf, | void winograd_2x3_1x1_qs8::output(const int* output_transform_buf, | ||||
const int* bias, int8_t* output, | const int* bias, int8_t* output, | ||||
int* transform_mid_buf, BiasMode bmode, | int* transform_mid_buf, BiasMode bmode, | ||||
NonlineMode nonline_mode, size_t oh_start, | |||||
size_t ow_start, size_t OH, size_t OW, | |||||
size_t oc_start, size_t oc_end, | |||||
size_t unit_idx, size_t nr_units_in_tile) { | |||||
NonlineMode nonline_mode, size_t OH, | |||||
size_t OW, size_t oc_start, size_t oc_end, | |||||
size_t unit_start_idx, | |||||
size_t nr_units_in_tile) { | |||||
float scale_input = src_dtype.param<dtype::QuantizedS8>().scale; | float scale_input = src_dtype.param<dtype::QuantizedS8>().scale; | ||||
float scale_filter = filter_dtype.param<dtype::QuantizedS8>().scale; | float scale_filter = filter_dtype.param<dtype::QuantizedS8>().scale; | ||||
::megdnn::winograd::StrategyHelper<int8_t, int8_t, int16_t, int>::output( | |||||
output_transform_buf, bias, output, transform_mid_buf, bmode, | |||||
nonline_mode, oh_start, ow_start, OH, OW, oc_start, oc_end, | |||||
unit_idx, nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, | |||||
{0, 1, -1}, dst_dtype, scale_input * scale_filter, 2.0f, 1.0f); | |||||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||||
size_t OC = oc_end - oc_start; | |||||
for (size_t oc = oc_start; oc < oc_end; ++oc) { | |||||
size_t oc_index = oc - oc_start; | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
auto nh = index / units_w; | |||||
auto nw = index % units_w; | |||||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||||
::megdnn::winograd::StrategyHelper<int8_t, int8_t, int16_t, int>:: | |||||
output(output_transform_buf, bias, output, | |||||
transform_mid_buf, bmode, nonline_mode, oh_start, | |||||
ow_start, OH, OW, OC, oc_start, oc_index, unit_idx, | |||||
nr_units_in_tile, OUTPUT_BLOCK_SIZE, KERNEL_SIZE, | |||||
{0, 1, -1}, dst_dtype, scale_input * scale_filter, | |||||
2.0f, 1.0f); | |||||
} | |||||
} | |||||
} | } | ||||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_8x8_qs8) | MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_8x8_qs8) | ||||
@@ -162,27 +250,44 @@ void winograd_2x3_8x8_qs8::filter(const int8_t* filter, | |||||
void winograd_2x3_8x8_qs8::input(const int8_t* input, | void winograd_2x3_8x8_qs8::input(const int8_t* input, | ||||
int16_t* input_transform_buf, | int16_t* input_transform_buf, | ||||
int16_t* transform_mid_buf, int ih_start, | |||||
int iw_start, size_t IH, size_t IW, size_t IC, | |||||
size_t unit_idx, size_t nr_units_in_tile) { | |||||
::megdnn::winograd::StrategyHelper< | |||||
int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW, | |||||
param::MatrixMul::Format::MK8>::input(input, input_transform_buf, | |||||
transform_mid_buf, ih_start, | |||||
iw_start, IH, IW, IC, | |||||
unit_idx, nr_units_in_tile, | |||||
OUTPUT_BLOCK_SIZE, | |||||
KERNEL_SIZE, {0, 1, -1}, | |||||
src_dtype, 1.0f); | |||||
int16_t* transform_mid_buf, size_t IH, | |||||
size_t IW, size_t IC, size_t PH, size_t PW, | |||||
size_t unit_start_idx, | |||||
size_t nr_units_in_tile) { | |||||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||||
auto units_w = | |||||
div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
rep(ic, IC) { | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
size_t nh = index / units_w; | |||||
size_t nw = index % units_w; | |||||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||||
::megdnn::winograd::StrategyHelper< | |||||
int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW, | |||||
param::MatrixMul::Format::MK8>::input(input, | |||||
input_transform_buf, | |||||
transform_mid_buf, | |||||
ih_start, iw_start, | |||||
IH, IW, IC, ic, | |||||
unit_idx, | |||||
nr_units_in_tile, | |||||
OUTPUT_BLOCK_SIZE, | |||||
KERNEL_SIZE, | |||||
{0, 1, -1}, src_dtype, | |||||
1.0f); | |||||
} | |||||
} | |||||
} | } | ||||
void winograd_2x3_8x8_qs8::output(const int* output_transform_buf, | void winograd_2x3_8x8_qs8::output(const int* output_transform_buf, | ||||
const int* bias, int8_t* output, | const int* bias, int8_t* output, | ||||
int* transform_mid_buf, BiasMode bmode, | int* transform_mid_buf, BiasMode bmode, | ||||
NonlineMode nonline_mode, size_t oh_start, | |||||
size_t ow_start, size_t OH, size_t OW, | |||||
size_t oc_start, size_t oc_end, | |||||
size_t unit_idx, size_t nr_units_in_tile) { | |||||
NonlineMode nonline_mode, size_t OH, | |||||
size_t OW, size_t oc_start, size_t oc_end, | |||||
size_t unit_start_idx, | |||||
size_t nr_units_in_tile) { | |||||
float scale_input = src_dtype.param<dtype::QuantizedS8>().scale; | float scale_input = src_dtype.param<dtype::QuantizedS8>().scale; | ||||
float scale_filter = 0.f; | float scale_filter = 0.f; | ||||
if (filter_dtype.enumv() == DTypeEnum::QuantizedS8) { | if (filter_dtype.enumv() == DTypeEnum::QuantizedS8) { | ||||
@@ -191,19 +296,37 @@ void winograd_2x3_8x8_qs8::output(const int* output_transform_buf, | |||||
megdnn_assert(filter_dtype.enumv() == DTypeEnum::QuantizedS16); | megdnn_assert(filter_dtype.enumv() == DTypeEnum::QuantizedS16); | ||||
scale_filter = filter_dtype.param<dtype::QuantizedS16>().scale; | scale_filter = filter_dtype.param<dtype::QuantizedS16>().scale; | ||||
} | } | ||||
::megdnn::winograd::StrategyHelper< | |||||
int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW, | |||||
param::MatrixMul::Format::MK8>::output(output_transform_buf, bias, | |||||
output, transform_mid_buf, | |||||
bmode, nonline_mode, | |||||
oh_start, ow_start, OH, OW, | |||||
oc_start, oc_end, unit_idx, | |||||
nr_units_in_tile, | |||||
OUTPUT_BLOCK_SIZE, | |||||
KERNEL_SIZE, {0, 1, -1}, | |||||
dst_dtype, | |||||
scale_input * scale_filter, | |||||
2.0f, 1.0f); | |||||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||||
size_t OC = oc_end - oc_start; | |||||
for (size_t oc = oc_start; oc < oc_end; ++oc) { | |||||
size_t oc_index = oc - oc_start; | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
auto nh = index / units_w; | |||||
auto nw = index % units_w; | |||||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||||
::megdnn::winograd::StrategyHelper< | |||||
int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW, | |||||
param::MatrixMul::Format::MK8>::output(output_transform_buf, | |||||
bias, output, | |||||
transform_mid_buf, | |||||
bmode, nonline_mode, | |||||
oh_start, ow_start, | |||||
OH, OW, OC, oc_start, | |||||
oc_index, unit_idx, | |||||
nr_units_in_tile, | |||||
OUTPUT_BLOCK_SIZE, | |||||
KERNEL_SIZE, | |||||
{0, 1, -1}, | |||||
dst_dtype, | |||||
scale_input * | |||||
scale_filter, | |||||
2.0f, 1.0f); | |||||
} | |||||
} | |||||
} | } | ||||
} // namespace winograd | } // namespace winograd | ||||
@@ -321,17 +321,10 @@ public: | |||||
"nr_tiles_in_unit: %zu TILE_SIZE:%zu", | "nr_tiles_in_unit: %zu TILE_SIZE:%zu", | ||||
nr_tiles_in_unit, unit_tile_size); | nr_tiles_in_unit, unit_tile_size); | ||||
} | } | ||||
rep(unit_idx, nr_tiles_in_unit) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
size_t nh = index / units_w; | |||||
size_t nw = index % units_w; | |||||
int ih_start = nh * Strategy::OUTPUT_BLOCK_SIZE - PH; | |||||
int iw_start = nw * Strategy::OUTPUT_BLOCK_SIZE - PW; | |||||
strategy.input(src_ptr, input_transform_buf, transform_mid_buf, | |||||
ih_start, iw_start, IH, IW, IC, unit_idx, | |||||
nr_tiles_in_unit); | |||||
} | |||||
//! BTdB | |||||
strategy.input(src_ptr, input_transform_buf, transform_mid_buf, | |||||
IH, IW, IC, PH, PW, unit_start_idx, nr_tiles_in_unit); | |||||
rep(i, Strategy::ALPHA) rep(j, Strategy::ALPHA) { | rep(i, Strategy::ALPHA) rep(j, Strategy::ALPHA) { | ||||
if (format == param::MatrixMul::Format::DEFAULT) { | if (format == param::MatrixMul::Format::DEFAULT) { | ||||
matmul_param.A_ptr = | matmul_param.A_ptr = | ||||
@@ -368,22 +361,14 @@ public: | |||||
} | } | ||||
matmul_kern(matmul_param); | matmul_kern(matmul_param); | ||||
} | } | ||||
/* Y = ATmA */ | |||||
rep(unit_idx, nr_tiles_in_unit) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
auto nh = index / units_w; | |||||
auto nw = index % units_w; | |||||
size_t oh_start = nh * Strategy::OUTPUT_BLOCK_SIZE; | |||||
size_t ow_start = nw * Strategy::OUTPUT_BLOCK_SIZE; | |||||
size_t oc_end_idx = oc_start_idx + nr_oc_in_unit; | |||||
strategy.output( | |||||
output_transform_buf, bias_ptr, dst_ptr, | |||||
reinterpret_cast<output_compute_type*>(transform_mid_buf), | |||||
ncb_param.bias_mode, ncb_param.nonlineMode, oh_start, | |||||
ow_start, OH, OW, oc_start_idx, oc_end_idx, unit_idx, | |||||
nr_tiles_in_unit); | |||||
} | |||||
//! Y = ATmA | |||||
size_t oc_end_idx = oc_start_idx + nr_oc_in_unit; | |||||
strategy.output( | |||||
output_transform_buf, bias_ptr, dst_ptr, | |||||
reinterpret_cast<output_compute_type*>(transform_mid_buf), | |||||
ncb_param.bias_mode, ncb_param.nonlineMode, OH, OW, | |||||
oc_start_idx, oc_end_idx, unit_start_idx, nr_tiles_in_unit); | |||||
}; | }; | ||||
SmallVector<NCBKern> get_kerns( | SmallVector<NCBKern> get_kerns( | ||||
@@ -542,15 +527,16 @@ public: | |||||
size_t IC, size_t oc_start, size_t oc_end); \ | size_t IC, size_t oc_start, size_t oc_end); \ | ||||
void input(const stype* input, \ | void input(const stype* input, \ | ||||
input_filter_compute_type* input_transform_buf, \ | input_filter_compute_type* input_transform_buf, \ | ||||
input_filter_compute_type* transform_mid_buf, int ih_start, \ | |||||
int iw_start, size_t IH, size_t IW, size_t IC, \ | |||||
size_t unit_idx, size_t nr_tiles_in_unit); \ | |||||
input_filter_compute_type* transform_mid_buf, \ | |||||
size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, \ | |||||
size_t unit_start_idx, size_t nr_tiles_in_unit); \ | |||||
void output(const output_compute_type* output_transform_buf, \ | void output(const output_compute_type* output_transform_buf, \ | ||||
const output_compute_type* bias, dst_type* output, \ | const output_compute_type* bias, dst_type* output, \ | ||||
output_compute_type* transform_mid_buf, BiasMode bmode, \ | output_compute_type* transform_mid_buf, BiasMode bmode, \ | ||||
NonlineMode nonline_mode, size_t oh_start, \ | |||||
size_t ow_start, size_t OH, size_t OW, size_t oc_start, \ | |||||
size_t oc_end, size_t unit_idx, size_t nr_tiles_in_unit); \ | |||||
NonlineMode nonline_mode, size_t OH, size_t OW, \ | |||||
size_t oc_start, size_t oc_end, size_t unit_start_idx, \ | |||||
size_t nr_tiles_in_unit); \ | |||||
}; | }; | ||||
#define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \ | #define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \ | ||||
@@ -274,31 +274,43 @@ void winograd_nchw88_2x3_8x8_f::filter(const float* filter, | |||||
transform_mid_buf, OC, IC, oc_start, | transform_mid_buf, OC, IC, oc_start, | ||||
oc_end); | oc_end); | ||||
} | } | ||||
void winograd_nchw88_2x3_8x8_f::input(const float* input, | void winograd_nchw88_2x3_8x8_f::input(const float* input, | ||||
float* input_transform_buf, | float* input_transform_buf, | ||||
float* transform_mid_buf, int ih_start, | |||||
int iw_start, size_t IH, size_t IW, | |||||
size_t IC, size_t unit_idx, | |||||
float* transform_mid_buf, size_t IH, | |||||
size_t IW, size_t IC, size_t PH, | |||||
size_t PW, size_t unit_start_idx, | |||||
size_t nr_units_in_tile) { | size_t nr_units_in_tile) { | ||||
megdnn_assert(IC % 8 == 0); | megdnn_assert(IC % 8 == 0); | ||||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
float* patch = transform_mid_buf; | float* patch = transform_mid_buf; | ||||
float* patchT = transform_mid_buf + 8 * alpha * alpha; | float* patchT = transform_mid_buf + 8 * alpha * alpha; | ||||
if (ih_start >= 0 && ih_start + alpha <= static_cast<size_t>(IH) && | |||||
iw_start >= 0 && iw_start + alpha <= static_cast<size_t>(IW)) { | |||||
for (size_t ic = 0; ic < IC; ic += 8) { | |||||
InputTransform2X3_NCHW88::prepare<true>( | |||||
input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC); | |||||
InputTransform2X3_NCHW88::transform(patchT, input_transform_buf, | |||||
unit_idx, nr_units_in_tile, ic, | |||||
IC); | |||||
} | |||||
} else { | |||||
for (size_t ic = 0; ic < IC; ic += 8) { | |||||
InputTransform2X3_NCHW88::prepare<false>(input, patch, patchT, ih_start, | |||||
iw_start, IH, IW, ic, IC); | |||||
InputTransform2X3_NCHW88::transform(patchT, input_transform_buf, | |||||
unit_idx, nr_units_in_tile, ic, | |||||
IC); | |||||
for (size_t ic = 0; ic < IC; ic += 8) { | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
size_t nh = index / units_w; | |||||
size_t nw = index % units_w; | |||||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||||
if (ih_start >= 0 && ih_start + alpha <= static_cast<size_t>(IH) && | |||||
iw_start >= 0 && iw_start + alpha <= static_cast<size_t>(IW)) { | |||||
InputTransform2X3_NCHW88::prepare<true>(input, patch, patchT, | |||||
ih_start, iw_start, IH, | |||||
IW, ic, IC); | |||||
InputTransform2X3_NCHW88::transform(patchT, input_transform_buf, | |||||
unit_idx, nr_units_in_tile, | |||||
ic, IC); | |||||
} else { | |||||
InputTransform2X3_NCHW88::prepare<false>(input, patch, patchT, | |||||
ih_start, iw_start, IH, | |||||
IW, ic, IC); | |||||
InputTransform2X3_NCHW88::transform(patchT, input_transform_buf, | |||||
unit_idx, nr_units_in_tile, | |||||
ic, IC); | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -338,32 +338,43 @@ void winograd_nchw88_6x3_8x8_f::filter(const float* filter, | |||||
transform_mid_buf, OC, IC, oc_start, | transform_mid_buf, OC, IC, oc_start, | ||||
oc_end); | oc_end); | ||||
} | } | ||||
void winograd_nchw88_6x3_8x8_f::input(const float* input, | void winograd_nchw88_6x3_8x8_f::input(const float* input, | ||||
float* input_transform_buf, | float* input_transform_buf, | ||||
float* transform_mid_buf, int ih_start, | |||||
int iw_start, size_t IH, size_t IW, | |||||
size_t IC, size_t unit_idx, | |||||
float* transform_mid_buf, size_t IH, | |||||
size_t IW, size_t IC, size_t PH, | |||||
size_t PW, size_t unit_start_idx, | |||||
size_t nr_units_in_tile) { | size_t nr_units_in_tile) { | ||||
megdnn_assert(IC % 8 == 0); | megdnn_assert(IC % 8 == 0); | ||||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
float* patch = transform_mid_buf; | float* patch = transform_mid_buf; | ||||
float* patchT = transform_mid_buf + 8 * alpha * alpha; | float* patchT = transform_mid_buf + 8 * alpha * alpha; | ||||
if (ih_start >= 0 && ih_start + alpha <= static_cast<size_t>(IH) && | |||||
iw_start >= 0 && iw_start + alpha <= static_cast<size_t>(IW)) { | |||||
for (size_t ic = 0; ic < IC; ic += 8) { | |||||
InputTransform6X3_NCHW88::prepare<true>( | |||||
input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC); | |||||
InputTransform6X3_NCHW88::transform(patchT, input_transform_buf, | |||||
unit_idx, nr_units_in_tile, ic, | |||||
IC); | |||||
} | |||||
} else { | |||||
for (size_t ic = 0; ic < IC; ic += 8) { | |||||
InputTransform6X3_NCHW88::prepare<false>(input, patch, patchT, ih_start, | |||||
iw_start, IH, IW, ic, IC); | |||||
InputTransform6X3_NCHW88::transform(patchT, input_transform_buf, | |||||
unit_idx, nr_units_in_tile, ic, | |||||
IC); | |||||
for (size_t ic = 0; ic < IC; ic += 8) { | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
size_t nh = index / units_w; | |||||
size_t nw = index % units_w; | |||||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||||
if (ih_start >= 0 && ih_start + alpha <= static_cast<size_t>(IH) && | |||||
iw_start >= 0 && iw_start + alpha <= static_cast<size_t>(IW)) { | |||||
InputTransform6X3_NCHW88::prepare<true>(input, patch, patchT, | |||||
ih_start, iw_start, IH, | |||||
IW, ic, IC); | |||||
InputTransform6X3_NCHW88::transform(patchT, input_transform_buf, | |||||
unit_idx, nr_units_in_tile, | |||||
ic, IC); | |||||
} else { | |||||
InputTransform6X3_NCHW88::prepare<false>(input, patch, patchT, | |||||
ih_start, iw_start, IH, | |||||
IW, ic, IC); | |||||
InputTransform6X3_NCHW88::transform(patchT, input_transform_buf, | |||||
unit_idx, nr_units_in_tile, | |||||
ic, IC); | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } | ||||