/** * \file dnn/src/cuda/matrix_mul/cublasLt_wrapper.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once #include #include "./algos.h" #include "megdnn/basic_types.h" #include "megdnn/oprs/nn.h" #include "src/common/utils.h" #include "src/cuda/utils.h" #if CUDA_VERSION >= 10010 #include namespace megdnn { namespace cuda { struct CUBLASLTMatmulDesc { struct SizeArgs { using MMSizeArgs = MatrixMulForwardImpl::AlgoBase::SizeArgs; HandleImpl* handle; bool transposeA, transposeB; TensorLayout layout_a, layout_b, layout_c; std::string to_string() const; SizeArgs(HandleImpl* handle, bool transposeA, bool transposeB, const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) : handle(handle), transposeA(transposeA), transposeB(transposeB), layout_a(A), layout_b(B), layout_c(C){}; explicit SizeArgs(const MMSizeArgs& args) : layout_a(args.layout_a), layout_b(args.layout_b), layout_c(args.layout_c) { handle = concrete_handle(args.opr->handle()); auto&& param = args.opr->param(); transposeA = param.transposeA; transposeB = param.transposeB; }; }; bool is_batched; cublasLtMatmulDesc_t matmul_desc; cudaDataType_t dt_a, dt_b, dt_c; cublasComputeType_t dt_compute; cublasLtMatrixLayout_t layout_a, layout_b, layout_c; cublasLtMatrixLayout_t layout_trans_a, layout_trans_b, layout_trans_c; size_t workspace_a, workspace_b, workspace_c; CUBLASLTMatmulDesc(const SizeArgs& args, bool batched = false) : matmul_desc(nullptr), layout_a(nullptr), layout_b(nullptr), layout_c(nullptr), layout_trans_a(nullptr), layout_trans_b(nullptr), layout_trans_c(nullptr), workspace_a(0), workspace_b(0), workspace_c(0) { is_batched = batched; set(args, batched); } ~CUBLASLTMatmulDesc(); void set(const SizeArgs& args, bool batched = false); void reset(); bool get_algorithm_heuristic(const SizeArgs& args, size_t ws_limit, cublasLtMatmulAlgo_t& algo); WorkspaceBundle get_workspace_bundle(const SizeArgs& args, const cublasLtMatmulAlgo_t& algo); bool is_available(const SizeArgs& args, size_t ws_limit); }; } // namespace cuda } // namespace megdnn #endif // vim: syntax=cpp.doxygen