|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546 |
- /**
- * \file dnn/src/fallback/convolution/img2col_helper.h
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- */
- #include "src/common/utils.h"
- namespace {
-
/*!
 * \brief Strided im2col over a whole image.
 *
 * \p src is read as a dense [IC][IH][IW] array. \p dst is written
 * sequentially in [IC][FH][FW][OH][OW] order:
 *
 *     dst[ic][fh][fw][oh][ow] = src[ic][oh * SH + fh2][ow * SW + fw2]
 *
 * where (fh2, fw2) == (fh, fw) when \p is_xcorr is true, and the mirrored
 * kernel offsets (FH - 1 - fh, FW - 1 - fw) otherwise.
 *
 * No bounds checking is done: the caller must supply buffers large enough
 * for every index generated above.
 *
 * \param OC unused
 */
template <bool is_xcorr, typename dtype>
void img2col_stride(const dtype* __restrict src, dtype* __restrict dst,
                    const int OC, const int OH, const int OW, const int IC,
                    const int IH, const int IW, const int FH, const int FW,
                    const int SH, const int SW) {
    static_cast<void>(OC);
    // Offsets are accumulated in size_t so that IC * IH * IW cannot overflow
    // int on large inputs.
    const size_t plane = static_cast<size_t>(IH) * IW;
    size_t out = 0;
    for (int ic = 0; ic < IC; ++ic) {
        const size_t channel_base = ic * plane;
        for (int fh = 0; fh < FH; ++fh) {
            for (int fw = 0; fw < FW; ++fw) {
                // The kernel offset does not depend on the output position,
                // so resolve the (optional) mirroring once per tap.
                const int fh2 = is_xcorr ? fh : FH - fh - 1;
                const int fw2 = is_xcorr ? fw : FW - fw - 1;
                for (int oh = 0; oh < OH; ++oh) {
                    const size_t row_base =
                            channel_base +
                            static_cast<size_t>(oh * SH + fh2) * IW + fw2;
                    for (int ow = 0; ow < OW; ++ow) {
                        dst[out++] = src[row_base + ow * SW];
                    }
                }
            }
        }
    }
}
-
-
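- //! Added for multi-threaded im2col + matmul: each variant below that takes
- //! cur_index and block_size fills only that slice of output positions.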
/*!
 * \brief Strided im2col for channel-packed-by-4 input, restricted to one
 *        block of output positions.
 *
 * \p src is read as [IC / 4][IH][IW][4]; channels beyond the last full
 * group of four (IC % 4) are not visited. The block covers the flattened
 * output positions p = oh * OW + ow with
 * cur_index <= p < cur_index + block_size. \p dst is written sequentially in
 * [IC / 4][FH][FW][block positions][4] order, reading the source pixel
 * (oh * SH + fh2, ow * SW + fw2), where (fh2, fw2) is (fh, fw) when
 * \p is_xcorr is true and (FH - 1 - fh, FW - 1 - fw) otherwise.
 *
 * When sizeof(dtype) == 1 each 4-channel pixel is moved as one 32-bit word.
 * NOTE(review): that path reads and writes the buffers through uint32_t
 * pointers, so it presumably relies on \p src and \p dst being 4-byte
 * aligned -- confirm against the callers.
 *
 * No bounds checking is done.
 *
 * \param OC unused
 * \param OH unused
 */
template <bool is_xcorr, typename dtype>
void img2col_stride_nchw4(const dtype* __restrict src, dtype* __restrict dst,
                          const int OC, const int OH, const int OW,
                          const int IC, const int IH, const int IW,
                          const int FH, const int FW, const int SH,
                          const int SW, const int cur_index,
                          const int block_size) {
    static_cast<void>(OC);
    static_cast<void>(OH);
    const int start_h = cur_index / OW;
    const int cur_remain_w = cur_index % OW;
    const int end_h = (cur_index + block_size) / OW;
    const int end_remain_w = (cur_index + block_size) % OW;

    const size_t packed_ic = IC / 4;
    const size_t plane = static_cast<size_t>(IH) * IW;
    // Compile-time constant: selects element-wise copy vs. 32-bit word copy.
    const bool elementwise = sizeof(dtype) != 1;
    const uint32_t* src_words =
            static_cast<const uint32_t*>(static_cast<const void*>(src));
    uint32_t* dst_words = static_cast<uint32_t*>(static_cast<void*>(dst));

    // Counts dtype elements on the element-wise path and 32-bit words on the
    // packed path.
    size_t out = 0;
    for (size_t ic = 0; ic < packed_ic; ++ic) {
        for (int fh = 0; fh < FH; ++fh) {
            for (int fw = 0; fw < FW; ++fw) {
                const int fh2 = is_xcorr ? fh : FH - fh - 1;
                const int fw2 = is_xcorr ? fw : FW - fw - 1;
                // Walk the output rows touched by the block. The first row
                // starts at cur_remain_w, the last row stops before
                // end_remain_w, rows in between are complete. When the block
                // lies inside a single row both limits apply to it.
                for (int h = start_h; h <= end_h; ++h) {
                    const int w_begin = (h == start_h) ? cur_remain_w : 0;
                    const int w_end = (h == end_h) ? end_remain_w : OW;
                    // Pixel offset (in units of one 4-channel group).
                    size_t pixel = ic * plane +
                                   static_cast<size_t>(h * SH + fh2) * IW +
                                   (w_begin * SW + fw2);
                    for (int w = w_begin; w < w_end; ++w) {
                        if (elementwise) {
                            const size_t first = 4 * pixel;
                            dst[out++] = src[first + 0];
                            dst[out++] = src[first + 1];
                            dst[out++] = src[first + 2];
                            dst[out++] = src[first + 3];
                        } else {
                            dst_words[out++] = src_words[pixel];
                        }
                        pixel += SW;
                    }
                }
            }
        }
    }
}
-
/*!
 * \brief Unit-stride im2col for channel-packed-by-4 input, restricted to one
 *        block of output positions.
 *
 * \p src is read as [IC / 4][IH][IW][4]; channels beyond the last full
 * group of four (IC % 4) are not visited. The block covers the flattened
 * output positions p = oh * OW + ow with
 * cur_index <= p < cur_index + block_size. \p dst is written sequentially in
 * [IC / 4][FH][FW][block positions][4] order, reading the source pixel
 * (oh + fh2, ow + fw2), where (fh2, fw2) is (fh, fw) when \p is_xcorr is
 * true and (FH - 1 - fh, FW - 1 - fw) otherwise.
 *
 * When sizeof(dtype) == 1 each 4-channel pixel is moved as one 32-bit word.
 * NOTE(review): that path reads and writes the buffers through uint32_t
 * pointers, so it presumably relies on \p src and \p dst being 4-byte
 * aligned -- confirm against the callers.
 *
 * No bounds checking is done.
 *
 * \param OC unused
 * \param OH unused
 * \param SH unused (stride is fixed to 1)
 * \param SW unused (stride is fixed to 1)
 */
template <bool is_xcorr, typename dtype>
void img2col_nchw4(const dtype* __restrict src, dtype* __restrict dst,
                   const int OC, const int OH, const int OW, const int IC,
                   const int IH, const int IW, const int FH, const int FW,
                   const int SH, const int SW, const int cur_index,
                   const int block_size) {
    static_cast<void>(OC);
    static_cast<void>(OH);
    static_cast<void>(SH);
    static_cast<void>(SW);
    const int start_h = cur_index / OW;
    const int cur_remain_w = cur_index % OW;
    const int end_h = (cur_index + block_size) / OW;
    const int end_remain_w = (cur_index + block_size) % OW;

    const size_t packed_ic = IC / 4;
    const size_t plane = static_cast<size_t>(IH) * IW;
    // Compile-time constant: selects element-wise copy vs. 32-bit word copy.
    const bool elementwise = sizeof(dtype) != 1;
    const uint32_t* src_words =
            static_cast<const uint32_t*>(static_cast<const void*>(src));
    uint32_t* dst_words = static_cast<uint32_t*>(static_cast<void*>(dst));

    // Counts dtype elements on the element-wise path and 32-bit words on the
    // packed path.
    size_t out = 0;
    for (size_t ic = 0; ic < packed_ic; ++ic) {
        for (int fh = 0; fh < FH; ++fh) {
            for (int fw = 0; fw < FW; ++fw) {
                const int fh2 = is_xcorr ? fh : FH - fh - 1;
                const int fw2 = is_xcorr ? fw : FW - fw - 1;
                // Walk the output rows touched by the block. The first row
                // starts at cur_remain_w, the last row stops before
                // end_remain_w, rows in between are complete. When the block
                // lies inside a single row both limits apply to it.
                for (int h = start_h; h <= end_h; ++h) {
                    const int w_begin = (h == start_h) ? cur_remain_w : 0;
                    const int w_end = (h == end_h) ? end_remain_w : OW;
                    // Offset of pixel (h + fh2, fw2) in 4-channel groups;
                    // the output column is added per iteration below.
                    const size_t row_base =
                            ic * plane + static_cast<size_t>(h + fh2) * IW +
                            fw2;
                    for (int w = w_begin; w < w_end; ++w) {
                        const size_t pixel = row_base + w;
                        if (elementwise) {
                            const size_t first = 4 * pixel;
                            dst[out++] = src[first + 0];
                            dst[out++] = src[first + 1];
                            dst[out++] = src[first + 2];
                            dst[out++] = src[first + 3];
                        } else {
                            dst_words[out++] = src_words[pixel];
                        }
                    }
                }
            }
        }
    }
}
-
/*!
 * \brief Strided im2col restricted to one block of output positions.
 *
 * \p src is read as a dense [IC][IH][IW] array. The block covers the
 * flattened output positions p = oh * OW + ow with
 * cur_index <= p < cur_index + block_size. \p dst is written sequentially in
 * [IC][FH][FW][block positions] order:
 *
 *     dst[...] = src[ic][oh * SH + fh2][ow * SW + fw2]
 *
 * where (fh2, fw2) == (fh, fw) when \p is_xcorr is true, and the mirrored
 * kernel offsets (FH - 1 - fh, FW - 1 - fw) otherwise.
 *
 * No bounds checking is done.
 *
 * \param OC unused
 * \param OH unused
 */
template <bool is_xcorr, typename dtype>
void img2col_stride(const dtype* __restrict src, dtype* __restrict dst,
                    const int OC, const int OH, const int OW, const int IC,
                    const int IH, const int IW, const int FH, const int FW,
                    const int SH, const int SW, const int cur_index,
                    const int block_size) {
    static_cast<void>(OC);
    static_cast<void>(OH);
    const int start_h = cur_index / OW;
    const int cur_remain_w = cur_index % OW;
    const int end_h = (cur_index + block_size) / OW;
    const int end_remain_w = (cur_index + block_size) % OW;

    // Offsets are accumulated in size_t so that IC * IH * IW cannot overflow
    // int on large inputs.
    const size_t plane = static_cast<size_t>(IH) * IW;
    size_t out = 0;
    for (int ic = 0; ic < IC; ++ic) {
        const size_t channel_base = ic * plane;
        for (int fh = 0; fh < FH; ++fh) {
            for (int fw = 0; fw < FW; ++fw) {
                const int fh2 = is_xcorr ? fh : FH - fh - 1;
                const int fw2 = is_xcorr ? fw : FW - fw - 1;
                // Walk the output rows touched by the block. The first row
                // starts at cur_remain_w, the last row stops before
                // end_remain_w, rows in between are complete. When the block
                // lies inside a single row both limits apply to it.
                for (int h = start_h; h <= end_h; ++h) {
                    const int w_begin = (h == start_h) ? cur_remain_w : 0;
                    const int w_end = (h == end_h) ? end_remain_w : OW;
                    const size_t row_base =
                            channel_base +
                            static_cast<size_t>(h * SH + fh2) * IW + fw2;
                    for (int w = w_begin; w < w_end; ++w) {
                        dst[out++] = src[row_base + w * SW];
                    }
                }
            }
        }
    }
}
-
/*!
 * \brief Unit-stride im2col restricted to one block of output positions.
 *
 * \p src is read as a dense [IC][IH][IW] array. The block covers the
 * flattened output positions p = oh * OW + ow with
 * cur_index <= p < cur_index + block_size. \p dst is written sequentially in
 * [IC][FH][FW][block positions] order:
 *
 *     dst[...] = src[ic][oh + fh2][ow + fw2]
 *
 * where (fh2, fw2) == (fh, fw) when \p is_xcorr is true, and the mirrored
 * kernel offsets (FH - 1 - fh, FW - 1 - fw) otherwise.
 *
 * No bounds checking is done.
 *
 * \param OC unused
 * \param OH unused
 */
template <bool is_xcorr, typename dtype>
void img2col(const dtype* __restrict src, dtype* __restrict dst, const int OC,
             const int OH, const int OW, const int IC, const int IH,
             const int IW, const int FH, const int FW, const int cur_index,
             const int block_size) {
    static_cast<void>(OC);
    static_cast<void>(OH);
    const int start_h = cur_index / OW;
    const int cur_remain_w = cur_index % OW;
    const int end_h = (cur_index + block_size) / OW;
    const int end_remain_w = (cur_index + block_size) % OW;

    // Offsets and the output cursor are size_t (matching the sibling
    // helpers) so that neither IC * IH * IW nor the write position can
    // overflow int.
    const size_t plane = static_cast<size_t>(IH) * IW;
    size_t out = 0;
    for (int ic = 0; ic < IC; ++ic) {
        const size_t channel_base = ic * plane;
        for (int fh = 0; fh < FH; ++fh) {
            for (int fw = 0; fw < FW; ++fw) {
                const int fh2 = is_xcorr ? fh : FH - fh - 1;
                const int fw2 = is_xcorr ? fw : FW - fw - 1;
                // Walk the output rows touched by the block. The first row
                // starts at cur_remain_w, the last row stops before
                // end_remain_w, rows in between are complete. When the block
                // lies inside a single row both limits apply to it.
                for (int h = start_h; h <= end_h; ++h) {
                    const int w_begin = (h == start_h) ? cur_remain_w : 0;
                    const int w_end = (h == end_h) ? end_remain_w : OW;
                    const size_t row_base =
                            channel_base +
                            static_cast<size_t>(h + fh2) * IW + fw2;
                    for (int w = w_begin; w < w_end; ++w) {
                        dst[out++] = src[row_base + w];
                    }
                }
            }
        }
    }
}
-
/*!
 * \brief Unit-stride im2col over a whole image.
 *
 * \p src is read as a dense [IC][IH][IW] array. \p dst is written
 * sequentially in [IC][FH][FW][OH][OW] order:
 *
 *     dst[ic][fh][fw][oh][ow] = src[ic][oh + fh2][ow + fw2]
 *
 * where (fh2, fw2) == (fh, fw) when \p is_xcorr is true, and the mirrored
 * kernel offsets (FH - 1 - fh, FW - 1 - fw) otherwise.
 *
 * Each output row is copied four elements at a time, followed by a scalar
 * tail for the remaining OW % 4 elements. Exactly IC * FH * FW * OH * OW
 * elements of \p dst are written and no source element outside the columns
 * [fw2, fw2 + OW) of a row is read.
 *
 * No further bounds checking is done.
 */
template <bool is_xcorr, typename dtype>
void img2col(const dtype* src, dtype* dst, size_t /* OC */, size_t OH,
             size_t OW, size_t IC, size_t IH, size_t IW, size_t FH, size_t FW) {
    // End of the part of each row that can be copied in groups of four.
    const size_t unrolled_end = OW - OW % 4;
    size_t i = 0;
    for (size_t ic = 0; ic < IC; ++ic) {
        for (size_t fh = 0; fh < FH; ++fh) {
            for (size_t fw = 0; fw < FW; ++fw) {
                const size_t fh2 = is_xcorr ? fh : FH - fh - 1;
                const size_t fw2 = is_xcorr ? fw : FW - fw - 1;
                for (size_t oh = 0; oh < OH; ++oh) {
                    const size_t row_base =
                            ic * IH * IW + (oh + fh2) * IW + fw2;
                    size_t ow = 0;
                    for (; ow < unrolled_end; ow += 4) {
                        dst[i++] = src[row_base + ow + 0];
                        dst[i++] = src[row_base + ow + 1];
                        dst[i++] = src[row_base + ow + 2];
                        dst[i++] = src[row_base + ow + 3];
                    }
                    // Scalar tail: copying a full group of four here would
                    // read past the end of the source row and write past the
                    // end of this output row (and, on the last row, past the
                    // end of both buffers).
                    for (; ow < OW; ++ow) {
                        dst[i++] = src[row_base + ow];
                    }
                }
            }
        }
    }
}
- } // anonymous namespace
-
- // vim: syntax=cpp.doxygen
|