
padding.cu

/**
 * \file dnn/src/cuda/padding/padding.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include <algorithm>
#include <cstring>
#include <iostream>
#include "megdnn/basic_types.h"
#include "padding.cuh"
#include "src/cuda/int_fastdiv.cuh"
#include "src/cuda/query_blocksize.cuh"

namespace megdnn {
namespace cuda {
namespace padding {
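
// Kernel parameters packed into one struct so a single by-value argument
// carries everything to the device. Strides are stored as Uint32Fastdiv so
// each per-dimension division in the kernels becomes a cheap
// multiply-and-shift; offsets[dim * 2] / offsets[dim * 2 + 1] are the
// front / back padding sizes along dimension `dim`.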
struct ShapeParams {
    size_t src_shape[MEGDNN_MAX_NDIM];
    size_t dst_shape[MEGDNN_MAX_NDIM];
    Uint32Fastdiv src_stride[MEGDNN_MAX_NDIM];
    Uint32Fastdiv dst_stride[MEGDNN_MAX_NDIM];
    size_t offsets[MEGDNN_MAX_NDIM * 2];
};
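
// CONSTANT mode, forward: each output element either maps back into the
// source (shifted by the front padding offset along every dimension) or lies
// in the padded border, in which case it is filled with padding_val.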
template <typename T>
__global__ void paddingConst_kernel(
        const size_t ndim, const size_t total_out_nr, const T* const src, T* const dst,
        ShapeParams params, const float_t padding_val) {
    KERN_FOR(out_index, total_out_nr) {
        bool in_src_valid_area = true;
        size_t in_index = 0;
        size_t out_index_tmp = out_index;
        for (size_t dim = 0; dim < ndim; ++dim) {
            Uint32Fastdiv dst_stride = params.dst_stride[dim],
                          src_stride = params.src_stride[dim];
            size_t src_shape = params.src_shape[dim];
            size_t offset = params.offsets[dim * 2];
            size_t dim_index = out_index_tmp / dst_stride;
            in_src_valid_area &=
                    (dim_index >= offset && dim_index < offset + src_shape);
            if (!in_src_valid_area)
                break;
            out_index_tmp -= dim_index * dst_stride.divisor();
            in_index += (dim_index - offset) * src_stride.divisor();
        }
        dst[out_index] =
                in_src_valid_area ? src[in_index] : static_cast<T>(padding_val);
    }
}
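
// REPLICATE mode, forward: out-of-range indices are clamped to the nearest
// valid source index along each dimension (edge replication).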
template <typename T>
__global__ void paddingReplicate_kernel(
        const size_t ndim, const size_t total_out_nr, const T* const src, T* const dst,
        ShapeParams params, const float_t) {
    KERN_FOR(out_index, total_out_nr) {
        size_t in_index = 0;
        size_t out_index_tmp = out_index;
        for (size_t dim = 0; dim < ndim; ++dim) {
            size_t dim_index = out_index_tmp / params.dst_stride[dim];
            out_index_tmp -= dim_index * params.dst_stride[dim].divisor();
            dim_index = (size_t)llmin(
                    (long long)params.src_shape[dim] - 1,
                    llmax((long long)dim_index - (long long)params.offsets[dim * 2],
                          (long long)0));
            in_index += dim_index * params.src_stride[dim].divisor();
        }
        dst[out_index] = src[in_index];
    }
}
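
// REFLECT mode, forward: out-of-range indices are mirrored around the source
// borders without repeating the border element: after shifting by the offset,
// index d maps to |d|, and indices past the far edge map to
// 2 * (src_shape - 1) - d.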
template <typename T>
__global__ void paddingReflect_kernel(
        const size_t ndim, const size_t total_out_nr, const T* const src, T* const dst,
        ShapeParams params, const float_t) {
    KERN_FOR(out_index, total_out_nr) {
        size_t in_index = 0;
        size_t out_index_tmp = out_index;
        for (size_t dim = 0; dim < ndim; ++dim) {
            long long dim_index = out_index_tmp / params.dst_stride[dim];
            out_index_tmp -= dim_index * params.dst_stride[dim].divisor();
            dim_index -= (long long)params.offsets[dim * 2];
            dim_index = llmax(dim_index, -dim_index);
            dim_index = llmin(
                    dim_index, 2 * (long long)params.src_shape[dim] - dim_index - 2);
            in_index += size_t(dim_index) * (size_t)params.src_stride[dim].divisor();
        }
        dst[out_index] = src[in_index];
    }
}
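
// CONSTANT mode, backward: here `src` is the gradient w.r.t. the padded
// output and `dst` is the (smaller) gradient w.r.t. the original input.
// Gradients of positions that came from the input copy straight through;
// gradients in the padded border are dropped.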
template <typename T>
__global__ void paddingConstBackward_kernel(
        const size_t ndim, const size_t total_in_nr, const T* const src, T* const dst,
        ShapeParams params) {
    KERN_FOR(in_index, total_in_nr) {
        bool in_dst_valid_area = true;
        size_t out_index = 0;
        size_t in_index_tmp = in_index;
        for (size_t dim = 0; dim < ndim; ++dim) {
            size_t dim_index = in_index_tmp / params.src_stride[dim];
            in_index_tmp -= dim_index * params.src_stride[dim].divisor();
            in_dst_valid_area &=
                    (dim_index >= params.offsets[dim * 2] &&
                     dim_index < params.offsets[dim * 2] + params.dst_shape[dim]);
            out_index += (dim_index - params.offsets[dim * 2]) *
                         params.dst_stride[dim].divisor();
        }
        if (in_dst_valid_area) {
            dst[out_index] = src[in_index];
        }
    }
}
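
// REPLICATE mode, backward: several padded positions can clamp to the same
// input element, so their gradients are accumulated with atomic_add.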
template <typename T>
__global__ void paddingReplicateBackward_kernel(
        const size_t ndim, const size_t total_in_nr, const T* const src, T* const dst,
        ShapeParams params) {
    KERN_FOR(in_index, total_in_nr) {
        size_t out_index = 0;
        size_t in_index_tmp = in_index;
        for (size_t dim = 0; dim < ndim; ++dim) {
            size_t dim_index = in_index_tmp / params.src_stride[dim];
            in_index_tmp -= dim_index * params.src_stride[dim].divisor();
            dim_index = (size_t)llmin(
                    (long long)params.dst_shape[dim] - 1,
                    llmax((long long)dim_index - (long long)params.offsets[dim * 2],
                          (long long)0));
            out_index += dim_index * params.dst_stride[dim].divisor();
        }
        atomic_add(&dst[out_index], src[in_index]);
    }
}
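
// REFLECT mode, backward: mirrored positions fold their gradients back onto
// the reflected source element, again accumulated with atomic_add.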
template <typename T>
__global__ void paddingReflectBackward_kernel(
        const size_t ndim, const size_t total_in_nr, const T* const src, T* const dst,
        ShapeParams params) {
    KERN_FOR(in_index, total_in_nr) {
        size_t out_index = 0;
        size_t in_index_tmp = in_index;
        for (size_t dim = 0; dim < ndim; ++dim) {
            long long dim_index = in_index_tmp / params.src_stride[dim];
            in_index_tmp -= dim_index * params.src_stride[dim].divisor();
            dim_index -= (long long)params.offsets[dim * 2];
            dim_index = llmax(dim_index, -dim_index);
            dim_index = llmin(
                    dim_index, 2 * (long long)params.dst_shape[dim] - dim_index - 2);
            out_index += size_t(dim_index) * (size_t)params.dst_stride[dim].divisor();
        }
        atomic_add(&dst[out_index], src[in_index]);
    }
}
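
// Host-side launcher for the forward kernels: packs shapes, strides and
// offsets into ShapeParams, picks the kernel for the requested padding mode,
// and launches one thread per output element.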
template <typename T>
void padding_forward_proxy(
        const TensorND& src, const TensorND& dst, size_t offsets[MEGDNN_MAX_NDIM * 2],
        uint32_t mode, const float_t padding_val, cudaStream_t stream) {
    ShapeParams params;
    for (size_t i = 0; i < src.layout.ndim; ++i) {
        params.src_shape[i] = src.layout.shape[i];
        params.dst_shape[i] = dst.layout.shape[i];
        params.src_stride[i] = src.layout.stride[i];
        params.dst_stride[i] = dst.layout.stride[i];
        params.offsets[i * 2] = offsets[i * 2];
        params.offsets[i * 2 + 1] = offsets[i * 2 + 1];
    }
    void (*fwd_kern)(
            const size_t, const size_t, const T* const, T* const, ShapeParams,
            const float_t);
    switch (mode) {
        case param_enumv::Padding::PaddingMode::CONSTANT:
            fwd_kern = paddingConst_kernel<T>;
            break;
        case param_enumv::Padding::PaddingMode::REPLICATE:
            fwd_kern = paddingReplicate_kernel<T>;
            break;
        case param_enumv::Padding::PaddingMode::REFLECT:
            fwd_kern = paddingReflect_kernel<T>;
            break;
        default:
            megdnn_assert(false, "invalid padding mode");
    }
    size_t total_nr = dst.layout.total_nr_elems();
    uint32_t nr_threads = query_blocksize_for_kernel(fwd_kern);
    dim3 threads(nr_threads);
    dim3 blocks(DIVUP(total_nr, nr_threads));
    fwd_kern<<<blocks, threads, 0, stream>>>(
            src.layout.ndim, total_nr, src.ptr<T>(), dst.ptr<T>(), params, padding_val);
    after_kernel_launch();
}
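
// Host-side launcher for the backward kernels: zeroes the gradient buffer
// first (padded positions may contribute no gradient), then launches one
// thread per element of the incoming gradient to scatter into it.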
template <typename T>
void padding_backward_proxy(
        const TensorND& src, const TensorND& dst, size_t offsets[MEGDNN_MAX_NDIM * 2],
        uint32_t mode, cudaStream_t stream) {
    ShapeParams params;
    for (size_t i = 0; i < src.layout.ndim; ++i) {
        params.src_shape[i] = src.layout.shape[i];
        params.dst_shape[i] = dst.layout.shape[i];
        params.src_stride[i] = src.layout.stride[i];
        params.dst_stride[i] = dst.layout.stride[i];
        params.offsets[i * 2] = offsets[i * 2];
        params.offsets[i * 2 + 1] = offsets[i * 2 + 1];
    }
    cudaMemset(dst.raw_ptr(), 0, dst.layout.access_bytes());
    void (*bwd_kern)(const size_t, const size_t, const T* const, T* const, ShapeParams);
    switch (mode) {
        case param_enumv::Padding::PaddingMode::CONSTANT:
            bwd_kern = paddingConstBackward_kernel<T>;
            break;
        case param_enumv::Padding::PaddingMode::REPLICATE:
            bwd_kern = paddingReplicateBackward_kernel<T>;
            break;
        case param_enumv::Padding::PaddingMode::REFLECT:
            bwd_kern = paddingReflectBackward_kernel<T>;
            break;
        default:
            megdnn_assert(false, "invalid padding mode");
    }
    size_t total_nr = src.layout.total_nr_elems();
    uint32_t nr_threads = query_blocksize_for_kernel(bwd_kern);
    dim3 threads(nr_threads);
    dim3 blocks(DIVUP(total_nr, nr_threads));
    bwd_kern<<<blocks, threads, 0, stream>>>(
            src.layout.ndim, total_nr, src.ptr<T>(), dst.ptr<T>(), params);
    after_kernel_launch();
}
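
// Explicit template instantiations: forward for all computing and quantized
// dtypes, backward for floating-point dtypes only.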
#define INST(T)                                                 \
    template void padding_forward_proxy<T>(                     \
            const TensorND& src, const TensorND& dst,           \
            size_t offsets[MEGDNN_MAX_NDIM * 2], uint32_t mode, \
            const float_t padding_val, cudaStream_t stream);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb
#undef INST

#define INST(T)                                                                  \
    template void padding_backward_proxy<T>(                                     \
            const TensorND& src, const TensorND& dst,                            \
            size_t offsets[MEGDNN_MAX_NDIM * 2], uint32_t mode, cudaStream_t stream);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
#undef INST

}  // namespace padding
}  // namespace cuda
}  // namespace megdnn