|
- #
- # \file generator.py
- #
- # \brief Generates the CUTLASS Library's instances
- #
-
- import re
-
- ###################################################################################################
-
- import enum
-
- # The following block implements enum.auto() for Python 3.5 variants that don't include it such
- # as the default 3.5.2 on Ubuntu 16.04.
- #
- # https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility
-
try:
    # Interpreters that provide enum.auto use it directly.
    from enum import auto as enum_auto
except ImportError:
    # Fallback for interpreters without enum.auto: a single module-wide
    # counter hands out distinct integers to every enum member in this file.
    __cutlass_library_auto_enum = 0

    def enum_auto() -> int:
        """Return the next unused integer from the module-wide counter.

        The first call returns 0; each later call returns the previous
        value plus one.
        """
        global __cutlass_library_auto_enum
        value = __cutlass_library_auto_enum
        __cutlass_library_auto_enum = value + 1
        return value
-
- ###################################################################################################
-
- #
class GeneratorTarget(enum.Enum):
    """Output targets known to the generator (currently only ``Library``)."""

    Library = enum_auto()


# Lower-case name associated with each generator target.
GeneratorTargetNames = {
    GeneratorTarget.Library: "library",
}
- #
-
- ###################################################################################################
-
- #
class DataType(enum.Enum):
    """Element types referenced by the lookup tables below.

    A ``c`` prefix marks a complex type whose components are the
    correspondingly named real type.  ``invalid`` is a sentinel and has no
    entry in any of the lookup tables.
    """

    # 1-bit and unsigned integer types
    b1 = enum_auto()
    u4 = enum_auto()
    u8 = enum_auto()
    u16 = enum_auto()
    u32 = enum_auto()
    u64 = enum_auto()
    # signed integer types
    s4 = enum_auto()
    s8 = enum_auto()
    s16 = enum_auto()
    s32 = enum_auto()
    s64 = enum_auto()
    # floating-point types
    f16 = enum_auto()
    bf16 = enum_auto()
    f32 = enum_auto()
    tf32 = enum_auto()
    f64 = enum_auto()
    # complex floating-point types
    cf16 = enum_auto()
    cbf16 = enum_auto()
    cf32 = enum_auto()
    ctf32 = enum_auto()
    cf64 = enum_auto()
    # complex signed integer types
    cs4 = enum_auto()
    cs8 = enum_auto()
    cs16 = enum_auto()
    cs32 = enum_auto()
    cs64 = enum_auto()
    # complex unsigned integer types
    cu4 = enum_auto()
    cu8 = enum_auto()
    cu16 = enum_auto()
    cu32 = enum_auto()
    cu64 = enum_auto()
    # sentinel returned by lookups that find no match
    invalid = enum_auto()


# Single-letter abbreviations; only this subset of types has one.
ShortDataTypeNames = {
    DataType.s32: 'i',
    DataType.f16: 'h',
    DataType.f32: 's',
    DataType.f64: 'd',
    DataType.cf32: 'c',
    DataType.cf64: 'z',
}

# Textual name of each data type.
DataTypeNames = {
    DataType.b1: "b1",
    DataType.u4: "u4",
    DataType.u8: "u8",
    DataType.u16: "u16",
    DataType.u32: "u32",
    DataType.u64: "u64",
    DataType.s4: "s4",
    DataType.s8: "s8",
    DataType.s16: "s16",
    DataType.s32: "s32",
    DataType.s64: "s64",
    DataType.f16: "f16",
    DataType.bf16: "bf16",
    DataType.f32: "f32",
    DataType.tf32: "tf32",
    DataType.f64: "f64",
    DataType.cf16: "cf16",
    DataType.cbf16: "cbf16",
    DataType.cf32: "cf32",
    DataType.ctf32: "ctf32",
    DataType.cf64: "cf64",
    DataType.cu4: "cu4",
    DataType.cu8: "cu8",
    DataType.cu16: "cu16",
    DataType.cu32: "cu32",
    DataType.cu64: "cu64",
    DataType.cs4: "cs4",
    DataType.cs8: "cs8",
    DataType.cs16: "cs16",
    DataType.cs32: "cs32",
    DataType.cs64: "cs64",
}

# C++ type spelling associated with each data type.
DataTypeTag = {
    DataType.b1: "cutlass::uint1b_t",
    DataType.u4: "cutlass::uint4b_t",
    DataType.u8: "uint8_t",
    DataType.u16: "uint16_t",
    DataType.u32: "uint32_t",
    DataType.u64: "uint64_t",
    DataType.s4: "cutlass::int4b_t",
    DataType.s8: "int8_t",
    DataType.s16: "int16_t",
    DataType.s32: "int32_t",
    DataType.s64: "int64_t",
    DataType.f16: "cutlass::half_t",
    DataType.bf16: "cutlass::bfloat16_t",
    DataType.f32: "float",
    DataType.tf32: "cutlass::tfloat32_t",
    DataType.f64: "double",
    DataType.cf16: "cutlass::complex<cutlass::half_t>",
    DataType.cbf16: "cutlass::complex<cutlass::bfloat16_t>",
    DataType.cf32: "cutlass::complex<float>",
    DataType.ctf32: "cutlass::complex<cutlass::tfloat32_t>",
    DataType.cf64: "cutlass::complex<double>",
    DataType.cu4: "cutlass::complex<cutlass::uint4b_t>",
    DataType.cu8: "cutlass::complex<cutlass::uint8_t>",
    DataType.cu16: "cutlass::complex<cutlass::uint16_t>",
    DataType.cu32: "cutlass::complex<cutlass::uint32_t>",
    DataType.cu64: "cutlass::complex<cutlass::uint64_t>",
    DataType.cs4: "cutlass::complex<cutlass::int4b_t>",
    DataType.cs8: "cutlass::complex<cutlass::int8_t>",
    DataType.cs16: "cutlass::complex<cutlass::int16_t>",
    DataType.cs32: "cutlass::complex<cutlass::int32_t>",
    DataType.cs64: "cutlass::complex<cutlass::int64_t>",
}

# Size of one element of each data type, in bits.
DataTypeSize = {
    DataType.b1: 1,
    DataType.u4: 4,
    # Fixed: was 4, which contradicted uint8_t and the 8-bit s8 entry.
    DataType.u8: 8,
    DataType.u16: 16,
    DataType.u32: 32,
    DataType.u64: 64,
    DataType.s4: 4,
    DataType.s8: 8,
    DataType.s16: 16,
    DataType.s32: 32,
    DataType.s64: 64,
    DataType.f16: 16,
    DataType.bf16: 16,
    DataType.f32: 32,
    DataType.tf32: 32,
    DataType.f64: 64,
    DataType.cf16: 32,
    DataType.cbf16: 32,
    DataType.cf32: 64,
    # NOTE(review): every other complex entry is twice its real type, which
    # would make this 64; left unchanged pending confirmation.
    DataType.ctf32: 32,
    DataType.cf64: 128,
    DataType.cu4: 8,
    DataType.cu8: 16,
    DataType.cu16: 32,
    DataType.cu32: 64,
    DataType.cu64: 128,
    DataType.cs4: 8,
    DataType.cs8: 16,
    DataType.cs16: 32,
    DataType.cs32: 64,
    DataType.cs64: 128,
}
-
- ###################################################################################################
-
- #
class ComplexTransform(enum.Enum):
    """Transform applied to a complex operand: none, or conjugation."""

    none = enum_auto()
    conj = enum_auto()


# C++ enumerator spelling for each complex transform.
ComplexTransformTag = {
    ComplexTransform.none: "cutlass::ComplexTransform::kNone",
    ComplexTransform.conj: "cutlass::ComplexTransform::kConjugate",
}
-
- #
# (real type, complex type) pairs consulted by is_complex(),
# get_complex_from_real() and get_real_from_complex().  Only the types
# listed here are recognised by those helpers.
RealComplexBijection = [
    (DataType.f16, DataType.cf16),
    (DataType.f32, DataType.cf32),
    (DataType.f64, DataType.cf64),
]
-
- #
def is_complex(data_type):
    """Return True if ``data_type`` is the complex member of a pair in
    ``RealComplexBijection``, and False otherwise.

    Complex types that are not listed in ``RealComplexBijection`` yield
    False.
    """
    return any(data_type == complex_type
               for _real_type, complex_type in RealComplexBijection)
-
- #
def get_complex_from_real(real_type):
    """Return the complex type paired with ``real_type`` in
    ``RealComplexBijection``, or ``DataType.invalid`` if there is none.
    """
    candidates = (complex_type
                  for candidate_real, complex_type in RealComplexBijection
                  if real_type == candidate_real)
    # First match wins; fall back to the sentinel when nothing matched.
    return next(candidates, DataType.invalid)
-
- #
def get_real_from_complex(complex_type):
    """Return the real type paired with ``complex_type`` in
    ``RealComplexBijection``, or ``DataType.invalid`` if there is none.
    """
    candidates = (real_type
                  for real_type, candidate_complex in RealComplexBijection
                  if complex_type == candidate_complex)
    # First match wins; fall back to the sentinel when nothing matched.
    return next(candidates, DataType.invalid)
-
- #
class ComplexMultiplyOp(enum.Enum):
    """Variants of complex multiplication: ``multiply_add`` or ``gaussian``."""

    multiply_add = enum_auto()
    gaussian = enum_auto()
-
- ###################################################################################################
-
- #
class MathOperation(enum.Enum):
    """Math operations that have a tag in ``MathOperationTag``."""

    multiply_add = enum_auto()
    multiply_add_saturate = enum_auto()
    xor_popc = enum_auto()
    multiply_add_fast_bf16 = enum_auto()
    multiply_add_fast_f16 = enum_auto()
    multiply_add_complex = enum_auto()
    multiply_add_complex_gaussian = enum_auto()


# C++ tag type (in cutlass::arch) for each math operation.
MathOperationTag = {
    MathOperation.multiply_add: "cutlass::arch::OpMultiplyAdd",
    MathOperation.multiply_add_saturate: "cutlass::arch::OpMultiplyAddSaturate",
    MathOperation.xor_popc: "cutlass::arch::OpXorPopc",
    MathOperation.multiply_add_fast_bf16: "cutlass::arch::OpMultiplyAddFastBF16",
    MathOperation.multiply_add_fast_f16: "cutlass::arch::OpMultiplyAddFastF16",
    MathOperation.multiply_add_complex: "cutlass::arch::OpMultiplyAddComplex",
    MathOperation.multiply_add_complex_gaussian: "cutlass::arch::OpMultiplyAddGaussianComplex",
}
-
- ###################################################################################################
-
- #
class LayoutType(enum.Enum):
    """Matrix and tensor layouts referenced by the lookup tables below."""

    # Matrix layouts (plain and interleaved)
    ColumnMajor = enum_auto()
    RowMajor = enum_auto()
    ColumnMajorInterleaved2 = enum_auto()
    RowMajorInterleaved2 = enum_auto()
    ColumnMajorInterleaved32 = enum_auto()
    RowMajorInterleaved32 = enum_auto()
    ColumnMajorInterleaved64 = enum_auto()
    RowMajorInterleaved64 = enum_auto()
    # Tensor layouts
    TensorNHWC = enum_auto()
    TensorNDHWC = enum_auto()
    TensorNCHW = enum_auto()
    TensorNGHWC = enum_auto()
    TensorNC4HW4 = enum_auto()
    TensorC4RSK4 = enum_auto()
    TensorNC8HW8 = enum_auto()
    TensorNC16HW16 = enum_auto()
    TensorNC32HW32 = enum_auto()
    TensorNC64HW64 = enum_auto()
    TensorC32RSK32 = enum_auto()
    TensorC64RSK64 = enum_auto()
    TensorK4RSC4 = enum_auto()
    TensorCK4RS4 = enum_auto()
    TensorCK8RS8 = enum_auto()
    TensorCK16RS16 = enum_auto()


# C++ layout type (in cutlass::layout) for each layout.
LayoutTag = {
    LayoutType.ColumnMajor: 'cutlass::layout::ColumnMajor',
    LayoutType.RowMajor: 'cutlass::layout::RowMajor',
    LayoutType.ColumnMajorInterleaved2: 'cutlass::layout::ColumnMajorInterleaved<2>',
    LayoutType.RowMajorInterleaved2: 'cutlass::layout::RowMajorInterleaved<2>',
    LayoutType.ColumnMajorInterleaved32: 'cutlass::layout::ColumnMajorInterleaved<32>',
    LayoutType.RowMajorInterleaved32: 'cutlass::layout::RowMajorInterleaved<32>',
    LayoutType.ColumnMajorInterleaved64: 'cutlass::layout::ColumnMajorInterleaved<64>',
    LayoutType.RowMajorInterleaved64: 'cutlass::layout::RowMajorInterleaved<64>',
    LayoutType.TensorNHWC: 'cutlass::layout::TensorNHWC',
    LayoutType.TensorNDHWC: 'cutlass::layout::TensorNDHWC',
    LayoutType.TensorNCHW: 'cutlass::layout::TensorNCHW',
    LayoutType.TensorNGHWC: 'cutlass::layout::TensorNGHWC',
    LayoutType.TensorNC4HW4: 'cutlass::layout::TensorNCxHWx<4>',
    LayoutType.TensorC4RSK4: 'cutlass::layout::TensorCxRSKx<4>',
    LayoutType.TensorNC8HW8: 'cutlass::layout::TensorNCxHWx<8>',
    LayoutType.TensorNC16HW16: 'cutlass::layout::TensorNCxHWx<16>',
    LayoutType.TensorNC32HW32: 'cutlass::layout::TensorNCxHWx<32>',
    LayoutType.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>',
    LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>',
    LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>',
    LayoutType.TensorK4RSC4: 'cutlass::layout::TensorKxRSCx<4>',
    LayoutType.TensorCK4RS4: 'cutlass::layout::TensorCKxRSx<4>',
    LayoutType.TensorCK8RS8: 'cutlass::layout::TensorCKxRSx<8>',
    LayoutType.TensorCK16RS16: 'cutlass::layout::TensorCKxRSx<16>',
}

# Layout obtained by transposing each matrix layout.  TensorNHWC maps to
# itself; the other tensor layouts have no entry.
TransposedLayout = {
    LayoutType.ColumnMajor: LayoutType.RowMajor,
    LayoutType.RowMajor: LayoutType.ColumnMajor,
    LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2,
    LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2,
    LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32,
    LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32,
    LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64,
    LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64,
    LayoutType.TensorNHWC: LayoutType.TensorNHWC,
}

# Short name of each layout.
ShortLayoutTypeNames = {
    LayoutType.ColumnMajor: 'n',
    # Fixed: this entry was keyed on ColumnMajorInterleaved32, so it was
    # overwritten by the next line and ColumnMajorInterleaved2 had no name.
    LayoutType.ColumnMajorInterleaved2: 'n2',
    LayoutType.ColumnMajorInterleaved32: 'n32',
    LayoutType.ColumnMajorInterleaved64: 'n64',
    LayoutType.RowMajor: 't',
    LayoutType.RowMajorInterleaved2: 't2',
    LayoutType.RowMajorInterleaved32: 't32',
    LayoutType.RowMajorInterleaved64: 't64',
    LayoutType.TensorNHWC: 'nhwc',
    LayoutType.TensorNDHWC: 'ndhwc',
    LayoutType.TensorNCHW: 'nchw',
    LayoutType.TensorNGHWC: 'nghwc',
    LayoutType.TensorNC4HW4: 'nc4hw4',
    LayoutType.TensorC4RSK4: 'c4rsk4',
    LayoutType.TensorNC8HW8: 'nc8hw8',
    LayoutType.TensorNC16HW16: 'nc16hw16',
    LayoutType.TensorNC32HW32: 'nc32hw32',
    LayoutType.TensorNC64HW64: 'nc64hw64',
    LayoutType.TensorC32RSK32: 'c32rsk32',
    LayoutType.TensorC64RSK64: 'c64rsk64',
    LayoutType.TensorK4RSC4: 'k4rsc4',
    LayoutType.TensorCK4RS4: 'ck4rs4',
    LayoutType.TensorCK8RS8: 'ck8rs8',
    LayoutType.TensorCK16RS16: 'ck16rs16',
}
-
- #
# Short name for each (layout, complex transform) combination; only the
# plain column-major and row-major layouts are covered.
ShortComplexLayoutNames = {
    (LayoutType.ColumnMajor, ComplexTransform.none): "n",
    (LayoutType.ColumnMajor, ComplexTransform.conj): "c",
    (LayoutType.RowMajor, ComplexTransform.none): "t",
    (LayoutType.RowMajor, ComplexTransform.conj): "h",
}
-
- ###################################################################################################
- #
class OpcodeClass(enum.Enum):
    """Opcode classes: SIMT, tensor-op and WMMA tensor-op."""

    Simt = enum_auto()
    TensorOp = enum_auto()
    WmmaTensorOp = enum_auto()


# Lower-case name of each opcode class.
OpcodeClassNames = {
    OpcodeClass.Simt: "simt",
    OpcodeClass.TensorOp: "tensorop",
    OpcodeClass.WmmaTensorOp: "wmma_tensorop",
}

# C++ tag type (in cutlass::arch) for each opcode class.
OpcodeClassTag = {
    OpcodeClass.Simt: "cutlass::arch::OpClassSimt",
    OpcodeClass.TensorOp: "cutlass::arch::OpClassTensorOp",
    OpcodeClass.WmmaTensorOp: "cutlass::arch::OpClassWmmaTensorOp",
}
-
- ###################################################################################################
-
- #
class OperationKind(enum.Enum):
    """Kinds of operation: GEMM and 2-D convolution."""

    Gemm = enum_auto()
    Conv2d = enum_auto()


# Lower-case name of each operation kind.
OperationKindNames = {
    OperationKind.Gemm: "gemm",
    OperationKind.Conv2d: "conv2d",
}
-
- #
class Target(enum.Enum):
    """Targets; ``library`` is the only member."""

    library = enum_auto()
-
# Architecture name for each integer key.
# NOTE(review): the keys look like compute-capability numbers
# (major * 10 + minor) -- confirm against callers.
ArchitectureNames = {
    50: "maxwell",
    60: "pascal",
    61: "pascal",
    70: "volta",
    75: "turing",
    80: "ampere",
}
-
- ###################################################################################################
-
- #
def SubstituteTemplate(template, values):
    """Expand ``${key}`` placeholders in ``template`` using ``values``.

    Substitution is repeated until a full pass over ``values`` changes
    nothing, so a replacement value may itself contain placeholders that
    are expanded on a later pass.

    Keys and values are treated as literal text.  The previous
    implementation passed them through ``re.sub``, which interpreted the
    key as a regular expression and the value as a replacement template;
    a value containing a backslash was therefore rewritten or rejected.

    Args:
        template: Text containing ``${key}`` placeholders.
        values: Mapping from placeholder name to replacement string.

    Returns:
        The expanded text.  Placeholders with no entry in ``values`` are
        left untouched.

    Note:
        A value that, directly or indirectly, expands to text containing
        its own placeholder plus other characters never reaches a fixed
        point, so this function would not terminate.
    """
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            placeholder = "${%s}" % key
            newtext = text.replace(placeholder, value)
            if newtext != text:
                changed = True
            text = newtext
    return text
-
- ###################################################################################################
-
- #
class GemmKind(enum.Enum):
    """GEMM variants that have a name in ``GemmKindNames``."""

    Gemm = enum_auto()
    Sparse = enum_auto()
    Universal = enum_auto()
    PlanarComplex = enum_auto()
    PlanarComplexArray = enum_auto()
    SplitKParallel = enum_auto()
    GemvBatchedStrided = enum_auto()


# Name of each GEMM kind.  Gemm and Universal share the name "gemm".
GemmKindNames = {
    GemmKind.Gemm: 'gemm',
    GemmKind.Sparse: 'spgemm',
    GemmKind.Universal: 'gemm',
    GemmKind.PlanarComplex: 'gemm_planar_complex',
    GemmKind.PlanarComplexArray: 'gemm_planar_complex_array',
    GemmKind.SplitKParallel: 'gemm_split_k_parallel',
    GemmKind.GemvBatchedStrided: 'gemv_batched_strided',
}
-
- #
class EpilogueFunctor(enum.Enum):
    """Epilogue functors that have a tag in ``EpilogueFunctorTag``."""

    LinearCombination = enum_auto()
    LinearCombinationClamp = enum_auto()
    BiasAddLinearCombination = enum_auto()
    BiasAddLinearCombinationRelu = enum_auto()
    BiasAddLinearCombinationHSwish = enum_auto()
    BiasAddLinearCombinationClamp = enum_auto()
    BiasAddLinearCombinationReluClamp = enum_auto()
    BiasAddLinearCombinationHSwishClamp = enum_auto()


# C++ functor type (in cutlass::epilogue::thread) for each epilogue.
EpilogueFunctorTag = {
    EpilogueFunctor.LinearCombination: "cutlass::epilogue::thread::LinearCombination",
    EpilogueFunctor.LinearCombinationClamp: "cutlass::epilogue::thread::LinearCombinationClamp",
    EpilogueFunctor.BiasAddLinearCombination: "cutlass::epilogue::thread::BiasAddLinearCombination",
    EpilogueFunctor.BiasAddLinearCombinationRelu: "cutlass::epilogue::thread::BiasAddLinearCombinationRelu",
    EpilogueFunctor.BiasAddLinearCombinationHSwish: "cutlass::epilogue::thread::BiasAddLinearCombinationHSwish",
    EpilogueFunctor.BiasAddLinearCombinationClamp: "cutlass::epilogue::thread::BiasAddLinearCombinationClamp",
    EpilogueFunctor.BiasAddLinearCombinationReluClamp: "cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp",
    EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: "cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp",
}

# Short name of each BiasAdd* epilogue.  The clamped and unclamped forms
# share a name; the two plain LinearCombination functors have no entry.
ShortEpilogueNames = {
    EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: "hswish",
    EpilogueFunctor.BiasAddLinearCombinationReluClamp: "relu",
    EpilogueFunctor.BiasAddLinearCombinationClamp: "id",
    EpilogueFunctor.BiasAddLinearCombinationHSwish: "hswish",
    EpilogueFunctor.BiasAddLinearCombinationRelu: "relu",
    EpilogueFunctor.BiasAddLinearCombination: "id",
}
-
-
-
-
-
-
- #
class SwizzlingFunctor(enum.Enum):
    """Threadblock swizzling functors for GEMM and convolution."""

    Identity1 = enum_auto()
    Identity2 = enum_auto()
    Identity4 = enum_auto()
    Identity8 = enum_auto()
    ConvFpropNCxHWx = enum_auto()
    ConvFpropTrans = enum_auto()
    ConvDgradNCxHWx = enum_auto()
    ConvDgradTrans = enum_auto()


# C++ swizzle type for each functor.
SwizzlingFunctorTag = {
    SwizzlingFunctor.Identity1: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>",
    SwizzlingFunctor.Identity2: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>",
    SwizzlingFunctor.Identity4: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>",
    SwizzlingFunctor.Identity8: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>",
    SwizzlingFunctor.ConvFpropNCxHWx: "cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle",
    SwizzlingFunctor.ConvFpropTrans: "cutlass::conv::threadblock::ConvolutionFpropTransThreadblockSwizzle",
    SwizzlingFunctor.ConvDgradNCxHWx: "cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle",
    SwizzlingFunctor.ConvDgradTrans: "cutlass::conv::threadblock::ConvolutionDgradTransThreadblockSwizzle",
}
-
- ###################################################################################################
-
class ConvType(enum.Enum):
    """Convolution types that have a tag in ``ConvTypeTag``."""

    Convolution = enum_auto()
    BatchConvolution = enum_auto()
    Local = enum_auto()
    LocalShare = enum_auto()


# C++ enumerator (cutlass::conv::ConvType) for each convolution type.
ConvTypeTag = {
    ConvType.Convolution: "cutlass::conv::ConvType::kConvolution",
    ConvType.BatchConvolution: "cutlass::conv::ConvType::kBatchConvolution",
    ConvType.Local: "cutlass::conv::ConvType::kLocal",
    ConvType.LocalShare: "cutlass::conv::ConvType::kLocalShare",
}
-
- #
class ConvKind(enum.Enum):
    """Convolution operators: fprop, dgrad and wgrad."""

    Fprop = enum_auto()
    Dgrad = enum_auto()
    Wgrad = enum_auto()


# C++ enumerator (cutlass::conv::Operator) for each convolution kind.
ConvKindTag = {
    ConvKind.Fprop: "cutlass::conv::Operator::kFprop",
    ConvKind.Dgrad: "cutlass::conv::Operator::kDgrad",
    ConvKind.Wgrad: "cutlass::conv::Operator::kWgrad",
}

# Lower-case name of each convolution kind.
ConvKindNames = {
    ConvKind.Fprop: "fprop",
    ConvKind.Dgrad: "dgrad",
    ConvKind.Wgrad: "wgrad",
}
-
- #
class IteratorAlgorithm(enum.Enum):
    """Convolution iterator algorithms: analytic and optimized."""

    Analytic = enum_auto()
    Optimized = enum_auto()


# C++ enumerator (cutlass::conv::IteratorAlgorithm) for each algorithm.
IteratorAlgorithmTag = {
    IteratorAlgorithm.Analytic: "cutlass::conv::IteratorAlgorithm::kAnalytic",
    IteratorAlgorithm.Optimized: "cutlass::conv::IteratorAlgorithm::kOptimized",
}

# Lower-case name of each algorithm.
IteratorAlgorithmNames = {
    IteratorAlgorithm.Analytic: "analytic",
    IteratorAlgorithm.Optimized: "optimized",
}
-
- #
class StrideSupport(enum.Enum):
    """Stride support modes: strided and unity."""

    Strided = enum_auto()
    Unity = enum_auto()


# C++ enumerator (cutlass::conv::StrideSupport) for each mode.
StrideSupportTag = {
    StrideSupport.Strided: "cutlass::conv::StrideSupport::kStrided",
    StrideSupport.Unity: "cutlass::conv::StrideSupport::kUnity",
}

# Name of each mode; Strided maps to the empty string.
StrideSupportNames = {
    StrideSupport.Strided: "",
    StrideSupport.Unity: "unity_stride",
}
-
class SpecialOptimizeDesc(enum.Enum):
    """Special-case optimizations that a convolution may select."""

    NoneSpecialOpt = enum_auto()
    ConvFilterUnity = enum_auto()
    DeconvDoubleUpsampling = enum_auto()


# Lower-case name of each special optimization.
SpecialOptimizeDescNames = {
    SpecialOptimizeDesc.NoneSpecialOpt: "none",
    SpecialOptimizeDesc.ConvFilterUnity: "conv_filter_unity",
    SpecialOptimizeDesc.DeconvDoubleUpsampling: "deconv_double_upsampling",
}

# C++ enumerator (cutlass::conv::SpecialOptimizeDesc) for each optimization.
SpecialOptimizeDescTag = {
    SpecialOptimizeDesc.NoneSpecialOpt: "cutlass::conv::SpecialOptimizeDesc::NONE",
    SpecialOptimizeDesc.ConvFilterUnity: "cutlass::conv::SpecialOptimizeDesc::CONV_FILTER_UNITY",
    SpecialOptimizeDesc.DeconvDoubleUpsampling: "cutlass::conv::SpecialOptimizeDesc::DECONV_DOUBLE_UPSAMPLING",
}
-
-
class ImplicitGemmMode(enum.Enum):
    """Implicit-GEMM modes: NT and TN."""

    GemmNT = enum_auto()
    GemmTN = enum_auto()


# Lower-case name of each implicit-GEMM mode.
ImplicitGemmModeNames = {
    ImplicitGemmMode.GemmNT: "gemm_nt",
    ImplicitGemmMode.GemmTN: "gemm_tn",
}

# C++ enumerator (cutlass::conv::ImplicitGemmMode) for each mode.
ImplicitGemmModeTag = {
    ImplicitGemmMode.GemmNT: "cutlass::conv::ImplicitGemmMode::GEMM_NT",
    ImplicitGemmMode.GemmTN: "cutlass::conv::ImplicitGemmMode::GEMM_TN",
}
-
- ###################################################################################################
-
- #
class MathInstruction:
    """Plain record describing one math instruction.

    Every constructor argument is stored unchanged on the attribute of the
    same name.
    """

    def __init__(self, instruction_shape, element_a, element_b,
                 element_accumulator, opcode_class,
                 math_operation=MathOperation.multiply_add):
        # Operand and accumulator element types
        self.element_a = element_a
        self.element_b = element_b
        self.element_accumulator = element_accumulator
        # Instruction geometry and classification
        self.instruction_shape = instruction_shape
        self.opcode_class = opcode_class
        self.math_operation = math_operation
-
-
- #
class TileDescription:
    """Plain record describing a threadblock tile configuration."""

    def __init__(self, threadblock_shape, stages, warp_count,
                 math_instruction, min_compute, max_compute):
        """Store the arguments unchanged.

        ``min_compute`` and ``max_compute`` are stored as
        ``minimum_compute_capability`` and ``maximum_compute_capability``;
        the other arguments keep their names.
        """
        self.minimum_compute_capability = min_compute
        self.maximum_compute_capability = max_compute
        self.math_instruction = math_instruction
        self.warp_count = warp_count
        self.stages = stages
        self.threadblock_shape = threadblock_shape

    def procedural_name(self):
        """Return ``"<s0>x<s1>_<s2>x<stages>"``, where ``s0``..``s2`` are the
        first three entries of ``threadblock_shape``.
        """
        shape = self.threadblock_shape
        fields = (shape[0], shape[1], shape[2], self.stages)
        return "%dx%d_%dx%d" % fields
-
- #
class TensorDescription:
    """Plain record describing one tensor operand.

    Every constructor argument is stored unchanged on the attribute of the
    same name.  ``alignment`` defaults to 1 and ``complex_transform`` to
    ``ComplexTransform.none``.
    """

    def __init__(self, element, layout, alignment=1,
                 complex_transform=ComplexTransform.none):
        self.layout = layout
        self.element = element
        self.complex_transform = complex_transform
        self.alignment = alignment
-
- ###################################################################################################
-
class GlobalCnt:
    """Holder for a class-level integer counter."""

    # Starts at zero.
    # NOTE(review): presumably incremented by code outside this section --
    # confirm against callers.
    cnt = 0
|