GitOrigin-RevId: 83a43fdf87
tags/v1.6.0-rc1
@@ -4,12 +4,12 @@ genrule( | |||||
name = "cutlass_kimpls", | name = "cutlass_kimpls", | ||||
outs = cutlass_gen_list, | outs = cutlass_gen_list, | ||||
cmd = """GEN=$(location //brain/megbrain/dnn/scripts/cutlass_generator:generator.py) | cmd = """GEN=$(location //brain/megbrain/dnn/scripts/cutlass_generator:generator.py) | ||||
python3 $$GEN --operations gemm --type simt $(@D) | |||||
python3 $$GEN --operations gemv --type simt $(@D) | |||||
python3 $$GEN --operations deconv --type simt $(@D) | |||||
python3 $$GEN --operations conv2d --type simt $(@D) | |||||
python3 $$GEN --operations conv2d --type tensorop8816 $(@D) | |||||
python3 $$GEN --operations conv2d --type tensorop8832 $(@D) | |||||
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type simt $(@D) | |||||
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemv --type simt $(@D) | |||||
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations deconv --type simt $(@D) | |||||
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type simt $(@D) | |||||
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type tensorop8816 $(@D) | |||||
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type tensorop8832 $(@D) | |||||
""", | """, | ||||
tools = ["//brain/megbrain/dnn/scripts/cutlass_generator:generator.py"], | tools = ["//brain/megbrain/dnn/scripts/cutlass_generator:generator.py"], | ||||
visibility = ["//visibility:public"], | visibility = ["//visibility:public"], | ||||
@@ -531,9 +531,10 @@ void initialize_${configuration_name}(Manifest &manifest) { | |||||
################################################################################################### | ################################################################################################### | ||||
class EmitConvSingleKernelWrapper(): | class EmitConvSingleKernelWrapper(): | ||||
def __init__(self, kernel_path, operation): | |||||
def __init__(self, kernel_path, operation, short_path=False): | |||||
self.kernel_path = kernel_path | self.kernel_path = kernel_path | ||||
self.operation = operation | self.operation = operation | ||||
self.short_path = short_path | |||||
if self.operation.conv_kind == ConvKind.Fprop: | if self.operation.conv_kind == ConvKind.Fprop: | ||||
self.instance_emitter = EmitConv2dInstance() | self.instance_emitter = EmitConv2dInstance() | ||||
@@ -582,7 +583,11 @@ void initialize_${operation_name}(Manifest &manifest) { | |||||
# | # | ||||
def __enter__(self): | def __enter__(self): | ||||
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) | |||||
if self.short_path: | |||||
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt) | |||||
GlobalCnt.cnt += 1 | |||||
else: | |||||
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) | |||||
self.kernel_file = open(self.kernel_path, "w") | self.kernel_file = open(self.kernel_path, "w") | ||||
self.kernel_file.write(self.header_template) | self.kernel_file.write(self.header_template) | ||||
return self | return self | ||||
@@ -994,7 +994,8 @@ void initialize_${configuration_name}(Manifest &manifest) { | |||||
################################################################################################### | ################################################################################################### | ||||
class EmitGemmSingleKernelWrapper: | class EmitGemmSingleKernelWrapper: | ||||
def __init__(self, kernel_path, gemm_operation): | |||||
def __init__(self, kernel_path, gemm_operation, short_path=False): | |||||
self.short_path = short_path | |||||
self.kernel_path = kernel_path | self.kernel_path = kernel_path | ||||
self.operation = gemm_operation | self.operation = gemm_operation | ||||
@@ -1070,10 +1071,11 @@ void initialize_${operation_name}(Manifest &manifest) { | |||||
################################################################################################### | ################################################################################################### | ||||
class EmitGemvSingleKernelWrapper: | class EmitGemvSingleKernelWrapper: | ||||
def __init__(self, kernel_path, gemm_operation, wrapper_path): | |||||
def __init__(self, kernel_path, gemm_operation, wrapper_path, short_path=False): | |||||
self.kernel_path = kernel_path | self.kernel_path = kernel_path | ||||
self.wrapper_path = wrapper_path | self.wrapper_path = wrapper_path | ||||
self.operation = gemm_operation | self.operation = gemm_operation | ||||
self.short_path = short_path | |||||
self.wrapper_template = """ | self.wrapper_template = """ | ||||
template void megdnn::cuda::cutlass_wrapper:: | template void megdnn::cuda::cutlass_wrapper:: | ||||
@@ -1107,7 +1109,11 @@ ${operation_instance} | |||||
""" | """ | ||||
# | # | ||||
def __enter__(self): | def __enter__(self): | ||||
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) | |||||
if self.short_path: | |||||
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt) | |||||
GlobalCnt.cnt += 1 | |||||
else: | |||||
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) | |||||
self.kernel_file = open(self.kernel_path, "w") | self.kernel_file = open(self.kernel_path, "w") | ||||
self.kernel_file.write(SubstituteTemplate(self.header_template, { | self.kernel_file.write(SubstituteTemplate(self.header_template, { | ||||
'wrapper_path': self.wrapper_path, | 'wrapper_path': self.wrapper_path, | ||||
@@ -8,6 +8,7 @@ import enum | |||||
import os.path | import os.path | ||||
import shutil | import shutil | ||||
import argparse | import argparse | ||||
import platform | |||||
from library import * | from library import * | ||||
from manifest import * | from manifest import * | ||||
@@ -634,7 +635,7 @@ if __name__ == "__main__": | |||||
default='simt', help="kernel type of CUTLASS kernel generator") | default='simt', help="kernel type of CUTLASS kernel generator") | ||||
gemv_wrapper_path = "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl" | gemv_wrapper_path = "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl" | ||||
short_path = (platform.system() == "Windows" or platform.system().find('NT') >= 0) and ('true'!= os.getenv("CUTLASS_WITH_LONG_PATH", default='False').lower()) | |||||
args = parser.parse_args() | args = parser.parse_args() | ||||
if args.operations == "gemm": | if args.operations == "gemm": | ||||
@@ -648,15 +649,15 @@ if __name__ == "__main__": | |||||
if args.operations == "conv2d" or args.operations == "deconv": | if args.operations == "conv2d" or args.operations == "deconv": | ||||
for operation in operations: | for operation in operations: | ||||
with EmitConvSingleKernelWrapper(args.output, operation) as emitter: | |||||
with EmitConvSingleKernelWrapper(args.output, operation, short_path) as emitter: | |||||
emitter.emit() | emitter.emit() | ||||
elif args.operations == "gemm": | elif args.operations == "gemm": | ||||
for operation in operations: | for operation in operations: | ||||
with EmitGemmSingleKernelWrapper(args.output, operation) as emitter: | |||||
with EmitGemmSingleKernelWrapper(args.output, operation, short_path) as emitter: | |||||
emitter.emit() | emitter.emit() | ||||
elif args.operations == "gemv": | elif args.operations == "gemv": | ||||
for operation in operations: | for operation in operations: | ||||
with EmitGemvSingleKernelWrapper(args.output, operation, gemv_wrapper_path) as emitter: | |||||
with EmitGemvSingleKernelWrapper(args.output, operation, gemv_wrapper_path, short_path) as emitter: | |||||
emitter.emit() | emitter.emit() | ||||
if args.operations != "gemv": | if args.operations != "gemv": | ||||
@@ -612,3 +612,6 @@ class TensorDescription: | |||||
self.complex_transform = complex_transform | self.complex_transform = complex_transform | ||||
################################################################################################### | ################################################################################################### | ||||
class GlobalCnt: | |||||
cnt = 0 |