You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

each_mode.cpp 19 kB

(line numbers 1–480)
  1. /**
  2. * \file src/jit/impl/mlir/ir/each_mode.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megbrain_build_config.h"
  13. #if MGB_JIT && MGB_JIT_MLIR
  14. #include "./common.h"
  15. #include "./each_mode.h"
  16. #include "./numerical.h"
  17. #include "./types.h"
  18. #include "megbrain/common.h"
  19. #include "megbrain/exception.h"
  20. #include "megbrain/jit/mlir/ir/dialect.h"
  21. #include <mlir/Dialect/StandardOps/IR/Ops.h>
  22. namespace mgb {
  23. namespace jit {
  24. using Mode = megdnn::param::Elemwise::Mode;
  25. template <Mode mode>
  26. mlir::Value lower_mode(mlir::OpBuilder& builder, mlir::Location loc,
  27. ValueRange operands);
  28. /* ===================== trivial implementations ===================== */
  29. #define cb(mode, fun) \
  30. template <> \
  31. mlir::Value lower_mode<Mode::mode>(mlir::OpBuilder & builder, \
  32. mlir::Location loc, \
  33. ValueRange operands) { \
  34. ValueBuilderHelper helper(builder, loc); \
  35. return helper.fun(operands); \
  36. }
  37. //! unary
  38. cb(ABS, abs);
  39. cb(CEIL, ceil);
  40. cb(COS, cos);
  41. cb(EXP, exp);
  42. cb(FLOOR, floor);
  43. cb(LOG, log);
  44. cb(NEGATE, neg);
  45. cb(SIN, sin);
  46. cb(TANH, tanh);
  47. //! binary
  48. cb(ADD, add);
  49. cb(MAX, max);
  50. cb(MIN, min);
  51. cb(MOD, mod);
  52. cb(MUL, mul);
  53. cb(SUB, sub);
  54. cb(TRUE_DIV, div);
  55. #undef cb
  56. /* ===================== unary op ===================== */
  57. //! ACOS: pi / 2 - arctan2(x, sqrt(1 - x * x))
  58. template <>
  59. mlir::Value lower_mode<Mode::ACOS>(mlir::OpBuilder& builder, mlir::Location loc,
  60. ValueRange operands) {
  61. ValueBuilderHelper helper(builder, loc);
  62. auto x = operands[0];
  63. auto one_minus_x_2 = helper.sub(helper.const_f32(1.f), helper.mul(x, x));
  64. auto asin = atan2_approx(helper, x, helper.sqrt(one_minus_x_2));
  65. auto pi_over_2 = helper.const_f32(1.57079637f);
  66. return helper.sub(pi_over_2, asin);
  67. }
  68. //! ASIN: arctan2(x, sqrt(1 - x * x))
  69. template <>
  70. mlir::Value lower_mode<Mode::ASIN>(mlir::OpBuilder& builder, mlir::Location loc,
  71. ValueRange operands) {
  72. ValueBuilderHelper helper(builder, loc);
  73. auto x = operands[0];
  74. auto one_minus_x_2 = helper.sub(helper.const_f32(1.f), helper.mul(x, x));
  75. return atan2_approx(helper, x, helper.sqrt(one_minus_x_2));
  76. }
  77. //! ERFCINV: inverse of complementary gauss error function
  78. //! https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c
  79. template <>
  80. mlir::Value lower_mode<Mode::ERFCINV>(mlir::OpBuilder& builder,
  81. mlir::Location loc, ValueRange operands) {
  82. ValueBuilderHelper helper(builder, loc);
  83. auto minus_sqrt2 = helper.const_f32(-1.4142135623f);
  84. auto x = helper.mul(helper.const_f32(0.5f), operands[0]);
  85. return helper.div(ndtri_approx(helper, x), minus_sqrt2);
  86. }
  87. //! ERFC: complementary error function
  88. template <>
  89. mlir::Value lower_mode<Mode::ERFC>(mlir::OpBuilder& builder, mlir::Location loc,
  90. ValueRange operands) {
  91. ValueBuilderHelper helper(builder, loc);
  92. return helper.sub(helper.const_f32(1.f), erf_approx(helper, operands[0]));
  93. }
  94. //! ERFINV: inverse of gauss error function
  95. //! https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c
  96. template <>
  97. mlir::Value lower_mode<Mode::ERFINV>(mlir::OpBuilder& builder,
  98. mlir::Location loc, ValueRange operands) {
  99. ValueBuilderHelper helper(builder, loc);
  100. auto sqrt2 = helper.const_f32(1.4142135623f);
  101. auto x = helper.mul(helper.const_f32(0.5f),
  102. helper.add(operands[0], helper.const_f32(1.f)));
  103. return helper.div(ndtri_approx(helper, x), sqrt2);
  104. }
  105. //! ERF: gauss error function
  106. template <>
  107. mlir::Value lower_mode<Mode::ERF>(mlir::OpBuilder& builder, mlir::Location loc,
  108. ValueRange operands) {
  109. ValueBuilderHelper helper(builder, loc);
  110. return erf_approx(helper, operands[0]);
  111. }
  112. //! EXPM1: exp(x) - 1
  113. template <>
  114. mlir::Value lower_mode<Mode::EXPM1>(mlir::OpBuilder& builder,
  115. mlir::Location loc, ValueRange operands) {
  116. ValueBuilderHelper helper(builder, loc);
  117. return helper.sub(helper.exp(operands[0]), helper.const_f32(1.f));
  118. }
  119. //! FAST_TANH: x * (27.f + x * x) / (27.f + 9.f * x * x);
  120. template <>
  121. mlir::Value lower_mode<Mode::FAST_TANH>(mlir::OpBuilder& builder,
  122. mlir::Location loc,
  123. ValueRange operands) {
  124. ValueBuilderHelper helper(builder, loc);
  125. auto square = helper.mul(operands[0], operands[0]);
  126. return helper.div(
  127. helper.mul(operands[0], helper.add(helper.const_f32(27.f), square)),
  128. helper.add(helper.const_f32(27.f),
  129. helper.mul(helper.const_f32(9.f), square)));
  130. }
  131. //! H_SWISH: x * clip(x + 3, 0, 6) / 6
  132. template <>
  133. mlir::Value lower_mode<Mode::H_SWISH>(mlir::OpBuilder& builder,
  134. mlir::Location loc, ValueRange operands) {
  135. ValueBuilderHelper helper(builder, loc);
  136. auto const_3 = helper.const_f32(3.f);
  137. auto const_0 = helper.const_f32(0.f);
  138. auto const_6 = helper.const_f32(6.f);
  139. auto tmp = helper.add(operands[0], const_3);
  140. return helper.div(helper.mul(operands[0],
  141. helper.min(helper.max(tmp, const_0), const_6)),
  142. const_6);
  143. }
  144. //! LOG1P: log(1 + p)
  145. template <>
  146. mlir::Value lower_mode<Mode::LOG1P>(mlir::OpBuilder& builder,
  147. mlir::Location loc, ValueRange operands) {
  148. ValueBuilderHelper helper(builder, loc);
  149. return helper.log(helper.add(operands[0], helper.const_f32(1.f)));
  150. }
  151. //! RELU: max(x, 0)
  152. template <>
  153. mlir::Value lower_mode<Mode::RELU>(mlir::OpBuilder& builder, mlir::Location loc,
  154. ValueRange operands) {
  155. ValueBuilderHelper helper(builder, loc);
  156. return helper.max(operands[0], helper.const_f32(0.f));
  157. }
  158. //! ROUND
  159. template <>
  160. mlir::Value lower_mode<Mode::ROUND>(mlir::OpBuilder& builder,
  161. mlir::Location loc, ValueRange operands) {
  162. ValueBuilderHelper helper(builder, loc);
  163. return helper.select(
  164. helper.gt(operands[0], helper.const_f32(0.f)),
  165. helper.floor(helper.add(operands[0], helper.const_f32(0.5f))),
  166. helper.ceil(helper.sub(operands[0], helper.const_f32(0.5f))));
  167. }
  168. //! SIGMOID: 1.f / (expf(-y) + 1.f))
  169. template <>
  170. mlir::Value lower_mode<Mode::SIGMOID>(mlir::OpBuilder& builder,
  171. mlir::Location loc, ValueRange operands) {
  172. ValueBuilderHelper helper(builder, loc);
  173. return helper.div(helper.const_f32(1.f),
  174. helper.add(helper.exp(helper.neg(operands[0])),
  175. helper.const_f32(1.f)));
  176. }
  177. /* ===================== binary op ===================== */
  178. //! ABS_GRAD: x > 0 ? y : -y
  179. template <>
  180. mlir::Value lower_mode<Mode::ABS_GRAD>(mlir::OpBuilder& builder,
  181. mlir::Location loc,
  182. ValueRange operands) {
  183. ValueBuilderHelper helper(builder, loc);
  184. return helper.select(helper.gt(operands[0], helper.const_f32(0.f)),
  185. operands[1], helper.neg(operands[1]));
  186. }
  187. //! ATAN2
  188. template <>
  189. mlir::Value lower_mode<Mode::ATAN2>(mlir::OpBuilder& builder,
  190. mlir::Location loc, ValueRange operands) {
  191. ValueBuilderHelper helper(builder, loc);
  192. return atan2_approx(helper, operands[0], operands[1]);
  193. }
  194. //! EQ: x == y ? 1 : 0
  195. template <>
  196. mlir::Value lower_mode<Mode::EQ>(mlir::OpBuilder& builder, mlir::Location loc,
  197. ValueRange operands) {
  198. ValueBuilderHelper helper(builder, loc);
  199. return helper.select(helper.eq(operands[0], operands[1]),
  200. helper.const_f32(1.f), helper.const_f32(0.f));
  201. }
  202. //! FAST_TANH_GRAD: ((-48.f * x * x) / (3.f + x * x) + 27.f + x * x) / (3.f + x
  203. //! * x) * y
  204. template <>
  205. mlir::Value lower_mode<Mode::FAST_TANH_GRAD>(mlir::OpBuilder& builder,
  206. mlir::Location loc,
  207. ValueRange operands) {
  208. ValueBuilderHelper helper(builder, loc);
  209. auto x_pow2 = helper.mul(operands[0], operands[0]);
  210. auto deno = helper.add(helper.const_f32(3.f), x_pow2);
  211. return helper.mul(
  212. helper.div(
  213. helper.add(
  214. helper.add(helper.div(helper.mul(helper.const_f32(
  215. -48.f),
  216. x_pow2),
  217. deno),
  218. helper.const_f32(27.f)),
  219. x_pow2),
  220. helper.mul(deno, helper.const_f32(9.f))),
  221. operands[1]);
  222. }
  223. //! FLOOR_DIV: floor(x/y)
  224. template <>
  225. mlir::Value lower_mode<Mode::FLOOR_DIV>(mlir::OpBuilder& builder,
  226. mlir::Location loc,
  227. ValueRange operands) {
  228. ValueBuilderHelper helper(builder, loc);
  229. return helper.floor(helper.div(operands[0], operands[1]));
  230. }
  231. //! FUSE_ADD_H_SWISH: (x+y) * min(max(x + y + 3, 0), 6) * (1/6)
  232. template <>
  233. mlir::Value lower_mode<Mode::FUSE_ADD_H_SWISH>(mlir::OpBuilder& builder,
  234. mlir::Location loc,
  235. ValueRange operands) {
  236. ValueBuilderHelper helper(builder, loc);
  237. auto sum = helper.add(operands[0], operands[1]);
  238. auto const_3 = helper.const_f32(3.f);
  239. auto const_0 = helper.const_f32(0.f);
  240. auto const_6 = helper.const_f32(6.f);
  241. auto tmp = helper.add(sum, const_3);
  242. return helper.div(
  243. helper.mul(sum, helper.min(helper.max(tmp, const_0), const_6)),
  244. const_6);
  245. }
  246. //! FUSE_ADD_RELU: (x + y) <= ctype(0) ? ctype(0) : (x + y)
  247. template <>
  248. mlir::Value lower_mode<Mode::FUSE_ADD_RELU>(mlir::OpBuilder& builder,
  249. mlir::Location loc,
  250. ValueRange operands) {
  251. ValueBuilderHelper helper(builder, loc);
  252. auto sum = helper.add(operands[0], operands[1]);
  253. return helper.max(sum, helper.const_f32(0.f));
  254. }
  255. //! FUSE_ADD_SIGMOID: 1.f / (expf(-(x+y)) + 1.f))
  256. template <>
  257. mlir::Value lower_mode<Mode::FUSE_ADD_SIGMOID>(mlir::OpBuilder& builder,
  258. mlir::Location loc,
  259. ValueRange operands) {
  260. ValueBuilderHelper helper(builder, loc);
  261. return helper.div(helper.const_f32(1.f),
  262. helper.add(helper.exp(helper.neg(
  263. helper.add(operands[0], operands[1]))),
  264. helper.const_f32(1.f)));
  265. }
  266. //! FUSE_ADD_TANH: tanh(x + y)
  267. template <>
  268. mlir::Value lower_mode<Mode::FUSE_ADD_TANH>(mlir::OpBuilder& builder,
  269. mlir::Location loc,
  270. ValueRange operands) {
  271. ValueBuilderHelper helper(builder, loc);
  272. return helper.tanh(helper.add(operands[0], operands[1]));
  273. }
  274. //! H_SWISH_GRAD: x < -3.f ? 0.f : (x > 3.f ? y : (2.f * x + 3.f) / 6.f * y)
  275. template <>
  276. mlir::Value lower_mode<Mode::H_SWISH_GRAD>(mlir::OpBuilder& builder,
  277. mlir::Location loc,
  278. ValueRange operands) {
  279. ValueBuilderHelper helper(builder, loc);
  280. return helper.select(
  281. helper.lt(operands[0], helper.const_f32(-3.f)),
  282. helper.const_f32(0.f),
  283. helper.select(
  284. helper.gt(operands[0], helper.const_f32(3.f)), operands[1],
  285. helper.mul(
  286. helper.div(
  287. helper.add(helper.mul(helper.const_f32(2.f),
  288. operands[0]),
  289. helper.const_f32(3.f)),
  290. helper.const_f32(6.f)),
  291. operands[1])));
  292. }
  293. //! LEQ: x <= y ? 1 : 0
  294. template <>
  295. mlir::Value lower_mode<Mode::LEQ>(mlir::OpBuilder& builder, mlir::Location loc,
  296. ValueRange operands) {
  297. ValueBuilderHelper helper(builder, loc);
  298. return helper.select(helper.le(operands[0], operands[1]),
  299. helper.const_f32(1.f), helper.const_f32(0.f));
  300. }
  301. //! LOG_SUM_EXP: log(exp(x) + exp(y))
  302. template <>
  303. mlir::Value lower_mode<Mode::LOG_SUM_EXP>(mlir::OpBuilder& builder,
  304. mlir::Location loc,
  305. ValueRange operands) {
  306. ValueBuilderHelper helper(builder, loc);
  307. return helper.log(
  308. helper.add(helper.exp(operands[0]), helper.exp(operands[1])));
  309. }
  310. //! LT: x < y ? 1 : 0
  311. template <>
  312. mlir::Value lower_mode<Mode::LT>(mlir::OpBuilder& builder, mlir::Location loc,
  313. ValueRange operands) {
  314. ValueBuilderHelper helper(builder, loc);
  315. return helper.select(helper.lt(operands[0], operands[1]),
  316. helper.const_f32(1.f), helper.const_f32(0.f));
  317. }
  318. //! POW: x^y = exp(y * log(x))
  319. template <>
  320. mlir::Value lower_mode<Mode::POW>(mlir::OpBuilder& builder, mlir::Location loc,
  321. ValueRange operands) {
  322. ValueBuilderHelper helper(builder, loc);
  323. return helper.exp(helper.mul(operands[1], helper.log(operands[0])));
  324. }
  325. //! SIGMOID_GRAD: x * (1 - x) * y
  326. template <>
  327. mlir::Value lower_mode<Mode::SIGMOID_GRAD>(mlir::OpBuilder& builder,
  328. mlir::Location loc,
  329. ValueRange operands) {
  330. ValueBuilderHelper helper(builder, loc);
  331. return helper.mul(helper.mul(operands[0], helper.sub(helper.const_f32(1.f),
  332. operands[0])),
  333. operands[1]);
  334. }
  335. //! SWITCH_GT0: (x > 0) * y
  336. template <>
  337. mlir::Value lower_mode<Mode::SWITCH_GT0>(mlir::OpBuilder& builder,
  338. mlir::Location loc,
  339. ValueRange operands) {
  340. ValueBuilderHelper helper(builder, loc);
  341. return helper.select(helper.gt(operands[0], helper.const_f32(0.f)),
  342. operands[1], helper.const_f32(0.f));
  343. }
  344. //! TANH_GRAD: (1 - x * x) * y
  345. template <>
  346. mlir::Value lower_mode<Mode::TANH_GRAD>(mlir::OpBuilder& builder,
  347. mlir::Location loc,
  348. ValueRange operands) {
  349. ValueBuilderHelper helper(builder, loc);
  350. return helper.mul(helper.sub(helper.const_f32(1.0f),
  351. helper.mul(operands[0], operands[0])),
  352. operands[1]);
  353. }
  354. /* ===================== ternary op ===================== */
  355. //! COND_LEQ_MOV: x <= y ? z : ctype(0)
  356. template <>
  357. mlir::Value lower_mode<Mode::COND_LEQ_MOV>(mlir::OpBuilder& builder,
  358. mlir::Location loc,
  359. ValueRange operands) {
  360. ValueBuilderHelper helper(builder, loc);
  361. return helper.select(helper.le(operands[0], operands[1]), operands[2],
  362. helper.const_f32(0.f));
  363. }
  364. //! FUSE_MUL_ADD3: x * y + z
  365. template <>
  366. mlir::Value lower_mode<Mode::FUSE_MUL_ADD3>(mlir::OpBuilder& builder,
  367. mlir::Location loc,
  368. ValueRange operands) {
  369. ValueBuilderHelper helper(builder, loc);
  370. return helper.add(helper.mul(operands[0], operands[1]), operands[2]);
  371. }
  372. /* ===================== elemwise ===================== */
  373. mlir::Value lower_elemwise_to_std(mlir::Operation* op, mlir::OpBuilder& builder,
  374. mlir::Location loc, ValueRange operands) {
  375. auto mode = llvm::dyn_cast<dialect::Elemwise>(op).mode();
  376. switch (mode) {
  377. #define cb(_, _mode) \
  378. case Mode::_mode: \
  379. return lower_mode<Mode::_mode>(builder, loc, operands);
  380. MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb);
  381. MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb);
  382. MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb);
  383. default:
  384. return nullptr;
  385. }
  386. #undef cb
  387. }
  388. /* ===================== typecvt ===================== */
  389. mlir::Value lower_typecvt_to_std(mlir::Operation* op, mlir::OpBuilder& builder,
  390. mlir::Location loc, mlir::Value input) {
  391. auto&& typecvt = llvm::dyn_cast<dialect::TypeCvt>(op);
  392. megdnn::DType idtype = typecvt.idtype();
  393. megdnn::DType odtype = typecvt.odtype();
  394. mlir::Type itype = input.getType();
  395. mlir::Type otype = megdnn_dtype_to_mlir_type(odtype, builder.getContext());
  396. if (mlir::FPExtOp::areCastCompatible(itype, otype)) {
  397. return builder.create<mlir::FPExtOp>(loc, otype, input);
  398. } else if (mlir::FPTruncOp::areCastCompatible(itype, otype)) {
  399. return builder.create<mlir::FPTruncOp>(loc, otype, input);
  400. } else if (mlir::FPToSIOp::areCastCompatible(itype, otype) and
  401. is_signed_int_dtype(odtype)) {
  402. return builder.create<mlir::FPToSIOp>(loc, otype, input);
  403. } else if (mlir::FPToUIOp::areCastCompatible(itype, otype) and
  404. is_unsigned_int_dtype(odtype)) {
  405. return builder.create<mlir::FPToUIOp>(loc, otype, input);
  406. } else if (mlir::SIToFPOp::areCastCompatible(itype, otype) and
  407. is_signed_int_dtype(idtype)) {
  408. return builder.create<mlir::SIToFPOp>(loc, otype, input);
  409. } else if (mlir::UIToFPOp::areCastCompatible(itype, otype) and
  410. is_unsigned_int_dtype(idtype)) {
  411. return builder.create<mlir::UIToFPOp>(loc, otype, input);
  412. } else {
  413. mgb_throw(InternalError, "cannot convert from %s to %s", idtype.name(),
  414. odtype.name());
  415. }
  416. return nullptr;
  417. }
  418. } // namespace jit
  419. } // namespace mgb
  420. #endif // MGB_JIT && MGB_JIT_MLIR
  421. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。