|
- #
- # \file generator.py
- #
- # \brief Generates the CUTLASS Library's instances
- #
-
- import re
-
- ###################################################################################################
-
- import enum
-
- # The following block implements enum.auto() for Python 3.5 variants that don't include it such
- # as the default 3.5.2 on Ubuntu 16.04.
- #
- # https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility
-
- try:
- from enum import auto as enum_auto
- except ImportError:
- __cutlass_library_auto_enum = 0
-
- def enum_auto() -> int:
- global __cutlass_library_auto_enum
- i = __cutlass_library_auto_enum
- __cutlass_library_auto_enum += 1
- return i
-
-
- ###################################################################################################
-
- #
- class GeneratorTarget(enum.Enum):
- Library = enum_auto()
-
-
- #
- GeneratorTargetNames = {GeneratorTarget.Library: "library"}
- #
-
- ###################################################################################################
-
- #
- class DataType(enum.Enum):
- b1 = enum_auto()
- u4 = enum_auto()
- u8 = enum_auto()
- u16 = enum_auto()
- u32 = enum_auto()
- u64 = enum_auto()
- s4 = enum_auto()
- s8 = enum_auto()
- s16 = enum_auto()
- s32 = enum_auto()
- s64 = enum_auto()
- f16 = enum_auto()
- bf16 = enum_auto()
- f32 = enum_auto()
- tf32 = enum_auto()
- f64 = enum_auto()
- cf16 = enum_auto()
- cbf16 = enum_auto()
- cf32 = enum_auto()
- ctf32 = enum_auto()
- cf64 = enum_auto()
- cs4 = enum_auto()
- cs8 = enum_auto()
- cs16 = enum_auto()
- cs32 = enum_auto()
- cs64 = enum_auto()
- cu4 = enum_auto()
- cu8 = enum_auto()
- cu16 = enum_auto()
- cu32 = enum_auto()
- cu64 = enum_auto()
- invalid = enum_auto()
-
-
- #
- ShortDataTypeNames = {
- DataType.s32: "i",
- DataType.f16: "h",
- DataType.f32: "s",
- DataType.f64: "d",
- DataType.cf32: "c",
- DataType.cf64: "z",
- }
-
- #
- DataTypeNames = {
- DataType.b1: "b1",
- DataType.u4: "u4",
- DataType.u8: "u8",
- DataType.u16: "u16",
- DataType.u32: "u32",
- DataType.u64: "u64",
- DataType.s4: "s4",
- DataType.s8: "s8",
- DataType.s16: "s16",
- DataType.s32: "s32",
- DataType.s64: "s64",
- DataType.f16: "f16",
- DataType.bf16: "bf16",
- DataType.f32: "f32",
- DataType.tf32: "tf32",
- DataType.f64: "f64",
- DataType.cf16: "cf16",
- DataType.cbf16: "cbf16",
- DataType.cf32: "cf32",
- DataType.ctf32: "ctf32",
- DataType.cf64: "cf64",
- DataType.cu4: "cu4",
- DataType.cu8: "cu8",
- DataType.cu16: "cu16",
- DataType.cu32: "cu32",
- DataType.cu64: "cu64",
- DataType.cs4: "cs4",
- DataType.cs8: "cs8",
- DataType.cs16: "cs16",
- DataType.cs32: "cs32",
- DataType.cs64: "cs64",
- }
-
- DataTypeTag = {
- DataType.b1: "cutlass::uint1b_t",
- DataType.u4: "cutlass::uint4b_t",
- DataType.u8: "uint8_t",
- DataType.u16: "uint16_t",
- DataType.u32: "uint32_t",
- DataType.u64: "uint64_t",
- DataType.s4: "cutlass::int4b_t",
- DataType.s8: "int8_t",
- DataType.s16: "int16_t",
- DataType.s32: "int32_t",
- DataType.s64: "int64_t",
- DataType.f16: "cutlass::half_t",
- DataType.bf16: "cutlass::bfloat16_t",
- DataType.f32: "float",
- DataType.tf32: "cutlass::tfloat32_t",
- DataType.f64: "double",
- DataType.cf16: "cutlass::complex<cutlass::half_t>",
- DataType.cbf16: "cutlass::complex<cutlass::bfloat16_t>",
- DataType.cf32: "cutlass::complex<float>",
- DataType.ctf32: "cutlass::complex<cutlass::tfloat32_t>",
- DataType.cf64: "cutlass::complex<double>",
- DataType.cu4: "cutlass::complex<cutlass::uint4b_t>",
- DataType.cu8: "cutlass::complex<cutlass::uint8_t>",
- DataType.cu16: "cutlass::complex<cutlass::uint16_t>",
- DataType.cu32: "cutlass::complex<cutlass::uint32_t>",
- DataType.cu64: "cutlass::complex<cutlass::uint64_t>",
- DataType.cs4: "cutlass::complex<cutlass::int4b_t>",
- DataType.cs8: "cutlass::complex<cutlass::int8_t>",
- DataType.cs16: "cutlass::complex<cutlass::int16_t>",
- DataType.cs32: "cutlass::complex<cutlass::int32_t>",
- DataType.cs64: "cutlass::complex<cutlass::int64_t>",
- }
-
- DataTypeSize = {
- DataType.b1: 1,
- DataType.u4: 4,
- DataType.u8: 4,
- DataType.u16: 16,
- DataType.u32: 32,
- DataType.u64: 64,
- DataType.s4: 4,
- DataType.s8: 8,
- DataType.s16: 16,
- DataType.s32: 32,
- DataType.s64: 64,
- DataType.f16: 16,
- DataType.bf16: 16,
- DataType.f32: 32,
- DataType.tf32: 32,
- DataType.f64: 64,
- DataType.cf16: 32,
- DataType.cbf16: 32,
- DataType.cf32: 64,
- DataType.ctf32: 32,
- DataType.cf64: 128,
- DataType.cu4: 8,
- DataType.cu8: 16,
- DataType.cu16: 32,
- DataType.cu32: 64,
- DataType.cu64: 128,
- DataType.cs4: 8,
- DataType.cs8: 16,
- DataType.cs16: 32,
- DataType.cs32: 64,
- DataType.cs64: 128,
- }
-
- ###################################################################################################
-
- #
- class ComplexTransform(enum.Enum):
- none = enum_auto()
- conj = enum_auto()
-
-
- #
- ComplexTransformTag = {
- ComplexTransform.none: "cutlass::ComplexTransform::kNone",
- ComplexTransform.conj: "cutlass::ComplexTransform::kConjugate",
- }
-
- #
- RealComplexBijection = [
- (DataType.f16, DataType.cf16),
- (DataType.f32, DataType.cf32),
- (DataType.f64, DataType.cf64),
- ]
-
- #
- def is_complex(data_type):
- for r, c in RealComplexBijection:
- if data_type == c:
- return True
- return False
-
-
- #
- def get_complex_from_real(real_type):
- for r, c in RealComplexBijection:
- if real_type == r:
- return c
- return DataType.invalid
-
-
- #
- def get_real_from_complex(complex_type):
- for r, c in RealComplexBijection:
- if complex_type == c:
- return r
- return DataType.invalid
-
-
- #
- class ComplexMultiplyOp(enum.Enum):
- multiply_add = enum_auto()
- gaussian = enum_auto()
-
-
- ###################################################################################################
-
- #
- class MathOperation(enum.Enum):
- multiply_add = enum_auto()
- multiply_add_saturate = enum_auto()
- xor_popc = enum_auto()
- multiply_add_fast_bf16 = enum_auto()
- multiply_add_fast_f16 = enum_auto()
- multiply_add_complex = enum_auto()
- multiply_add_complex_gaussian = enum_auto()
-
-
- #
- MathOperationTag = {
- MathOperation.multiply_add: "cutlass::arch::OpMultiplyAdd",
- MathOperation.multiply_add_saturate: "cutlass::arch::OpMultiplyAddSaturate",
- MathOperation.xor_popc: "cutlass::arch::OpXorPopc",
- MathOperation.multiply_add_fast_bf16: "cutlass::arch::OpMultiplyAddFastBF16",
- MathOperation.multiply_add_fast_f16: "cutlass::arch::OpMultiplyAddFastF16",
- MathOperation.multiply_add_complex: "cutlass::arch::OpMultiplyAddComplex",
- MathOperation.multiply_add_complex_gaussian: "cutlass::arch::OpMultiplyAddGaussianComplex",
- }
-
- ###################################################################################################
-
- #
- class LayoutType(enum.Enum):
- ColumnMajor = enum_auto()
- RowMajor = enum_auto()
- ColumnMajorInterleaved2 = enum_auto()
- RowMajorInterleaved2 = enum_auto()
- ColumnMajorInterleaved32 = enum_auto()
- RowMajorInterleaved32 = enum_auto()
- ColumnMajorInterleaved64 = enum_auto()
- RowMajorInterleaved64 = enum_auto()
- TensorNHWC = enum_auto()
- TensorNDHWC = enum_auto()
- TensorNCHW = enum_auto()
- TensorNGHWC = enum_auto()
- TensorNC4HW4 = enum_auto()
- TensorC4RSK4 = enum_auto()
- TensorNC8HW8 = enum_auto()
- TensorNC16HW16 = enum_auto()
- TensorNC32HW32 = enum_auto()
- TensorNC64HW64 = enum_auto()
- TensorC32RSK32 = enum_auto()
- TensorC64RSK64 = enum_auto()
- TensorK4RSC4 = enum_auto()
- TensorCK4RS4 = enum_auto()
- TensorCK8RS8 = enum_auto()
- TensorCK16RS16 = enum_auto()
-
-
- #
- LayoutTag = {
- LayoutType.ColumnMajor: "cutlass::layout::ColumnMajor",
- LayoutType.RowMajor: "cutlass::layout::RowMajor",
- LayoutType.ColumnMajorInterleaved2: "cutlass::layout::ColumnMajorInterleaved<2>",
- LayoutType.RowMajorInterleaved2: "cutlass::layout::RowMajorInterleaved<2>",
- LayoutType.ColumnMajorInterleaved32: "cutlass::layout::ColumnMajorInterleaved<32>",
- LayoutType.RowMajorInterleaved32: "cutlass::layout::RowMajorInterleaved<32>",
- LayoutType.ColumnMajorInterleaved64: "cutlass::layout::ColumnMajorInterleaved<64>",
- LayoutType.RowMajorInterleaved64: "cutlass::layout::RowMajorInterleaved<64>",
- LayoutType.TensorNHWC: "cutlass::layout::TensorNHWC",
- LayoutType.TensorNDHWC: "cutlass::layout::TensorNDHWC",
- LayoutType.TensorNCHW: "cutlass::layout::TensorNCHW",
- LayoutType.TensorNGHWC: "cutlass::layout::TensorNGHWC",
- LayoutType.TensorNC4HW4: "cutlass::layout::TensorNCxHWx<4>",
- LayoutType.TensorC4RSK4: "cutlass::layout::TensorCxRSKx<4>",
- LayoutType.TensorNC8HW8: "cutlass::layout::TensorNCxHWx<8>",
- LayoutType.TensorNC16HW16: "cutlass::layout::TensorNCxHWx<16>",
- LayoutType.TensorNC32HW32: "cutlass::layout::TensorNCxHWx<32>",
- LayoutType.TensorC32RSK32: "cutlass::layout::TensorCxRSKx<32>",
- LayoutType.TensorNC64HW64: "cutlass::layout::TensorNCxHWx<64>",
- LayoutType.TensorC64RSK64: "cutlass::layout::TensorCxRSKx<64>",
- LayoutType.TensorK4RSC4: "cutlass::layout::TensorKxRSCx<4>",
- LayoutType.TensorCK4RS4: "cutlass::layout::TensorCKxRSx<4>",
- LayoutType.TensorCK8RS8: "cutlass::layout::TensorCKxRSx<8>",
- LayoutType.TensorCK16RS16: "cutlass::layout::TensorCKxRSx<16>",
- }
-
- #
- TransposedLayout = {
- LayoutType.ColumnMajor: LayoutType.RowMajor,
- LayoutType.RowMajor: LayoutType.ColumnMajor,
- LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2,
- LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2,
- LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32,
- LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32,
- LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64,
- LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64,
- LayoutType.TensorNHWC: LayoutType.TensorNHWC,
- }
-
- #
- ShortLayoutTypeNames = {
- LayoutType.ColumnMajor: "n",
- LayoutType.ColumnMajorInterleaved32: "n2",
- LayoutType.ColumnMajorInterleaved32: "n32",
- LayoutType.ColumnMajorInterleaved64: "n64",
- LayoutType.RowMajor: "t",
- LayoutType.RowMajorInterleaved2: "t2",
- LayoutType.RowMajorInterleaved32: "t32",
- LayoutType.RowMajorInterleaved64: "t64",
- LayoutType.TensorNHWC: "nhwc",
- LayoutType.TensorNDHWC: "ndhwc",
- LayoutType.TensorNCHW: "nchw",
- LayoutType.TensorNGHWC: "nghwc",
- LayoutType.TensorNC4HW4: "nc4hw4",
- LayoutType.TensorC4RSK4: "c4rsk4",
- LayoutType.TensorNC8HW8: "nc8hw8",
- LayoutType.TensorNC16HW16: "nc16hw16",
- LayoutType.TensorNC32HW32: "nc32hw32",
- LayoutType.TensorNC64HW64: "nc64hw64",
- LayoutType.TensorC32RSK32: "c32rsk32",
- LayoutType.TensorC64RSK64: "c64rsk64",
- LayoutType.TensorK4RSC4: "k4rsc4",
- LayoutType.TensorCK4RS4: "ck4rs4",
- LayoutType.TensorCK8RS8: "ck8rs8",
- LayoutType.TensorCK16RS16: "ck16rs16",
- }
-
- #
- ShortComplexLayoutNames = {
- (LayoutType.ColumnMajor, ComplexTransform.none): "n",
- (LayoutType.ColumnMajor, ComplexTransform.conj): "c",
- (LayoutType.RowMajor, ComplexTransform.none): "t",
- (LayoutType.RowMajor, ComplexTransform.conj): "h",
- }
-
- ###################################################################################################
- #
- class OpcodeClass(enum.Enum):
- Simt = enum_auto()
- TensorOp = enum_auto()
- WmmaTensorOp = enum_auto()
-
-
- OpcodeClassNames = {
- OpcodeClass.Simt: "simt",
- OpcodeClass.TensorOp: "tensorop",
- OpcodeClass.WmmaTensorOp: "wmma_tensorop",
- }
-
- OpcodeClassTag = {
- OpcodeClass.Simt: "cutlass::arch::OpClassSimt",
- OpcodeClass.TensorOp: "cutlass::arch::OpClassTensorOp",
- OpcodeClass.WmmaTensorOp: "cutlass::arch::OpClassWmmaTensorOp",
- }
-
- ###################################################################################################
-
- #
- class OperationKind(enum.Enum):
- Gemm = enum_auto()
- Conv2d = enum_auto()
-
-
- #
- OperationKindNames = {OperationKind.Gemm: "gemm", OperationKind.Conv2d: "conv2d"}
-
- #
- class Target(enum.Enum):
- library = enum_auto()
-
-
- ArchitectureNames = {
- 50: "maxwell",
- 60: "pascal",
- 61: "pascal",
- 70: "volta",
- 75: "turing",
- 80: "ampere",
- }
-
- ###################################################################################################
-
- #
- def SubstituteTemplate(template, values):
- text = template
- changed = True
- while changed:
- changed = False
- for key, value in values.items():
- regex = "\\$\\{%s\\}" % key
- newtext = re.sub(regex, value, text)
- if newtext != text:
- changed = True
- text = newtext
- return text
-
-
- ###################################################################################################
-
- #
- class GemmKind(enum.Enum):
- Gemm = enum_auto()
- Sparse = enum_auto()
- Universal = enum_auto()
- PlanarComplex = enum_auto()
- PlanarComplexArray = enum_auto()
- SplitKParallel = enum_auto()
- GemvBatchedStrided = enum_auto()
-
-
- #
- GemmKindNames = {
- GemmKind.Gemm: "gemm",
- GemmKind.Sparse: "spgemm",
- GemmKind.Universal: "gemm",
- GemmKind.PlanarComplex: "gemm_planar_complex",
- GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
- GemmKind.SplitKParallel: "gemm_split_k_parallel",
- GemmKind.GemvBatchedStrided: "gemv_batched_strided",
- }
-
- #
- class EpilogueFunctor(enum.Enum):
- LinearCombination = enum_auto()
- LinearCombinationClamp = enum_auto()
- BiasAddLinearCombination = enum_auto()
- BiasAddLinearCombinationRelu = enum_auto()
- BiasAddLinearCombinationHSwish = enum_auto()
- BiasAddLinearCombinationClamp = enum_auto()
- BiasAddLinearCombinationReluClamp = enum_auto()
- BiasAddLinearCombinationHSwishClamp = enum_auto()
-
-
- #
- EpilogueFunctorTag = {
- EpilogueFunctor.LinearCombination: "cutlass::epilogue::thread::LinearCombination",
- EpilogueFunctor.LinearCombinationClamp: "cutlass::epilogue::thread::LinearCombinationClamp",
- EpilogueFunctor.BiasAddLinearCombination: "cutlass::epilogue::thread::BiasAddLinearCombination",
- EpilogueFunctor.BiasAddLinearCombinationRelu: "cutlass::epilogue::thread::BiasAddLinearCombinationRelu",
- EpilogueFunctor.BiasAddLinearCombinationHSwish: "cutlass::epilogue::thread::BiasAddLinearCombinationHSwish",
- EpilogueFunctor.BiasAddLinearCombinationClamp: "cutlass::epilogue::thread::BiasAddLinearCombinationClamp",
- EpilogueFunctor.BiasAddLinearCombinationReluClamp: "cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp",
- EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: "cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp",
- }
-
- #
- ShortEpilogueNames = {
- EpilogueFunctor.LinearCombination: "id",
- EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: "hswish",
- EpilogueFunctor.BiasAddLinearCombinationReluClamp: "relu",
- EpilogueFunctor.BiasAddLinearCombinationClamp: "id",
- EpilogueFunctor.BiasAddLinearCombinationHSwish: "hswish",
- EpilogueFunctor.BiasAddLinearCombinationRelu: "relu",
- EpilogueFunctor.BiasAddLinearCombination: "id",
- }
-
-
- #
- class SwizzlingFunctor(enum.Enum):
- Identity1 = enum_auto()
- Identity2 = enum_auto()
- Identity4 = enum_auto()
- Identity8 = enum_auto()
- ConvFpropNCxHWx = enum_auto()
- ConvFpropTrans = enum_auto()
- ConvDgradNCxHWx = enum_auto()
- ConvDgradTrans = enum_auto()
- DepthwiseConvolutionFprop = enum_auto()
- DepthwiseConvolutionDgrad = enum_auto()
- DepthwiseConvolutionWgrad = enum_auto()
-
-
- #
- SwizzlingFunctorTag = {
- SwizzlingFunctor.Identity1: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>",
- SwizzlingFunctor.Identity2: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>",
- SwizzlingFunctor.Identity4: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>",
- SwizzlingFunctor.Identity8: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>",
- SwizzlingFunctor.ConvFpropNCxHWx: "cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle",
- SwizzlingFunctor.ConvFpropTrans: "cutlass::conv::threadblock::ConvolutionFpropTransThreadblockSwizzle",
- SwizzlingFunctor.ConvDgradNCxHWx: "cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle",
- SwizzlingFunctor.ConvDgradTrans: "cutlass::conv::threadblock::ConvolutionDgradTransThreadblockSwizzle",
- SwizzlingFunctor.DepthwiseConvolutionFprop: "cutlass::conv::threadblock::DepthwiseConvolutionFpropThreadblockSwizzle",
- SwizzlingFunctor.DepthwiseConvolutionDgrad: "cutlass::conv::threadblock::DepthwiseConvolutionDgradThreadblockSwizzle",
- SwizzlingFunctor.DepthwiseConvolutionWgrad: "cutlass::conv::threadblock::DepthwiseConvolutionWgradThreadblockSwizzle",
- }
-
- ###################################################################################################
-
-
- class ConvType(enum.Enum):
- Convolution = enum_auto()
- BatchConvolution = enum_auto()
- Local = enum_auto()
- LocalShare = enum_auto()
- DepthwiseConvolution = enum_auto()
-
-
- ConvTypeTag = {
- ConvType.Convolution: "cutlass::conv::ConvType::kConvolution",
- ConvType.BatchConvolution: "cutlass::conv::ConvType::kBatchConvolution",
- ConvType.Local: "cutlass::conv::ConvType::kLocal",
- ConvType.LocalShare: "cutlass::conv::ConvType::kLocalShare",
- ConvType.DepthwiseConvolution: "cutlass::conv::ConvType::kDepthwiseConvolution",
- }
-
- #
- class ConvKind(enum.Enum):
- Fprop = enum_auto()
- Dgrad = enum_auto()
- Wgrad = enum_auto()
-
-
- #
- ConvKindTag = {
- ConvKind.Fprop: "cutlass::conv::Operator::kFprop",
- ConvKind.Dgrad: "cutlass::conv::Operator::kDgrad",
- ConvKind.Wgrad: "cutlass::conv::Operator::kWgrad",
- }
-
- ConvKindNames = {
- ConvKind.Fprop: "fprop",
- ConvKind.Dgrad: "dgrad",
- ConvKind.Wgrad: "wgrad",
- }
-
- #
- class IteratorAlgorithm(enum.Enum):
- Analytic = enum_auto()
- Optimized = enum_auto()
-
-
- #
- IteratorAlgorithmTag = {
- IteratorAlgorithm.Analytic: "cutlass::conv::IteratorAlgorithm::kAnalytic",
- IteratorAlgorithm.Optimized: "cutlass::conv::IteratorAlgorithm::kOptimized",
- }
-
- IteratorAlgorithmNames = {
- IteratorAlgorithm.Analytic: "analytic",
- IteratorAlgorithm.Optimized: "optimized",
- }
-
- #
- class StrideSupport(enum.Enum):
- Strided = enum_auto()
- Unity = enum_auto()
-
-
- #
- StrideSupportTag = {
- StrideSupport.Strided: "cutlass::conv::StrideSupport::kStrided",
- StrideSupport.Unity: "cutlass::conv::StrideSupport::kUnity",
- }
-
- StrideSupportNames = {StrideSupport.Strided: "", StrideSupport.Unity: "unity_stride"}
-
-
- class SpecialOptimizeDesc(enum.Enum):
- NoneSpecialOpt = enum_auto()
- ConvFilterUnity = enum_auto()
- DeconvDoubleUpsampling = enum_auto()
-
-
- SpecialOptimizeDescNames = {
- SpecialOptimizeDesc.NoneSpecialOpt: "none",
- SpecialOptimizeDesc.ConvFilterUnity: "conv_filter_unity",
- SpecialOptimizeDesc.DeconvDoubleUpsampling: "deconv_double_upsampling",
- }
-
- SpecialOptimizeDescTag = {
- SpecialOptimizeDesc.NoneSpecialOpt: "cutlass::conv::SpecialOptimizeDesc::NONE",
- SpecialOptimizeDesc.ConvFilterUnity: "cutlass::conv::SpecialOptimizeDesc::CONV_FILTER_UNITY",
- SpecialOptimizeDesc.DeconvDoubleUpsampling: "cutlass::conv::SpecialOptimizeDesc::DECONV_DOUBLE_UPSAMPLING",
- }
-
-
- class ImplicitGemmMode(enum.Enum):
- GemmNT = enum_auto()
- GemmTN = enum_auto()
-
-
- ImplicitGemmModeNames = {
- ImplicitGemmMode.GemmNT: "gemm_nt",
- ImplicitGemmMode.GemmTN: "gemm_tn",
- }
-
- ImplicitGemmModeTag = {
- ImplicitGemmMode.GemmNT: "cutlass::conv::ImplicitGemmMode::GEMM_NT",
- ImplicitGemmMode.GemmTN: "cutlass::conv::ImplicitGemmMode::GEMM_TN",
- }
-
- ###################################################################################################
-
- #
- class MathInstruction:
- def __init__(
- self,
- instruction_shape,
- element_a,
- element_b,
- element_accumulator,
- opcode_class,
- math_operation=MathOperation.multiply_add,
- ):
- self.instruction_shape = instruction_shape
- self.element_a = element_a
- self.element_b = element_b
- self.element_accumulator = element_accumulator
- self.opcode_class = opcode_class
- self.math_operation = math_operation
-
-
- #
- class TileDescription:
- def __init__(
- self,
- threadblock_shape,
- stages,
- warp_count,
- math_instruction,
- min_compute,
- max_compute,
- ):
- self.threadblock_shape = threadblock_shape
- self.stages = stages
- self.warp_count = warp_count
- self.math_instruction = math_instruction
- self.minimum_compute_capability = min_compute
- self.maximum_compute_capability = max_compute
-
- def procedural_name(self):
- return "%dx%d_%dx%d" % (
- self.threadblock_shape[0],
- self.threadblock_shape[1],
- self.threadblock_shape[2],
- self.stages,
- )
-
-
- #
- class TensorDescription:
- def __init__(
- self, element, layout, alignment=1, complex_transform=ComplexTransform.none
- ):
- self.element = element
- self.layout = layout
- self.alignment = alignment
- self.complex_transform = complex_transform
-
-
- ###################################################################################################
-
-
- class GlobalCnt:
- cnt = 0
|