You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

kern.cuh 2.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. /**
  2. * \file dnn/src/cuda/convolution/chanwise/kern.cuh
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "src/cuda/utils.cuh"
  13. #include <cuda_runtime.h>
  14. #include <stdint.h>
  15. #if MEGDNN_CC_HOST
  16. #include "src/cuda/convolution/helper.h"
  17. #endif
  namespace megdnn {
namespace cuda {
namespace convolution {
namespace chanwise {

/*!
 * \brief Sizes and geometry of a channel-wise convolution problem, consumed by
 * the launchers declared below.
 *
 * All fields are uint32_t. When built by from_fwd_args() they hold:
 *  - batch, src_chl, src_h, src_w: batch size, channel count and spatial
 *    extent of the source tensor;
 *  - chl_mul: filter_meta.ocpg (presumably output channels per group, i.e.
 *    the channel multiplier);
 *  - flt_h, flt_w: filter spatial size;
 *  - out_h, out_w: spatial extent of the destination tensor;
 *  - pad_*, stride_*, dilation_*: per-axis padding, stride and dilation.
 */
struct Param {
    uint32_t batch, src_chl, src_h, src_w, chl_mul, flt_h, flt_w, out_h, out_w, pad_h,
            pad_w, stride_h, stride_w, dilation_h, dilation_w;
#if MEGDNN_CC_HOST
    /*!
     * \brief Build a Param from forward-convolution size arguments.
     *
     * Only two layouts are accepted:
     *  - NCHW: channel at axis 1, spatial at axes 2-3;
     *  - NHWC: spatial at axes 1-2, channel at axis 3.
     * Any other format would be indexed with the wrong axes and yield
     * meaningless sizes. It is therefore rejected with an assertion instead
     * of being silently treated as NHWC.
     *
     * Sizes are narrowed with static_cast<uint32_t>. Values that do not fit
     * in 32 bits are truncated, so callers must keep shapes within that range.
     */
    static Param from_fwd_args(const ForwardSizeArgs& args) {
        // Narrowing helper. A local lambda instead of a function-local macro
        // keeps this header from defining (or clobbering) a preprocessor name.
        auto u = [](size_t v) { return static_cast<uint32_t>(v); };
        auto&& src = args.src_layout->shape;
        auto&& dst = args.dst_layout->shape;
        auto&& fm = args.filter_meta;
        const bool is_nchw = fm.format == param::Convolution::Format::NCHW;
        megdnn_assert(
                is_nchw || fm.format == param::Convolution::Format::NHWC,
                "channel-wise convolution expects NCHW or NHWC format");
        // Axis of the channel dimension and of the first spatial dimension.
        const size_t c_pos = is_nchw ? 1 : 3;
        const size_t hw_pos = is_nchw ? 2 : 1;
        // Order must match the field declaration order of Param.
        return {
                u(src[0]),          u(src[c_pos]),     u(src[hw_pos]),
                u(src[hw_pos + 1]), u(fm.ocpg),        u(fm.spatial[0]),
                u(fm.spatial[1]),   u(dst[hw_pos]),    u(dst[hw_pos + 1]),
                u(fm.padding[0]),   u(fm.padding[1]),  u(fm.stride[0]),
                u(fm.stride[1]),    u(fm.dilation[0]), u(fm.dilation[1]),
        };
    }
#endif
};

/*
 * Host-side launchers for the channel-wise kernels.
 *
 * They are only declared here. NOTE(review): the template definitions and
 * their instantiations presumably live in sibling .cu files; confirm.
 * Each launcher:
 *  - receives the problem geometry in `param`;
 *  - receives a cudaStream_t, presumably the stream its kernels are enqueued on;
 *  - has exactly one non-const pointer, which is the buffer it writes.
 */

//! Backward data: writes src_grad from dst_grad and flt. Presumably a variant
//! for small problem sizes; confirm against its definition.
template <typename T>
void run_bwd_data_small(
        T* src_grad, const T* dst_grad, const T* flt, const Param& param,
        cudaStream_t stream);

//! Backward data: writes src_grad from dst_grad and flt.
template <typename T>
void run_bwd_data(
        T* src_grad, const T* dst_grad, const T* flt, const Param& param,
        cudaStream_t stream);

//! Depthwise backward path for large filters. Writes dst from src and flt.
//! NOTE(review): unlike run_bwd_data, the parameters use forward-style names
//! (dst/src). Presumably dst receives the source gradient and src is the
//! output gradient; confirm against the definition before relying on it.
template <typename T>
void run_bwd_depthwise_large_filter(
        T* dst, const T* src, const T* flt, const Param& param, cudaStream_t stream);

//! Backward filter: writes filter_grad from src and dst_grad.
template <typename T>
void run_bwd_filter(
        T* filter_grad, const T* src, const T* dst_grad, const Param& param,
        cudaStream_t stream);

}  // namespace chanwise
}  // namespace convolution
}  // namespace cuda
}  // namespace megdnn
  69. // vim: ft=cpp syntax=cpp.doxygen