GitOrigin-RevId: 5684e5ea43
dev-support-lite-fork-debug-mode
@@ -8,7 +8,7 @@ | |||||
import enum | import enum | ||||
import os.path | import os.path | ||||
import shutil | import shutil | ||||
from typing import Tuple, List | |||||
from typing import List, Tuple | |||||
from library import * | from library import * | ||||
@@ -5,14 +5,13 @@ | |||||
# | # | ||||
import enum | import enum | ||||
import os.path | |||||
import shutil | |||||
import functools | import functools | ||||
import operator | import operator | ||||
import os.path | |||||
import shutil | |||||
from library import * | from library import * | ||||
################################################################################################### | ################################################################################################### | ||||
# | # | ||||
# Data structure modeling a GEMM operation | # Data structure modeling a GEMM operation | ||||
@@ -1,11 +1,11 @@ | |||||
from generator import ( | |||||
GenerateGemmOperations, | |||||
GenerateGemvOperations, | |||||
from generator import ( # isort: skip; isort: skip | |||||
GenerateConv2dOperations, | GenerateConv2dOperations, | ||||
GenerateDeconvOperations, | GenerateDeconvOperations, | ||||
GenerateDwconv2dFpropOperations, | |||||
GenerateDwconv2dDgradOperations, | GenerateDwconv2dDgradOperations, | ||||
GenerateDwconv2dFpropOperations, | |||||
GenerateDwconv2dWgradOperations, | GenerateDwconv2dWgradOperations, | ||||
GenerateGemmOperations, | |||||
GenerateGemvOperations, | |||||
) | ) | ||||
@@ -35,12 +35,14 @@ def write_op_list(f, gen_op, gen_type): | |||||
if gen_op != "gemv": | if gen_op != "gemv": | ||||
f.write(' "all_%s_%s_operations.cu",\n' % (gen_op, gen_type)) | f.write(' "all_%s_%s_operations.cu",\n' % (gen_op, gen_type)) | ||||
# Write down a list of merged filenames | # Write down a list of merged filenames | ||||
def write_merge_file_name(f, gen_op, gen_type, split_number): | def write_merge_file_name(f, gen_op, gen_type, split_number): | ||||
for i in range(0, split_number): | for i in range(0, split_number): | ||||
f.write(' "{}_{}_{}.cu",\n'.format(gen_op,gen_type,i)) | |||||
f.write(' "{}_{}_{}.cu",\n'.format(gen_op, gen_type, i)) | |||||
if gen_op != "gemv": | if gen_op != "gemv": | ||||
f.write(' "all_{}_{}_operations.cu",\n'.format(gen_op,gen_type)) | |||||
f.write(' "all_{}_{}_operations.cu",\n'.format(gen_op, gen_type)) | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
with open("list.bzl", "w") as f: | with open("list.bzl", "w") as f: | ||||
@@ -4,12 +4,12 @@ | |||||
# \brief Generates the CUTLASS Library's instances | # \brief Generates the CUTLASS Library's instances | ||||
# | # | ||||
import argparse | |||||
import enum | import enum | ||||
import os.path | import os.path | ||||
import shutil | |||||
import argparse | |||||
import platform | import platform | ||||
import string | import string | ||||
from library import * | from library import * | ||||
from manifest import * | from manifest import * | ||||
@@ -899,9 +899,12 @@ def GenerateGemm_Simt(args): | |||||
warpShapes.append([warp0, warp1]) | warpShapes.append([warp0, warp1]) | ||||
# sgemm | # sgemm | ||||
precisionType, precisionBits, threadblockMaxElements, threadblockTilesL0 = precisions[ | |||||
"s" | |||||
] | |||||
( | |||||
precisionType, | |||||
precisionBits, | |||||
threadblockMaxElements, | |||||
threadblockTilesL0, | |||||
) = precisions["s"] | |||||
layouts = [ | layouts = [ | ||||
(LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn | (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn | ||||
@@ -1091,9 +1094,12 @@ def GenerateDwconv2d_Simt(args, conv_kind): | |||||
warpShapes.append([warp0, warp1]) | warpShapes.append([warp0, warp1]) | ||||
# sgemm | # sgemm | ||||
precisionType, precisionBits, threadblockMaxElements, threadblockTilesL0 = precisions[ | |||||
"s" | |||||
] | |||||
( | |||||
precisionType, | |||||
precisionBits, | |||||
threadblockMaxElements, | |||||
threadblockTilesL0, | |||||
) = precisions["s"] | |||||
layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)] | layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)] | ||||
@@ -1304,7 +1310,7 @@ def GenerateDwconv2d_TensorOp_884(args, conv_kind): | |||||
for dst_type, dst_layout in zip(dst_types, dst_layouts): | for dst_type, dst_layout in zip(dst_types, dst_layouts): | ||||
for alignment_src in alignment_constraints: | for alignment_src in alignment_constraints: | ||||
if conv_kind == ConvKind.Wgrad: | if conv_kind == ConvKind.Wgrad: | ||||
# skip io16xc16 | |||||
# skip io16xc16 | |||||
if math_inst.element_accumulator == DataType.f16: | if math_inst.element_accumulator == DataType.f16: | ||||
continue | continue | ||||
for alignment_diff in alignment_constraints: | for alignment_diff in alignment_constraints: | ||||
@@ -1319,7 +1325,7 @@ def GenerateDwconv2d_TensorOp_884(args, conv_kind): | |||||
min_cc, | min_cc, | ||||
alignment_src, | alignment_src, | ||||
alignment_diff, | alignment_diff, | ||||
32, # always f32 output | |||||
32, # always f32 output | |||||
SpecialOptimizeDesc.NoneSpecialOpt, | SpecialOptimizeDesc.NoneSpecialOpt, | ||||
ImplicitGemmMode.GemmNT, | ImplicitGemmMode.GemmNT, | ||||
False, | False, | ||||
@@ -1656,6 +1662,7 @@ def GenerateGemvOperations(args): | |||||
) | ) | ||||
return GenerateGemv_Simt(args) | return GenerateGemv_Simt(args) | ||||
################################################################################ | ################################################################################ | ||||
# parameters | # parameters | ||||
# split_number - the concated file will be divided into split_number parts | # split_number - the concated file will be divided into split_number parts | ||||
@@ -1668,10 +1675,21 @@ def GenerateGemvOperations(args): | |||||
# epilogue - the epilogue in the file | # epilogue - the epilogue in the file | ||||
# wrapper_path - wrapper path | # wrapper_path - wrapper path | ||||
################################################################################ | ################################################################################ | ||||
def ConcatFile(split_number:int, file_path:str,operations:str,type:str,head:str,required_cuda_ver_major:str, required_cuda_ver_minor:str, epilogue:str, wrapper_path = None): | |||||
def ConcatFile( | |||||
split_number: int, | |||||
file_path: str, | |||||
operations: str, | |||||
type: str, | |||||
head: str, | |||||
required_cuda_ver_major: str, | |||||
required_cuda_ver_minor: str, | |||||
epilogue: str, | |||||
wrapper_path=None, | |||||
): | |||||
import os | import os | ||||
meragefiledir = file_path | meragefiledir = file_path | ||||
filenames=os.listdir(meragefiledir) | |||||
filenames = os.listdir(meragefiledir) | |||||
# filter file | # filter file | ||||
if "tensorop" in type: | if "tensorop" in type: | ||||
sub_string_1 = "tensorop" | sub_string_1 = "tensorop" | ||||
@@ -1679,197 +1697,183 @@ def ConcatFile(split_number:int, file_path:str,operations:str,type:str,head:str, | |||||
else: | else: | ||||
sub_string_1 = sub_string_2 = "simt" | sub_string_1 = sub_string_2 = "simt" | ||||
if "dwconv2d_" in operations: | if "dwconv2d_" in operations: | ||||
filtered_operations = operations[:2]+operations[9:] | |||||
filtered_operations = operations[:2] + operations[9:] | |||||
elif ("conv2d" in operations) or ("deconv" in operations): | elif ("conv2d" in operations) or ("deconv" in operations): | ||||
filtered_operations = "cutlass" | filtered_operations = "cutlass" | ||||
else: | else: | ||||
filtered_operations = operations | filtered_operations = operations | ||||
#get the file list number | |||||
# get the file list number | |||||
file_list = {} | file_list = {} | ||||
file_list[operations + type] = 0 | file_list[operations + type] = 0 | ||||
for filename in filenames: | for filename in filenames: | ||||
if (filtered_operations in filename) and (sub_string_1 in filename) and (sub_string_2 in filename) and ("all_" not in filename): | |||||
if ( | |||||
(filtered_operations in filename) | |||||
and (sub_string_1 in filename) | |||||
and (sub_string_2 in filename) | |||||
and ("all_" not in filename) | |||||
): | |||||
file_list[operations + type] += 1 | file_list[operations + type] += 1 | ||||
#concat file for linux | |||||
# concat file for linux | |||||
flag_1 = 0 | flag_1 = 0 | ||||
flag_2 = 0 | flag_2 = 0 | ||||
for filename in filenames: | for filename in filenames: | ||||
if (filtered_operations in filename) and (sub_string_1 in filename) and (sub_string_2 in filename) and ("all_" not in filename): | |||||
if ( | |||||
(filtered_operations in filename) | |||||
and (sub_string_1 in filename) | |||||
and (sub_string_2 in filename) | |||||
and ("all_" not in filename) | |||||
): | |||||
flag_1 += 1 | flag_1 += 1 | ||||
filepath=meragefiledir+'/'+filename | |||||
if (flag_1 >= flag_2 * (file_list[operations + type]/split_number)) and (flag_1 <= (flag_2 + 1) * (file_list[operations + type]/split_number)): | |||||
file =open(file_path + '/{}_{}_{}.cu'.format(operations,type, flag_2),'a') | |||||
#write Template at the head | |||||
filepath = meragefiledir + "/" + filename | |||||
if (flag_1 >= flag_2 * (file_list[operations + type] / split_number)) and ( | |||||
flag_1 <= (flag_2 + 1) * (file_list[operations + type] / split_number) | |||||
): | |||||
file = open( | |||||
file_path + "/{}_{}_{}.cu".format(operations, type, flag_2), "a" | |||||
) | |||||
# write Template at the head | |||||
if wrapper_path is None: | if wrapper_path is None: | ||||
file.write( | file.write( | ||||
SubstituteTemplate( | SubstituteTemplate( | ||||
head, | head, | ||||
{ | { | ||||
"required_cuda_ver_major": str( | |||||
required_cuda_ver_major | |||||
), | |||||
"required_cuda_ver_minor": str( | |||||
required_cuda_ver_minor | |||||
), | |||||
"required_cuda_ver_major": str(required_cuda_ver_major), | |||||
"required_cuda_ver_minor": str(required_cuda_ver_minor), | |||||
}, | }, | ||||
) | ) | ||||
) | ) | ||||
else: | else: | ||||
file.write( | file.write( | ||||
SubstituteTemplate( | |||||
head, | |||||
{ | |||||
"wrapper_path": wrapper_path, | |||||
"required_cuda_ver_major": str( | |||||
required_cuda_ver_major | |||||
), | |||||
"required_cuda_ver_minor": str( | |||||
required_cuda_ver_minor | |||||
), | |||||
}, | |||||
) | |||||
SubstituteTemplate( | |||||
head, | |||||
{ | |||||
"wrapper_path": wrapper_path, | |||||
"required_cuda_ver_major": str(required_cuda_ver_major), | |||||
"required_cuda_ver_minor": str(required_cuda_ver_minor), | |||||
}, | |||||
) | ) | ||||
) | |||||
# concat all the remaining files | # concat all the remaining files | ||||
if flag_2 == (split_number - 1): | if flag_2 == (split_number - 1): | ||||
for line in open(filepath): | for line in open(filepath): | ||||
file.writelines(line) | file.writelines(line) | ||||
os.remove(filepath) | os.remove(filepath) | ||||
file.write('\n') | |||||
file.write("\n") | |||||
file.write(epilogue) | file.write(epilogue) | ||||
continue | continue | ||||
for line in open(filepath): | for line in open(filepath): | ||||
file.writelines(line) | file.writelines(line) | ||||
os.remove(filepath) | os.remove(filepath) | ||||
file.write('\n') | |||||
file.write("\n") | |||||
file.write(epilogue) | file.write(epilogue) | ||||
else: | else: | ||||
#write Template at the head | |||||
# write Template at the head | |||||
if wrapper_path is None: | if wrapper_path is None: | ||||
file.write( | file.write( | ||||
SubstituteTemplate( | SubstituteTemplate( | ||||
head, | head, | ||||
{ | { | ||||
"required_cuda_ver_major": str( | |||||
required_cuda_ver_major | |||||
), | |||||
"required_cuda_ver_minor": str( | |||||
required_cuda_ver_minor | |||||
), | |||||
"required_cuda_ver_major": str(required_cuda_ver_major), | |||||
"required_cuda_ver_minor": str(required_cuda_ver_minor), | |||||
}, | }, | ||||
) | ) | ||||
) | ) | ||||
else: | else: | ||||
file.write( | file.write( | ||||
SubstituteTemplate( | |||||
head, | |||||
{ | |||||
"wrapper_path": wrapper_path, | |||||
"required_cuda_ver_major": str( | |||||
required_cuda_ver_major | |||||
), | |||||
"required_cuda_ver_minor": str( | |||||
required_cuda_ver_minor | |||||
), | |||||
}, | |||||
) | |||||
SubstituteTemplate( | |||||
head, | |||||
{ | |||||
"wrapper_path": wrapper_path, | |||||
"required_cuda_ver_major": str(required_cuda_ver_major), | |||||
"required_cuda_ver_minor": str(required_cuda_ver_minor), | |||||
}, | |||||
) | ) | ||||
) | |||||
for line in open(filepath): | for line in open(filepath): | ||||
file.writelines(line) | file.writelines(line) | ||||
os.remove(filepath) | os.remove(filepath) | ||||
file.write('\n') | |||||
file.write("\n") | |||||
file.write(epilogue) | file.write(epilogue) | ||||
file.close() | file.close() | ||||
flag_2 += 1 | flag_2 += 1 | ||||
#concat file for windows | |||||
# concat file for windows | |||||
elif filename[0].isdigit() and ("all_" not in filename): | elif filename[0].isdigit() and ("all_" not in filename): | ||||
flag_1 += 1 | flag_1 += 1 | ||||
filepath=meragefiledir+'/'+filename | |||||
if (flag_1 >= flag_2 * (len(filenames)/split_number)) and (flag_1 <= (flag_2 + 1) * (len(filenames)/split_number)): | |||||
file =open(file_path + '/{}_{}_{}.cu'.format(operations,type, flag_2),'a') | |||||
#write Template at the head | |||||
filepath = meragefiledir + "/" + filename | |||||
if (flag_1 >= flag_2 * (len(filenames) / split_number)) and ( | |||||
flag_1 <= (flag_2 + 1) * (len(filenames) / split_number) | |||||
): | |||||
file = open( | |||||
file_path + "/{}_{}_{}.cu".format(operations, type, flag_2), "a" | |||||
) | |||||
# write Template at the head | |||||
if wrapper_path is None: | if wrapper_path is None: | ||||
file.write( | file.write( | ||||
SubstituteTemplate( | SubstituteTemplate( | ||||
head, | head, | ||||
{ | { | ||||
"required_cuda_ver_major": str( | |||||
required_cuda_ver_major | |||||
), | |||||
"required_cuda_ver_minor": str( | |||||
required_cuda_ver_minor | |||||
), | |||||
"required_cuda_ver_major": str(required_cuda_ver_major), | |||||
"required_cuda_ver_minor": str(required_cuda_ver_minor), | |||||
}, | }, | ||||
) | ) | ||||
) | ) | ||||
else: | else: | ||||
file.write( | file.write( | ||||
SubstituteTemplate( | |||||
head, | |||||
{ | |||||
"wrapper_path": wrapper_path, | |||||
"required_cuda_ver_major": str( | |||||
required_cuda_ver_major | |||||
), | |||||
"required_cuda_ver_minor": str( | |||||
required_cuda_ver_minor | |||||
), | |||||
}, | |||||
) | |||||
SubstituteTemplate( | |||||
head, | |||||
{ | |||||
"wrapper_path": wrapper_path, | |||||
"required_cuda_ver_major": str(required_cuda_ver_major), | |||||
"required_cuda_ver_minor": str(required_cuda_ver_minor), | |||||
}, | |||||
) | ) | ||||
) | |||||
# concat all the remaining files | # concat all the remaining files | ||||
if flag_2 == (split_number - 1): | if flag_2 == (split_number - 1): | ||||
for line in open(filepath): | for line in open(filepath): | ||||
file.writelines(line) | file.writelines(line) | ||||
os.remove(filepath) | os.remove(filepath) | ||||
file.write('\n') | |||||
file.write("\n") | |||||
file.write(epilogue) | file.write(epilogue) | ||||
continue | continue | ||||
for line in open(filepath): | for line in open(filepath): | ||||
file.writelines(line) | file.writelines(line) | ||||
os.remove(filepath) | os.remove(filepath) | ||||
file.write('\n') | |||||
file.write("\n") | |||||
file.write(epilogue) | file.write(epilogue) | ||||
else: | else: | ||||
#write Template at the head | |||||
# write Template at the head | |||||
if wrapper_path is None: | if wrapper_path is None: | ||||
file.write( | file.write( | ||||
SubstituteTemplate( | SubstituteTemplate( | ||||
head, | head, | ||||
{ | { | ||||
"required_cuda_ver_major": str( | |||||
required_cuda_ver_major | |||||
), | |||||
"required_cuda_ver_minor": str( | |||||
required_cuda_ver_minor | |||||
), | |||||
"required_cuda_ver_major": str(required_cuda_ver_major), | |||||
"required_cuda_ver_minor": str(required_cuda_ver_minor), | |||||
}, | }, | ||||
) | ) | ||||
) | ) | ||||
else: | else: | ||||
file.write( | file.write( | ||||
SubstituteTemplate( | |||||
head, | |||||
{ | |||||
"wrapper_path": wrapper_path, | |||||
"required_cuda_ver_major": str( | |||||
required_cuda_ver_major | |||||
), | |||||
"required_cuda_ver_minor": str( | |||||
required_cuda_ver_minor | |||||
), | |||||
}, | |||||
) | |||||
SubstituteTemplate( | |||||
head, | |||||
{ | |||||
"wrapper_path": wrapper_path, | |||||
"required_cuda_ver_major": str(required_cuda_ver_major), | |||||
"required_cuda_ver_minor": str(required_cuda_ver_minor), | |||||
}, | |||||
) | ) | ||||
) | |||||
for line in open(filepath): | for line in open(filepath): | ||||
file.writelines(line) | file.writelines(line) | ||||
os.remove(filepath) | os.remove(filepath) | ||||
file.write('\n') | |||||
file.write("\n") | |||||
file.write(epilogue) | file.write(epilogue) | ||||
file.close() | file.close() | ||||
flag_2 += 1 | flag_2 += 1 | ||||
################################################################################################### | ################################################################################################### | ||||
################################################################################################### | ################################################################################################### | ||||
@@ -1940,39 +1944,97 @@ if __name__ == "__main__": | |||||
args.output, operation, short_path | args.output, operation, short_path | ||||
) as emitter: | ) as emitter: | ||||
emitter.emit() | emitter.emit() | ||||
head = EmitConvSingleKernelWrapper(args.output, operations[0], short_path).header_template | |||||
head = EmitConvSingleKernelWrapper( | |||||
args.output, operations[0], short_path | |||||
).header_template | |||||
required_cuda_ver_major = operations[0].required_cuda_ver_major | required_cuda_ver_major = operations[0].required_cuda_ver_major | ||||
required_cuda_ver_minor = operations[0].required_cuda_ver_minor | required_cuda_ver_minor = operations[0].required_cuda_ver_minor | ||||
epilogue = EmitConvSingleKernelWrapper(args.output, operations[0], short_path).epilogue_template | |||||
epilogue = EmitConvSingleKernelWrapper( | |||||
args.output, operations[0], short_path | |||||
).epilogue_template | |||||
if "tensorop" in args.type: | if "tensorop" in args.type: | ||||
ConcatFile(4, args.output,args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue) | |||||
ConcatFile( | |||||
4, | |||||
args.output, | |||||
args.operations, | |||||
args.type, | |||||
head, | |||||
required_cuda_ver_major, | |||||
required_cuda_ver_minor, | |||||
epilogue, | |||||
) | |||||
else: | else: | ||||
ConcatFile(2, args.output,args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue) | |||||
ConcatFile( | |||||
2, | |||||
args.output, | |||||
args.operations, | |||||
args.type, | |||||
head, | |||||
required_cuda_ver_major, | |||||
required_cuda_ver_minor, | |||||
epilogue, | |||||
) | |||||
elif args.operations == "gemm": | elif args.operations == "gemm": | ||||
for operation in operations: | for operation in operations: | ||||
with EmitGemmSingleKernelWrapper( | with EmitGemmSingleKernelWrapper( | ||||
args.output, operation, short_path | args.output, operation, short_path | ||||
) as emitter: | ) as emitter: | ||||
emitter.emit() | emitter.emit() | ||||
head = EmitGemmSingleKernelWrapper(args.output, operations[0], short_path).header_template | |||||
head = EmitGemmSingleKernelWrapper( | |||||
args.output, operations[0], short_path | |||||
).header_template | |||||
required_cuda_ver_major = operations[0].required_cuda_ver_major | required_cuda_ver_major = operations[0].required_cuda_ver_major | ||||
required_cuda_ver_minor = operations[0].required_cuda_ver_minor | required_cuda_ver_minor = operations[0].required_cuda_ver_minor | ||||
epilogue = EmitGemmSingleKernelWrapper(args.output, operations[0], short_path).epilogue_template | |||||
epilogue = EmitGemmSingleKernelWrapper( | |||||
args.output, operations[0], short_path | |||||
).epilogue_template | |||||
if args.type == "tensorop884": | if args.type == "tensorop884": | ||||
ConcatFile(30, args.output, args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue) | |||||
ConcatFile( | |||||
30, | |||||
args.output, | |||||
args.operations, | |||||
args.type, | |||||
head, | |||||
required_cuda_ver_major, | |||||
required_cuda_ver_minor, | |||||
epilogue, | |||||
) | |||||
else: | else: | ||||
ConcatFile(2, args.output, args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue) | |||||
ConcatFile( | |||||
2, | |||||
args.output, | |||||
args.operations, | |||||
args.type, | |||||
head, | |||||
required_cuda_ver_major, | |||||
required_cuda_ver_minor, | |||||
epilogue, | |||||
) | |||||
elif args.operations == "gemv": | elif args.operations == "gemv": | ||||
for operation in operations: | for operation in operations: | ||||
with EmitGemvSingleKernelWrapper( | with EmitGemvSingleKernelWrapper( | ||||
args.output, operation, gemv_wrapper_path, short_path | args.output, operation, gemv_wrapper_path, short_path | ||||
) as emitter: | ) as emitter: | ||||
emitter.emit() | emitter.emit() | ||||
head = EmitGemvSingleKernelWrapper(args.output, operations[0], gemv_wrapper_path, short_path).header_template | |||||
head = EmitGemvSingleKernelWrapper( | |||||
args.output, operations[0], gemv_wrapper_path, short_path | |||||
).header_template | |||||
required_cuda_ver_major = operations[0].required_cuda_ver_major | required_cuda_ver_major = operations[0].required_cuda_ver_major | ||||
required_cuda_ver_minor = operations[0].required_cuda_ver_minor | required_cuda_ver_minor = operations[0].required_cuda_ver_minor | ||||
epilogue = EmitGemvSingleKernelWrapper(args.output, operations[0], gemv_wrapper_path, short_path).epilogue_template | |||||
ConcatFile(2, args.output,args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue, wrapper_path = gemv_wrapper_path) | |||||
epilogue = EmitGemvSingleKernelWrapper( | |||||
args.output, operations[0], gemv_wrapper_path, short_path | |||||
).epilogue_template | |||||
ConcatFile( | |||||
2, | |||||
args.output, | |||||
args.operations, | |||||
args.type, | |||||
head, | |||||
required_cuda_ver_major, | |||||
required_cuda_ver_minor, | |||||
epilogue, | |||||
wrapper_path=gemv_wrapper_path, | |||||
) | |||||
if args.operations != "gemv": | if args.operations != "gemv": | ||||
GenerateManifest(args, operations, args.output) | GenerateManifest(args, operations, args.output) | ||||
@@ -4,11 +4,11 @@ | |||||
# \brief Generates the CUTLASS Library's instances | # \brief Generates the CUTLASS Library's instances | ||||
# | # | ||||
import enum | |||||
import re | import re | ||||
################################################################################################### | ################################################################################################### | ||||
import enum | |||||
# The following block implements enum.auto() for Python 3.5 variants that don't include it such | # The following block implements enum.auto() for Python 3.5 variants that don't include it such | ||||
# as the default 3.5.2 on Ubuntu 16.04. | # as the default 3.5.2 on Ubuntu 16.04. | ||||
@@ -8,9 +8,9 @@ import enum | |||||
import os.path | import os.path | ||||
import shutil | import shutil | ||||
from library import * | |||||
from gemm_operation import * | |||||
from conv2d_operation import * | from conv2d_operation import * | ||||
from gemm_operation import * | |||||
from library import * | |||||
################################################################################################### | ################################################################################################### | ||||
@@ -1,59 +1,67 @@ | |||||
#!/usr/bin/env python3 | #!/usr/bin/env python3 | ||||
# -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
import os | |||||
import argparse | import argparse | ||||
import os | |||||
from gen_elemwise_utils import DTYPES | from gen_elemwise_utils import DTYPES | ||||
def main(): | def main(): | ||||
parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
description='generate elemwise impl files', | |||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||||
parser.add_argument('--type', type=str, choices=['cuda'], | |||||
default='cuda', | |||||
help='generate cuda cond take kernel file') | |||||
parser.add_argument('output', help='output directory') | |||||
description="generate elemwise impl files", | |||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||||
) | |||||
parser.add_argument( | |||||
"--type", | |||||
type=str, | |||||
choices=["cuda"], | |||||
default="cuda", | |||||
help="generate cuda cond take kernel file", | |||||
) | |||||
parser.add_argument("output", help="output directory") | |||||
args = parser.parse_args() | args = parser.parse_args() | ||||
if not os.path.isdir(args.output): | if not os.path.isdir(args.output): | ||||
os.makedirs(args.output) | os.makedirs(args.output) | ||||
assert args.type =='cuda' | |||||
cpp_ext = 'cu' | |||||
assert args.type == "cuda" | |||||
cpp_ext = "cu" | |||||
for dtype in DTYPES.keys(): | for dtype in DTYPES.keys(): | ||||
fname = '{}.{}'.format(dtype, cpp_ext) | |||||
fname = "{}.{}".format(dtype, cpp_ext) | |||||
fname = os.path.join(args.output, fname) | fname = os.path.join(args.output, fname) | ||||
with open(fname, 'w') as fout: | |||||
with open(fname, "w") as fout: | |||||
w = lambda s: print(s, file=fout) | w = lambda s: print(s, file=fout) | ||||
w('// generated by gen_cond_take_kern_impls.py') | |||||
w("// generated by gen_cond_take_kern_impls.py") | |||||
w('#include "../kern.inl"') | w('#include "../kern.inl"') | ||||
w('') | |||||
if dtype == 'dt_float16' or dtype == 'dt_bfloat16': | |||||
w('#if !MEGDNN_DISABLE_FLOAT16') | |||||
w('namespace megdnn {') | |||||
w('namespace cuda {') | |||||
w('namespace cond_take {') | |||||
w('') | |||||
w('inst_genidx(::megdnn::dtype::{})'.format(DTYPES[dtype][0])) | |||||
w('#undef inst_genidx') | |||||
w('') | |||||
w('inst_copy(::megdnn::dtype::{})'.format(DTYPES[dtype][0])) | |||||
w('#undef inst_copy') | |||||
w('#undef inst_copy_') | |||||
w('') | |||||
w('} // cond_take') | |||||
w('} // cuda') | |||||
w('} // megdnn') | |||||
if dtype == 'dt_float16' or dtype == 'dt_bfloat16': | |||||
w('#endif') | |||||
print('generated {}'.format(fname)) | |||||
w("") | |||||
if dtype == "dt_float16" or dtype == "dt_bfloat16": | |||||
w("#if !MEGDNN_DISABLE_FLOAT16") | |||||
w("namespace megdnn {") | |||||
w("namespace cuda {") | |||||
w("namespace cond_take {") | |||||
w("") | |||||
w("inst_genidx(::megdnn::dtype::{})".format(DTYPES[dtype][0])) | |||||
w("#undef inst_genidx") | |||||
w("") | |||||
w("inst_copy(::megdnn::dtype::{})".format(DTYPES[dtype][0])) | |||||
w("#undef inst_copy") | |||||
w("#undef inst_copy_") | |||||
w("") | |||||
w("} // cond_take") | |||||
w("} // cuda") | |||||
w("} // megdnn") | |||||
if dtype == "dt_float16" or dtype == "dt_bfloat16": | |||||
w("#endif") | |||||
print("generated {}".format(fname)) | |||||
os.utime(args.output) | os.utime(args.output) | ||||
if __name__ == '__main__': | |||||
if __name__ == "__main__": | |||||
main() | main() |
@@ -1,37 +1,47 @@ | |||||
#!/usr/bin/env python3 | #!/usr/bin/env python3 | ||||
# -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
import os | |||||
import argparse | import argparse | ||||
import itertools | import itertools | ||||
import os | |||||
PREFIXES = { | |||||
"dp4a": [ | |||||
("batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4", True), | |||||
("batch_conv_bias_int8_gemm_ncdiv4hw4", False), | |||||
("batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128", False), | |||||
] | |||||
} | |||||
PREFIXES = {"dp4a": [("batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4", True), ("batch_conv_bias_int8_gemm_ncdiv4hw4", False), ("batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128", False)]} | |||||
ACTIVATIONS = {1: ("IDENTITY", "_id"), 2: ("RELU", "_relu"), 3: ("H_SWISH", "_hswish")} | |||||
ACTIVATIONS = {1: ("IDENTITY", "_id"), | |||||
2: ("RELU", "_relu"), | |||||
3: ("H_SWISH", "_hswish")} | |||||
BIASES = { | |||||
1: ("PerElementBiasVisitor", "_per_elem"), | |||||
2: ("PerChannelBiasVisitor", "_per_chan"), | |||||
} | |||||
BIASES = {1: ("PerElementBiasVisitor", "_per_elem"), | |||||
2: ("PerChannelBiasVisitor", "_per_chan")} | |||||
SUFFIXES = {"dp4a": [""], "imma": [""]} | |||||
SUFFIXES = {"dp4a": [""], | |||||
"imma": [""]} | |||||
def main(): | def main(): | ||||
parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
description='generate cuda batch conv bias (dp4a/imma) kern impl files', | |||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||||
parser.add_argument('--type', type=str, choices=['dp4a', | |||||
'imma'], | |||||
default='dp4a', help='generate cuda conv bias kernel file') | |||||
parser.add_argument('output', help='output directory') | |||||
description="generate cuda batch conv bias (dp4a/imma) kern impl files", | |||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||||
) | |||||
parser.add_argument( | |||||
"--type", | |||||
type=str, | |||||
choices=["dp4a", "imma"], | |||||
default="dp4a", | |||||
help="generate cuda conv bias kernel file", | |||||
) | |||||
parser.add_argument("output", help="output directory") | |||||
args = parser.parse_args() | args = parser.parse_args() | ||||
if not os.path.isdir(args.output): | if not os.path.isdir(args.output): | ||||
os.makedirs(args.output) | os.makedirs(args.output) | ||||
inst = ''' | |||||
inst = """ | |||||
template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS, | template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS, | ||||
IConvEpilogue<Activation<megdnn::param_enumv::BatchConvBias::NonlineMode::ACTIVATION>>>( | IConvEpilogue<Activation<megdnn::param_enumv::BatchConvBias::NonlineMode::ACTIVATION>>>( | ||||
const int8_t* d_src, | const int8_t* d_src, | ||||
@@ -41,7 +51,7 @@ template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS, | |||||
const ConvParam& param, | const ConvParam& param, | ||||
float alpha, | float alpha, | ||||
float beta, | float beta, | ||||
cudaStream_t stream);''' | |||||
cudaStream_t stream);""" | |||||
for prefix in PREFIXES[args.type]: | for prefix in PREFIXES[args.type]: | ||||
for suffix in SUFFIXES[args.type]: | for suffix in SUFFIXES[args.type]: | ||||
@@ -52,17 +62,23 @@ template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS, | |||||
fname = os.path.join(args.output, fname) | fname = os.path.join(args.output, fname) | ||||
with open(fname, "w") as fout: | with open(fname, "w") as fout: | ||||
w = lambda s: print(s, file=fout) | w = lambda s: print(s, file=fout) | ||||
w('// generated by gen_batch_cuda_conv_bias_kern_impls.py') | |||||
cur_inst = inst.replace("PREFIX", prefix[0]).replace("SUFFIX", suffix).replace("BIAS", bias[0]).replace("ACTIVATION", act[0]) | |||||
w("// generated by gen_batch_cuda_conv_bias_kern_impls.py") | |||||
cur_inst = ( | |||||
inst.replace("PREFIX", prefix[0]) | |||||
.replace("SUFFIX", suffix) | |||||
.replace("BIAS", bias[0]) | |||||
.replace("ACTIVATION", act[0]) | |||||
) | |||||
if has_workspace: | if has_workspace: | ||||
cur_inst = cur_inst.replace("WORKSPACE", "\nint* d_workspace, ") | cur_inst = cur_inst.replace("WORKSPACE", "\nint* d_workspace, ") | ||||
else: | else: | ||||
cur_inst = cur_inst.replace("WORKSPACE", "") | |||||
cur_inst = cur_inst.replace("WORKSPACE", "") | |||||
w('#include "../{}{}.cuinl"'.format(prefix[0], suffix)) | w('#include "../{}{}.cuinl"'.format(prefix[0], suffix)) | ||||
w(cur_inst) | w(cur_inst) | ||||
print('generated {}'.format(fname)) | |||||
print("generated {}".format(fname)) | |||||
os.utime(args.output) | os.utime(args.output) | ||||
if __name__ == '__main__': | |||||
if __name__ == "__main__": | |||||
main() | main() |
@@ -1,39 +1,57 @@ | |||||
#!/usr/bin/env python3 | #!/usr/bin/env python3 | ||||
# -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
import os | |||||
import argparse | import argparse | ||||
import itertools | import itertools | ||||
import os | |||||
PREFIXES = { | |||||
"dp4a": "conv_bias_int8_implicit_gemm_cdiv4hwn4", | |||||
"imma": "conv_bias_int8_implicit_gemm", | |||||
} | |||||
PREFIXES = {"dp4a": "conv_bias_int8_implicit_gemm_cdiv4hwn4", "imma": "conv_bias_int8_implicit_gemm"} | |||||
ACTIVATIONS = {1: ("IDENTITY", "_id"), 2: ("RELU", "_relu"), 3: ("H_SWISH", "_hswish")} | |||||
ACTIVATIONS = {1: ("IDENTITY", "_id"), | |||||
2: ("RELU", "_relu"), | |||||
3: ("H_SWISH", "_hswish")} | |||||
BIASES = { | |||||
1: ("PerElementBiasVisitor", "_per_elem"), | |||||
2: ("PerChannelBiasVisitor", "_per_chan"), | |||||
} | |||||
BIASES = {1: ("PerElementBiasVisitor", "_per_elem"), | |||||
2: ("PerChannelBiasVisitor", "_per_chan")} | |||||
SUFFIXES = { | |||||
"dp4a": ["", "_ld_64bit", "_ld_64bit_unroll_width", "_unroll_width"], | |||||
"imma": [ | |||||
"_imma16x16x16_cdiv4hwn4", | |||||
"_imma8x32x16_cdiv4hwn4", | |||||
"_imma32x8x16_cdiv4hwn4", | |||||
"_imma16x16x16_cdiv4hwn4_reorder_filter", | |||||
"_imma8x32x16_cdiv4hwn4_reorder_filter", | |||||
"_imma32x8x16_cdiv4hwn4_reorder_filter", | |||||
"_imma16x16x16_cdiv4hwn4_unroll_width", | |||||
"_imma8x32x16_cdiv4hwn4_unroll_width", | |||||
"_imma32x8x16_cdiv4hwn4_unroll_width", | |||||
], | |||||
} | |||||
SUFFIXES = {"dp4a": ["", "_ld_64bit", "_ld_64bit_unroll_width", "_unroll_width"], | |||||
"imma": ["_imma16x16x16_cdiv4hwn4", "_imma8x32x16_cdiv4hwn4", "_imma32x8x16_cdiv4hwn4", | |||||
"_imma16x16x16_cdiv4hwn4_reorder_filter", "_imma8x32x16_cdiv4hwn4_reorder_filter", "_imma32x8x16_cdiv4hwn4_reorder_filter", | |||||
"_imma16x16x16_cdiv4hwn4_unroll_width", "_imma8x32x16_cdiv4hwn4_unroll_width", "_imma32x8x16_cdiv4hwn4_unroll_width"]} | |||||
def main(): | def main(): | ||||
parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
description='generate cuda conv bias (dp4a/imma) kern impl files', | |||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||||
parser.add_argument('--type', type=str, choices=['dp4a', | |||||
'imma'], | |||||
default='dp4a', help='generate cuda conv bias kernel file') | |||||
parser.add_argument('output', help='output directory') | |||||
description="generate cuda conv bias (dp4a/imma) kern impl files", | |||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||||
) | |||||
parser.add_argument( | |||||
"--type", | |||||
type=str, | |||||
choices=["dp4a", "imma"], | |||||
default="dp4a", | |||||
help="generate cuda conv bias kernel file", | |||||
) | |||||
parser.add_argument("output", help="output directory") | |||||
args = parser.parse_args() | args = parser.parse_args() | ||||
if not os.path.isdir(args.output): | if not os.path.isdir(args.output): | ||||
os.makedirs(args.output) | os.makedirs(args.output) | ||||
inst = ''' | |||||
inst = """ | |||||
template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS, | template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS, | ||||
IConvEpilogue<Activation<megdnn::param_enumv::ConvBias::NonlineMode::ACTIVATION>>>( | IConvEpilogue<Activation<megdnn::param_enumv::ConvBias::NonlineMode::ACTIVATION>>>( | ||||
const int8_t* d_src, | const int8_t* d_src, | ||||
@@ -43,7 +61,7 @@ template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS, | |||||
const ConvParam& param, | const ConvParam& param, | ||||
float alpha, | float alpha, | ||||
float beta, | float beta, | ||||
cudaStream_t stream);''' | |||||
cudaStream_t stream);""" | |||||
for suffix in SUFFIXES[args.type]: | for suffix in SUFFIXES[args.type]: | ||||
for _, act in ACTIVATIONS.items(): | for _, act in ACTIVATIONS.items(): | ||||
@@ -53,13 +71,19 @@ template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS, | |||||
fname = os.path.join(args.output, fname) | fname = os.path.join(args.output, fname) | ||||
with open(fname, "w") as fout: | with open(fname, "w") as fout: | ||||
w = lambda s: print(s, file=fout) | w = lambda s: print(s, file=fout) | ||||
w('// generated by gen_cuda_conv_bias_kern_impls.py') | |||||
cur_inst = inst.replace("PREFIX", prefix).replace("SUFFIX", suffix).replace("BIAS", bias[0]).replace("ACTIVATION", act[0]) | |||||
w("// generated by gen_cuda_conv_bias_kern_impls.py") | |||||
cur_inst = ( | |||||
inst.replace("PREFIX", prefix) | |||||
.replace("SUFFIX", suffix) | |||||
.replace("BIAS", bias[0]) | |||||
.replace("ACTIVATION", act[0]) | |||||
) | |||||
w('#include "../{}{}.cuinl"'.format(prefix, suffix)) | w('#include "../{}{}.cuinl"'.format(prefix, suffix)) | ||||
w(cur_inst) | w(cur_inst) | ||||
print('generated {}'.format(fname)) | |||||
print("generated {}".format(fname)) | |||||
os.utime(args.output) | os.utime(args.output) | ||||
if __name__ == '__main__': | |||||
if __name__ == "__main__": | |||||
main() | main() |
@@ -1,34 +1,39 @@ | |||||
#!/usr/bin/env python3 | #!/usr/bin/env python3 | ||||
# -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
import os | |||||
import argparse | import argparse | ||||
import os | |||||
from gen_elemwise_utils import ARITIES, MODES | from gen_elemwise_utils import ARITIES, MODES | ||||
def main(): | def main(): | ||||
parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
description='generate elemwise each mode', | |||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||||
description="generate elemwise each mode", | |||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||||
) | |||||
parser.add_argument('output', help='output directory') | |||||
parser.add_argument("output", help="output directory") | |||||
args = parser.parse_args() | args = parser.parse_args() | ||||
with open(args.output, 'w') as fout: | |||||
with open(args.output, "w") as fout: | |||||
w = lambda s: print(s, file=fout) | w = lambda s: print(s, file=fout) | ||||
w('// generated by gen_elemwise_each_mode.py') | |||||
w("// generated by gen_elemwise_each_mode.py") | |||||
keys = list(MODES.keys()) | keys = list(MODES.keys()) | ||||
keys.sort() | keys.sort() | ||||
for (anum, ctype) in keys: | for (anum, ctype) in keys: | ||||
w('#define MEGDNN_FOREACH_ELEMWISE_MODE_{}_{}(cb) \\'.format( | |||||
ARITIES[anum], ctype)) | |||||
w( | |||||
"#define MEGDNN_FOREACH_ELEMWISE_MODE_{}_{}(cb) \\".format( | |||||
ARITIES[anum], ctype | |||||
) | |||||
) | |||||
for mode in MODES[(anum, ctype)]: | for mode in MODES[(anum, ctype)]: | ||||
w(' MEGDNN_ELEMWISE_MODE_ENABLE({}, cb) \\'.format(mode)) | |||||
w('') | |||||
w(" MEGDNN_ELEMWISE_MODE_ENABLE({}, cb) \\".format(mode)) | |||||
w("") | |||||
print('generated each_mode.inl') | |||||
print("generated each_mode.inl") | |||||
os.utime(args.output) | os.utime(args.output) | ||||
if __name__ == '__main__': | |||||
if __name__ == "__main__": | |||||
main() | main() |
@@ -1,56 +1,63 @@ | |||||
#!/usr/bin/env python3 | #!/usr/bin/env python3 | ||||
# -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
import os | |||||
import argparse | import argparse | ||||
import itertools | import itertools | ||||
import os | |||||
from gen_elemwise_utils import ARITIES, DTYPES, MODES | from gen_elemwise_utils import ARITIES, DTYPES, MODES | ||||
def main(): | def main(): | ||||
parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
description='generate elemwise impl files', | |||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||||
parser.add_argument('--type', type=str, choices=['cuda', | |||||
'hip', | |||||
'cpp'], | |||||
default='cpp', help='generate cuda/hip kernel file') | |||||
parser.add_argument('output', help='output directory') | |||||
description="generate elemwise impl files", | |||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||||
) | |||||
parser.add_argument( | |||||
"--type", | |||||
type=str, | |||||
choices=["cuda", "hip", "cpp"], | |||||
default="cpp", | |||||
help="generate cuda/hip kernel file", | |||||
) | |||||
parser.add_argument("output", help="output directory") | |||||
args = parser.parse_args() | args = parser.parse_args() | ||||
if not os.path.isdir(args.output): | if not os.path.isdir(args.output): | ||||
os.makedirs(args.output) | os.makedirs(args.output) | ||||
if args.type == 'cuda': | |||||
cpp_ext = 'cu' | |||||
elif args.type == 'hip': | |||||
cpp_ext = 'cpp.hip' | |||||
if args.type == "cuda": | |||||
cpp_ext = "cu" | |||||
elif args.type == "hip": | |||||
cpp_ext = "cpp.hip" | |||||
else: | else: | ||||
assert args.type == 'cpp' | |||||
cpp_ext = 'cpp' | |||||
assert args.type == "cpp" | |||||
cpp_ext = "cpp" | |||||
for anum, ctype in itertools.product(ARITIES.keys(), DTYPES.keys()): | for anum, ctype in itertools.product(ARITIES.keys(), DTYPES.keys()): | ||||
for mode in MODES[(anum, DTYPES[ctype][1])]: | for mode in MODES[(anum, DTYPES[ctype][1])]: | ||||
formode = 'MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)'.format(mode) | |||||
fname = '{}_{}.{}'.format(mode, ctype, cpp_ext) | |||||
formode = "MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)".format(mode) | |||||
fname = "{}_{}.{}".format(mode, ctype, cpp_ext) | |||||
fname = os.path.join(args.output, fname) | fname = os.path.join(args.output, fname) | ||||
with open(fname, 'w') as fout: | |||||
with open(fname, "w") as fout: | |||||
w = lambda s: print(s, file=fout) | w = lambda s: print(s, file=fout) | ||||
w('// generated by gen_elemwise_kern_impls.py') | |||||
w("// generated by gen_elemwise_kern_impls.py") | |||||
if ctype == 'dt_float16' or ctype == 'dt_bfloat16': | |||||
w('#if !MEGDNN_DISABLE_FLOAT16') | |||||
if ctype == "dt_float16" or ctype == "dt_bfloat16": | |||||
w("#if !MEGDNN_DISABLE_FLOAT16") | |||||
w('#define KERN_IMPL_MODE(cb) {}'.format(formode)) | |||||
w('#define KERN_IMPL_ARITY {}'.format(anum)) | |||||
w('#define KERN_IMPL_CTYPE {}'.format(ctype)) | |||||
w("#define KERN_IMPL_MODE(cb) {}".format(formode)) | |||||
w("#define KERN_IMPL_ARITY {}".format(anum)) | |||||
w("#define KERN_IMPL_CTYPE {}".format(ctype)) | |||||
w('#include "../kern_impl.inl"') | w('#include "../kern_impl.inl"') | ||||
if ctype == 'dt_float16' or ctype == 'dt_bfloat16': | |||||
w('#endif') | |||||
if ctype == "dt_float16" or ctype == "dt_bfloat16": | |||||
w("#endif") | |||||
print('generated {}'.format(fname)) | |||||
print("generated {}".format(fname)) | |||||
os.utime(args.output) | os.utime(args.output) | ||||
if __name__ == '__main__': | |||||
if __name__ == "__main__": | |||||
main() | main() |
@@ -1,52 +1,66 @@ | |||||
#!/usr/bin/env python3 | #!/usr/bin/env python3 | ||||
# -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
import os | |||||
import argparse | import argparse | ||||
import itertools | import itertools | ||||
from gen_elemwise_multi_type_utils import SUPPORT_DTYPES, MODES, SUPPORT_QINT32_DTYPES, QINT32_MODES | |||||
import os | |||||
from gen_elemwise_multi_type_utils import ( # isort: skip; isort: skip | |||||
MODES, | |||||
QINT32_MODES, | |||||
SUPPORT_DTYPES, | |||||
SUPPORT_QINT32_DTYPES, | |||||
) | |||||
def generate(modes, support_dtypes, output, cpp_ext): | def generate(modes, support_dtypes, output, cpp_ext): | ||||
for anum, ctype in itertools.product(modes.keys(), support_dtypes): | for anum, ctype in itertools.product(modes.keys(), support_dtypes): | ||||
print('{} : {}'.format(anum, ctype)) | |||||
print("{} : {}".format(anum, ctype)) | |||||
src_ctype = ctype[0] | src_ctype = ctype[0] | ||||
dst_ctype = ctype[1] | dst_ctype = ctype[1] | ||||
for mode in modes[anum]: | for mode in modes[anum]: | ||||
formode = 'MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)'.format(mode) | |||||
fname = '{}_{}_{}.{}'.format(mode, src_ctype, dst_ctype, cpp_ext) | |||||
formode = "MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)".format(mode) | |||||
fname = "{}_{}_{}.{}".format(mode, src_ctype, dst_ctype, cpp_ext) | |||||
fname = os.path.join(output, fname) | fname = os.path.join(output, fname) | ||||
with open(fname, 'w') as fout: | |||||
with open(fname, "w") as fout: | |||||
w = lambda s: print(s, file=fout) | w = lambda s: print(s, file=fout) | ||||
w('// generated by gen_elemwise_multi_type_kern_impls.py') | |||||
w("// generated by gen_elemwise_multi_type_kern_impls.py") | |||||
w('#define KERN_IMPL_MODE(cb) {}'.format(formode)) | |||||
w('#define KERN_IMPL_ARITY {}'.format(anum)) | |||||
w('#define KERN_IMPL_STYPE {}'.format(src_ctype)) | |||||
w('#define KERN_IMPL_DTYPE {}'.format(dst_ctype)) | |||||
w("#define KERN_IMPL_MODE(cb) {}".format(formode)) | |||||
w("#define KERN_IMPL_ARITY {}".format(anum)) | |||||
w("#define KERN_IMPL_STYPE {}".format(src_ctype)) | |||||
w("#define KERN_IMPL_DTYPE {}".format(dst_ctype)) | |||||
w('#include "../kern_impl.inl"') | w('#include "../kern_impl.inl"') | ||||
print('generated {}'.format(fname)) | |||||
print("generated {}".format(fname)) | |||||
def main(): | def main(): | ||||
parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
description='generate elemwise impl files', | |||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||||
parser.add_argument('--type', type=str, choices=['cuda'], | |||||
default='cuda', help='generate cuda kernel file') | |||||
parser.add_argument('output', help='output directory') | |||||
description="generate elemwise impl files", | |||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||||
) | |||||
parser.add_argument( | |||||
"--type", | |||||
type=str, | |||||
choices=["cuda"], | |||||
default="cuda", | |||||
help="generate cuda kernel file", | |||||
) | |||||
parser.add_argument("output", help="output directory") | |||||
args = parser.parse_args() | args = parser.parse_args() | ||||
if not os.path.isdir(args.output): | if not os.path.isdir(args.output): | ||||
os.makedirs(args.output) | os.makedirs(args.output) | ||||
assert args.type == 'cuda' | |||||
if args.type == 'cuda': | |||||
cpp_ext = 'cu' | |||||
assert args.type == "cuda" | |||||
if args.type == "cuda": | |||||
cpp_ext = "cu" | |||||
generate(MODES, SUPPORT_DTYPES, args.output, cpp_ext) | generate(MODES, SUPPORT_DTYPES, args.output, cpp_ext) | ||||
generate(QINT32_MODES, SUPPORT_QINT32_DTYPES, args.output, cpp_ext) | generate(QINT32_MODES, SUPPORT_QINT32_DTYPES, args.output, cpp_ext) | ||||
os.utime(args.output) | os.utime(args.output) | ||||
if __name__ == '__main__': | |||||
if __name__ == "__main__": | |||||
main() | main() |
@@ -1,48 +1,131 @@ | |||||
# As cuda currently do not support quint8, so we just ignore it. | # As cuda currently do not support quint8, so we just ignore it. | ||||
SUPPORT_DTYPES = [('dt_qint8', 'dt_qint8')] | |||||
SUPPORT_QINT32_DTYPES = [('dt_qint32', 'dt_qint8'), ('dt_qint8', 'dt_qint32'), | |||||
('dt_qint4', 'dt_qint32'), ('dt_quint4', 'dt_qint32')] | |||||
SUPPORT_DTYPES = [("dt_qint8", "dt_qint8")] | |||||
SUPPORT_QINT32_DTYPES = [ | |||||
("dt_qint32", "dt_qint8"), | |||||
("dt_qint8", "dt_qint32"), | |||||
("dt_qint4", "dt_qint32"), | |||||
("dt_quint4", "dt_qint32"), | |||||
] | |||||
SUPPORT_DTYPES_Q4 = [('dt_qint4', 'dt_qint4'), ('dt_quint4', 'dt_quint4')] | |||||
SUPPORT_QINT32_DTYPES_Q4 = [('dt_qint32', 'dt_qint4'), ('dt_qint32', 'dt_quint4')] | |||||
SUPPORT_DTYPES_Q4 = [("dt_qint4", "dt_qint4"), ("dt_quint4", "dt_quint4")] | |||||
SUPPORT_QINT32_DTYPES_Q4 = [("dt_qint32", "dt_qint4"), ("dt_qint32", "dt_quint4")] | |||||
SUPPORT_ARRITY2_DTYPES = ['dt_int32', 'dt_uint8', 'dt_int8', 'dt_int16', 'dt_bool', 'dt_float32', | |||||
'dt_float16', 'dt_bfloat16'] | |||||
SUPPORT_ARRITY1_DTYPES = ['dt_float32','dt_float16', 'dt_bfloat16'] | |||||
SUPPORT_ARRITY2_DTYPES = [ | |||||
"dt_int32", | |||||
"dt_uint8", | |||||
"dt_int8", | |||||
"dt_int16", | |||||
"dt_bool", | |||||
"dt_float32", | |||||
"dt_float16", | |||||
"dt_bfloat16", | |||||
] | |||||
SUPPORT_ARRITY1_DTYPES = ["dt_float32", "dt_float16", "dt_bfloat16"] | |||||
MODES = { | MODES = { | ||||
1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', | |||||
'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', | |||||
'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC', | |||||
'ERFCINV', 'H_SWISH', 'SILU', 'GELU'], | |||||
2: ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', | |||||
'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', | |||||
'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW', | |||||
'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', | |||||
'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | |||||
'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], | |||||
3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], | |||||
1: [ | |||||
"RELU", | |||||
"ABS", | |||||
"NEGATE", | |||||
"ACOS", | |||||
"ASIN", | |||||
"CEIL", | |||||
"COS", | |||||
"EXP", | |||||
"EXPM1", | |||||
"FLOOR", | |||||
"LOG", | |||||
"LOG1P", | |||||
"SIGMOID", | |||||
"SIN", | |||||
"TANH", | |||||
"FAST_TANH", | |||||
"ROUND", | |||||
"ERF", | |||||
"ERFINV", | |||||
"ERFC", | |||||
"ERFCINV", | |||||
"H_SWISH", | |||||
"SILU", | |||||
"GELU", | |||||
], | |||||
2: [ | |||||
"ABS_GRAD", | |||||
"ADD", | |||||
"FLOOR_DIV", | |||||
"MAX", | |||||
"MIN", | |||||
"MOD", | |||||
"MUL", | |||||
"SIGMOID_GRAD", | |||||
"SUB", | |||||
"SWITCH_GT0", | |||||
"TANH_GRAD", | |||||
"LT", | |||||
"LEQ", | |||||
"EQ", | |||||
"FUSE_ADD_RELU", | |||||
"TRUE_DIV", | |||||
"POW", | |||||
"LOG_SUM_EXP", | |||||
"FUSE_ADD_TANH", | |||||
"FAST_TANH_GRAD", | |||||
"FUSE_ADD_SIGMOID", | |||||
"ATAN2", | |||||
"H_SWISH_GRAD", | |||||
"FUSE_ADD_H_SWISH", | |||||
"SILU_GRAD", | |||||
"GELU_GRAD", | |||||
], | |||||
3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"], | |||||
} | } | ||||
QINT4_MODES = { | QINT4_MODES = { | ||||
1: ['RELU', 'ABS', 'NEGATE', 'CEIL', 'FLOOR', 'SIGMOID', | |||||
'TANH', 'FAST_TANH', 'ROUND', 'H_SWISH'], | |||||
2: ['ADD', 'MAX', 'MIN', 'MUL', 'SUB', 'SWITCH_GT0', | |||||
'LT', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'FUSE_ADD_TANH', | |||||
'FUSE_ADD_SIGMOID', 'FUSE_ADD_H_SWISH'], | |||||
3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], | |||||
1: [ | |||||
"RELU", | |||||
"ABS", | |||||
"NEGATE", | |||||
"CEIL", | |||||
"FLOOR", | |||||
"SIGMOID", | |||||
"TANH", | |||||
"FAST_TANH", | |||||
"ROUND", | |||||
"H_SWISH", | |||||
], | |||||
2: [ | |||||
"ADD", | |||||
"MAX", | |||||
"MIN", | |||||
"MUL", | |||||
"SUB", | |||||
"SWITCH_GT0", | |||||
"LT", | |||||
"LEQ", | |||||
"EQ", | |||||
"FUSE_ADD_RELU", | |||||
"FUSE_ADD_TANH", | |||||
"FUSE_ADD_SIGMOID", | |||||
"FUSE_ADD_H_SWISH", | |||||
], | |||||
3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"], | |||||
} | } | ||||
QINT32_MODES = { | QINT32_MODES = { | ||||
1: ['RELU', 'SIGMOID', 'TANH', 'FAST_TANH', 'H_SWISH'], | |||||
2: ['ADD', 'FUSE_ADD_RELU', 'FUSE_ADD_SIGMOID', | |||||
'FUSE_ADD_TANH', 'FUSE_ADD_H_SWISH'] | |||||
1: ["RELU", "SIGMOID", "TANH", "FAST_TANH", "H_SWISH"], | |||||
2: [ | |||||
"ADD", | |||||
"FUSE_ADD_RELU", | |||||
"FUSE_ADD_SIGMOID", | |||||
"FUSE_ADD_TANH", | |||||
"FUSE_ADD_H_SWISH", | |||||
], | |||||
} | } | ||||
ARRITY1_BOOL_MODES = { | ARRITY1_BOOL_MODES = { | ||||
1: ['ISINF','ISNAN'], | |||||
1: ["ISINF", "ISNAN"], | |||||
} | } | ||||
ARRITY2_BOOL_MODES = { | ARRITY2_BOOL_MODES = { | ||||
2: ['EQ','LEQ','NEQ','LT'], | |||||
2: ["EQ", "LEQ", "NEQ", "LT"], | |||||
} | } |
@@ -1,52 +1,57 @@ | |||||
#!/usr/bin/env python3 | #!/usr/bin/env python3 | ||||
# -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
import os | |||||
import argparse | import argparse | ||||
import os | |||||
from gen_elemwise_utils import DTYPES | from gen_elemwise_utils import DTYPES | ||||
def main(): | def main(): | ||||
parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
description='generate elemwise impl files', | |||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||||
parser.add_argument('--type', type=str, choices=[ | |||||
'cuda', | |||||
'hip' | |||||
], | |||||
default='cuda', | |||||
help='generate cuda/hip elemwise special kernel file') | |||||
parser.add_argument('output', help='output directory') | |||||
description="generate elemwise impl files", | |||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||||
) | |||||
parser.add_argument( | |||||
"--type", | |||||
type=str, | |||||
choices=["cuda", "hip"], | |||||
default="cuda", | |||||
help="generate cuda/hip elemwise special kernel file", | |||||
) | |||||
parser.add_argument("output", help="output directory") | |||||
args = parser.parse_args() | args = parser.parse_args() | ||||
if not os.path.isdir(args.output): | if not os.path.isdir(args.output): | ||||
os.makedirs(args.output) | os.makedirs(args.output) | ||||
if args.type == 'cuda': | |||||
cpp_ext = 'cu' | |||||
if args.type == "cuda": | |||||
cpp_ext = "cu" | |||||
else: | else: | ||||
assert args.type =='hip' | |||||
cpp_ext = 'cpp.hip' | |||||
assert args.type == "hip" | |||||
cpp_ext = "cpp.hip" | |||||
for dtype in DTYPES.keys(): | for dtype in DTYPES.keys(): | ||||
fname = 'special_{}.{}'.format(dtype, cpp_ext) | |||||
fname = "special_{}.{}".format(dtype, cpp_ext) | |||||
fname = os.path.join(args.output, fname) | fname = os.path.join(args.output, fname) | ||||
with open(fname, 'w') as fout: | |||||
with open(fname, "w") as fout: | |||||
w = lambda s: print(s, file=fout) | w = lambda s: print(s, file=fout) | ||||
w('// generated by gen_elemwise_special_kern_impls.py') | |||||
if dtype == 'dt_float16' or dtype == 'dt_bfloat16': | |||||
w('#if !MEGDNN_DISABLE_FLOAT16') | |||||
w("// generated by gen_elemwise_special_kern_impls.py") | |||||
if dtype == "dt_float16" or dtype == "dt_bfloat16": | |||||
w("#if !MEGDNN_DISABLE_FLOAT16") | |||||
w('#include "../special_kerns.inl"') | w('#include "../special_kerns.inl"') | ||||
w('INST(::megdnn::dtype::{})'.format(DTYPES[dtype][0])) | |||||
w('#undef INST') | |||||
w('}') | |||||
w('}') | |||||
if dtype == 'dt_float16' or dtype == 'dt_bfloat16': | |||||
w('#endif') | |||||
w("INST(::megdnn::dtype::{})".format(DTYPES[dtype][0])) | |||||
w("#undef INST") | |||||
w("}") | |||||
w("}") | |||||
if dtype == "dt_float16" or dtype == "dt_bfloat16": | |||||
w("#endif") | |||||
print('generated {}'.format(fname)) | |||||
print("generated {}".format(fname)) | |||||
os.utime(args.output) | os.utime(args.output) | ||||
if __name__ == '__main__': | |||||
if __name__ == "__main__": | |||||
main() | main() |
@@ -1,35 +1,95 @@ | |||||
ARITIES = {1: "UNARY", 2: "BINARY", 3: "TERNARY"} | |||||
ARITIES = {1: 'UNARY', 2: 'BINARY', 3: 'TERNARY'} | |||||
DTYPES = {'dt_int32': ('Int32', 'INT'), | |||||
'dt_uint8': ('Uint8', 'INT'), | |||||
'dt_int8': ('Int8', 'INT'), | |||||
'dt_int16': ('Int16', 'INT'), | |||||
'dt_bool': ('Bool', 'BOOL'), | |||||
'dt_float32': ('Float32', 'FLOAT'), | |||||
'dt_float16': ('Float16', 'FLOAT'), | |||||
'dt_bfloat16': ('BFloat16', 'FLOAT') | |||||
} | |||||
DTYPES = { | |||||
"dt_int32": ("Int32", "INT"), | |||||
"dt_uint8": ("Uint8", "INT"), | |||||
"dt_int8": ("Int8", "INT"), | |||||
"dt_int16": ("Int16", "INT"), | |||||
"dt_bool": ("Bool", "BOOL"), | |||||
"dt_float32": ("Float32", "FLOAT"), | |||||
"dt_float16": ("Float16", "FLOAT"), | |||||
"dt_bfloat16": ("BFloat16", "FLOAT"), | |||||
} | |||||
MODES = { | MODES = { | ||||
(1, 'INT'): ['RELU', 'ABS', 'NEGATE'], | |||||
(2, 'INT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', | |||||
'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', 'LEQ', | |||||
'EQ', 'FUSE_ADD_RELU', 'SHL', 'SHR', 'RMULH'], | |||||
(3, 'INT'): ['COND_LEQ_MOV', 'COND_LT_MOV'], | |||||
(1, 'FLOAT'): ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', | |||||
'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', | |||||
'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC', | |||||
'ERFCINV', 'H_SWISH', 'SILU', 'GELU'], | |||||
(2, 'FLOAT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', | |||||
'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', | |||||
'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW', | |||||
'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', | |||||
'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | |||||
'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], | |||||
(3, 'FLOAT'): ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], | |||||
(1, 'BOOL'): ['NOT'], | |||||
(2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'], | |||||
(3, 'BOOL'): [] | |||||
(1, "INT"): ["RELU", "ABS", "NEGATE"], | |||||
(2, "INT"): [ | |||||
"ABS_GRAD", | |||||
"ADD", | |||||
"FLOOR_DIV", | |||||
"MAX", | |||||
"MIN", | |||||
"MOD", | |||||
"MUL", | |||||
"SIGMOID_GRAD", | |||||
"SUB", | |||||
"SWITCH_GT0", | |||||
"TANH_GRAD", | |||||
"LT", | |||||
"LEQ", | |||||
"EQ", | |||||
"FUSE_ADD_RELU", | |||||
"SHL", | |||||
"SHR", | |||||
"RMULH", | |||||
], | |||||
(3, "INT"): ["COND_LEQ_MOV", "COND_LT_MOV"], | |||||
(1, "FLOAT"): [ | |||||
"RELU", | |||||
"ABS", | |||||
"NEGATE", | |||||
"ACOS", | |||||
"ASIN", | |||||
"CEIL", | |||||
"COS", | |||||
"EXP", | |||||
"EXPM1", | |||||
"FLOOR", | |||||
"LOG", | |||||
"LOG1P", | |||||
"SIGMOID", | |||||
"SIN", | |||||
"TANH", | |||||
"FAST_TANH", | |||||
"ROUND", | |||||
"ERF", | |||||
"ERFINV", | |||||
"ERFC", | |||||
"ERFCINV", | |||||
"H_SWISH", | |||||
"SILU", | |||||
"GELU", | |||||
], | |||||
(2, "FLOAT"): [ | |||||
"ABS_GRAD", | |||||
"ADD", | |||||
"FLOOR_DIV", | |||||
"MAX", | |||||
"MIN", | |||||
"MOD", | |||||
"MUL", | |||||
"SIGMOID_GRAD", | |||||
"SUB", | |||||
"SWITCH_GT0", | |||||
"TANH_GRAD", | |||||
"LT", | |||||
"LEQ", | |||||
"EQ", | |||||
"FUSE_ADD_RELU", | |||||
"TRUE_DIV", | |||||
"POW", | |||||
"LOG_SUM_EXP", | |||||
"FUSE_ADD_TANH", | |||||
"FAST_TANH_GRAD", | |||||
"FUSE_ADD_SIGMOID", | |||||
"ATAN2", | |||||
"H_SWISH_GRAD", | |||||
"FUSE_ADD_H_SWISH", | |||||
"SILU_GRAD", | |||||
"GELU_GRAD", | |||||
], | |||||
(3, "FLOAT"): ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"], | |||||
(1, "BOOL"): ["NOT"], | |||||
(2, "BOOL"): ["AND", "OR", "XOR", "LT", "LEQ", "EQ"], | |||||
(3, "BOOL"): [], | |||||
} | } |
@@ -3,13 +3,14 @@ | |||||
import argparse | import argparse | ||||
import collections | import collections | ||||
import textwrap | |||||
import os | |||||
import hashlib | import hashlib | ||||
import struct | |||||
import io | import io | ||||
import os | |||||
import struct | |||||
import textwrap | |||||
from gen_param_defs import IndentWriterBase, ParamDef, member_defs | |||||
from gen_param_defs import member_defs, ParamDef, IndentWriterBase | |||||
class ConverterWriter(IndentWriterBase): | class ConverterWriter(IndentWriterBase): | ||||
_skip_current_param = False | _skip_current_param = False | ||||
@@ -20,7 +21,7 @@ class ConverterWriter(IndentWriterBase): | |||||
def __call__(self, fout, defs): | def __call__(self, fout, defs): | ||||
super().__call__(fout) | super().__call__(fout) | ||||
self._write("// %s", self._get_header()) | self._write("// %s", self._get_header()) | ||||
self._write('#include <flatbuffers/flatbuffers.h>') | |||||
self._write("#include <flatbuffers/flatbuffers.h>") | |||||
self._write("namespace mgb {") | self._write("namespace mgb {") | ||||
self._write("namespace serialization {") | self._write("namespace serialization {") | ||||
self._write("namespace fbs {") | self._write("namespace fbs {") | ||||
@@ -33,8 +34,9 @@ class ConverterWriter(IndentWriterBase): | |||||
self._last_param = p | self._last_param = p | ||||
self._param_fields = [] | self._param_fields = [] | ||||
self._fb_fields = ["builder"] | self._fb_fields = ["builder"] | ||||
self._write("template<>\nstruct ParamConverter<megdnn::param::%s> {", | |||||
p.name, indent=1) | |||||
self._write( | |||||
"template<>\nstruct ParamConverter<megdnn::param::%s> {", p.name, indent=1 | |||||
) | |||||
self._write("using MegDNNType = megdnn::param::%s;", p.name) | self._write("using MegDNNType = megdnn::param::%s;", p.name) | ||||
self._write("using FlatBufferType = fbs::param::%s;\n", p.name) | self._write("using FlatBufferType = fbs::param::%s;\n", p.name) | ||||
@@ -42,22 +44,22 @@ class ConverterWriter(IndentWriterBase): | |||||
if self._skip_current_param: | if self._skip_current_param: | ||||
self._skip_current_param = False | self._skip_current_param = False | ||||
return | return | ||||
self._write("static MegDNNType to_param(const FlatBufferType* fb) {", | |||||
indent=1) | |||||
line = 'return {' | |||||
line += ', '.join(self._param_fields) | |||||
line += '};' | |||||
self._write("static MegDNNType to_param(const FlatBufferType* fb) {", indent=1) | |||||
line = "return {" | |||||
line += ", ".join(self._param_fields) | |||||
line += "};" | |||||
self._write(line) | self._write(line) | ||||
self._write("}\n", indent=-1) | self._write("}\n", indent=-1) | ||||
self._write( | self._write( | ||||
"static flatbuffers::Offset<FlatBufferType> to_flatbuffer(flatbuffers::FlatBufferBuilder& builder, const MegDNNType& param) {", | "static flatbuffers::Offset<FlatBufferType> to_flatbuffer(flatbuffers::FlatBufferBuilder& builder, const MegDNNType& param) {", | ||||
indent=1) | |||||
line = 'return fbs::param::Create{}('.format(str(p.name)) | |||||
line += ', '.join(self._fb_fields) | |||||
line += ');' | |||||
indent=1, | |||||
) | |||||
line = "return fbs::param::Create{}(".format(str(p.name)) | |||||
line += ", ".join(self._fb_fields) | |||||
line += ");" | |||||
self._write(line) | self._write(line) | ||||
self._write('}', indent=-1) | |||||
self._write("}", indent=-1) | |||||
self._write("};\n", indent=-1) | self._write("};\n", indent=-1) | ||||
@@ -68,18 +70,23 @@ class ConverterWriter(IndentWriterBase): | |||||
return | return | ||||
self._param_fields.append( | self._param_fields.append( | ||||
"static_cast<megdnn::param::{}::{}>(fb->{}())".format( | "static_cast<megdnn::param::{}::{}>(fb->{}())".format( | ||||
str(p.name), str(e.name), e.name_field)) | |||||
self._fb_fields.append("static_cast<fbs::param::{}>(param.{})".format( | |||||
key, e.name_field)) | |||||
str(p.name), str(e.name), e.name_field | |||||
) | |||||
) | |||||
self._fb_fields.append( | |||||
"static_cast<fbs::param::{}>(param.{})".format(key, e.name_field) | |||||
) | |||||
def _on_member_field(self, f): | def _on_member_field(self, f): | ||||
if self._skip_current_param: | if self._skip_current_param: | ||||
return | return | ||||
if f.dtype.cname == 'DTypeEnum': | |||||
if f.dtype.cname == "DTypeEnum": | |||||
self._param_fields.append( | self._param_fields.append( | ||||
"intl::convert_dtype_to_megdnn(fb->{}())".format(f.name)) | |||||
"intl::convert_dtype_to_megdnn(fb->{}())".format(f.name) | |||||
) | |||||
self._fb_fields.append( | self._fb_fields.append( | ||||
"intl::convert_dtype_to_fbs(param.{})".format(f.name)) | |||||
"intl::convert_dtype_to_fbs(param.{})".format(f.name) | |||||
) | |||||
else: | else: | ||||
self._param_fields.append("fb->{}()".format(f.name)) | self._param_fields.append("fb->{}()".format(f.name)) | ||||
self._fb_fields.append("param.{}".format(f.name)) | self._fb_fields.append("param.{}".format(f.name)) | ||||
@@ -93,28 +100,33 @@ class ConverterWriter(IndentWriterBase): | |||||
enum_name = e.src_class + e.src_name | enum_name = e.src_class + e.src_name | ||||
self._param_fields.append( | self._param_fields.append( | ||||
"static_cast<megdnn::param::{}::{}>(fb->{}())".format( | "static_cast<megdnn::param::{}::{}>(fb->{}())".format( | ||||
e.src_class, e.src_name, e.name_field)) | |||||
self._fb_fields.append("static_cast<fbs::param::{}>(param.{})".format( | |||||
enum_name, e.name_field)) | |||||
e.src_class, e.src_name, e.name_field | |||||
) | |||||
) | |||||
self._fb_fields.append( | |||||
"static_cast<fbs::param::{}>(param.{})".format(enum_name, e.name_field) | |||||
) | |||||
def main(): | def main(): | ||||
parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
'generate convert functions between FlatBuffers type and MegBrain type') | |||||
parser.add_argument('input') | |||||
parser.add_argument('output') | |||||
"generate convert functions between FlatBuffers type and MegBrain type" | |||||
) | |||||
parser.add_argument("input") | |||||
parser.add_argument("output") | |||||
args = parser.parse_args() | args = parser.parse_args() | ||||
with open(args.input) as fin: | with open(args.input) as fin: | ||||
inputs = fin.read() | inputs = fin.read() | ||||
exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) | |||||
exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc}) | |||||
input_hash = hashlib.sha256() | input_hash = hashlib.sha256() | ||||
input_hash.update(inputs.encode(encoding='UTF-8')) | |||||
input_hash.update(inputs.encode(encoding="UTF-8")) | |||||
input_hash = input_hash.hexdigest() | input_hash = input_hash.hexdigest() | ||||
writer = ConverterWriter() | writer = ConverterWriter() | ||||
with open(args.output, 'w') as fout: | |||||
with open(args.output, "w") as fout: | |||||
writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | ||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
main() | main() |
@@ -3,13 +3,14 @@ | |||||
import argparse | import argparse | ||||
import collections | import collections | ||||
import textwrap | |||||
import os | |||||
import hashlib | import hashlib | ||||
import struct | |||||
import io | import io | ||||
import os | |||||
import struct | |||||
import textwrap | |||||
from gen_param_defs import IndentWriterBase, ParamDef, member_defs | |||||
from gen_param_defs import member_defs, ParamDef, IndentWriterBase | |||||
def _cname_to_fbname(cname): | def _cname_to_fbname(cname): | ||||
return { | return { | ||||
@@ -22,17 +23,19 @@ def _cname_to_fbname(cname): | |||||
"bool": "bool", | "bool": "bool", | ||||
}[cname] | }[cname] | ||||
def scramble_enum_member_name(name): | def scramble_enum_member_name(name): | ||||
s = name.find('<<') | |||||
s = name.find("<<") | |||||
if s != -1: | if s != -1: | ||||
name = name[0:name.find('=') + 1] + ' ' + name[s+2:] | |||||
name = name[0 : name.find("=") + 1] + " " + name[s + 2 :] | |||||
if name in ("MIN", "MAX"): | if name in ("MIN", "MAX"): | ||||
return name + "_" | return name + "_" | ||||
o_name = name.split(' ')[0].split('=')[0] | |||||
o_name = name.split(" ")[0].split("=")[0] | |||||
if o_name in ("MIN", "MAX"): | if o_name in ("MIN", "MAX"): | ||||
return name.replace(o_name, o_name + "_") | return name.replace(o_name, o_name + "_") | ||||
return name | return name | ||||
class FlatBuffersWriter(IndentWriterBase): | class FlatBuffersWriter(IndentWriterBase): | ||||
_skip_current_param = False | _skip_current_param = False | ||||
_last_param = None | _last_param = None | ||||
@@ -66,12 +69,13 @@ class FlatBuffersWriter(IndentWriterBase): | |||||
self._write("}\n", indent=-1) | self._write("}\n", indent=-1) | ||||
def _write_doc(self, doc): | def _write_doc(self, doc): | ||||
if not isinstance(doc, member_defs.Doc) or not doc.doc: return | |||||
if not isinstance(doc, member_defs.Doc) or not doc.doc: | |||||
return | |||||
doc_lines = [] | doc_lines = [] | ||||
if doc.no_reformat: | if doc.no_reformat: | ||||
doc_lines = doc.raw_lines | doc_lines = doc.raw_lines | ||||
else: | else: | ||||
doc = doc.doc.replace('\n', ' ') | |||||
doc = doc.doc.replace("\n", " ") | |||||
text_width = 80 - len(self._cur_indent) - 4 | text_width = 80 - len(self._cur_indent) - 4 | ||||
doc_lines = textwrap.wrap(doc, text_width) | doc_lines = textwrap.wrap(doc, text_width) | ||||
for line in doc_lines: | for line in doc_lines: | ||||
@@ -101,7 +105,8 @@ class FlatBuffersWriter(IndentWriterBase): | |||||
default = e.compose_combined_enum(e.default) | default = e.compose_combined_enum(e.default) | ||||
else: | else: | ||||
default = scramble_enum_member_name( | default = scramble_enum_member_name( | ||||
str(e.members[e.default]).split(' ')[0].split('=')[0]) | |||||
str(e.members[e.default]).split(" ")[0].split("=")[0] | |||||
) | |||||
self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default) | self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default) | ||||
def _resolve_const(self, v): | def _resolve_const(self, v): | ||||
@@ -113,8 +118,12 @@ class FlatBuffersWriter(IndentWriterBase): | |||||
if self._skip_current_param: | if self._skip_current_param: | ||||
return | return | ||||
self._write_doc(f.name) | self._write_doc(f.name) | ||||
self._write("%s:%s = %s;", f.name, _cname_to_fbname(f.dtype.cname), | |||||
self._get_fb_default(self._resolve_const(f.default))) | |||||
self._write( | |||||
"%s:%s = %s;", | |||||
f.name, | |||||
_cname_to_fbname(f.dtype.cname), | |||||
self._get_fb_default(self._resolve_const(f.default)), | |||||
) | |||||
def _on_const_field(self, f): | def _on_const_field(self, f): | ||||
self._cur_const_val[str(f.name)] = str(f.default) | self._cur_const_val[str(f.name)] = str(f.default) | ||||
@@ -129,7 +138,8 @@ class FlatBuffersWriter(IndentWriterBase): | |||||
default = s.compose_combined_enum(e.get_default()) | default = s.compose_combined_enum(e.get_default()) | ||||
else: | else: | ||||
default = scramble_enum_member_name( | default = scramble_enum_member_name( | ||||
str(s.members[e.get_default()]).split(' ')[0].split('=')[0]) | |||||
str(s.members[e.get_default()]).split(" ")[0].split("=")[0] | |||||
) | |||||
self._write("%s:%s = %s;", e.name_field, enum_name, default) | self._write("%s:%s = %s;", e.name_field, enum_name, default) | ||||
def _get_fb_default(self, cppdefault): | def _get_fb_default(self, cppdefault): | ||||
@@ -137,9 +147,9 @@ class FlatBuffersWriter(IndentWriterBase): | |||||
return cppdefault | return cppdefault | ||||
d = cppdefault | d = cppdefault | ||||
if d.endswith('f'): # 1.f | |||||
if d.endswith("f"): # 1.f | |||||
return d[:-1] | return d[:-1] | ||||
if d.endswith('ull'): | |||||
if d.endswith("ull"): | |||||
return d[:-3] | return d[:-3] | ||||
if d.startswith("DTypeEnum::"): | if d.startswith("DTypeEnum::"): | ||||
return d[11:] | return d[11:] | ||||
@@ -148,21 +158,23 @@ class FlatBuffersWriter(IndentWriterBase): | |||||
def main(): | def main(): | ||||
parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
'generate FlatBuffers schema of operator param from description file') | |||||
parser.add_argument('input') | |||||
parser.add_argument('output') | |||||
"generate FlatBuffers schema of operator param from description file" | |||||
) | |||||
parser.add_argument("input") | |||||
parser.add_argument("output") | |||||
args = parser.parse_args() | args = parser.parse_args() | ||||
with open(args.input) as fin: | with open(args.input) as fin: | ||||
inputs = fin.read() | inputs = fin.read() | ||||
exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) | |||||
exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc}) | |||||
input_hash = hashlib.sha256() | input_hash = hashlib.sha256() | ||||
input_hash.update(inputs.encode(encoding='UTF-8')) | |||||
input_hash.update(inputs.encode(encoding="UTF-8")) | |||||
input_hash = input_hash.hexdigest() | input_hash = input_hash.hexdigest() | ||||
writer = FlatBuffersWriter() | writer = FlatBuffersWriter() | ||||
with open(args.output, 'w') as fout: | |||||
with open(args.output, "w") as fout: | |||||
writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | ||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
main() | main() |
@@ -1,14 +1,16 @@ | |||||
#! /usr/local/env python3 | #! /usr/local/env python3 | ||||
import pickle | |||||
import numpy as np | |||||
import os | |||||
import argparse | import argparse | ||||
import re | |||||
import collections | import collections | ||||
import os | |||||
import pickle | |||||
import re | |||||
import numpy as np | |||||
def define_template(**kwargs): | def define_template(**kwargs): | ||||
template = ''' | |||||
template = """ | |||||
float cuda{cuda_arch}_{conv_type}_time_pred[{out_dim}] = {{0.0f}}; | float cuda{cuda_arch}_{conv_type}_time_pred[{out_dim}] = {{0.0f}}; | ||||
float cuda{cuda_arch}_{conv_type}_mask[{out_dim}] = {{0.0f}}; | float cuda{cuda_arch}_{conv_type}_mask[{out_dim}] = {{0.0f}}; | ||||
float cuda{cuda_arch}_{conv_type}_hidden_units[{hidden_num}] = {{0.0f}}; | float cuda{cuda_arch}_{conv_type}_hidden_units[{hidden_num}] = {{0.0f}}; | ||||
@@ -17,21 +19,23 @@ def define_template(**kwargs): | |||||
const static float cuda{cuda_arch}_{conv_type}_biases[{biases_dim}] = {{{biases}}}; | const static float cuda{cuda_arch}_{conv_type}_biases[{biases_dim}] = {{{biases}}}; | ||||
const static float cuda{cuda_arch}_{conv_type}_alpha[{out_dim}] = {{{alpha}}}; | const static float cuda{cuda_arch}_{conv_type}_alpha[{out_dim}] = {{{alpha}}}; | ||||
const static float cuda{cuda_arch}_{conv_type}_beta[{out_dim}] = {{{beta}}}; | const static float cuda{cuda_arch}_{conv_type}_beta[{out_dim}] = {{{beta}}}; | ||||
''' | |||||
""" | |||||
return template.format(**kwargs) | return template.format(**kwargs) | ||||
def cudnn_slt_template(**kwargs): | def cudnn_slt_template(**kwargs): | ||||
template = ("#if CUDNN_MAJOR == {cudnn_major} && CUDNN_MINOR == {cudnn_minor}\n" + | |||||
" {define_cmd}\n" + | |||||
" {select_cmd}\n" + | |||||
" return true;\n" + | |||||
"#endif\n" | |||||
) | |||||
template = ( | |||||
"#if CUDNN_MAJOR == {cudnn_major} && CUDNN_MINOR == {cudnn_minor}\n" | |||||
+ " {define_cmd}\n" | |||||
+ " {select_cmd}\n" | |||||
+ " return true;\n" | |||||
+ "#endif\n" | |||||
) | |||||
return template.format(**kwargs) | return template.format(**kwargs) | ||||
def select_template(**kwargs): | def select_template(**kwargs): | ||||
template = \ | |||||
'''if (conv_type == ConvolutionType::{conv_type} && cuda_major == {cuda_major} && | |||||
template = """if (conv_type == ConvolutionType::{conv_type} && cuda_major == {cuda_major} && | |||||
cuda_minor == {cuda_minor}) {{ | cuda_minor == {cuda_minor}) {{ | ||||
*layer_num_p = {layer_num}; | *layer_num_p = {layer_num}; | ||||
*hidden_units_p = cuda{cuda_arch}_{conv_type}_hidden_units; | *hidden_units_p = cuda{cuda_arch}_{conv_type}_hidden_units; | ||||
@@ -42,7 +46,7 @@ def select_template(**kwargs): | |||||
*beta_p = cuda{cuda_arch}_{conv_type}_beta; | *beta_p = cuda{cuda_arch}_{conv_type}_beta; | ||||
*time_pred_p = cuda{cuda_arch}_{conv_type}_time_pred; | *time_pred_p = cuda{cuda_arch}_{conv_type}_time_pred; | ||||
*mask_p = cuda{cuda_arch}_{conv_type}_mask; | *mask_p = cuda{cuda_arch}_{conv_type}_mask; | ||||
}} else ''' | |||||
}} else """ | |||||
return template.format(**kwargs) | return template.format(**kwargs) | ||||
@@ -58,48 +62,48 @@ def fill_src(): | |||||
if len(matrix_files) == 0: | if len(matrix_files) == 0: | ||||
print("Warning: no param files detected.") | print("Warning: no param files detected.") | ||||
for fpath in matrix_files: | for fpath in matrix_files: | ||||
cudnn_version = re.findall('cudnn([\d.]+)',fpath)[0] | |||||
cudnn_version = re.findall("cudnn([\d.]+)", fpath)[0] | |||||
gen_list[cudnn_version].append(fpath) | gen_list[cudnn_version].append(fpath) | ||||
for cudnn in gen_list: | for cudnn in gen_list: | ||||
select_cmd = ("{\n" + | |||||
" " * 8 + "return false;\n" + | |||||
" " * 4 + "}") | |||||
select_cmd = "{\n" + " " * 8 + "return false;\n" + " " * 4 + "}" | |||||
define_cmd = "" | define_cmd = "" | ||||
cudnn_major, cudnn_minor = cudnn.split('.') | |||||
cudnn_major, cudnn_minor = cudnn.split(".") | |||||
for fpath in gen_list[cudnn]: | for fpath in gen_list[cudnn]: | ||||
cuda_arch = fpath.split("-")[1].replace(".", "_") | cuda_arch = fpath.split("-")[1].replace(".", "_") | ||||
print('cudnn_version: {}, cuda_arch: {}'.format(cudnn,cuda_arch)) | |||||
print("cudnn_version: {}, cuda_arch: {}".format(cudnn, cuda_arch)) | |||||
conv_type = fpath.split("-")[2].split(".")[0] | conv_type = fpath.split("-")[2].split(".")[0] | ||||
with open(os.path.join(home, "params/{}".format(fpath)), "rb") as pobj: | with open(os.path.join(home, "params/{}".format(fpath)), "rb") as pobj: | ||||
params = pickle.load(pobj) | params = pickle.load(pobj) | ||||
crt_define_cmd, crt_select_cmd = gen_cmds( | |||||
cuda_arch, conv_type, params) | |||||
crt_define_cmd, crt_select_cmd = gen_cmds(cuda_arch, conv_type, params) | |||||
select_cmd = crt_select_cmd + select_cmd | select_cmd = crt_select_cmd + select_cmd | ||||
define_cmd = crt_define_cmd + define_cmd | define_cmd = crt_define_cmd + define_cmd | ||||
cudnn_slt_cmd += cudnn_slt_template(cudnn_major=cudnn_major, | |||||
cudnn_minor=cudnn_minor, | |||||
select_cmd=select_cmd, | |||||
define_cmd=define_cmd) | |||||
cudnn_slt_cmd += cudnn_slt_template( | |||||
cudnn_major=cudnn_major, | |||||
cudnn_minor=cudnn_minor, | |||||
select_cmd=select_cmd, | |||||
define_cmd=define_cmd, | |||||
) | |||||
#select_cmd = select_cmd | |||||
# select_cmd = select_cmd | |||||
with open(os.path.join(home, "get_params.template"), "r") as srcf: | with open(os.path.join(home, "get_params.template"), "r") as srcf: | ||||
src = srcf.read() | src = srcf.read() | ||||
dst = src.replace("{cudnn_select}", cudnn_slt_cmd) | dst = src.replace("{cudnn_select}", cudnn_slt_cmd) | ||||
MegDNN_path = os.path.join(home, "../..") | MegDNN_path = os.path.join(home, "../..") | ||||
with open(os.path.join(MegDNN_path, | |||||
"src/cuda/convolution/get_params.cpp"), "w") as dstf: | |||||
with open( | |||||
os.path.join(MegDNN_path, "src/cuda/convolution/get_params.cpp"), "w" | |||||
) as dstf: | |||||
dstf.write(dst) | dstf.write(dst) | ||||
def gen_cmds(cuda_arch, conv_type, params): | def gen_cmds(cuda_arch, conv_type, params): | ||||
cuda_major, cuda_minor = cuda_arch.split("_") | cuda_major, cuda_minor = cuda_arch.split("_") | ||||
alphastr = format_array(params['alpha']).rstrip()[:-1] | |||||
betastr = format_array(params['beta']).rstrip()[:-1] | |||||
W_list = params['W'] | |||||
b_list = params['b'] | |||||
Wstr = '' | |||||
bstr = '' | |||||
alphastr = format_array(params["alpha"]).rstrip()[:-1] | |||||
betastr = format_array(params["beta"]).rstrip()[:-1] | |||||
W_list = params["W"] | |||||
b_list = params["b"] | |||||
Wstr = "" | |||||
bstr = "" | |||||
layer_num = str(len(b_list) + 1) | layer_num = str(len(b_list) + 1) | ||||
layers_dim = [W_list[0].shape[1]] | layers_dim = [W_list[0].shape[1]] | ||||
matrices_dim = 0 | matrices_dim = 0 | ||||
@@ -118,16 +122,27 @@ def gen_cmds(cuda_arch, conv_type, params): | |||||
out_dim = layers_dim[-1] | out_dim = layers_dim[-1] | ||||
layers_dim_str = format_array(np.array(layers_dim)).rstrip()[:-1] | layers_dim_str = format_array(np.array(layers_dim)).rstrip()[:-1] | ||||
select_cmd = select_template(conv_type=conv_type.upper(), cuda_major=cuda_major, | |||||
cuda_minor=cuda_minor, layer_num=layer_num, | |||||
cuda_arch=cuda_arch) | |||||
define_cmd = define_template(cuda_arch=cuda_arch, conv_type=conv_type.upper(), | |||||
hidden_num=hidden_num, | |||||
layer_num=layer_num, out_dim=out_dim, | |||||
layers_dim=layers_dim_str, | |||||
matrices_dim=matrices_dim, matrices=Wstr, | |||||
biases_dim=biases_dim, biases=bstr, | |||||
alpha=alphastr, beta=betastr) | |||||
select_cmd = select_template( | |||||
conv_type=conv_type.upper(), | |||||
cuda_major=cuda_major, | |||||
cuda_minor=cuda_minor, | |||||
layer_num=layer_num, | |||||
cuda_arch=cuda_arch, | |||||
) | |||||
define_cmd = define_template( | |||||
cuda_arch=cuda_arch, | |||||
conv_type=conv_type.upper(), | |||||
hidden_num=hidden_num, | |||||
layer_num=layer_num, | |||||
out_dim=out_dim, | |||||
layers_dim=layers_dim_str, | |||||
matrices_dim=matrices_dim, | |||||
matrices=Wstr, | |||||
biases_dim=biases_dim, | |||||
biases=bstr, | |||||
alpha=alphastr, | |||||
beta=betastr, | |||||
) | |||||
return (define_cmd, select_cmd) | return (define_cmd, select_cmd) | ||||
@@ -153,8 +168,9 @@ def format_array(array): | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
description="Generate cuDNN heuristic code by neural network into" | description="Generate cuDNN heuristic code by neural network into" | ||||
" {MEGDNN_ROOT}/src/cuda/convolution/get_params.cpp," | |||||
" using parameter value from pickle files in" | |||||
" {MEGDNN_ROOT}/scripts/gen_heuristic/params/") | |||||
" {MEGDNN_ROOT}/src/cuda/convolution/get_params.cpp," | |||||
" using parameter value from pickle files in" | |||||
" {MEGDNN_ROOT}/scripts/gen_heuristic/params/" | |||||
) | |||||
args = parser.parse_args() | args = parser.parse_args() | ||||
main() | main() |
@@ -3,19 +3,17 @@ | |||||
import argparse | import argparse | ||||
import collections | import collections | ||||
import textwrap | |||||
import os | |||||
import hashlib | import hashlib | ||||
import struct | |||||
import io | import io | ||||
import os | |||||
import struct | |||||
import textwrap | |||||
from gen_param_defs import member_defs, ParamDef, IndentWriterBase | |||||
from gen_param_defs import IndentWriterBase, ParamDef, member_defs | |||||
# FIXME: move supportToString flag definition into the param def source file | # FIXME: move supportToString flag definition into the param def source file | ||||
ENUM_TO_STRING_SPECIAL_RULES = [ | |||||
("Elemwise", "Mode"), | |||||
("ElemwiseMultiType", "Mode") | |||||
] | |||||
ENUM_TO_STRING_SPECIAL_RULES = [("Elemwise", "Mode"), ("ElemwiseMultiType", "Mode")] | |||||
class ConverterWriter(IndentWriterBase): | class ConverterWriter(IndentWriterBase): | ||||
_skip_current_param = False | _skip_current_param = False | ||||
@@ -33,21 +31,21 @@ class ConverterWriter(IndentWriterBase): | |||||
self._write("#endif // MGB_PARAM") | self._write("#endif // MGB_PARAM") | ||||
def _ctype2attr(self, ctype, value): | def _ctype2attr(self, ctype, value): | ||||
if ctype == 'uint32_t': | |||||
return 'MgbUI32Attr', value | |||||
if ctype == 'uint64_t': | |||||
return 'MgbUI64Attr', value | |||||
if ctype == 'int32_t': | |||||
return 'MgbI32Attr', value | |||||
if ctype == 'float': | |||||
return 'MgbF32Attr', value | |||||
if ctype == 'double': | |||||
return 'MgbF64Attr', value | |||||
if ctype == 'bool': | |||||
return 'MgbBoolAttr', value | |||||
if ctype == 'DTypeEnum': | |||||
if ctype == "uint32_t": | |||||
return "MgbUI32Attr", value | |||||
if ctype == "uint64_t": | |||||
return "MgbUI64Attr", value | |||||
if ctype == "int32_t": | |||||
return "MgbI32Attr", value | |||||
if ctype == "float": | |||||
return "MgbF32Attr", value | |||||
if ctype == "double": | |||||
return "MgbF64Attr", value | |||||
if ctype == "bool": | |||||
return "MgbBoolAttr", value | |||||
if ctype == "DTypeEnum": | |||||
self._packed = False | self._packed = False | ||||
return 'MgbDTypeAttr', 'megdnn::DType::from_enum(megdnn::{})'.format(value) | |||||
return "MgbDTypeAttr", "megdnn::DType::from_enum(megdnn::{})".format(value) | |||||
raise RuntimeError("unknown ctype") | raise RuntimeError("unknown ctype") | ||||
def _on_param_begin(self, p): | def _on_param_begin(self, p): | ||||
@@ -61,21 +59,26 @@ class ConverterWriter(IndentWriterBase): | |||||
self._skip_current_param = False | self._skip_current_param = False | ||||
return | return | ||||
if self._packed: | if self._packed: | ||||
self._write("class {0}ParamBase<string accessor> : MgbPackedParamBase<\"{0}\", accessor> {{".format(p.name), indent=1) | |||||
self._write( | |||||
'class {0}ParamBase<string accessor> : MgbPackedParamBase<"{0}", accessor> {{'.format( | |||||
p.name | |||||
), | |||||
indent=1, | |||||
) | |||||
else: | else: | ||||
self._write("def {0}Param: MgbParamBase<\"{0}\"> {{".format(p.name), indent=1) | |||||
self._write('def {0}Param: MgbParamBase<"{0}"> {{'.format(p.name), indent=1) | |||||
self._write("let fields = (ins", indent=1) | self._write("let fields = (ins", indent=1) | ||||
self._write(",\n{}".format(self._cur_indent).join(self._current_tparams)) | self._write(",\n{}".format(self._cur_indent).join(self._current_tparams)) | ||||
self._write(");", indent=-1) | self._write(");", indent=-1) | ||||
self._write("}\n", indent=-1) | self._write("}\n", indent=-1) | ||||
if self._packed: | if self._packed: | ||||
self._write("def {0}Param : {0}ParamBase<\"param\">;\n".format(p.name)) | |||||
self._write('def {0}Param : {0}ParamBase<"param">;\n'.format(p.name)) | |||||
self._current_tparams = None | self._current_tparams = None | ||||
self._packed = None | self._packed = None | ||||
self._const = None | self._const = None | ||||
def _wrapped_with_default_value(self, attr, default): | def _wrapped_with_default_value(self, attr, default): | ||||
return 'MgbDefaultValuedAttr<{}, \"{}\">'.format(attr, default) | |||||
return 'MgbDefaultValuedAttr<{}, "{}">'.format(attr, default) | |||||
def _on_member_enum(self, e): | def _on_member_enum(self, e): | ||||
p = self._last_param | p = self._last_param | ||||
@@ -84,10 +87,12 @@ class ConverterWriter(IndentWriterBase): | |||||
# directly used by any operator, or other enum couldn't alias to this enum | # directly used by any operator, or other enum couldn't alias to this enum | ||||
td_class = "{}{}".format(p.name, e.name) | td_class = "{}{}".format(p.name, e.name) | ||||
fullname = "::megdnn::param::{}".format(p.name) | fullname = "::megdnn::param::{}".format(p.name) | ||||
enum_def = "MgbEnumAttr<\"{}\", \"{}\", [".format(fullname, e.name) | |||||
enum_def = 'MgbEnumAttr<"{}", "{}", ['.format(fullname, e.name) | |||||
def format(v): | def format(v): | ||||
return '\"{}\"'.format(str(v).split(' ')[0].split('=')[0]) | |||||
enum_def += ','.join(format(i) for i in e.members) | |||||
return '"{}"'.format(str(v).split(" ")[0].split("=")[0]) | |||||
enum_def += ",".join(format(i) for i in e.members) | |||||
if e.combined: | if e.combined: | ||||
enum_def += "], 1" | enum_def += "], 1" | ||||
@@ -95,7 +100,7 @@ class ConverterWriter(IndentWriterBase): | |||||
enum_def += "], 0" | enum_def += "], 0" | ||||
if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)): | if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)): | ||||
enum_def += ", 1" # whether generate ToStringTrait | |||||
enum_def += ", 1" # whether generate ToStringTrait | |||||
enum_def += ">" | enum_def += ">" | ||||
self._write("def {} : {};".format(td_class, enum_def)) | self._write("def {} : {};".format(td_class, enum_def)) | ||||
@@ -105,10 +110,12 @@ class ConverterWriter(IndentWriterBase): | |||||
# wrapped with default value | # wrapped with default value | ||||
if e.combined: | if e.combined: | ||||
default_val = "static_cast<{}::{}>({})".format( | default_val = "static_cast<{}::{}>({})".format( | ||||
fullname, e.name, e.compose_combined_enum(e.default)) | |||||
fullname, e.name, e.compose_combined_enum(e.default) | |||||
) | |||||
else: | else: | ||||
default_val = "{}::{}::{}".format( | default_val = "{}::{}::{}".format( | ||||
fullname, e.name, str(e.members[e.default]).split(' ')[0].split('=')[0]) | |||||
fullname, e.name, str(e.members[e.default]).split(" ")[0].split("=")[0] | |||||
) | |||||
wrapped = self._wrapped_with_default_value(td_class, default_val) | wrapped = self._wrapped_with_default_value(td_class, default_val) | ||||
@@ -123,51 +130,58 @@ class ConverterWriter(IndentWriterBase): | |||||
td_class = "{}{}".format(p.name, e.name) | td_class = "{}{}".format(p.name, e.name) | ||||
fullname = "::megdnn::param::{}".format(p.name) | fullname = "::megdnn::param::{}".format(p.name) | ||||
base_td_class = "{}{}".format(e.src_class, e.src_name) | base_td_class = "{}{}".format(e.src_class, e.src_name) | ||||
enum_def = "MgbEnumAliasAttr<\"{}\", \"{}\", {}>".format(fullname, e.name, base_td_class) | |||||
enum_def = 'MgbEnumAliasAttr<"{}", "{}", {}>'.format( | |||||
fullname, e.name, base_td_class | |||||
) | |||||
self._write("def {} : {};".format(td_class, enum_def)) | self._write("def {} : {};".format(td_class, enum_def)) | ||||
# wrapped with default value | # wrapped with default value | ||||
s = e.src_enum | s = e.src_enum | ||||
if s.combined: | if s.combined: | ||||
default_val = "static_cast<{}::{}>({})".format( | default_val = "static_cast<{}::{}>({})".format( | ||||
fullname, e.name, s.compose_combined_enum(e.get_default())) | |||||
fullname, e.name, s.compose_combined_enum(e.get_default()) | |||||
) | |||||
else: | else: | ||||
default_val = "{}::{}::{}".format(fullname, e.name, str( | |||||
s.members[e.get_default()]).split(' ')[0].split('=')[0]) | |||||
default_val = "{}::{}::{}".format( | |||||
fullname, | |||||
e.name, | |||||
str(s.members[e.get_default()]).split(" ")[0].split("=")[0], | |||||
) | |||||
wrapped = self._wrapped_with_default_value(td_class, default_val) | wrapped = self._wrapped_with_default_value(td_class, default_val) | ||||
self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) | self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) | ||||
def _on_member_field(self, f): | def _on_member_field(self, f): | ||||
if self._skip_current_param: | if self._skip_current_param: | ||||
return | return | ||||
attr, value = self._ctype2attr(f.dtype.cname, str(f.default)) | attr, value = self._ctype2attr(f.dtype.cname, str(f.default)) | ||||
if str(value) in self._const: | if str(value) in self._const: | ||||
value = '::megdnn::param::{}::{}'.format(self._last_param.name, value) | |||||
value = "::megdnn::param::{}::{}".format(self._last_param.name, value) | |||||
wrapped = self._wrapped_with_default_value(attr, value) | wrapped = self._wrapped_with_default_value(attr, value) | ||||
self._current_tparams.append("{}:${}".format(wrapped, f.name)) | self._current_tparams.append("{}:${}".format(wrapped, f.name)) | ||||
def _on_const_field(self, f): | def _on_const_field(self, f): | ||||
self._const.add(str(f.name)) | self._const.add(str(f.name)) | ||||
def main(): | def main(): | ||||
parser = argparse.ArgumentParser('generate op param tablegen file') | |||||
parser.add_argument('input') | |||||
parser.add_argument('output') | |||||
parser = argparse.ArgumentParser("generate op param tablegen file") | |||||
parser.add_argument("input") | |||||
parser.add_argument("output") | |||||
args = parser.parse_args() | args = parser.parse_args() | ||||
with open(args.input) as fin: | with open(args.input) as fin: | ||||
inputs = fin.read() | inputs = fin.read() | ||||
exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) | |||||
exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc}) | |||||
input_hash = hashlib.sha256() | input_hash = hashlib.sha256() | ||||
input_hash.update(inputs.encode(encoding='UTF-8')) | |||||
input_hash.update(inputs.encode(encoding="UTF-8")) | |||||
input_hash = input_hash.hexdigest() | input_hash = input_hash.hexdigest() | ||||
writer = ConverterWriter() | writer = ConverterWriter() | ||||
with open(args.output, 'w') as fout: | |||||
with open(args.output, "w") as fout: | |||||
writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | ||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
main() | main() |
@@ -19,6 +19,7 @@ device = { | |||||
"thread_number": 3, | "thread_number": 3, | ||||
} | } | ||||
class SshConnector: | class SshConnector: | ||||
"""imp ssh control master connector""" | """imp ssh control master connector""" | ||||
@@ -83,17 +84,17 @@ def main(): | |||||
model_file = args.model_file | model_file = args.model_file | ||||
# copy model file | # copy model file | ||||
ssh.copy([args.model_file], workspace) | ssh.copy([args.model_file], workspace) | ||||
m = model_file.split('\\')[-1] | |||||
m = model_file.split("\\")[-1] | |||||
# run single thread | # run single thread | ||||
result = [] | result = [] | ||||
thread_number = [1, 2, 4] | thread_number = [1, 2, 4] | ||||
for b in thread_number : | |||||
for b in thread_number: | |||||
cmd = [] | cmd = [] | ||||
cmd1 = "cd {} && ./load_and_run {} -multithread {} --fast-run --fast_run_algo_policy fastrun.cache --iter 1 --warmup-iter 1 --no-sanity-check --weight-preprocess".format( | cmd1 = "cd {} && ./load_and_run {} -multithread {} --fast-run --fast_run_algo_policy fastrun.cache --iter 1 --warmup-iter 1 --no-sanity-check --weight-preprocess".format( | ||||
workspace, m, b | |||||
workspace, m, b | |||||
) | ) | ||||
cmd2 = "cd {} && ./load_and_run {} -multithread {} --fast_run_algo_policy fastrun.cache --iter 20 --warmup-iter 5 --no-sanity-check --weight-preprocess ".format( | cmd2 = "cd {} && ./load_and_run {} -multithread {} --fast_run_algo_policy fastrun.cache --iter 20 --warmup-iter 5 --no-sanity-check --weight-preprocess ".format( | ||||
workspace, m, b | |||||
workspace, m, b | |||||
) | ) | ||||
cmd.append(cmd1) | cmd.append(cmd1) | ||||
cmd.append(cmd2) | cmd.append(cmd2) | ||||
@@ -103,12 +104,20 @@ def main(): | |||||
logging.debug("model: {} with backend: {} result is: {}".format(m, b, ret)) | logging.debug("model: {} with backend: {} result is: {}".format(m, b, ret)) | ||||
result.append(ret) | result.append(ret) | ||||
thread_2 = result[0]/result[1] | |||||
thread_4 = result[0]/result[2] | |||||
thread_2 = result[0] / result[1] | |||||
thread_4 = result[0] / result[2] | |||||
if thread_2 > 1.6 or thread_4 > 3.0: | if thread_2 > 1.6 or thread_4 > 3.0: | ||||
print("model: {} can has good parallelism. 2 thread is {}, 4 thread is {}".format(m, thread_2, thread_4)) | |||||
print( | |||||
"model: {} can has good parallelism. 2 thread is {}, 4 thread is {}".format( | |||||
m, thread_2, thread_4 | |||||
) | |||||
) | |||||
else: | else: | ||||
print("model: {} can has bad parallelism. 2 thread is {}, 4 thread is {}".format(m, thread_2, thread_4)) | |||||
print( | |||||
"model: {} can has bad parallelism. 2 thread is {}, 4 thread is {}".format( | |||||
m, thread_2, thread_4 | |||||
) | |||||
) | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
@@ -20,8 +20,12 @@ failed_files = Manager().list() | |||||
def process_file(file, clang_format, write): | def process_file(file, clang_format, write): | ||||
original_source = open(file, "r").read() | original_source = open(file, "r").read() | ||||
source = original_source | source = original_source | ||||
source = re.sub(r"MGB_DEFINE(?P<r>([^\\]|\n)*?)// *{", r"class MGB_DEFINE\g<r>{", source) | |||||
source, count = re.subn(r"(?<!#define )MGB_DEFINE(.*) +\\", r"class MGB_DEFINE\1{\\", source) | |||||
source = re.sub( | |||||
r"MGB_DEFINE(?P<r>([^\\]|\n)*?)// *{", r"class MGB_DEFINE\g<r>{", source | |||||
) | |||||
source, count = re.subn( | |||||
r"(?<!#define )MGB_DEFINE(.*) +\\", r"class MGB_DEFINE\1{\\", source | |||||
) | |||||
result = subprocess.check_output( | result = subprocess.check_output( | ||||
[ | [ | ||||
@@ -36,7 +40,9 @@ def process_file(file, clang_format, write): | |||||
result = result.decode("utf-8") | result = result.decode("utf-8") | ||||
if count: | if count: | ||||
result = re.sub(r"class MGB_DEFINE(.*){( *)\\", r"MGB_DEFINE\1\2 \\", result) | |||||
result = re.sub( | |||||
r"class MGB_DEFINE(.*){( *)\\", r"MGB_DEFINE\1\2 \\", result | |||||
) | |||||
result = re.sub(r"class MGB_DEFINE((.|\n)*?){", r"MGB_DEFINE\1// {", result) | result = re.sub(r"class MGB_DEFINE((.|\n)*?){", r"MGB_DEFINE\1// {", result) | ||||
if write and original_source != result: | if write and original_source != result: | ||||
@@ -109,19 +115,17 @@ def main(): | |||||
raise ValueError("Invalid path {}".format(path)) | raise ValueError("Invalid path {}".format(path)) | ||||
# check version, we only support 12.0.1 now | # check version, we only support 12.0.1 now | ||||
version = subprocess.check_output( | |||||
[ | |||||
args.clang_format, | |||||
"--version", | |||||
], | |||||
) | |||||
version = subprocess.check_output([args.clang_format, "--version",],) | |||||
version = version.decode("utf-8") | version = version.decode("utf-8") | ||||
need_version = '12.0.1' | |||||
need_version = "12.0.1" | |||||
if version.find(need_version) < 0: | if version.find(need_version) < 0: | ||||
print('We only support {} now, please install {} version, find version: {}' | |||||
.format(need_version, need_version, version)) | |||||
raise RuntimeError('clang-format version not equal {}'.format(need_version)) | |||||
print( | |||||
"We only support {} now, please install {} version, find version: {}".format( | |||||
need_version, need_version, version | |||||
) | |||||
) | |||||
raise RuntimeError("clang-format version not equal {}".format(need_version)) | |||||
process_map( | process_map( | ||||
partial(process_file, clang_format=args.clang_format, write=args.write,), | partial(process_file, clang_format=args.clang_format, write=args.write,), | ||||
@@ -20,6 +20,7 @@ device = { | |||||
"thread_number": 3, | "thread_number": 3, | ||||
} | } | ||||
class SshConnector: | class SshConnector: | ||||
"""imp ssh control master connector""" | """imp ssh control master connector""" | ||||
@@ -54,6 +55,7 @@ class SshConnector: | |||||
except: | except: | ||||
raise | raise | ||||
def main(): | def main(): | ||||
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) | parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) | ||||
parser.add_argument("--model_file", help="megengine model", required=True) | parser.add_argument("--model_file", help="megengine model", required=True) | ||||
@@ -78,10 +80,10 @@ def main(): | |||||
model_file = args.model_file | model_file = args.model_file | ||||
# copy model file | # copy model file | ||||
ssh.copy([model_file], workspace) | ssh.copy([model_file], workspace) | ||||
m = model_file.split('\\')[-1] | |||||
m = model_file.split("\\")[-1] | |||||
# run single thread | # run single thread | ||||
cmd = "cd {} && ./load_and_run {} --fast-run --record-comp-seq --iter 1 --warmup-iter 1".format( | cmd = "cd {} && ./load_and_run {} --fast-run --record-comp-seq --iter 1 --warmup-iter 1".format( | ||||
workspace, m | |||||
workspace, m | |||||
) | ) | ||||
try: | try: | ||||
raw_log = ssh.cmd([cmd]) | raw_log = ssh.cmd([cmd]) | ||||
@@ -91,6 +93,7 @@ def main(): | |||||
print("model: {} is static model.".format(m)) | print("model: {} is static model.".format(m)) | ||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s" | LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s" | ||||
DATE_FORMAT = "%Y/%m/%d %H:%M:%S" | DATE_FORMAT = "%Y/%m/%d %H:%M:%S" | ||||