You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_impl.cpp 6.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. #include "src/fallback/softmax/opr_impl.h"
  2. #include <cstring>
  3. #include <numeric>
  4. #include "src/fallback/elemwise/gi_impl/gi_mathfun.h"
  5. #include "src/naive/handle.h"
  6. namespace megdnn {
  7. namespace fallback {
  8. static void do_softmax(
  9. const float* sptr, float* dptr, size_t A, size_t B, size_t C,
  10. _megdnn_workspace workspace) {
  11. constexpr auto float_min = std::numeric_limits<float>::min();
  12. constexpr auto step = GI_SIMD_LEN_BYTE / sizeof(float);
  13. // TODO: When C=2,3,4..., src_ptr span is relatively large, the performance may
  14. // be poor
  15. if (C != 1) {
  16. WorkspaceBundle workspace_bundle{
  17. workspace.raw_ptr, {A * C * sizeof(float), A * C * sizeof(float)}};
  18. float* max = workspace_bundle.get_workspace(0).raw_ptr->as<float>();
  19. GI_FLOAT32_t v_max = GiBroadcastFloat32(float_min);
  20. size_t i = 0;
  21. for (; i + step <= A * C; i += step)
  22. GiStoreFloat32(max + i, v_max);
  23. for (; i < A * C; i++)
  24. max[i] = float_min;
  25. for (size_t a = 0; a < A; a++) {
  26. for (size_t b = 0; b < B; b++) {
  27. auto max_ptr = max + a * C;
  28. auto limit = max_ptr + C;
  29. auto src_ptr = sptr + a * B * C + b * C;
  30. for (; max_ptr + step <= limit; max_ptr += step, src_ptr += step) {
  31. GI_FLOAT32_t v_p = GiLoadFloat32(src_ptr);
  32. GI_FLOAT32_t v_max = GiLoadFloat32(max_ptr);
  33. v_max = GiMaximumFloat32(v_max, v_p);
  34. GiStoreFloat32(max_ptr, v_max);
  35. }
  36. for (; max_ptr < limit; ++max_ptr, ++src_ptr) {
  37. *max_ptr = std::max(*src_ptr, *max_ptr);
  38. }
  39. }
  40. }
  41. float* sum = workspace_bundle.get_workspace(1).raw_ptr->as<float>();
  42. memset(sum, 0, A * C * sizeof(float));
  43. for (size_t a = 0; a < A; a++) {
  44. for (size_t b = 0; b < B; b++) {
  45. auto max_ptr = max + a * C;
  46. auto limit = max_ptr + C;
  47. auto sum_ptr = sum + a * C;
  48. auto src_ptr = sptr + a * B * C + C * b;
  49. auto dst_ptr = dptr + a * B * C + C * b;
  50. for (; max_ptr + step <= limit; max_ptr += step, sum_ptr += step,
  51. src_ptr += step, dst_ptr += step) {
  52. GI_FLOAT32_t v_p = GiLoadFloat32(src_ptr);
  53. GI_FLOAT32_t v_max = GiLoadFloat32(max_ptr);
  54. GI_FLOAT32_t v_sum = GiLoadFloat32(sum_ptr);
  55. v_p = GiExpPsFloat32(GiSubtractFloat32(v_p, v_max));
  56. v_sum = GiAddFloat32(v_p, v_sum);
  57. GiStoreFloat32(dst_ptr, v_p);
  58. GiStoreFloat32(sum_ptr, v_sum);
  59. }
  60. for (; max_ptr < limit; ++max_ptr, ++sum_ptr, ++src_ptr, ++dst_ptr) {
  61. *dst_ptr = exp(*src_ptr - *max_ptr);
  62. *sum_ptr += *dst_ptr;
  63. }
  64. }
  65. }
  66. for (size_t a = 0; a < A; a++) {
  67. for (size_t b = 0; b < B; b++) {
  68. auto sum_ptr = sum + a * C;
  69. auto limit = sum_ptr + C;
  70. auto dst_ptr = dptr + a * B * C + C * b;
  71. for (; sum_ptr + step <= limit; sum_ptr += step, dst_ptr += step) {
  72. GI_FLOAT32_t v_p = GiLoadFloat32(dst_ptr);
  73. GI_FLOAT32_t v_sum = GiLoadFloat32(sum_ptr);
  74. v_p = GiDivideFloat32(v_p, v_sum);
  75. GiStoreFloat32(dst_ptr, v_p);
  76. }
  77. for (; sum_ptr < limit; ++sum_ptr, ++dst_ptr)
  78. *dst_ptr = *dst_ptr / *sum_ptr;
  79. }
  80. }
  81. } else {
  82. for (size_t a = 0; a < A; a++) {
  83. auto max = float_min;
  84. {
  85. auto src_ptr = sptr + a * B;
  86. auto limit = src_ptr + B;
  87. GI_FLOAT32_t v_max = GiBroadcastFloat32(max);
  88. for (; src_ptr + step <= limit; src_ptr += step) {
  89. GI_FLOAT32_t v_p = GiLoadFloat32(src_ptr);
  90. v_max = GiMaximumFloat32(v_max, v_p);
  91. }
  92. max = std::max(max, GiReduceMaxNanFloat32(v_max));
  93. for (; src_ptr < limit; ++src_ptr) {
  94. max = std::max(*src_ptr, max);
  95. }
  96. }
  97. auto sum = 0.f;
  98. {
  99. auto src_ptr = sptr + a * B;
  100. auto limit = src_ptr + B;
  101. auto dst_ptr = dptr + a * B;
  102. GI_FLOAT32_t v_sum = GiZeroFloat32();
  103. GI_FLOAT32_t v_max = GiBroadcastFloat32(max);
  104. for (; src_ptr + step <= limit; src_ptr += step, dst_ptr += step) {
  105. GI_FLOAT32_t v_p = GiLoadFloat32(src_ptr);
  106. v_p = GiExpPsFloat32(GiSubtractFloat32(v_p, v_max));
  107. GiStoreFloat32(dst_ptr, v_p);
  108. v_sum = GiAddFloat32(v_sum, v_p);
  109. }
  110. sum += GiReduceAddFloat32(v_sum);
  111. for (; src_ptr < limit; ++src_ptr, ++dst_ptr) {
  112. *dst_ptr = exp(*src_ptr - max);
  113. sum += *dst_ptr;
  114. }
  115. }
  116. {
  117. auto dst_ptr = dptr + a * B;
  118. auto limit = dst_ptr + B;
  119. sum = 1 / sum;
  120. GI_FLOAT32_t v_sum = GiBroadcastFloat32(sum);
  121. for (; dst_ptr + step <= limit; dst_ptr += step) {
  122. GI_FLOAT32_t v_p = GiLoadFloat32(dst_ptr);
  123. v_p = GiMultiplyFloat32(v_p, v_sum);
  124. GiStoreFloat32(dst_ptr, v_p);
  125. }
  126. for (; dst_ptr < limit; ++dst_ptr) {
  127. *dst_ptr *= sum;
  128. }
  129. }
  130. }
  131. }
  132. }
  133. void SoftmaxForwardImpl::exec(
  134. _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
  135. auto axis = param().axis;
  136. if (axis < 0)
  137. axis += src.layout.ndim;
  138. megdnn_assert(axis >= 0);
  139. check_exec(src.layout, dst.layout, workspace.size);
  140. if (!usable(src.layout)) {
  141. naive::SoftmaxForwardImpl::exec(src, dst, workspace);
  142. return;
  143. }
  144. typedef DTypeTrait<dtype::Float32>::ctype Float32;
  145. auto sptr = src.ptr<Float32>();
  146. auto dptr = dst.ptr<Float32>();
  147. size_t A, B, C;
  148. reduce::get_ABC(src.layout, A, B, C, axis);
  149. MEGDNN_DISPATCH_CPU_KERN_OPR(do_softmax(sptr, dptr, A, B, C, workspace));
  150. }
  151. } // namespace fallback
  152. } // namespace megdnn
  153. // vim: syntax=cpp.doxygen