You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

elemwise_unary_trait_def.inl 1.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. /**
  2. * \file src/opr/test/basic_arith/elemwise_unary_trait_def.inl
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #ifndef DEF_TRAIT
  12. #error "DEF_TRAIT must be defined"
  13. #endif
  14. /* ======================= unary ======================= */
  15. #define _CUR_ARITY 1
  16. #define _EXPAND_PARAMS \
  17. ctype x = inp[0][idx]
  18. #define _ALLOW_BOOL true
  19. #define _ALLOW_FLOAT false
  20. #define _ALLOW_INT false
  21. DEF_TRAIT(NOT, !x)
  22. #undef _ALLOW_INT
  23. #undef _ALLOW_FLOAT
  24. #undef _ALLOW_BOOL
  25. #define _ALLOW_BOOL false
  26. #define _ALLOW_FLOAT true
  27. #define _ALLOW_INT true
  28. DEF_TRAIT(ABS, std::abs(x))
  29. DEF_TRAIT(NEGATE, -x)
  30. DEF_TRAIT(RELU, std::max<ctype>(x, 0))
  31. #undef _ALLOW_INT
  32. #define _ALLOW_INT false
  33. DEF_TRAIT(ACOS, std::acos(x))
  34. DEF_TRAIT(ASIN, std::asin(x))
  35. DEF_TRAIT(CEIL, std::ceil(x))
  36. DEF_TRAIT(COS, std::cos(x))
  37. DEF_TRAIT(EXP, std::exp(x))
  38. DEF_TRAIT(EXPM1, std::expm1(x))
  39. DEF_TRAIT(FLOOR, std::floor(x))
  40. DEF_TRAIT(LOG, std::log(x))
  41. DEF_TRAIT(LOG1P, std::log1p(x))
  42. DEF_TRAIT(SIGMOID, 1 / (1 + std::exp(-x)))
  43. DEF_TRAIT(SIN, std::sin(x))
  44. DEF_TRAIT(TANH, std::tanh(x))
  45. DEF_TRAIT(FAST_TANH, do_fast_tanh(x))
  46. DEF_TRAIT(ROUND, std::round(x))
  47. DEF_TRAIT(ERF, std::erf(x))
  48. DEF_TRAIT(ERFINV, do_erfinv(x))
  49. DEF_TRAIT(ERFC, std::erfc(x))
  50. DEF_TRAIT(ERFCINV, do_erfcinv(x))
  51. DEF_TRAIT(H_SWISH, do_h_swish(x))
  52. #undef _ALLOW_INT
  53. #undef _ALLOW_FLOAT
  54. #undef _ALLOW_BOOL
  55. #undef _CUR_ARITY
  56. #undef _EXPAND_PARAMS
  57. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。