You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

gemm_operation.h 8.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. /***************************************************************************************************
  2. * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
  3. *
  4. * Redistribution and use in source and binary forms, with or without
  5. *modification, are permitted provided that the following conditions are met:
  6. * * Redistributions of source code must retain the above copyright notice,
  7. *this list of conditions and the following disclaimer.
  8. * * Redistributions in binary form must reproduce the above copyright
  9. *notice, this list of conditions and the following disclaimer in the
  10. *documentation and/or other materials provided with the distribution.
  11. * * Neither the name of the NVIDIA CORPORATION nor the names of its
  12. *contributors may be used to endorse or promote products derived from this
  13. *software without specific prior written permission.
  14. *
  15. * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  16. *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  17. *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  18. *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
  19. *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
  20. * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  21. *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
  22. *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
  23. *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
  24. *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  25. *
  26. **************************************************************************************************/
  27. /**
  28. * \file dnn/src/cuda/cutlass/gemm_operation.h
  29. *
  30. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  31. *
  32. * Unless required by applicable law or agreed to in writing,
  33. * software distributed under the License is distributed on an
  34. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  35. * implied.
  36. */
  37. #pragma once
  38. #include "cutlass/gemm/device/gemm.h"
  39. #include "src/cuda/cutlass/library_internal.h"
  40. ///////////////////////////////////////////////////////////////////////////////////////////////////
  41. namespace cutlass {
  42. namespace library {
  43. ///////////////////////////////////////////////////////////////////////////////////////////////////
  44. /// Check whether Operator has member ReductionKernel using SFINAE (Substitution
  45. /// Failure Is Not An Error)
  46. template <typename Operator>
  47. struct split_k_mode {
  48. template <typename T>
  49. static char check(typename T::ReductionKernel*);
  50. template <typename T>
  51. static int check(...);
  52. SplitKMode operator()() {
  53. if (sizeof(check<Operator>(0)) == sizeof(char)) {
  54. // cutlass::gemm::device::GemmSplitKParallel
  55. return SplitKMode::kParallel;
  56. } else {
  57. // cutlass::gemm::device::Gemm
  58. return SplitKMode::kNone;
  59. }
  60. }
  61. };
  62. ///////////////////////////////////////////////////////////////////////////////////////////////////
  63. template <typename Operator_>
  64. class GemmOperationBase : public Operation {
  65. public:
  66. using Operator = Operator_;
  67. using ElementA = typename Operator::ElementA;
  68. using LayoutA = typename Operator::LayoutA;
  69. using ElementB = typename Operator::ElementB;
  70. using LayoutB = typename Operator::LayoutB;
  71. using ElementC = typename Operator::ElementC;
  72. using LayoutC = typename Operator::LayoutC;
  73. using ElementAccumulator = typename Operator::ElementAccumulator;
  74. GemmOperationBase(char const* name = "unknown_gemm") {
  75. m_description.name = name;
  76. m_description.provider = Provider::kCUTLASS;
  77. m_description.kind = OperationKind::kGemm;
  78. m_description.gemm_kind = GemmKind::kGemm;
  79. m_description.tile_description.threadblock_shape = make_Coord(
  80. Operator::ThreadblockShape::kM, Operator::ThreadblockShape::kN,
  81. Operator::ThreadblockShape::kK);
  82. m_description.tile_description.threadblock_stages = Operator::kStages;
  83. m_description.tile_description.warp_count = make_Coord(
  84. Operator::GemmKernel::WarpCount::kM,
  85. Operator::GemmKernel::WarpCount::kN,
  86. Operator::GemmKernel::WarpCount::kK);
  87. m_description.tile_description.math_instruction.instruction_shape = make_Coord(
  88. Operator::InstructionShape::kM, Operator::InstructionShape::kN,
  89. Operator::InstructionShape::kK);
  90. m_description.tile_description.math_instruction.element_accumulator =
  91. NumericTypeMap<ElementAccumulator>::kId;
  92. m_description.tile_description.math_instruction.opcode_class =
  93. OpcodeClassMap<typename Operator::OperatorClass>::kId;
  94. m_description.tile_description.math_instruction.math_operation =
  95. MathOperationMap<typename Operator::Operator>::kId;
  96. m_description.tile_description.minimum_compute_capability =
  97. ArchMap<typename Operator::ArchTag,
  98. typename Operator::OperatorClass>::kMin;
  99. m_description.tile_description.maximum_compute_capability =
  100. ArchMap<typename Operator::ArchTag,
  101. typename Operator::OperatorClass>::kMax;
  102. m_description.A =
  103. make_TensorDescription<ElementA, LayoutA>(Operator::kAlignmentA);
  104. m_description.B =
  105. make_TensorDescription<ElementB, LayoutB>(Operator::kAlignmentB);
  106. m_description.C =
  107. make_TensorDescription<ElementC, LayoutC>(Operator::kAlignmentC);
  108. m_description.stages = Operator::kStages;
  109. split_k_mode<Operator> mode;
  110. m_description.split_k_mode = mode();
  111. }
  112. virtual OperationDescription const& description() const { return m_description; }
  113. protected:
  114. GemmDescription m_description;
  115. };
  116. ///////////////////////////////////////////////////////////////////////////////////////////////////
  117. template <typename Operator_>
  118. class GemmOperation : public GemmOperationBase<Operator_> {
  119. public:
  120. using Operator = Operator_;
  121. using ElementA = typename Operator::ElementA;
  122. using LayoutA = typename Operator::LayoutA;
  123. using ElementB = typename Operator::ElementB;
  124. using LayoutB = typename Operator::LayoutB;
  125. using ElementC = typename Operator::ElementC;
  126. using LayoutC = typename Operator::LayoutC;
  127. using ElementAccumulator = typename Operator::ElementAccumulator;
  128. using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
  129. using OperatorArguments = typename Operator::Arguments;
  130. GemmOperation(char const* name = "unknown_gemm")
  131. : GemmOperationBase<Operator_>(name) {}
  132. virtual Status run(
  133. void const* arguments_ptr, void* device_workspace = nullptr,
  134. cudaStream_t stream = nullptr) const {
  135. GemmArguments const* gemm_args =
  136. reinterpret_cast<GemmArguments const*>(arguments_ptr);
  137. OperatorArguments args;
  138. args.problem_size = gemm_args->problem_size;
  139. args.ref_A = {static_cast<ElementA const*>(gemm_args->A), int(gemm_args->lda)};
  140. args.ref_B = {static_cast<ElementB const*>(gemm_args->B), int(gemm_args->ldb)};
  141. args.ref_C = {static_cast<ElementC const*>(gemm_args->C), int(gemm_args->ldc)};
  142. args.ref_D = {static_cast<ElementC*>(gemm_args->D), int(gemm_args->ldd)};
  143. args.split_k_slices = gemm_args->split_k_slices;
  144. args.epilogue = {
  145. *static_cast<ElementCompute const*>(gemm_args->alpha),
  146. *static_cast<ElementCompute const*>(gemm_args->beta)};
  147. Operator op;
  148. Status status = op.initialize(args, device_workspace);
  149. if (status != Status::kSuccess) {
  150. return status;
  151. }
  152. return op.run(stream);
  153. }
  154. };
  155. ///////////////////////////////////////////////////////////////////////////////////////////////////
  156. } // namespace library
  157. } // namespace cutlass
  158. ///////////////////////////////////////////////////////////////////////////////////////////////////