You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

numerical.cpp 9.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. /**
  2. * \file src/jit/impl/mlir/ir/numerical.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megbrain_build_config.h"
  13. #if MGB_JIT && MGB_JIT_MLIR
  14. #include "numerical.h"
  15. namespace mgb {
  16. namespace jit {
  17. mlir::Value polynomial(ValueBuilderHelper& helper, mlir::Value x,
  18. std::vector<mlir::Value>& coeff) {
  19. size_t n = coeff.size();
  20. if (n == 0) {
  21. return helper.const_val(0);
  22. }
  23. mlir::Value r = coeff[0];
  24. for (size_t i = 1; i < n; i++) {
  25. r = helper.add(helper.mul(r, x), coeff[i]);
  26. }
  27. return r;
  28. }
  29. // polynomial approximation of arctangent
  30. // atan(t) = t + c3 * t^3 + c5 * t^5 + ... + c17 * t^17
  31. // original paper:
  32. // https://arxiv.org/pdf/1508.03211.pdf
  33. mlir::Value atan2_approx(ValueBuilderHelper& helper, mlir::Value y,
  34. mlir::Value x) {
  35. auto atan_poly = [&](mlir::Value t) {
  36. std::vector<mlir::Value> coeff = {
  37. helper.const_val(2.90188402868807315826416015625E-3),
  38. helper.const_val(-1.62907354533672332763671875E-2),
  39. helper.const_val(4.3082617223262786865234375E-2),
  40. helper.const_val(-7.5408883392810821533203125E-2),
  41. helper.const_val(0.1066047251224517822265625),
  42. helper.const_val(-0.14209578931331634521484375),
  43. helper.const_val(0.19993579387664794921875),
  44. helper.const_val(-0.3333314359188079833984375)};
  45. auto t2 = helper.mul(t, t);
  46. auto p = polynomial(helper, t2, coeff);
  47. return helper.add(helper.mul(helper.mul(p, t2), t), t);
  48. };
  49. // constants
  50. auto zero = helper.const_val(0);
  51. auto pi = helper.const_val(3.141592653589793);
  52. auto pi_over_2 = helper.const_val(1.570796326794897);
  53. // transform the angle into interval [0, pi/4]
  54. auto ax = helper.abs(x);
  55. auto ay = helper.abs(y);
  56. auto q = helper.div(helper.min(ax, ay), helper.max(ax, ay));
  57. // get approximation for interval [0, pi/4]
  58. auto r = atan_poly(q);
  59. // [0, pi/4] => [0, pi/2]
  60. r = helper.select(helper.le(ax, ay), helper.sub(pi_over_2, r), r);
  61. // [0, pi/2] => [0, pi]
  62. r = helper.select(helper.le(x, zero), helper.sub(pi, r), r);
  63. // [0, pi] => [-pi, pi]
  64. r = helper.select(helper.le(y, zero), helper.sub(zero, r), r);
  65. return r;
  66. }
  67. // numerical approximation of gauss error function
  68. // https://en.wikipedia.org/wiki/Error_function#Polynomial
  69. // original book:
  70. // Numerical Recipes in Fortran 77: The Art of Scientific Computing
  71. mlir::Value erf_approx(ValueBuilderHelper& helper, mlir::Value x) {
  72. auto zero = helper.const_val(0);
  73. auto one = helper.const_val(1);
  74. auto half = helper.const_val(0.5);
  75. auto t = helper.div(one, helper.add(one, helper.mul(half, helper.abs(x))));
  76. std::vector<mlir::Value> coeff = {
  77. helper.const_val(0.17087277),
  78. helper.const_val(-0.82215223),
  79. helper.const_val(1.48851587),
  80. helper.const_val(-1.13520398),
  81. helper.const_val(0.27886807),
  82. helper.const_val(-0.18628806),
  83. helper.const_val(0.09678418),
  84. helper.const_val(0.37409196),
  85. helper.const_val(1.00002368),
  86. helper.const_val(-1.26551223)};
  87. auto p = polynomial(helper, t, coeff);
  88. auto r = helper.mul(t, helper.exp(helper.sub(p, helper.mul(x, x))));
  89. return helper.select(helper.ge(x, zero),
  90. helper.sub(one, r),
  91. helper.sub(r, one));
  92. }
  93. // numerical approximation of the inverse of normal distribution function
  94. // original algorithm:
  95. // https://github.com/scipy/scipy/blob/master/scipy/special/cephes/ndtri.c
  96. // case 1: 0 < x < exp(-2)
  97. // z = sqrt(-2 * log(x))
  98. // t = 1 / z
  99. // res = log(z) / z - z + t * P(t) / Q(t)
  100. // where coefficients of P and Q are different
  101. // for z < 8 and for z >= 8
  102. //
  103. // case2: exp(-2) <= x <= 1 - exp(-2)
  104. // w = x - 0.5
  105. // res = sqrt(2pi) * (w + w^3 * R(w^2) / S(w^2))
  106. //
  107. // case3: 1 - exp(-2) < x < 1
  108. // 0 < 1 - x < exp(-2)
  109. // ndtri(x) = -ndtri(1 - x)
  110. // fallback to case 1
  111. mlir::Value ndtri_approx(ValueBuilderHelper& helper, mlir::Value x) {
  112. // polynomial P
  113. auto P = [&](mlir::Value i, mlir::Value cond) {
  114. std::vector<mlir::Value> coeff0 = {
  115. helper.const_val(4.05544892305962419923E0),
  116. helper.const_val(3.15251094599893866154E1),
  117. helper.const_val(5.71628192246421288162E1),
  118. helper.const_val(4.40805073893200834700E1),
  119. helper.const_val(1.46849561928858024014E1),
  120. helper.const_val(2.18663306850790267539E0),
  121. helper.const_val(-1.40256079171354495875E-1),
  122. helper.const_val(-3.50424626827848203418E-2),
  123. helper.const_val(-8.57456785154685413611E-4)};
  124. std::vector<mlir::Value> coeff1 = {
  125. helper.const_val(3.23774891776946035970E0),
  126. helper.const_val(6.91522889068984211695E0),
  127. helper.const_val(3.93881025292474443415E0),
  128. helper.const_val(1.33303460815807542389E0),
  129. helper.const_val(2.01485389549179081538E-1),
  130. helper.const_val(1.23716634817820021358E-2),
  131. helper.const_val(3.01581553508235416007E-4),
  132. helper.const_val(2.65806974686737550832E-6),
  133. helper.const_val(6.23974539184983293730E-9)};
  134. return helper.select(cond,
  135. polynomial(helper, i, coeff0),
  136. polynomial(helper, i, coeff1));
  137. };
  138. // polynomial Q
  139. auto Q = [&](mlir::Value i, mlir::Value cond) {
  140. std::vector<mlir::Value> coeff0 = {
  141. helper.const_val(1.f),
  142. helper.const_val(1.57799883256466749731E1),
  143. helper.const_val(4.53907635128879210584E1),
  144. helper.const_val(4.13172038254672030440E1),
  145. helper.const_val(1.50425385692907503408E1),
  146. helper.const_val(2.50464946208309415979E0),
  147. helper.const_val(-1.42182922854787788574E-1),
  148. helper.const_val(-3.80806407691578277194E-2),
  149. helper.const_val(-9.33259480895457427372E-4)};
  150. std::vector<mlir::Value> coeff1 = {
  151. helper.const_val(1.f),
  152. helper.const_val(6.02427039364742014255E0),
  153. helper.const_val(3.67983563856160859403E0),
  154. helper.const_val(1.37702099489081330271E0),
  155. helper.const_val(2.16236993594496635890E-1),
  156. helper.const_val(1.34204006088543189037E-2),
  157. helper.const_val(3.28014464682127739104E-4),
  158. helper.const_val(2.89247864745380683936E-6),
  159. helper.const_val(6.79019408009981274425E-9)};
  160. return helper.select(cond,
  161. polynomial(helper, i, coeff0),
  162. polynomial(helper, i, coeff1));
  163. };
  164. // polynomial R
  165. auto R = [&](mlir::Value i) {
  166. std::vector<mlir::Value> coeff = {
  167. helper.const_val(-5.99633501014107895267E1),
  168. helper.const_val(9.80010754185999661536E1),
  169. helper.const_val(-5.66762857469070293439E1),
  170. helper.const_val(1.39312609387279679503E1),
  171. helper.const_val(-1.23916583867381258016E0)};
  172. return polynomial(helper, i, coeff);
  173. };
  174. // polynomial S
  175. auto S = [&](mlir::Value i) {
  176. std::vector<mlir::Value> coeff = {
  177. helper.const_val(1.f),
  178. helper.const_val(1.95448858338141759834E0),
  179. helper.const_val(4.67627912898881538453E0),
  180. helper.const_val(8.63602421390890590575E1),
  181. helper.const_val(-2.25462687854119370527E2),
  182. helper.const_val(2.00260212380060660359E2),
  183. helper.const_val(-8.20372256168333339912E1),
  184. helper.const_val(1.59056225126211695515E1),
  185. helper.const_val(-1.18331621121330003142E0)};
  186. return polynomial(helper, i, coeff);
  187. };
  188. // constants
  189. auto zero = helper.const_val(0);
  190. auto one = helper.const_val(1);
  191. auto half = helper.const_val(0.5);
  192. auto eight = helper.const_val(8);
  193. auto minus_2 = helper.const_val(-2);
  194. auto exp_minus_2 = helper.const_val(0.135335283236); // exp(-2)
  195. auto sqrt_2pi = helper.const_val(2.506628274631); // sqrt(2pi)
  196. // conditions
  197. auto case1 = helper.lt(x, exp_minus_2); // x < exp(-2)
  198. auto case3 = helper.gt(x, helper.sub(one, exp_minus_2)); // x > 1 - exp(-2)
  199. auto case13 = helper.bit_or(case1, case3);
  200. // case1 or case3
  201. auto x13 = helper.select(case1, x, helper.sub(one, x)); // x or (1 - x)
  202. auto z = helper.sqrt(helper.mul(minus_2, helper.log(x13)));
  203. auto z_lt_8 = helper.lt(z, eight);
  204. auto t = helper.div(one, z);
  205. auto res1 = helper.add(helper.sub(helper.div(helper.log(z), z), z),
  206. helper.div(helper.mul(t, P(t, z_lt_8)), Q(t, z_lt_8)));
  207. auto res13 = helper.select(case1, res1, helper.sub(zero, res1));
  208. // case2
  209. auto w = helper.sub(x, half);
  210. auto w2 = helper.mul(w, w);
  211. auto w3 = helper.mul(w, w2);
  212. auto res2 = helper.mul(
  213. sqrt_2pi, helper.add(w, helper.div(helper.mul(w3, R(w2)), S(w2))));
  214. return helper.select(case13, res13, res2);
  215. }
  216. } // namespace jit
  217. } // namespace mgb
  218. #endif // MGB_JIT && MGB_JIT_MLIR
  219. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台