You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_impl.h 3.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. /**
  2. * \file dnn/src/arm_common/elemwise/opr_impl.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include "src/fallback/elemwise/opr_impl.h"
  14. #include "src/arm_common/elemwise_op.h"
  15. namespace megdnn {
  16. namespace arm_common {
  17. class ElemwiseImpl final : public fallback::ElemwiseImpl {
  18. public:
  19. using fallback::ElemwiseImpl::ElemwiseImpl;
  20. void exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) override;
  21. const char* get_algorithm_set_name() const { return "ARM COMMON ELEMWISE"; }
  22. private:
  23. struct KernParam {
  24. BcastType broad_cast_type;
  25. Mode mode;
  26. const TensorND* m_dst;
  27. Handle* handle;
  28. ElemwiseOpParamN<3> ternary_elparam;
  29. ElemwiseOpParamN<2> binary_elparam;
  30. ElemwiseOpParamN<1> unary_elparam;
  31. };
  32. KernParam make_kern_param(ElemwiseImpl* opr);
  33. class AlgoBase;
  34. class AlgoUnary;
  35. class AlgoBinaryVecVec;
  36. class AlgoBinaryVecScalar;
  37. class AlgoBinaryVecBcast101;
  38. class AlgoBinaryVecBcast101x4;
  39. class AlgoTernaryFma3VecVecVec;
  40. class AlgoTernaryFma3VecVecScalar;
  41. class AlgoTernaryFma3Bcast101VecBcast101;
  42. class AlgoTernaryFma3Bcast101x4VecBcast101x4;
  43. class AlgoTernaryFma3VecBcast101Vec;
  44. class AlgoTernaryFma3VecBcast101x4Vec;
  45. class AlgoTernaryFma3VecScalarVec;
  46. class AlgoTernaryFma3VecScalarScalar;
  47. class AlgoPack;
  48. };
  49. /*!
  50. *
  51. * \brief base class for Elemwise algo
  52. *
  53. */
  54. class ElemwiseImpl::AlgoBase : public detail::Algorithm {
  55. public:
  56. virtual bool is_available(const KernParam&) const = 0;
  57. virtual void exec(const KernParam&) const = 0;
  58. virtual ~AlgoBase() = default;
  59. uint32_t type() const override { return INVALID_ALGO_TYPE; };
  60. };
  /*!
   * \brief dtype dispatch used by the elemwise algorithm implementations.
   *
   * DISPATCH_TYPE(_case) expands to an if / else-if chain over
   * `src0.layout.dtype`. For the matching dtype it invokes the caller's
   * DISPATCH_MODE_FLOAT or DISPATCH_MODE_INT macro with
   * (_case, element type, index):
   *   Float32 -> (float, 0), Float16 -> (__fp16, 1), Int32 -> (int, 2),
   *   Int16 -> (dt_int16, 3), Int8 -> (dt_int8, 4).
   * The expansion site must provide `src0` and both DISPATCH_MODE_* macros.
   * Any other dtype matches no branch, so the macro itself does nothing for it.
   *
   * The Float16 branch is emitted only when __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
   * is set, and its condition is further gated by MEGDNN_FLOAT16_SELECT.
   * It is spliced into the single chain through
   * MEGDNN_ARM_COMMON_ELEMWISE_DISPATCH_FP16. This way the branches shared by
   * both configurations are written once and cannot drift apart.
   */
  #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  #define MEGDNN_ARM_COMMON_ELEMWISE_DISPATCH_FP16(_case)                    \
      else if (MEGDNN_FLOAT16_SELECT(src0.layout.dtype == dtype::Float16{}, \
                                     false)) {                              \
          DISPATCH_MODE_FLOAT(_case, __fp16, 1);                            \
      }
  #else
  #define MEGDNN_ARM_COMMON_ELEMWISE_DISPATCH_FP16(_case)
  #endif

  #define DISPATCH_TYPE(_case)                          \
      if (src0.layout.dtype == dtype::Float32{}) {      \
          DISPATCH_MODE_FLOAT(_case, float, 0);         \
      }                                                 \
      MEGDNN_ARM_COMMON_ELEMWISE_DISPATCH_FP16(_case)   \
      else if (src0.layout.dtype == dtype::Int32{}) {   \
          DISPATCH_MODE_INT(_case, int, 2);             \
      } else if (src0.layout.dtype == dtype::Int16{}) { \
          DISPATCH_MODE_INT(_case, dt_int16, 3);        \
      } else if (src0.layout.dtype == dtype::Int8{}) {  \
          DISPATCH_MODE_INT(_case, dt_int8, 4);         \
      }
  87. } // namespace arm_common
  88. } // namespace megdnn
  89. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台。