- #
- # \file generator.py
- #
- # \brief Generates the CUTLASS Library's instances
- #
-
- import argparse
- import enum
- import os.path
- import platform
- import string
-
- from library import *
- from manifest import *
-
- ###################################################################################################
-
- #
- def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch=0):
-
- # by default, assume CUDA Toolkit version 11.0.132
- cuda_version = [11, 0, 132]
-
- # Update cuda_version based on parsed string
- if semantic_ver_string != "":
- for i, x in enumerate([int(x) for x in semantic_ver_string.split(".")]):
- if i < len(cuda_version):
- cuda_version[i] = x
- else:
- cuda_version.append(x)
- return cuda_version >= [major, minor, patch]
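-
- # Worked example (illustrative): CudaToolkitVersionSatisfies("11.2", 11, 1) parses
- # "11.2" into cuda_version == [11, 2, 132] (the default patch level is kept) and
- # returns True, since the lexicographic comparison [11, 2, 132] >= [11, 1, 0] holds.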
-
-
- ###################################################################################################
- ###################################################################################################
-
- #
- def CreateGemmOperator(
- manifest,
- layouts,
- tile_descriptions,
- data_type,
- alignment_constraints,
- complex_transforms=None,
- epilogue_functor=EpilogueFunctor.LinearCombination,
- swizzling_functor=SwizzlingFunctor.Identity8,
- ):
-
- if complex_transforms is None:
- complex_transforms = [(ComplexTransform.none, ComplexTransform.none)]
-
- element_a, element_b, element_c, element_epilogue = data_type
-
- operations = []
-
- # by default, only generate the largest tile and largest alignment
- if manifest.args.kernels == "":
- tile_descriptions = [tile_descriptions[0]]
- alignment_constraints = [alignment_constraints[0]]
-
- for layout in layouts:
- for tile_description in tile_descriptions:
- for alignment in alignment_constraints:
- for complex_transform in complex_transforms:
-
- alignment_c = min(8, alignment)
-
- A = TensorDescription(
- element_a, layout[0], alignment, complex_transform[0]
- )
- B = TensorDescription(
- element_b, layout[1], alignment, complex_transform[1]
- )
- C = TensorDescription(element_c, layout[2], alignment_c)
-
- new_operation = GemmOperation(
- GemmKind.Universal,
- tile_description.minimum_compute_capability,
- tile_description,
- A,
- B,
- C,
- element_epilogue,
- epilogue_functor,
- swizzling_functor,
- )
-
- manifest.append(new_operation)
- operations.append(new_operation)
-
- return operations
-
-
- ###########################################################################################################
- # ConvolutionOperator support variations
- # ____________________________________________________________________
- # ConvolutionOperator | Analytic | Optimized
- # ____________________________________________________________________
- # | Fprop | (strided) | (strided)
- # | Dgrad | (strided, unity*) | (unity)
- # | Wgrad | (strided) | (strided)
- # ____________________________________________________________________
- #
- # Note: Operators marked (*) are supported but not generated, to keep the instantiated kernel count low
- ###########################################################################################################
- # Convolution for 2D operations
- def CreateConv2dOperator(
- manifest,
- layout,
- tile_descriptions,
- data_type,
- alignment,
- conv_kinds=[ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad],
- epilogue_functor=EpilogueFunctor.LinearCombination,
- ):
-
- element_a, element_b, element_c, element_epilogue = data_type
-
- # one exceptional case: the C tensor alignment is capped at 8 elements
- alignment_c = min(8, alignment)
-
- # iterator algorithm (analytic and optimized)
- iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized]
-
- # by default, only generate the largest tile size
- if manifest.args.kernels == "":
- tile_descriptions = [tile_descriptions[0]]
-
- operations = []
-
- for tile in tile_descriptions:
- for conv_kind in conv_kinds:
- for iterator_algorithm in iterator_algorithms:
- A = TensorDescription(element_a, layout[0], alignment)
- B = TensorDescription(element_b, layout[1], alignment)
- C = TensorDescription(element_c, layout[2], alignment_c)
-
- # unity stride only for Optimized Dgrad
- if (iterator_algorithm == IteratorAlgorithm.Optimized) and (
- conv_kind == ConvKind.Dgrad
- ):
- new_operation = Conv2dOperation(
- conv_kind,
- iterator_algorithm,
- tile.minimum_compute_capability,
- tile,
- A,
- B,
- C,
- element_epilogue,
- StrideSupport.Unity,
- epilogue_functor,
- )
-
- manifest.append(new_operation)
- operations.append(new_operation)
-
- # strided dgrad is not supported by Optimized Dgrad
- if (iterator_algorithm == IteratorAlgorithm.Optimized) and (
- conv_kind == ConvKind.Dgrad
- ):
- continue
-
- # strided support for Fprop (Analytic/Optimized), Dgrad (Analytic), and Wgrad (Analytic)
- new_operation = Conv2dOperation(
- conv_kind,
- iterator_algorithm,
- tile.minimum_compute_capability,
- tile,
- A,
- B,
- C,
- element_epilogue,
- StrideSupport.Strided,
- epilogue_functor,
- )
-
- manifest.append(new_operation)
- operations.append(new_operation)
-
- return operations
-
-
- ###################################################################################################
- ###################################################################################################
-
-
- def GenerateConv2d_Simt(args):
- operations = []
-
- layouts = [(LayoutType.TensorNC4HW4, LayoutType.TensorC4RSK4)]
-
- math_instructions = [
- MathInstruction(
- [1, 1, 4],
- DataType.s8,
- DataType.s8,
- DataType.s32,
- OpcodeClass.Simt,
- MathOperation.multiply_add,
- )
- ]
-
- dst_layouts = [
- LayoutType.TensorNC4HW4,
- LayoutType.TensorNC32HW32,
- LayoutType.TensorNHWC,
- LayoutType.TensorNHWC,
- LayoutType.TensorNCHW,
- ]
-
- dst_types = [DataType.s8, DataType.s8, DataType.u4, DataType.s4, DataType.f32]
-
- max_cc = 1024
-
- for math_inst in math_instructions:
- for layout in layouts:
- for dst_type, dst_layout in zip(dst_types, dst_layouts):
- if dst_type == DataType.s4 or dst_type == DataType.u4:
- min_cc = 75
- use_special_optimization = SpecialOptimizeDesc.NoneSpecialOpt
- else:
- min_cc = 61
- use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity
- tile_descriptions = [
- TileDescription(
- [128, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [64, 128, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [32, 64, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [64, 32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc
- ),
- ]
- for tile in tile_descriptions:
- if (
- dst_layout == LayoutType.TensorNC32HW32
- and tile.threadblock_shape[0] > 32
- ):
- continue
- if (
- dst_layout == LayoutType.TensorNCHW
- or dst_layout == LayoutType.TensorNHWC
- ) and tile.threadblock_shape[0] > 16:
- continue
- operations += GenerateConv2d(
- ConvType.Convolution,
- ConvKind.Fprop,
- [tile],
- layout[0],
- layout[1],
- dst_layout,
- dst_type,
- min_cc,
- 32,
- 32,
- 32,
- use_special_optimization,
- )
- return operations
-
-
- def GenerateConv2d_TensorOp_8816(args):
- operations = []
-
- layouts = [(LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32)]
-
- math_instructions = [
- MathInstruction(
- [8, 8, 16],
- DataType.s8,
- DataType.s8,
- DataType.s32,
- OpcodeClass.TensorOp,
- MathOperation.multiply_add_saturate,
- )
- ]
-
- dst_layouts = [LayoutType.TensorNC32HW32, LayoutType.TensorNC4HW4]
-
- dst_types = [DataType.s8, DataType.s8]
-
- use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity
-
- min_cc = 75
- max_cc = 1024
-
- cuda_major = 10
- cuda_minor = 2
-
- for math_inst in math_instructions:
- for layout in layouts:
- for dst_type, dst_layout in zip(dst_types, dst_layouts):
- if dst_layout == LayoutType.TensorNC32HW32:
- tile_descriptions = [
- TileDescription(
- [128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [128, 64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc
- ),
- ]
- operations += GenerateConv2d(
- ConvType.Convolution,
- ConvKind.Fprop,
- tile_descriptions,
- layout[0],
- layout[1],
- dst_layout,
- dst_type,
- min_cc,
- 128,
- 128,
- 64,
- use_special_optimization,
- ImplicitGemmMode.GemmTN,
- True,
- cuda_major,
- cuda_minor,
- )
- else:
- assert dst_layout == LayoutType.TensorNC4HW4
- tile_descriptions = [
- TileDescription(
- [64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc
- ),
- ]
- operations += GenerateConv2d(
- ConvType.Convolution,
- ConvKind.Fprop,
- tile_descriptions,
- layout[0],
- layout[1],
- dst_layout,
- dst_type,
- min_cc,
- 128,
- 128,
- 64,
- use_special_optimization,
- ImplicitGemmMode.GemmNT,
- False,
- cuda_major,
- cuda_minor,
- )
-
- layouts_nhwc = [
- (LayoutType.TensorNHWC, LayoutType.TensorNC4HW4, 32),
- (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 64),
- (LayoutType.TensorNHWC, LayoutType.TensorNC16HW16, 128),
- ]
-
- dst_layouts_nhwc = [LayoutType.TensorNHWC]
-
- for math_inst in math_instructions:
- for layout in layouts_nhwc:
- for dst_layout in dst_layouts_nhwc:
- dst_type = math_inst.element_b
- tile_descriptions = [
- TileDescription(
- [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
- ),
- ]
- for tile in tile_descriptions:
- dst_align = 32 if tile.threadblock_shape[1] == 16 else 64
- operations += GenerateConv2d(
- ConvType.Convolution,
- ConvKind.Fprop,
- [tile],
- layout[0],
- layout[1],
- dst_layout,
- dst_type,
- min_cc,
- layout[2],
- layout[2],
- dst_align,
- use_special_optimization,
- ImplicitGemmMode.GemmTN,
- False,
- cuda_major,
- cuda_minor,
- )
- if (
- tile.threadblock_shape[1] == 16
- or tile.threadblock_shape[1] == 32
- ):
- operations += GenerateConv2d(
- ConvType.Convolution,
- ConvKind.Fprop,
- [tile],
- layout[0],
- layout[1],
- dst_layout,
- dst_type,
- min_cc,
- layout[2],
- layout[2],
- dst_align,
- use_special_optimization,
- ImplicitGemmMode.GemmTN,
- True,
- cuda_major,
- cuda_minor,
- )
-
- out_dtypes = [DataType.s4, DataType.u4, DataType.f32]
-
- # INT8x8x4 and INT8x8x32
- for math_inst in math_instructions:
- for layout in layouts_nhwc:
- for dst_layout in dst_layouts_nhwc:
- for out_dtype in out_dtypes:
- tile_descriptions = [
- TileDescription(
- [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
- ),
- ]
- for tile in tile_descriptions:
- dst_align = (
- 4 * DataTypeSize[out_dtype]
- if tile.threadblock_shape[1] == 16
- or out_dtype == DataType.f32
- else 8 * DataTypeSize[out_dtype]
- )
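- # dst_align is in bits: 4 output elements per store for narrow (N == 16)
- # tiles or f32 output, 8 elements otherwise (our reading of the expression
- # above, not a documented contract)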
- operations += GenerateConv2d(
- ConvType.Convolution,
- ConvKind.Fprop,
- [tile],
- layout[0],
- layout[1],
- dst_layout,
- out_dtype,
- min_cc,
- layout[2],
- layout[2],
- dst_align,
- use_special_optimization,
- ImplicitGemmMode.GemmTN,
- False,
- cuda_major,
- cuda_minor,
- )
- if tile.threadblock_shape[1] == 16 or (
- tile.threadblock_shape[1] == 32
- and out_dtype != DataType.f32
- ):
- operations += GenerateConv2d(
- ConvType.Convolution,
- ConvKind.Fprop,
- [tile],
- layout[0],
- layout[1],
- dst_layout,
- out_dtype,
- min_cc,
- layout[2],
- layout[2],
- dst_align,
- use_special_optimization,
- ImplicitGemmMode.GemmTN,
- True,
- cuda_major,
- cuda_minor,
- )
-
- return operations
-
-
- def GenerateConv2d_TensorOp_8832(args):
- operations = []
-
- layouts = [(LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64)]
-
- math_instructions = [
- MathInstruction(
- [8, 8, 32],
- DataType.s4,
- DataType.s4,
- DataType.s32,
- OpcodeClass.TensorOp,
- MathOperation.multiply_add_saturate,
- ),
- MathInstruction(
- [8, 8, 32],
- DataType.s4,
- DataType.u4,
- DataType.s32,
- OpcodeClass.TensorOp,
- MathOperation.multiply_add_saturate,
- ),
- ]
-
- dst_layouts = [LayoutType.TensorNC64HW64]
-
- use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity
-
- min_cc = 75
- max_cc = 1024
-
- cuda_major = 10
- cuda_minor = 2
-
- for math_inst in math_instructions:
- for layout in layouts:
- for dst_layout in dst_layouts:
- dst_type = math_inst.element_b
- tile_descriptions = [
- TileDescription(
- [128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [128, 64, 128], 2, [2, 1, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
- ),
- ]
- operations += GenerateConv2d(
- ConvType.Convolution,
- ConvKind.Fprop,
- tile_descriptions,
- layout[0],
- layout[1],
- dst_layout,
- dst_type,
- min_cc,
- 128,
- 128,
- 64,
- use_special_optimization,
- ImplicitGemmMode.GemmTN,
- True,
- cuda_major,
- cuda_minor,
- )
-
- layouts_nhwc = [
- (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32),
- (LayoutType.TensorNHWC, LayoutType.TensorNC16HW16, 64),
- (LayoutType.TensorNHWC, LayoutType.TensorNC32HW32, 128),
- ]
-
- dst_layouts_nhwc = [LayoutType.TensorNHWC]
-
- for math_inst in math_instructions:
- for layout in layouts_nhwc:
- for dst_layout in dst_layouts_nhwc:
- dst_type = math_inst.element_b
- tile_descriptions = [
- TileDescription(
- [128, 16, 64], 2, [1, 1, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
- ),
- ]
- for tile in tile_descriptions:
- dst_align = 16 if tile.threadblock_shape[1] == 16 else 32
- operations += GenerateConv2d(
- ConvType.Convolution,
- ConvKind.Fprop,
- [tile],
- layout[0],
- layout[1],
- dst_layout,
- dst_type,
- min_cc,
- layout[2],
- layout[2],
- dst_align,
- use_special_optimization,
- ImplicitGemmMode.GemmTN,
- False,
- cuda_major,
- cuda_minor,
- )
- if (
- tile.threadblock_shape[1] == 32
- or tile.threadblock_shape[1] == 64
- ):
- dst_align = 32 if tile.threadblock_shape[1] == 32 else 64
- operations += GenerateConv2d(
- ConvType.Convolution,
- ConvKind.Fprop,
- [tile],
- layout[0],
- layout[1],
- dst_layout,
- dst_type,
- min_cc,
- layout[2],
- layout[2],
- dst_align,
- use_special_optimization,
- ImplicitGemmMode.GemmTN,
- True,
- cuda_major,
- cuda_minor,
- )
- # INT4x4x8
- for math_inst in math_instructions:
- for layout in layouts_nhwc:
- for dst_layout in dst_layouts_nhwc:
- tile_descriptions = [
- TileDescription(
- [128, 16, 64], 2, [1, 1, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
- ),
- ]
- for tile in tile_descriptions:
- dst_align = 32 if tile.threadblock_shape[1] == 16 else 64
- operations += GenerateConv2d(
- ConvType.Convolution,
- ConvKind.Fprop,
- [tile],
- layout[0],
- layout[1],
- dst_layout,
- DataType.s8,
- min_cc,
- layout[2],
- layout[2],
- dst_align,
- use_special_optimization,
- ImplicitGemmMode.GemmTN,
- False,
- cuda_major,
- cuda_minor,
- )
- if (
- tile.threadblock_shape[1] == 32
- or tile.threadblock_shape[1] == 64
- ):
- dst_align = 64 if tile.threadblock_shape[1] == 32 else 128
- operations += GenerateConv2d(
- ConvType.Convolution,
- ConvKind.Fprop,
- [tile],
- layout[0],
- layout[1],
- dst_layout,
- DataType.s8,
- min_cc,
- layout[2],
- layout[2],
- dst_align,
- use_special_optimization,
- ImplicitGemmMode.GemmTN,
- True,
- cuda_major,
- cuda_minor,
- )
-
- return operations
-
-
- def GenerateDeconv_Simt(args):
- operations = []
-
- layouts = [(LayoutType.TensorNC4HW4, LayoutType.TensorK4RSC4)]
-
- math_instructions = [
- MathInstruction(
- [1, 1, 4],
- DataType.s8,
- DataType.s8,
- DataType.s32,
- OpcodeClass.Simt,
- MathOperation.multiply_add,
- )
- ]
-
- dst_layouts = [LayoutType.TensorNC4HW4]
-
- dst_types = [DataType.s8]
-
- use_special_optimization = SpecialOptimizeDesc.DeconvDoubleUpsampling
-
- min_cc = 61
- max_cc = 1024
-
- for math_inst in math_instructions:
- for layout in layouts:
- for dst_type, dst_layout in zip(dst_types, dst_layouts):
- tile_descriptions = [
- TileDescription(
- [32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [16, 128, 16], 2, [1, 2, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc
- ),
- ]
- operations += GenerateConv2d(
- ConvType.Convolution,
- ConvKind.Dgrad,
- tile_descriptions,
- layout[0],
- layout[1],
- dst_layout,
- dst_type,
- min_cc,
- 32,
- 32,
- 32,
- use_special_optimization,
- )
- return operations
-
-
- def GenerateDeconv_TensorOp_8816(args):
- operations = []
-
- layouts = [
- (LayoutType.TensorNHWC, LayoutType.TensorCK4RS4, 32),
- (LayoutType.TensorNHWC, LayoutType.TensorCK8RS8, 64),
- (LayoutType.TensorNHWC, LayoutType.TensorCK16RS16, 128),
- ]
-
- math_instructions = [
- MathInstruction(
- [8, 8, 16],
- DataType.s8,
- DataType.s8,
- DataType.s32,
- OpcodeClass.TensorOp,
- MathOperation.multiply_add_saturate,
- )
- ]
-
- dst_layouts = [LayoutType.TensorNHWC]
-
- dst_types = [DataType.s8]
-
- use_special_optimization = SpecialOptimizeDesc.DeconvDoubleUpsampling
-
- min_cc = 75
- max_cc = 1024
-
- cuda_major = 10
- cuda_minor = 2
-
- for math_inst in math_instructions:
- for layout in layouts:
- for dst_type, dst_layout in zip(dst_types, dst_layouts):
- tile_descriptions = [
- TileDescription(
- [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
- ),
- ]
- for tile in tile_descriptions:
- dst_align = 32 if tile.threadblock_shape[1] == 16 else 64
- operations += GenerateConv2d(
- ConvType.Convolution,
- ConvKind.Dgrad,
- [tile],
- layout[0],
- layout[1],
- dst_layout,
- dst_type,
- min_cc,
- layout[2],
- layout[2],
- dst_align,
- use_special_optimization,
- ImplicitGemmMode.GemmTN,
- False,
- cuda_major,
- cuda_minor,
- )
- return operations
-
-
- ################################################################################
- # parameters
- # Edge - for tiles, the edges represent the length of one side
- # Ratio - the maximum ratio between 2 edges, limits the skinniness of tiles
- # MaxEdge - maximum length of each edge
- # Min/Max - minimum/maximum of the product of edge lengths
- ################################################################################
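-
- # For example, under these limits the warp shape [64, 32] is kept (ratio 2 <= 4,
- # area 2048 within (8*8, 64*64]), while [128, 8] is rejected (ratio 16 > 4).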
-
- warpsPerThreadblockEdge = [1, 2, 4, 8, 16]
- warpsPerThreadblockRatio = 2
- warpsPerThreadblockMax = 16
- # NOTE 1x32 and 2x16 warp tile shapes fail validation for ~10% of cases
-
- warpShapeEdges = [8, 16, 32, 64, 128, 256]
- warpShapeRatio = 4
- warpShapeMax = 64 * 64
- warpShapeMin = 8 * 8
-
- threadblockEdgeMax = 256
-
- # char, type bits/elem, max tile, L0 threadblock tiles
- precisions = {
- "c": ["cutlass::complex<float>", 64, 64 * 128, [[64, 128], [64, 32]]],
- "d": ["double", 64, 64 * 64, [[64, 64], [32, 32]]],
- "h": ["cutlass::half_t", 16, 128 * 256, [[256, 128], [64, 128], [64, 32]]],
- "i": ["int", 32, 128 * 128, [[128, 64], [16, 32]]],
- "s": ["float", 32, 128 * 128, [[128, 256], [128, 128], [64, 64]]],
- "z": ["cutlass::complex<double>", 128, 64 * 64, [[32, 64], [16, 32]]],
- }
- # L1 will have a single kernel for every unique shape
- # L2 will have everything else
-
-
- def GenerateGemm_Simt(args):
- ################################################################################
- # warps per threadblock
- ################################################################################
- warpsPerThreadblocks = []
- for warpsPerThreadblock0 in warpsPerThreadblockEdge:
- for warpsPerThreadblock1 in warpsPerThreadblockEdge:
- if (
- warpsPerThreadblock0 / warpsPerThreadblock1 <= warpsPerThreadblockRatio
- and warpsPerThreadblock1 / warpsPerThreadblock0
- <= warpsPerThreadblockRatio
- and warpsPerThreadblock0 * warpsPerThreadblock1
- <= warpsPerThreadblockMax
- ):
- warpsPerThreadblocks.append(
- [warpsPerThreadblock0, warpsPerThreadblock1]
- )
-
- ################################################################################
- # warp shapes
- ################################################################################
- warpNumThreads = 32
- warpShapes = []
- for warp0 in warpShapeEdges:
- for warp1 in warpShapeEdges:
- if (
- warp0 / warp1 <= warpShapeRatio
- and warp1 / warp0 <= warpShapeRatio
- and warp0 * warp1 <= warpShapeMax
- and warp0 * warp1 > warpShapeMin
- ):
- warpShapes.append([warp0, warp1])
-
- # sgemm
- (
- precisionType,
- precisionBits,
- threadblockMaxElements,
- threadblockTilesL0,
- ) = precisions["s"]
-
- layouts = [
- (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn
- (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), # nt
- (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # tn
- (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), # tt
- ]
-
- math_instructions = [
- MathInstruction(
- [1, 1, 1],
- DataType.f32,
- DataType.f32,
- DataType.f32,
- OpcodeClass.Simt,
- MathOperation.multiply_add,
- )
- ]
-
- min_cc = 50
- max_cc = 1024
-
- operations = []
- for math_inst in math_instructions:
- for layout in layouts:
- data_type = [
- math_inst.element_a,
- math_inst.element_b,
- math_inst.element_accumulator,
- math_inst.element_accumulator,
- ]
- tile_descriptions = [
- TileDescription([64, 256, 8], 2, [2, 4, 1], math_inst, min_cc, max_cc),
- TileDescription([256, 64, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([32, 256, 8], 2, [2, 4, 1], math_inst, min_cc, max_cc),
- TileDescription([256, 32, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
- TileDescription([32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
- TileDescription([32, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
- TileDescription([64, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
- TileDescription([32, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
- TileDescription([8, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
- TileDescription([16, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
- TileDescription([16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
- TileDescription([16, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
- ]
- for warpsPerThreadblock in warpsPerThreadblocks:
- for warpShape in warpShapes:
- warpThreadsM = 8 if warpShape[0] > warpShape[1] else 4
- warpThreadsN = warpNumThreads // warpThreadsM
-
- # skip shapes with conflicting rectangularity
- # they are unlikely to be fastest
- blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1]
- blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1]
- warpG = warpShape[0] > warpShape[1]
- warpL = warpShape[0] < warpShape[1]
-
- blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1] * 2
- blockL2 = warpsPerThreadblock[0] * 2 < warpsPerThreadblock[1]
- warpG2 = warpShape[0] > warpShape[1] * 2
- warpL2 = warpShape[0] * 2 < warpShape[1]
-
- if blockG2 and warpL:
- continue
- if blockL2 and warpG:
- continue
- if warpG2 and blockL:
- continue
- if warpL2 and blockG:
- continue
-
- # check threadblock ratios and max
- threadblockTile = [
- warpShape[0] * warpsPerThreadblock[0],
- warpShape[1] * warpsPerThreadblock[1],
- ]
- if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements:
- continue
- if threadblockTile[0] > threadblockEdgeMax:
- continue
- if threadblockTile[1] > threadblockEdgeMax:
- continue
- totalThreads = (
- warpNumThreads * warpsPerThreadblock[0] * warpsPerThreadblock[1]
- )
-
- # calculate unroll
- # ensure that every iteration performs at least one full load of A and B
- unrollMin = 8
- unrollMin0 = totalThreads // threadblockTile[0]
- unrollMin1 = totalThreads // threadblockTile[1]
- unroll = max(unrollMin, unrollMin0, unrollMin1)
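- # e.g. threadblockTile [128, 64] with 2x2 warps (128 threads):
- # unrollMin0 = 128 // 128 = 1, unrollMin1 = 128 // 64 = 2, so unroll stays 8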
-
- threadTileM = warpShape[0] // warpThreadsM
- threadTileN = warpShape[1] // warpThreadsN
- if threadTileM < 2 or threadTileN < 2:
- continue
- if threadTileM * threadTileN * precisionBits > 8 * 8 * 32:
- continue
-
- # the epilogue currently does not support N < warpNumThreads, so skip such tiles
- if threadblockTile[1] < warpNumThreads:
- continue
-
- # limit smem
- smemBitsA = threadblockTile[0] * unroll * 2 * precisionBits
- smemBitsB = threadblockTile[1] * unroll * 2 * precisionBits
- smemKBytes = (smemBitsA + smemBitsB) / 8 / 1024
- if smemKBytes > 48:
- continue
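- # e.g. threadblockTile [128, 128] at unroll 8 with 32-bit data:
- # smemBitsA = smemBitsB = 128 * 8 * 2 * 32 = 65536 bits, so smemKBytes = 16,
- # well under the 48 KB limit (the factor of 2 presumably covers double buffering)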
-
- tile = TileDescription(
- [threadblockTile[0], threadblockTile[1], unroll],
- 2,
- [
- threadblockTile[0] // warpShape[0],
- threadblockTile[1] // warpShape[1],
- 1,
- ],
- math_inst,
- min_cc,
- max_cc,
- )
-
- def matches(t: TileDescription) -> bool:
- nonlocal tile
- return (
- t.threadblock_shape[0] == tile.threadblock_shape[0]
- and t.threadblock_shape[1] == tile.threadblock_shape[1]
- and t.threadblock_shape[2] == tile.threadblock_shape[2]
- and t.warp_count[0] == tile.warp_count[0]
- and t.warp_count[1] == tile.warp_count[1]
- and t.warp_count[2] == tile.warp_count[2]
- and t.stages == tile.stages
- )
-
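- # emit this generated tile shape only if it matches one of the curated
- # tile_descriptions enumerated above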
- if not any(matches(t) for t in tile_descriptions):
- continue
-
- operations += GeneratesGemm(
- tile, data_type, layout[0], layout[1], layout[2], min_cc
- )
- return operations
-
-
- #
- def GenerateDwconv2d_Simt(args, conv_kind):
- ################################################################################
- # warps per threadblock
- ################################################################################
- warpsPerThreadblocks = []
- for warpsPerThreadblock0 in warpsPerThreadblockEdge:
- for warpsPerThreadblock1 in warpsPerThreadblockEdge:
- if (
- warpsPerThreadblock0 / warpsPerThreadblock1 <= warpsPerThreadblockRatio
- and warpsPerThreadblock1 / warpsPerThreadblock0
- <= warpsPerThreadblockRatio
- and warpsPerThreadblock0 * warpsPerThreadblock1
- <= warpsPerThreadblockMax
- ):
- warpsPerThreadblocks.append(
- [warpsPerThreadblock0, warpsPerThreadblock1]
- )
-
- ################################################################################
- # warp shapes
- ################################################################################
- warpNumThreads = 32
- warpShapes = []
- for warp0 in warpShapeEdges:
- for warp1 in warpShapeEdges:
- if (
- warp0 / warp1 <= warpShapeRatio
- and warp1 / warp0 <= warpShapeRatio
- and warp0 * warp1 <= warpShapeMax
- and warp0 * warp1 > warpShapeMin
- ):
- warpShapes.append([warp0, warp1])
-
- # sgemm
- (
- precisionType,
- precisionBits,
- threadblockMaxElements,
- threadblockTilesL0,
- ) = precisions["s"]
-
- layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)]
-
- math_instructions = [
- MathInstruction(
- [1, 1, 1],
- DataType.f32,
- DataType.f32,
- DataType.f32,
- OpcodeClass.Simt,
- MathOperation.multiply_add,
- )
- ]
-
- min_cc = 50
- max_cc = 1024
-
- dst_layouts = [LayoutType.TensorNCHW]
-
- dst_types = [DataType.f32]
-
- if conv_kind == ConvKind.Wgrad:
- alignment_constraints = [32]
- else:
- alignment_constraints = [128, 32]
-
- operations = []
- for math_inst in math_instructions:
- tile_descriptions = [
- TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
- TileDescription([32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
- TileDescription([32, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
- TileDescription([64, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
- TileDescription([32, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
- ]
- for warpsPerThreadblock in warpsPerThreadblocks:
- for warpShape in warpShapes:
- warpThreadsM = 8 if warpShape[0] > warpShape[1] else 4
- warpThreadsN = warpNumThreads // warpThreadsM
-
- # skip shapes with conflicting rectangularity
- # they are unlikely to be fastest
- blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1]
- blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1]
- warpG = warpShape[0] > warpShape[1]
- warpL = warpShape[0] < warpShape[1]
-
- blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1] * 2
- blockL2 = warpsPerThreadblock[0] * 2 < warpsPerThreadblock[1]
- warpG2 = warpShape[0] > warpShape[1] * 2
- warpL2 = warpShape[0] * 2 < warpShape[1]
-
- if blockG2 and warpL:
- continue
- if blockL2 and warpG:
- continue
- if warpG2 and blockL:
- continue
- if warpL2 and blockG:
- continue
-
- # check threadblock ratios and max
- threadblockTile = [
- warpShape[0] * warpsPerThreadblock[0],
- warpShape[1] * warpsPerThreadblock[1],
- ]
- if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements:
- continue
- if threadblockTile[0] > threadblockEdgeMax:
- continue
- if threadblockTile[1] > threadblockEdgeMax:
- continue
- totalThreads = (
- warpNumThreads * warpsPerThreadblock[0] * warpsPerThreadblock[1]
- )
-
- # calculate unroll
- # ensure that every iteration performs at least one full load of A and B
- unrollMin = 8
- unrollMin0 = totalThreads // threadblockTile[0]
- unrollMin1 = totalThreads // threadblockTile[1]
- unroll = max(unrollMin, unrollMin0, unrollMin1)
-
- threadTileM = warpShape[0] // warpThreadsM
- threadTileN = warpShape[1] // warpThreadsN
- if threadTileM < 2 or threadTileN < 2:
- continue
- if threadTileM * threadTileN * precisionBits > 8 * 8 * 32:
- continue
-
- # the epilogue currently does not support N < warpNumThreads, so skip such tiles
- if threadblockTile[1] < warpNumThreads:
- continue
-
- # limit smem
- smemBitsA = threadblockTile[0] * unroll * 2 * precisionBits
- smemBitsB = threadblockTile[1] * unroll * 2 * precisionBits
- smemKBytes = (smemBitsA + smemBitsB) / 8 / 1024
- if smemKBytes > 48:
- continue
-
- tile = TileDescription(
- [threadblockTile[0], threadblockTile[1], unroll],
- 2,
- [
- threadblockTile[0] // warpShape[0],
- threadblockTile[1] // warpShape[1],
- 1,
- ],
- math_inst,
- min_cc,
- max_cc,
- )
-
- def matches(t: TileDescription) -> bool:
- nonlocal tile
- return (
- t.threadblock_shape[0] == tile.threadblock_shape[0]
- and t.threadblock_shape[1] == tile.threadblock_shape[1]
- and t.threadblock_shape[2] == tile.threadblock_shape[2]
- and t.warp_count[0] == tile.warp_count[0]
- and t.warp_count[1] == tile.warp_count[1]
- and t.warp_count[2] == tile.warp_count[2]
- and t.stages == tile.stages
- )
-
- if not any(matches(t) for t in tile_descriptions):
- continue
-
- for layout in layouts:
- for dst_type, dst_layout in zip(dst_types, dst_layouts):
- for alignment_src in alignment_constraints:
- operations += GenerateConv2d(
- ConvType.DepthwiseConvolution,
- conv_kind,
- [tile],
- layout[0],
- layout[1],
- dst_layout,
- dst_type,
- min_cc,
- alignment_src,
- 32,
- 32,
- SpecialOptimizeDesc.NoneSpecialOpt,
- ImplicitGemmMode.GemmNT
- if conv_kind == ConvKind.Wgrad
- else ImplicitGemmMode.GemmTN,
- )
- return operations
-
-
- #
- def GenerateDwconv2d_TensorOp_884(args, conv_kind):
- layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)]
-
- math_instructions = [
- MathInstruction(
- [8, 8, 4],
- DataType.f16,
- DataType.f16,
- DataType.f32,
- OpcodeClass.TensorOp,
- MathOperation.multiply_add,
- ),
- MathInstruction(
- [8, 8, 4],
- DataType.f16,
- DataType.f16,
- DataType.f16,
- OpcodeClass.TensorOp,
- MathOperation.multiply_add,
- ),
- ]
-
- min_cc = 70
- max_cc = 75
-
- dst_layouts = [LayoutType.TensorNCHW]
-
- if conv_kind == ConvKind.Wgrad:
- dst_types = [DataType.f32]
- else:
- dst_types = [DataType.f16]
-
- alignment_constraints = [128, 32, 16]
- cuda_major = 10
- cuda_minor = 1
-
- operations = []
- for math_inst in math_instructions:
- tile_descriptions = [
- TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
- TileDescription([128, 128, 32], 2, [4, 4, 1], math_inst, min_cc, max_cc),
- TileDescription([64, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
- TileDescription([128, 64, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
- ]
- for layout in layouts:
- for dst_type, dst_layout in zip(dst_types, dst_layouts):
- for alignment_src in alignment_constraints:
- if conv_kind == ConvKind.Wgrad:
- # skip io16xc16: wgrad is only generated with f32 accumulation
- if math_inst.element_accumulator == DataType.f16:
- continue
- for alignment_diff in alignment_constraints:
- operations += GenerateConv2d(
- ConvType.DepthwiseConvolution,
- conv_kind,
- tile_descriptions,
- layout[0],
- layout[1],
- dst_layout,
- dst_type,
- min_cc,
- alignment_src,
- alignment_diff,
- 32, # always f32 output
- SpecialOptimizeDesc.NoneSpecialOpt,
- ImplicitGemmMode.GemmNT,
- False,
- cuda_major,
- cuda_minor,
- )
- else:
- operations += GenerateConv2d(
- ConvType.DepthwiseConvolution,
- conv_kind,
- tile_descriptions,
- layout[0],
- layout[1],
- dst_layout,
- dst_type,
- min_cc,
- alignment_src,
- 16,
- 16,
- SpecialOptimizeDesc.NoneSpecialOpt,
- ImplicitGemmMode.GemmTN,
- False,
- cuda_major,
- cuda_minor,
- )
-
- return operations
-
-
- #
- def GenerateGemv_Simt(args):
- threadBlockShape_N = [128, 64, 32]
- ldgBits_A = [128, 64, 32]
- ldgBits_B = [128, 64, 32]
-
- layouts = [(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor)]
-
- math_instructions = [
- MathInstruction(
- [1, 1, 1],
- DataType.f32,
- DataType.f32,
- DataType.f32,
- OpcodeClass.Simt,
- MathOperation.multiply_add,
- )
- ]
-
- min_cc = 50
-
- operations = []
- for math_inst in math_instructions:
- for layout in layouts:
- data_type = [
- math_inst.element_a,
- math_inst.element_b,
- math_inst.element_accumulator,
- math_inst.element_accumulator,
- ]
- for threadblock_shape_n in threadBlockShape_N:
- for align_a in ldgBits_A:
- for align_b in ldgBits_B:
- ldg_elements_a = align_a // DataTypeSize[math_inst.element_a]
- ldg_elements_b = align_b // DataTypeSize[math_inst.element_b]
- threadblock_shape_k = (256 * ldg_elements_a) // (
- threadblock_shape_n // ldg_elements_b
- )
- threadblock_shape = [
- 1,
- threadblock_shape_n,
- threadblock_shape_k,
- ]
- thread_shape = [1, ldg_elements_b, ldg_elements_a]
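- # e.g. 128-bit loads of f32 give ldg_elements_a == ldg_elements_b == 4; with
- # threadblock_shape_n == 128, threadblock_shape_k == (256 * 4) // (128 // 4) == 32,
- # i.e. threadblock_shape == [1, 128, 32] and thread_shape == [1, 4, 4]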
-
- operations.append(
- GeneratesGemv(
- math_inst,
- threadblock_shape,
- thread_shape,
- data_type,
- layout[0],
- layout[1],
- layout[2],
- min_cc,
- align_a,
- align_b,
- )
- )
- return operations
-
-
- #
- def GeneratesGemm_TensorOp_1688(args):
- layouts = [
- (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn
- (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), # nt
- (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # tn
- (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), # tt
- ]
-
- math_instructions = [
- MathInstruction(
- [16, 8, 8],
- DataType.f16,
- DataType.f16,
- DataType.f32,
- OpcodeClass.TensorOp,
- MathOperation.multiply_add,
- ),
- MathInstruction(
- [16, 8, 8],
- DataType.f16,
- DataType.f16,
- DataType.f16,
- OpcodeClass.TensorOp,
- MathOperation.multiply_add,
- ),
- ]
-
- min_cc = 75
- max_cc = 1024
-
- alignment_constraints = [
- 8,
- 4,
- 2,
- # 1
- ]
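- # alignment constraints are in elements; the "align * 16" at the GeneratesGemm
- # call sites below presumably converts them to bits (16 bits per f16 element)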
- cuda_major = 10
- cuda_minor = 2
-
- operations = []
- for math_inst in math_instructions:
- for layout in layouts:
- for align in alignment_constraints:
- tile_descriptions = [
- TileDescription(
- [256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
- ),
- ## some configurations are commented out to reduce compilation time and binary size
- # TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
- # TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
- # TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
- ]
-
- data_type = [
- math_inst.element_a,
- math_inst.element_b,
- math_inst.element_a,
- math_inst.element_accumulator,
- ]
-
- for tile in tile_descriptions:
- operations += GeneratesGemm(
- tile,
- data_type,
- layout[0],
- layout[1],
- layout[2],
- min_cc,
- align * 16,
- align * 16,
- align * 16,
- cuda_major,
- cuda_minor,
- )
- return operations
-
-
- #
- def GeneratesGemm_TensorOp_884(args):
- layouts = [
- (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn
- (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), # nt
- (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # tn
- (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), # tt
- ]
-
- math_instructions = [
- MathInstruction(
- [8, 8, 4],
- DataType.f16,
- DataType.f16,
- DataType.f32,
- OpcodeClass.TensorOp,
- MathOperation.multiply_add,
- ),
- MathInstruction(
- [8, 8, 4],
- DataType.f16,
- DataType.f16,
- DataType.f16,
- OpcodeClass.TensorOp,
- MathOperation.multiply_add,
- ),
- ]
-
- min_cc = 70
- max_cc = 75
-
- alignment_constraints = [
- 8,
- 4,
- 2,
- # 1
- ]
- cuda_major = 10
- cuda_minor = 1
-
- operations = []
- for math_inst in math_instructions:
- for layout in layouts:
- for align in alignment_constraints:
- tile_descriptions = [
- TileDescription(
- [256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc
- ),
- TileDescription(
- [128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
- ),
- ## some configurations are commented out to reduce compilation time and binary size
- # TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
- # TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
- # TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
- ]
-
- data_type = [
- math_inst.element_a,
- math_inst.element_b,
- math_inst.element_a,
- math_inst.element_accumulator,
- ]
-
- for tile in tile_descriptions:
- operations += GeneratesGemm(
- tile,
- data_type,
- layout[0],
- layout[1],
- layout[2],
- min_cc,
- align * 16,
- align * 16,
- align * 16,
- cuda_major,
- cuda_minor,
- )
-
- return operations
-
-
- #
- def GenerateConv2dOperations(args):
- if args.type == "simt":
- return GenerateConv2d_Simt(args)
- elif args.type == "tensorop8816":
- return GenerateConv2d_TensorOp_8816(args)
- else:
- assert args.type == "tensorop8832", (
- "operation conv2d only support"
- "simt, tensorop8816 and tensorop8832. (got:{})".format(args.type)
- )
- return GenerateConv2d_TensorOp_8832(args)
-
-
- def GenerateDeconvOperations(args):
- if args.type == "simt":
- return GenerateDeconv_Simt(args)
- else:
- assert args.type == "tensorop8816", (
- "operation deconv only support"
- "simt and tensorop8816. (got:{})".format(args.type)
- )
- return GenerateDeconv_TensorOp_8816(args)
-
-
- def GenerateDwconv2dFpropOperations(args):
- if args.type == "simt":
- return GenerateDwconv2d_Simt(args, ConvKind.Fprop)
- else:
- assert args.type == "tensorop884", (
- "operation dwconv2d fprop only support"
- "simt, tensorop884. (got:{})".format(args.type)
- )
- return GenerateDwconv2d_TensorOp_884(args, ConvKind.Fprop)
-
-
- def GenerateDwconv2dDgradOperations(args):
- if args.type == "simt":
- return GenerateDwconv2d_Simt(args, ConvKind.Dgrad)
- else:
- assert args.type == "tensorop884", (
- "operation dwconv2d fprop only support"
- "simt, tensorop884. (got:{})".format(args.type)
- )
- return GenerateDwconv2d_TensorOp_884(args, ConvKind.Dgrad)
-
-
- def GenerateDwconv2dWgradOperations(args):
- if args.type == "simt":
- return GenerateDwconv2d_Simt(args, ConvKind.Wgrad)
- else:
- assert args.type == "tensorop884", (
- "operation dwconv2d fprop only support"
- "simt, tensorop884. (got:{})".format(args.type)
- )
- return GenerateDwconv2d_TensorOp_884(args, ConvKind.Wgrad)
-
-
- def GenerateGemmOperations(args):
- if args.type == "tensorop884":
- return GeneratesGemm_TensorOp_884(args)
- elif args.type == "tensorop1688":
- return GeneratesGemm_TensorOp_1688(args)
- else:
- assert (
- args.type == "simt"
- ), "operation gemm only supports simt. (got: {})".format(args.type)
- return GenerateGemm_Simt(args)
-
-
- def GenerateGemvOperations(args):
- assert args.type == "simt", "operation gemv only support" "simt. (got:{})".format(
- args.type
- )
- return GenerateGemv_Simt(args)
-
-
- ################################################################################
- # parameters
- # split_number - the number of concatenated .cu files to produce
- # file_path - the directory containing the files to be concatenated
- # operations - args.operations
- # type - args.type
- # head - the header template written at the top of each output file
- # required_cuda_ver_major - required CUDA major version
- # required_cuda_ver_minor - required CUDA minor version
- # epilogue - the epilogue template appended to the end of each output file
- # wrapper_path - path of the wrapper file (used for gemv)
- ################################################################################
- def ConcatFile(
- split_number: int,
- file_path: str,
- operations: str,
- type: str,
- head: str,
- required_cuda_ver_major: str,
- required_cuda_ver_minor: str,
- epilogue: str,
- wrapper_path=None,
- ):
- import os
-
- merge_file_dir = file_path
- filenames = os.listdir(merge_file_dir)
- # build the filename filter substrings
- if "tensorop" in type:
- sub_string_1 = "tensorop"
- sub_string_2 = type[8:]
- else:
- sub_string_1 = sub_string_2 = "simt"
- if "dwconv2d_" in operations:
- filtered_operations = operations[:2] + operations[9:]
- elif ("conv2d" in operations) or ("deconv" in operations):
- filtered_operations = "cutlass"
- else:
- filtered_operations = operations
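- # e.g. "dwconv2d_fprop" is filtered as "dwfprop", and conv2d/deconv kernel files
- # are matched on the generic "cutlass" prefix (assumed to match the file names
- # emitted by the kernel wrappers)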
- # count the files that match the filter
- file_list = {}
- file_list[operations + type] = 0
- for filename in filenames:
- if (
- (filtered_operations in filename)
- and (sub_string_1 in filename)
- and (sub_string_2 in filename)
- and ("all_" not in filename)
- ):
- file_list[operations + type] += 1
- # concatenate the matching files (long-path naming, used on Linux)
- flag_1 = 0
- flag_2 = 0
- for filename in filenames:
- if (
- (filtered_operations in filename)
- and (sub_string_1 in filename)
- and (sub_string_2 in filename)
- and ("all_" not in filename)
- ):
- flag_1 += 1
- filepath = merge_file_dir + "/" + filename
- if (flag_1 >= flag_2 * (file_list[operations + type] / split_number)) and (
- flag_1 <= (flag_2 + 1) * (file_list[operations + type] / split_number)
- ):
- file = open(
- file_path + "/{}_{}_{}.cu".format(operations, type, flag_2), "a"
- )
- # write the header template at the top of the output file
- if wrapper_path is None:
- file.write(
- SubstituteTemplate(
- head,
- {
- "required_cuda_ver_major": str(required_cuda_ver_major),
- "required_cuda_ver_minor": str(required_cuda_ver_minor),
- },
- )
- )
- else:
- file.write(
- SubstituteTemplate(
- head,
- {
- "wrapper_path": wrapper_path,
- "required_cuda_ver_major": str(required_cuda_ver_major),
- "required_cuda_ver_minor": str(required_cuda_ver_minor),
- },
- )
- )
- # the last split absorbs all remaining files
- if flag_2 == (split_number - 1):
- for line in open(filepath):
- file.writelines(line)
- os.remove(filepath)
- file.write("\n")
- file.write(epilogue)
- continue
- for line in open(filepath):
- file.writelines(line)
- os.remove(filepath)
- file.write("\n")
- file.write(epilogue)
- else:
- # write the header template at the top of the output file
- if wrapper_path is None:
- file.write(
- SubstituteTemplate(
- head,
- {
- "required_cuda_ver_major": str(required_cuda_ver_major),
- "required_cuda_ver_minor": str(required_cuda_ver_minor),
- },
- )
- )
- else:
- file.write(
- SubstituteTemplate(
- head,
- {
- "wrapper_path": wrapper_path,
- "required_cuda_ver_major": str(required_cuda_ver_major),
- "required_cuda_ver_minor": str(required_cuda_ver_minor),
- },
- )
- )
- for line in open(filepath):
- file.writelines(line)
- os.remove(filepath)
- file.write("\n")
- file.write(epilogue)
- file.close()
- flag_2 += 1
-
- # concat file for windows
- elif filename[0].isdigit() and ("all_" not in filename):
- flag_1 += 1
- filepath = merge_file_dir + "/" + filename
- if (flag_1 >= flag_2 * (len(filenames) / split_number)) and (
- flag_1 <= (flag_2 + 1) * (len(filenames) / split_number)
- ):
- file = open(
- file_path + "/{}_{}_{}.cu".format(operations, type, flag_2), "a"
- )
- # write the header template at the top of the output file
- if wrapper_path is None:
- file.write(
- SubstituteTemplate(
- head,
- {
- "required_cuda_ver_major": str(required_cuda_ver_major),
- "required_cuda_ver_minor": str(required_cuda_ver_minor),
- },
- )
- )
- else:
- file.write(
- SubstituteTemplate(
- head,
- {
- "wrapper_path": wrapper_path,
- "required_cuda_ver_major": str(required_cuda_ver_major),
- "required_cuda_ver_minor": str(required_cuda_ver_minor),
- },
- )
- )
- # the last split absorbs all remaining files
- if flag_2 == (split_number - 1):
- for line in open(filepath):
- file.writelines(line)
- os.remove(filepath)
- file.write("\n")
- file.write(epilogue)
- continue
- for line in open(filepath):
- file.writelines(line)
- os.remove(filepath)
- file.write("\n")
- file.write(epilogue)
- else:
- # write the header template at the top of the output file
- if wrapper_path is None:
- file.write(
- SubstituteTemplate(
- head,
- {
- "required_cuda_ver_major": str(required_cuda_ver_major),
- "required_cuda_ver_minor": str(required_cuda_ver_minor),
- },
- )
- )
- else:
- file.write(
- SubstituteTemplate(
- head,
- {
- "wrapper_path": wrapper_path,
- "required_cuda_ver_major": str(required_cuda_ver_major),
- "required_cuda_ver_minor": str(required_cuda_ver_minor),
- },
- )
- )
- for line in open(filepath):
- file.writelines(line)
- os.remove(filepath)
- file.write("\n")
- file.write(epilogue)
- file.close()
- flag_2 += 1
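-
- # Illustrative call (hypothetical arguments; mirrors the invocations in __main__ below):
- #
- # ConcatFile(2, "./generated", "gemm", "simt", head, 11, 0, epilogue)
- #
- # would merge every matching per-kernel gemm/simt .cu file under ./generated into
- # two concatenated files, gemm_simt_0.cu and gemm_simt_1.cu, writing the head
- # template at the top and the epilogue at the bottom of each.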
-
-
- ###################################################################################################
- ###################################################################################################
-
- if __name__ == "__main__":
-
- parser = argparse.ArgumentParser(
- description="Generates device kernel registration code for CUTLASS Kernels"
- )
- parser.add_argument(
- "--operations",
- type=str,
- choices=[
- "gemm",
- "gemv",
- "conv2d",
- "deconv",
- "dwconv2d_fprop",
- "dwconv2d_dgrad",
- "dwconv2d_wgrad",
- ],
- required=True,
- help="Specifies the operation to generate (gemm, gemv, conv2d, deconv, dwconv2d_fprop, dwconv2d_dgrad, dwconv2d_wgrad)",
- )
- parser.add_argument(
- "output", type=str, help="output directory for CUTLASS kernel files"
- )
- parser.add_argument(
- "--type",
- type=str,
- choices=["simt", "tensorop8816", "tensorop8832", "tensorop884", "tensorop1688"],
- default="simt",
- help="kernel type of CUTLASS kernel generator",
- )
-
- gemv_wrapper_path = (
- "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
- )
- short_path = (
- platform.system() == "Windows" or platform.system().find("NT") >= 0
- ) and ("true" != os.getenv("CUTLASS_WITH_LONG_PATH", default="False").lower())
- args = parser.parse_args()
-
- if args.operations == "gemm":
- operations = GenerateGemmOperations(args)
- elif args.operations == "gemv":
- operations = GenerateGemvOperations(args)
- elif args.operations == "conv2d":
- operations = GenerateConv2dOperations(args)
- elif args.operations == "deconv":
- operations = GenerateDeconvOperations(args)
- elif args.operations == "dwconv2d_fprop":
- operations = GenerateDwconv2dFpropOperations(args)
- elif args.operations == "dwconv2d_dgrad":
- operations = GenerateDwconv2dDgradOperations(args)
- else:
- assert args.operations == "dwconv2d_wgrad", "invalid operation"
- operations = GenerateDwconv2dWgradOperations(args)
-
- if (
- args.operations == "conv2d"
- or args.operations == "deconv"
- or args.operations == "dwconv2d_fprop"
- or args.operations == "dwconv2d_dgrad"
- or args.operations == "dwconv2d_wgrad"
- ):
- for operation in operations:
- with EmitConvSingleKernelWrapper(
- args.output, operation, short_path
- ) as emitter:
- emitter.emit()
- head = EmitConvSingleKernelWrapper(
- args.output, operations[0], short_path
- ).header_template
- required_cuda_ver_major = operations[0].required_cuda_ver_major
- required_cuda_ver_minor = operations[0].required_cuda_ver_minor
- epilogue = EmitConvSingleKernelWrapper(
- args.output, operations[0], short_path
- ).epilogue_template
- if "tensorop" in args.type:
- ConcatFile(
- 4,
- args.output,
- args.operations,
- args.type,
- head,
- required_cuda_ver_major,
- required_cuda_ver_minor,
- epilogue,
- )
- else:
- ConcatFile(
- 2,
- args.output,
- args.operations,
- args.type,
- head,
- required_cuda_ver_major,
- required_cuda_ver_minor,
- epilogue,
- )
- elif args.operations == "gemm":
- for operation in operations:
- with EmitGemmSingleKernelWrapper(
- args.output, operation, short_path
- ) as emitter:
- emitter.emit()
- head = EmitGemmSingleKernelWrapper(
- args.output, operations[0], short_path
- ).header_template
- required_cuda_ver_major = operations[0].required_cuda_ver_major
- required_cuda_ver_minor = operations[0].required_cuda_ver_minor
- epilogue = EmitGemmSingleKernelWrapper(
- args.output, operations[0], short_path
- ).epilogue_template
- if args.type == "tensorop884":
- ConcatFile(
- 30,
- args.output,
- args.operations,
- args.type,
- head,
- required_cuda_ver_major,
- required_cuda_ver_minor,
- epilogue,
- )
- else:
- ConcatFile(
- 2,
- args.output,
- args.operations,
- args.type,
- head,
- required_cuda_ver_major,
- required_cuda_ver_minor,
- epilogue,
- )
- elif args.operations == "gemv":
- for operation in operations:
- with EmitGemvSingleKernelWrapper(
- args.output, operation, gemv_wrapper_path, short_path
- ) as emitter:
- emitter.emit()
- head = EmitGemvSingleKernelWrapper(
- args.output, operations[0], gemv_wrapper_path, short_path
- ).header_template
- required_cuda_ver_major = operations[0].required_cuda_ver_major
- required_cuda_ver_minor = operations[0].required_cuda_ver_minor
- epilogue = EmitGemvSingleKernelWrapper(
- args.output, operations[0], gemv_wrapper_path, short_path
- ).epilogue_template
- ConcatFile(
- 2,
- args.output,
- args.operations,
- args.type,
- head,
- required_cuda_ver_major,
- required_cuda_ver_minor,
- epilogue,
- wrapper_path=gemv_wrapper_path,
- )
-
- if args.operations != "gemv":
- GenerateManifest(args, operations, args.output)
-
- #
- ###################################################################################################