
algo.cpp 5.6 kB

/**
 * \file dnn/src/cuda/convolution/backward_data/algo.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "./algo.h"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;
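
// AlgoPack registers every available backward-data algorithm exactly once.
// The push order into all_algos appears to encode heuristic preference (note
// the "prefer chanwise" comments below), and m_all_algos_map, built at the
// end, allows an algorithm to be looked up by its descriptor.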
ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() {
    non_cudnn_algos.push_back(&chanwise);
    non_cudnn_algos.push_back(&chanwise_small);
    non_cudnn_algos.push_back(&matmul);

    all_algos.push_back(&chanwise);        // prefer chanwise
    all_algos.push_back(&chanwise_small);  // prefer small chanwise

    fill_cudnn_algos();
    for (auto&& i : cudnn) {
        all_algos.push_back(&i);
    }
    all_algos.push_back(&matmul);

    fill_int8_dp4a_algos();
    for (auto&& algo : int8_nchw4_dotprod) {
        all_algos.push_back(&algo);
        int8_algos.push_back(&algo);
    }

    fill_int8_imma_algos();
    for (auto&& algo : int8_nhwc_imma) {
        all_algos.push_back(&algo);
        int8_algos.push_back(&algo);
    }

    fill_dwconv_algos();
    int8_algos.push_back(&int8_nchw_dotprod);
    all_algos.push_back(&int8_nchw_dotprod);

    all_algos.push_back(&bfloat16);
    bfloat16_algos.push_back(&bfloat16);

    all_algos.push_back(&group);

    for (auto&& algo : all_algos) {
        m_all_algos_map.emplace(algo->info().desc, algo);
    }
}
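
// The AlgoParam initializer lists in fill_dwconv_algos() below are
// cutlass-style tile shapes. Assumed field order (by analogy with similar
// MegEngine cutlass kernels): {threadblock_m, threadblock_n, threadblock_k,
// warp_m, warp_n, warp_k, stage} for the FMA variant; the HMMA variant
// additionally carries {instruction_m, instruction_n, instruction_k} (the
// 8x8x4 tensor-core mma shape) before `stage`.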
void ConvolutionBackwardDataImpl::AlgoPack::fill_dwconv_algos() {
    {
        using AlgoParam = AlgoFloat32NCHWFMAImplicitBatchedGemm::AlgoParam;
        /// preferred algo
        implbmm_nchw_fma.emplace_back(AlgoParam{64, 128, 8, 32, 64, 8, 2});
        implbmm_nchw_fma.emplace_back(AlgoParam{128, 128, 8, 32, 64, 8, 2});
        implbmm_nchw_fma.emplace_back(AlgoParam{128, 64, 8, 64, 32, 8, 2});
        implbmm_nchw_fma.emplace_back(AlgoParam{128, 32, 8, 64, 32, 8, 2});
        implbmm_nchw_fma.emplace_back(AlgoParam{32, 128, 8, 32, 64, 8, 2});
        implbmm_nchw_fma.emplace_back(AlgoParam{64, 64, 8, 32, 64, 8, 2});
        implbmm_nchw_fma.emplace_back(AlgoParam{32, 64, 8, 32, 64, 8, 2});
        implbmm_nchw_fma.emplace_back(AlgoParam{32, 32, 8, 32, 32, 8, 2});
        implbmm_nchw_fma.emplace_back(AlgoParam{64, 32, 8, 64, 32, 8, 2});
        for (auto&& algo : implbmm_nchw_fma) {
            all_algos.push_back(&algo);
        }
    }
#if CUDA_VERSION >= 10010
    {
        using AlgoParam = AlgoFloat16NCHWHMMAImplicitBatchedGemm::AlgoParam;
        /// preferred algo
        implbmm_nchw_hmma.emplace_back(AlgoParam{64, 128, 32, 32, 32, 32, 8, 8, 4, 2});
        implbmm_nchw_hmma.emplace_back(AlgoParam{128, 128, 32, 32, 32, 32, 8, 8, 4, 2});
        implbmm_nchw_hmma.emplace_back(AlgoParam{128, 256, 32, 64, 64, 32, 8, 8, 4, 2});
        implbmm_nchw_hmma.emplace_back(AlgoParam{128, 64, 32, 32, 32, 32, 8, 8, 4, 2});
        implbmm_nchw_hmma.emplace_back(AlgoParam{64, 64, 32, 32, 32, 32, 8, 8, 4, 2});
        for (auto&& algo : implbmm_nchw_hmma) {
            all_algos.push_back(&algo);
        }
    }
#endif
}
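
// MEGDNN_DEF_GET_ALGO_FROM_DESC presumably expands to the standard
// get_algo_from_desc() helper that resolves an algorithm descriptor through
// sm_algo_pack; cudnn_from_enum() is the analogous lookup keyed on the raw
// cuDNN enum value instead.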
MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl)

ConvolutionBackwardDataImpl::AlgoCUDNN* ConvolutionBackwardDataImpl::AlgoPack::
        cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo) {
    for (auto&& i : cudnn) {
        if (i.cudnn_enum() == algo)
            return &i;
    }
    megdnn_throw(ssprintf(
            "can not find cudnn bwd_data algorithm %d", static_cast<int>(algo)));
}
ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack;
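
// The first SizeArgs constructor canonizes the raw filter layout into a
// CanonizedFilterMeta and delegates to the full constructor; ExecArgs extends
// SizeArgs with the concrete tensors and the workspace used at execution time.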
ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs(
        const ConvolutionBackwardDataImpl* o, const TensorLayout& filter,
        const TensorLayout& diff, const TensorLayout& grad)
        : SizeArgs(
                  o, filter, o->make_canonized_filter_meta(grad.ndim, filter), diff,
                  grad) {}

ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs(
        const ConvolutionBackwardDataImpl* o, const TensorLayout& filter,
        const CanonizedFilterMeta& filter_meta, const TensorLayout& diff,
        const TensorLayout& grad)
        : handle{concrete_handle(o->handle())},
          filter_meta{filter_meta},
          diff_layout{&diff},
          grad_layout{&grad},
          filter_layout{&filter},
          opr{o} {}

ConvolutionBackwardDataImpl::AlgoBase::ExecArgs::ExecArgs(
        const ConvolutionBackwardDataImpl* opr, _megdnn_tensor_in filter,
        _megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace)
        : SizeArgs(opr, filter.layout, diff.layout, grad.layout),
          filter_tensor{&filter},
          diff_tensor{&diff},
          grad_tensor{&grad},
          workspace{workspace} {}
std::string ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::to_string() const {
    auto&& fm = filter_meta;
    MEGDNN_MARK_USED_VAR(fm);
    return ssprintf(
            "filter=%u{%u,%u,%u,%u}, diff=%s, grad=%s, "
            "pad=%ux%u, stride=%ux%u, dilate=%ux%u, xcorr=%d, dtype=%s,%s",
            fm.group, fm.ocpg, fm.icpg, fm.spatial[0], fm.spatial[1],
            diff_layout->to_string().c_str(), grad_layout->to_string().c_str(),
            fm.padding[0], fm.padding[1], fm.stride[0], fm.stride[1], fm.dilation[0],
            fm.dilation[1], !fm.should_flip, diff_layout->dtype.name(),
            grad_layout->dtype.name());
}
// vim: syntax=cpp.doxygen