You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_impl.h 3.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. /**
  2. * \file dnn/src/fallback/elemwise/opr_impl.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "src/fallback/elemwise_helper/elemwise_op.h"
  13. #include "src/naive/elemwise/opr_impl.h"
  14. namespace megdnn {
  15. namespace fallback {
  16. class ElemwiseImpl : public naive::ElemwiseForwardImpl {  // Elemwise operator in the fallback namespace; extends the naive forward implementation.
  17. template <typename dtype, uint32_t mode>
  18. void unary_kern(const ElemwiseOpParamN<1>& param);  //!< One-operand kernel templated on element type and mode. No access specifier precedes it, so it is private by class default.
  19. template <uint32_t mode>
  20. void exec_UNARY_INT();  //!< Mode-templated one-operand entry; presumably the integer-dtype path (inferred from the name -- confirm in the .cpp).
  21. template <uint32_t mode>
  22. void exec_UNARY_FLOAT();  //!< Mode-templated one-operand entry; presumably the float-dtype path (inferred from the name -- confirm in the .cpp).
  23. template <typename dtype, uint32_t mode>
  24. void binary_kern(const ElemwiseOpParamN<2>& param);  //!< Two-operand kernel templated on element type and mode.
  25. template <uint32_t mode>
  26. void exec_BINARY_INT();  //!< Mode-templated two-operand entry; presumably the integer-dtype path (inferred from the name -- confirm in the .cpp).
  27. template <uint32_t mode>
  28. void exec_BINARY_FLOAT();  //!< Mode-templated two-operand entry; presumably the float-dtype path (inferred from the name -- confirm in the .cpp).
  29. void exec_fallback(const TensorNDArray& srcs, _megdnn_tensor_out dst);  //!< Takes the same (srcs, dst) arguments as exec(). NOTE(review): presumably the generic path used when no specialised algorithm applies -- confirm.
  30. bool exec_gi_intrinsic(const TensorNDArray& srcs, _megdnn_tensor_out dst);  //!< Takes the same (srcs, dst) arguments as exec(). NOTE(review): the bool presumably reports whether this path handled the request -- confirm.
  31. private:  // Redundant with the class default, but makes the access of the nested algorithm classes below explicit.
  32. class AlgoUnary;  // Forward declarations of nested algorithm classes (through AlgoPack); their definitions are outside this class body.
  33. class AlgoBinaryVecVec;
  34. class AlgoBinaryVecScalar;
  35. class AlgoBinaryVecBcast101;
  36. class AlgoBinaryVecBcastX0X;
  37. class AlgoBinaryVecBcast111C;
  38. class AlgoBinaryVecBcast101xX;
  39. class AlgoTernaryFma3VecVecVec;
  40. class AlgoTernaryFma3VecVecScalar;
  41. class AlgoTernaryFma3Bcast101VecBcast101;
  42. class AlgoTernaryFma3Bcast111CVecBcast111C;
  43. class AlgoTernaryFma3Bcast101xXVecBcast101xX;
  44. class AlgoTernaryFma3VecBcast101Vec;
  45. class AlgoTernaryFma3VecBcast111CVec;
  46. class AlgoTernaryFma3VecBcast101xXVec;
  47. class AlgoTernaryFma3VecScalarVec;
  48. class AlgoTernaryFma3VecScalarScalar;
  49. class AlgoPack;  // NOTE(review): presumably the holder of all algorithm instances -- confirm against its definition.
  50. public:
  51. class AlgoBase;  //!< Algorithm interface; defined after this class as ElemwiseImpl::AlgoBase.
  52. struct KernParam {  // Argument bundle accepted by AlgoBase::is_available() and AlgoBase::exec().
  53. elemwise::BcastType broad_cast_type;  //!< Broadcast pattern of the operands.
  54. Mode mode;  //!< Elemwise mode to compute.
  55. const TensorND* m_dst;  //!< Destination tensor, held by pointer; ownership is not expressed by this type.
  56. Handle* handle;  //!< Handle pointer; NOTE(review): presumably the operator's handle -- confirm in make_kern_param().
  57. ElemwiseOpParamN<3> ternary_elparam;  //!< Three-operand parameter pack.
  58. ElemwiseOpParamN<2> binary_elparam;  //!< Two-operand parameter pack.
  59. ElemwiseOpParamN<1> unary_elparam;  //!< One-operand parameter pack. NOTE(review): presumably only the pack matching the operand count is meaningful -- confirm.
  60. };
  61. KernParam make_kern_param(ElemwiseImpl* opr);  //!< Returns a KernParam for the operator \p opr.
  62. using naive::ElemwiseForwardImpl::ElemwiseForwardImpl;  // Inherit the base-class constructors.
  63. void exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) override;  //!< Overrides the base-class exec(); takes the source tensors and the output tensor.
  64. const char* get_algorithm_set_name() const { return "FALLBACK ELEMWISE"; }  //!< Returns a fixed string literal. NOTE(review): not marked override -- confirm whether a base class declares it virtual.
  65. bool is_thread_safe() const override { return true; }  //!< Always reports true.
  66. };
  67. /*!
  68. * \brief Abstract interface for Elemwise algorithms: a KernParam-based availability check and execution, layered on detail::Algorithm.
  69. */
  70. class ElemwiseImpl::AlgoBase : public detail::Algorithm {
  71. public:
  72. virtual bool is_available(const KernParam&) const = 0;  //!< Pure virtual; reports whether the algorithm is available for the given KernParam.
  73. virtual void exec(const KernParam&) const = 0;  //!< Pure virtual; runs the algorithm on the given KernParam. Const, so it cannot modify the algorithm object's non-mutable state.
  74. virtual ~AlgoBase() = default;  // Virtual destructor so derived algorithms can be destroyed through an AlgoBase pointer.
  75. uint32_t type() const override { return INVALID_ALGO_TYPE; };  //!< Overrides the base type() and always returns INVALID_ALGO_TYPE. The `;` after the body is an empty declaration and has no effect.
  76. };
  77. //! Dispatch on src0.layout.dtype: Float32, Int32 and Int8 are handled; any other dtype matches no branch and nothing is executed. Expects `src0` plus the DISPATCH_MODE_FLOAT / DISPATCH_MODE_INT macros to be visible at the expansion site.
  78. #define DISPATCH_TYPE_FALLBACK(_case) \
  79. if (src0.layout.dtype == dtype::Float32{}) { \
  80. DISPATCH_MODE_FLOAT(_case, float, 0); \
  81. } else if (src0.layout.dtype == dtype::Int32{}) { \
  82. DISPATCH_MODE_INT(_case, int, 2); \
  83. } else if (src0.layout.dtype == dtype::Int8{}) { \
  84. DISPATCH_MODE_INT(_case, dt_int8, 4); \
  85. }
  86. } // namespace fallback
  87. } // namespace megdnn
  88. // vim: syntax=cpp.doxygen