/*************************************************************************************************** * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without *modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright notice, *this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright *notice, this list of conditions and the following disclaimer in the *documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the names of its *contributors may be used to endorse or promote products derived from this *software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** * \file dnn/src/cuda/cutlass/gemm_operation.h * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ #pragma once #include "cutlass/gemm/device/gemm.h" #include "src/cuda/cutlass/library_internal.h" /////////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { namespace library { /////////////////////////////////////////////////////////////////////////////////////////////////// /// Check whether Operator has member ReductionKernel using SFINAE (Substitution /// Failure Is Not An Error) template struct split_k_mode { template static char check(typename T::ReductionKernel*); template static int check(...); SplitKMode operator()() { if (sizeof(check(0)) == sizeof(char)) { // cutlass::gemm::device::GemmSplitKParallel return SplitKMode::kParallel; } else { // cutlass::gemm::device::Gemm return SplitKMode::kNone; } } }; /////////////////////////////////////////////////////////////////////////////////////////////////// template class GemmOperationBase : public Operation { public: using Operator = Operator_; using ElementA = typename Operator::ElementA; using LayoutA = typename Operator::LayoutA; using ElementB = typename Operator::ElementB; using LayoutB = typename Operator::LayoutB; using ElementC = typename Operator::ElementC; using LayoutC = typename Operator::LayoutC; using ElementAccumulator = typename Operator::ElementAccumulator; GemmOperationBase(char const* name = "unknown_gemm") { m_description.name = name; m_description.provider = Provider::kCUTLASS; m_description.kind = OperationKind::kGemm; m_description.gemm_kind = GemmKind::kGemm; m_description.tile_description.threadblock_shape = make_Coord( Operator::ThreadblockShape::kM, Operator::ThreadblockShape::kN, Operator::ThreadblockShape::kK); m_description.tile_description.threadblock_stages = Operator::kStages; m_description.tile_description.warp_count = make_Coord( Operator::GemmKernel::WarpCount::kM, Operator::GemmKernel::WarpCount::kN, Operator::GemmKernel::WarpCount::kK); m_description.tile_description.math_instruction.instruction_shape = make_Coord( Operator::InstructionShape::kM, Operator::InstructionShape::kN, Operator::InstructionShape::kK); m_description.tile_description.math_instruction.element_accumulator = NumericTypeMap::kId; m_description.tile_description.math_instruction.opcode_class = OpcodeClassMap::kId; m_description.tile_description.math_instruction.math_operation = MathOperationMap::kId; m_description.tile_description.minimum_compute_capability = ArchMap::kMin; m_description.tile_description.maximum_compute_capability = ArchMap::kMax; m_description.A = make_TensorDescription(Operator::kAlignmentA); m_description.B = make_TensorDescription(Operator::kAlignmentB); m_description.C = make_TensorDescription(Operator::kAlignmentC); m_description.stages = Operator::kStages; split_k_mode mode; m_description.split_k_mode = mode(); } virtual OperationDescription const& description() const { return m_description; } protected: GemmDescription m_description; }; /////////////////////////////////////////////////////////////////////////////////////////////////// template class GemmOperation : public GemmOperationBase { public: using Operator = Operator_; using ElementA = typename Operator::ElementA; using LayoutA = typename Operator::LayoutA; using ElementB = typename Operator::ElementB; using LayoutB = typename Operator::LayoutB; using ElementC = typename Operator::ElementC; using LayoutC = typename Operator::LayoutC; using ElementAccumulator = typename Operator::ElementAccumulator; using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; using OperatorArguments = typename Operator::Arguments; GemmOperation(char const* name = "unknown_gemm") : GemmOperationBase(name) {} virtual Status run( void const* arguments_ptr, void* device_workspace = nullptr, cudaStream_t stream = nullptr) const { GemmArguments const* gemm_args = reinterpret_cast(arguments_ptr); OperatorArguments args; args.problem_size = gemm_args->problem_size; args.ref_A = {static_cast(gemm_args->A), int(gemm_args->lda)}; args.ref_B = {static_cast(gemm_args->B), int(gemm_args->ldb)}; args.ref_C = {static_cast(gemm_args->C), int(gemm_args->ldc)}; args.ref_D = {static_cast(gemm_args->D), int(gemm_args->ldd)}; args.split_k_slices = gemm_args->split_k_slices; args.epilogue = { *static_cast(gemm_args->alpha), *static_cast(gemm_args->beta)}; Operator op; Status status = op.initialize(args, device_workspace); if (status != Status::kSuccess) { return status; } return op.run(stream); } }; /////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace library } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////////