You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

algos.cpp 5.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. /**
  2. * \file dnn/src/aarch64/conv_bias/fp16/algos.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "src/aarch64/conv_bias/fp16/algos.h"
  13. #include "src/aarch64/conv_bias/fp16/stride2_kern.h"
  14. #include "src/arm_common/conv_bias/direct/multi_thread_common.h"
  15. #include "src/arm_common/conv_bias/postprocess_helper.h"
  16. using namespace megdnn;
  17. using namespace aarch64;
  18. #include "midout.h"
  19. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  20. /* ===================== stride-2 algo ===================== */
  21. MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp16)
  22. bool ConvBiasImpl::AlgoF16DirectStride2::usable(const NCBKernSizeParam& param,
  23. AlgoSelectionStrategy) const {
  24. MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 0) {
  25. auto&& fm = param.filter_meta;
  26. auto FH = fm.spatial[0];
  27. return param.filter_meta.format == param::Convolution::Format::NCHW &&
  28. param.src_type.enumv() == DTypeEnum::Float16 &&
  29. param.filter_type.enumv() == DTypeEnum::Float16 &&
  30. param.dst_type.enumv() == DTypeEnum::Float16 &&
  31. !fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
  32. fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 &&
  33. FH == fm.spatial[1] &&
  34. (FH == 2 || FH == 3 || FH == 5 || FH == 7);
  35. }
  36. MIDOUT_END();
  37. return false;
  38. }
  39. size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace(
  40. const NCBKernSizeParam& param) const {
  41. MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 1) {
  42. bool large_group = param.filter_meta.group >= param.nr_threads;
  43. auto wbundle = arm_common::MultithreadDirectConvCommon<
  44. dt_float16, __fp16>::get_bundle_stride(param, large_group);
  45. return wbundle.total_size_in_bytes();
  46. }
  47. MIDOUT_END();
  48. return 0;
  49. }
  50. SmallVector<ConvBiasImpl::NCBKern>
  51. ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns(
  52. const NCBKernSizeParam& param) const {
  53. MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 2) {
  54. return get_kimpls(param);
  55. }
  56. MIDOUT_END();
  57. return {};
  58. }
  59. SmallVector<ConvBiasImpl::NCBKern>
  60. ConvBiasImpl::AlgoF16DirectStride2::get_kimpls(
  61. const NCBKernSizeParam& param) const {
  62. auto fm = param.filter_meta;
  63. auto FH = fm.spatial[0];
  64. size_t N = param.n;
  65. size_t IC = param.filter_meta.icpg;
  66. size_t OC = param.filter_meta.ocpg;
  67. size_t group = fm.group;
  68. bool large_group = group >= param.nr_threads;
  69. using Func = std::function<void(const __fp16*, const __fp16*, __fp16*,
  70. size_t, size_t, size_t, size_t, size_t)>;
  71. Func conv = nullptr;
  72. if (FH == 2) {
  73. conv = fp16::conv_stride2::do_conv_2x2_stride2;
  74. } else if (FH == 3) {
  75. conv = fp16::conv_stride2::do_conv_3x3_stride2;
  76. } else if (FH == 5) {
  77. conv = fp16::conv_stride2::do_conv_5x5_stride2;
  78. } else if (FH == 7) {
  79. conv = fp16::conv_stride2::do_conv_7x7_stride2;
  80. }
  81. WorkspaceBundle bundle = arm_common::MultithreadDirectConvCommon<
  82. dt_float16, __fp16>::get_bundle_stride(param, large_group);
  83. SmallVector<NCBKern> ret_kerns;
  84. //! Dense conv and small group
  85. if (large_group) {
  86. //! Channel wise conv and big groups
  87. auto exec_one_group = [bundle, conv](
  88. const NCBKernParam& kern_param,
  89. const NCBKernIndex& ncb_index) mutable {
  90. auto fm = kern_param.filter_meta;
  91. size_t IC = fm.icpg;
  92. size_t OC = fm.ocpg;
  93. bundle.set(kern_param.workspace_ptr);
  94. for (size_t ic = 0; ic < IC; ic++) {
  95. arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
  96. copy_padding_kern_stride(bundle, kern_param, ncb_index,
  97. {ncb_index.thread_id, 0, ic});
  98. }
  99. for (size_t oc = 0; oc < OC; oc++) {
  100. arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
  101. do_conv_kern_stride(bundle, kern_param, ncb_index, conv,
  102. {ncb_index.thread_id, 0, oc});
  103. }
  104. };
  105. ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
  106. } else {
  107. auto copy_padding = [bundle](const NCBKernParam& kern_param,
  108. const NCBKernIndex& ncb_index) mutable {
  109. bundle.set(kern_param.workspace_ptr);
  110. arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
  111. copy_padding_kern_stride(bundle, kern_param, ncb_index,
  112. ncb_index.ndrange_id);
  113. };
  114. ret_kerns.push_back({copy_padding, {group, N, IC}});
  115. auto do_conv = [bundle, conv](const NCBKernParam& kern_param,
  116. const NCBKernIndex& ncb_index) mutable {
  117. bundle.set(kern_param.workspace_ptr);
  118. arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
  119. do_conv_kern_stride(bundle, kern_param, ncb_index, conv,
  120. ncb_index.ndrange_id);
  121. };
  122. ret_kerns.push_back({do_conv, {group, N, OC}});
  123. }
  124. return ret_kerns;
  125. }
  126. #endif
  127. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台