You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_impl.cpp 6.7 kB

Lines 1–277
  1. /**
  2. * \file dnn/src/fallback/powc/opr_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
#include "./opr_impl.h"
#include "src/naive/handle.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include "src/arm_common/simd_macro/marm_neon.h"
#endif

#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>
  17. using namespace megdnn;
  18. using namespace fallback;
  19. namespace {
  20. template <int exp>
  21. struct powci;
  22. template <>
  23. struct powci<0> {
  24. template <typename T>
  25. static T apply(T) {
  26. return static_cast<T>(1);
  27. }
  28. };
  29. template <>
  30. struct powci<1> {
  31. template <typename T>
  32. static T apply(T x) {
  33. return x;
  34. }
  35. };
  36. template <>
  37. struct powci<2> {
  38. template <typename T>
  39. static T apply(T x) {
  40. return x * x;
  41. }
  42. };
  43. template <>
  44. struct powci<3> {
  45. template <typename T>
  46. static T apply(T x) {
  47. return x * x * x;
  48. }
  49. };
  50. template <>
  51. struct powci<4> {
  52. template <typename T>
  53. static T apply(T x) {
  54. x = x * x;
  55. return x * x;
  56. }
  57. };
  58. template <int exp>
  59. struct powci {
  60. static_assert(exp < 0, "bad arg");
  61. template <typename T>
  62. static T apply(T x) {
  63. return powci<-exp>::apply(static_cast<T>(1) / x);
  64. }
  65. };
  66. struct powci_general_even {
  67. int exp;
  68. powci_general_even(int e) : exp{e} {}
  69. template <typename T>
  70. T apply(T x) {
  71. return static_cast<T>(std::pow(std::abs(x), static_cast<T>(exp)));
  72. }
  73. };
  74. template <size_t size>
  75. struct float_itype;
  76. #ifndef MEGDNN_DISABLE_FLOAT16
  77. template <>
  78. struct float_itype<2> {
  79. using type = uint16_t;
  80. static constexpr uint16_t mask = 1u << 15;
  81. };
  82. #endif
  83. template <>
  84. struct float_itype<4> {
  85. using type = uint32_t;
  86. static constexpr uint32_t mask = 1u << 31;
  87. };
  88. struct powci_general_odd {
  89. template <typename T>
  90. union fiu {
  91. T f;
  92. typename float_itype<sizeof(T)>::type i;
  93. fiu() {}
  94. };
  95. int exp;
  96. powci_general_odd(int e) : exp{e} {}
  97. template <typename T>
  98. T apply(T x) {
  99. fiu<T> iret, ix;
  100. iret.f = std::pow(std::abs(x), static_cast<T>(exp));
  101. ix.f = x;
  102. iret.i |= ix.i & float_itype<sizeof(T)>::mask;
  103. return iret.f;
  104. }
  105. };
  106. struct powcf_sqrt {
  107. template <typename T>
  108. static T apply(T x) {
  109. return static_cast<T>(std::sqrt(x));
  110. }
  111. };
  112. struct powcf_cbrt {
  113. template <typename T>
  114. static T apply(T x) {
  115. return static_cast<T>(std::cbrt(x));
  116. }
  117. };
  118. struct powcf_rep_sqrt {
  119. template <typename T>
  120. static T apply(T x) {
  121. return static_cast<T>(std::sqrt(static_cast<T>(1) / x));
  122. }
  123. };
  124. struct powcf_rep_cbrt {
  125. template <typename T>
  126. static T apply(T x) {
  127. return static_cast<T>(std::cbrt(static_cast<T>(1) / x));
  128. }
  129. };
  130. template <typename T>
  131. struct powcf_general {
  132. float exp;
  133. powcf_general(float e) : exp{e} {}
  134. T apply(T x) { return static_cast<T>(std::pow(std::abs(x), exp)); }
  135. };
  136. template <typename T, class ExpFunc>
  137. void pow_invoke(const T* src, T* dst, size_t size, ExpFunc expfunc) {
  138. size_t i;
  139. for (i = 0; i + 4 <= size; i += 4) {
  140. T a0 = src[i], a1 = src[i + 1], a2 = src[i + 2], a3 = src[i + 3];
  141. T b0 = expfunc.apply(a0), b1 = expfunc.apply(a1),
  142. b2 = expfunc.apply(a2), b3 = expfunc.apply(a3);
  143. dst[i] = b0;
  144. dst[i + 1] = b1;
  145. dst[i + 2] = b2;
  146. dst[i + 3] = b3;
  147. }
  148. #if MEGDNN_FIX_AARCH32_BUG
  149. // FIXME: as llvm may cause cannot select error if enable vectorize
  150. #pragma clang loop vectorize(disable)
  151. #endif
  152. for (; i < size; ++i) {
  153. dst[i] = expfunc.apply(src[i]);
  154. }
  155. }
  156. bool float_eq(float x, float y) {
  157. return std::abs(x - y) < std::numeric_limits<float>::epsilon();
  158. }
  159. } // anonymous namespace
  160. template <typename T>
  161. void PowCImpl::do_exec_ct(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  162. const float* exp_f, const int* exp_i) {
  163. auto handle = static_cast<naive::HandleImpl*>(this->handle());
  164. auto sptr = reinterpret_cast<T*>(src.raw_ptr);
  165. auto dptr = reinterpret_cast<T*>(dst.raw_ptr);
  166. auto size = src.layout.total_nr_elems();
  167. #define CALL(_expfunc) \
  168. do { \
  169. auto kern = [ sptr, dptr, size, expfunc = _expfunc ]() { \
  170. pow_invoke(sptr, dptr, size, expfunc); \
  171. }; \
  172. handle->dispatch_kern(kern); \
  173. return; \
  174. } while (0)
  175. if (exp_f) {
  176. float fv = *exp_f;
  177. #define CALL_IF(_v, _expfunc) \
  178. if (float_eq(fv, _v)) { \
  179. CALL(_expfunc); \
  180. return; \
  181. }
  182. constexpr float croot = 1.f / 3.f;
  183. CALL_IF(.5f, powcf_sqrt{});
  184. CALL_IF(croot, powcf_cbrt{});
  185. CALL_IF(-.5f, powcf_rep_sqrt{});
  186. CALL_IF(-croot, powcf_rep_cbrt{});
  187. CALL(powcf_general<T>{fv});
  188. #undef CALL_IF
  189. }
  190. int iv = *exp_i;
  191. switch (iv) {
  192. #define CASE(n) \
  193. case n: \
  194. CALL(powci<n>{}); \
  195. return
  196. CASE(0);
  197. CASE(1);
  198. CASE(2);
  199. CASE(3);
  200. CASE(4);
  201. CASE(-1);
  202. CASE(-2);
  203. CASE(-3);
  204. CASE(-4);
  205. #undef CASE
  206. }
  207. if (iv & 1) {
  208. CALL(powci_general_odd{iv});
  209. } else {
  210. CALL(powci_general_even{iv});
  211. }
  212. #undef CALL
  213. }
  214. void PowCImpl::do_exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  215. const float* exp_f, const int* exp_i) {
  216. if (!src.layout.is_contiguous()) {
  217. naive::PowCImpl::do_exec(src, dst, exp_f, exp_i);
  218. return;
  219. }
  220. switch (src.layout.dtype.enumv()) {
  221. #define cb(dt) \
  222. case DTypeTrait<dt>::enumv: \
  223. return do_exec_ct<DTypeTrait<dt>::ctype>(src, dst, exp_f, exp_i);
  224. cb(dtype::Float32);
  225. #undef cb
  226. #if !MEGDNN_DISABLE_FLOAT16
  227. case DTypeTrait<dtype::Float16>::enumv:
  228. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  229. return DNN_INC_FLOAT16(
  230. do_exec_ct<__fp16>(src, dst, exp_f, exp_i));
  231. #else
  232. return DNN_INC_FLOAT16(
  233. do_exec_ct<dt_float16>(src, dst, exp_f, exp_i));
  234. #endif
  235. #endif
  236. default:
  237. megdnn_throw("unsupported dtype for PowC");
  238. }
  239. }
  240. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。