GitOrigin-RevId: 5684e5ea43
dev-support-lite-fork-debug-mode
@@ -8,7 +8,7 @@ | |||
import enum | |||
import os.path | |||
import shutil | |||
from typing import Tuple, List | |||
from typing import List, Tuple | |||
from library import * | |||
@@ -5,14 +5,13 @@ | |||
# | |||
import enum | |||
import os.path | |||
import shutil | |||
import functools | |||
import operator | |||
import os.path | |||
import shutil | |||
from library import * | |||
################################################################################################### | |||
# | |||
# Data structure modeling a GEMM operation | |||
@@ -1,11 +1,11 @@ | |||
from generator import ( | |||
GenerateGemmOperations, | |||
GenerateGemvOperations, | |||
from generator import ( # isort: skip | |||
GenerateConv2dOperations, | |||
GenerateDeconvOperations, | |||
GenerateDwconv2dFpropOperations, | |||
GenerateDwconv2dDgradOperations, | |||
GenerateDwconv2dFpropOperations, | |||
GenerateDwconv2dWgradOperations, | |||
GenerateGemmOperations, | |||
GenerateGemvOperations, | |||
) | |||
@@ -35,12 +35,14 @@ def write_op_list(f, gen_op, gen_type): | |||
if gen_op != "gemv": | |||
f.write(' "all_%s_%s_operations.cu",\n' % (gen_op, gen_type)) | |||
# Write down a list of merged filenames | |||
def write_merge_file_name(f, gen_op, gen_type, split_number): | |||
for i in range(0, split_number): | |||
f.write(' "{}_{}_{}.cu",\n'.format(gen_op,gen_type,i)) | |||
f.write(' "{}_{}_{}.cu",\n'.format(gen_op, gen_type, i)) | |||
if gen_op != "gemv": | |||
f.write(' "all_{}_{}_operations.cu",\n'.format(gen_op,gen_type)) | |||
f.write(' "all_{}_{}_operations.cu",\n'.format(gen_op, gen_type)) | |||
if __name__ == "__main__": | |||
with open("list.bzl", "w") as f: | |||
@@ -4,12 +4,12 @@ | |||
# \brief Generates the CUTLASS Library's instances | |||
# | |||
import argparse | |||
import enum | |||
import os.path | |||
import shutil | |||
import argparse | |||
import platform | |||
import string | |||
from library import * | |||
from manifest import * | |||
@@ -899,9 +899,12 @@ def GenerateGemm_Simt(args): | |||
warpShapes.append([warp0, warp1]) | |||
# sgemm | |||
precisionType, precisionBits, threadblockMaxElements, threadblockTilesL0 = precisions[ | |||
"s" | |||
] | |||
( | |||
precisionType, | |||
precisionBits, | |||
threadblockMaxElements, | |||
threadblockTilesL0, | |||
) = precisions["s"] | |||
layouts = [ | |||
(LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn | |||
@@ -1091,9 +1094,12 @@ def GenerateDwconv2d_Simt(args, conv_kind): | |||
warpShapes.append([warp0, warp1]) | |||
# sgemm | |||
precisionType, precisionBits, threadblockMaxElements, threadblockTilesL0 = precisions[ | |||
"s" | |||
] | |||
( | |||
precisionType, | |||
precisionBits, | |||
threadblockMaxElements, | |||
threadblockTilesL0, | |||
) = precisions["s"] | |||
layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)] | |||
@@ -1304,7 +1310,7 @@ def GenerateDwconv2d_TensorOp_884(args, conv_kind): | |||
for dst_type, dst_layout in zip(dst_types, dst_layouts): | |||
for alignment_src in alignment_constraints: | |||
if conv_kind == ConvKind.Wgrad: | |||
# skip io16xc16 | |||
# skip io16xc16 | |||
if math_inst.element_accumulator == DataType.f16: | |||
continue | |||
for alignment_diff in alignment_constraints: | |||
@@ -1319,7 +1325,7 @@ def GenerateDwconv2d_TensorOp_884(args, conv_kind): | |||
min_cc, | |||
alignment_src, | |||
alignment_diff, | |||
32, # always f32 output | |||
32, # always f32 output | |||
SpecialOptimizeDesc.NoneSpecialOpt, | |||
ImplicitGemmMode.GemmNT, | |||
False, | |||
@@ -1656,6 +1662,7 @@ def GenerateGemvOperations(args): | |||
) | |||
return GenerateGemv_Simt(args) | |||
################################################################################ | |||
# parameters | |||
# split_number - the concatenated file will be divided into split_number parts | |||
@@ -1668,10 +1675,21 @@ def GenerateGemvOperations(args): | |||
# epilogue - the epilogue template appended to each generated file | |||
# wrapper_path - wrapper path substituted into the header template (only used for gemv) | |||
################################################################################ | |||
def ConcatFile(split_number:int, file_path:str,operations:str,type:str,head:str,required_cuda_ver_major:str, required_cuda_ver_minor:str, epilogue:str, wrapper_path = None): | |||
def ConcatFile( | |||
split_number: int, | |||
file_path: str, | |||
operations: str, | |||
type: str, | |||
head: str, | |||
required_cuda_ver_major: str, | |||
required_cuda_ver_minor: str, | |||
epilogue: str, | |||
wrapper_path=None, | |||
): | |||
import os | |||
meragefiledir = file_path | |||
filenames=os.listdir(meragefiledir) | |||
filenames = os.listdir(meragefiledir) | |||
# filter files by operation and type substrings | |||
if "tensorop" in type: | |||
sub_string_1 = "tensorop" | |||
@@ -1679,197 +1697,183 @@ def ConcatFile(split_number:int, file_path:str,operations:str,type:str,head:str, | |||
else: | |||
sub_string_1 = sub_string_2 = "simt" | |||
if "dwconv2d_" in operations: | |||
filtered_operations = operations[:2]+operations[9:] | |||
filtered_operations = operations[:2] + operations[9:] | |||
elif ("conv2d" in operations) or ("deconv" in operations): | |||
filtered_operations = "cutlass" | |||
else: | |||
filtered_operations = operations | |||
#get the file list number | |||
# get the file list number | |||
file_list = {} | |||
file_list[operations + type] = 0 | |||
for filename in filenames: | |||
if (filtered_operations in filename) and (sub_string_1 in filename) and (sub_string_2 in filename) and ("all_" not in filename): | |||
if ( | |||
(filtered_operations in filename) | |||
and (sub_string_1 in filename) | |||
and (sub_string_2 in filename) | |||
and ("all_" not in filename) | |||
): | |||
file_list[operations + type] += 1 | |||
#concat file for linux | |||
# concat file for linux | |||
flag_1 = 0 | |||
flag_2 = 0 | |||
for filename in filenames: | |||
if (filtered_operations in filename) and (sub_string_1 in filename) and (sub_string_2 in filename) and ("all_" not in filename): | |||
if ( | |||
(filtered_operations in filename) | |||
and (sub_string_1 in filename) | |||
and (sub_string_2 in filename) | |||
and ("all_" not in filename) | |||
): | |||
flag_1 += 1 | |||
filepath=meragefiledir+'/'+filename | |||
if (flag_1 >= flag_2 * (file_list[operations + type]/split_number)) and (flag_1 <= (flag_2 + 1) * (file_list[operations + type]/split_number)): | |||
file =open(file_path + '/{}_{}_{}.cu'.format(operations,type, flag_2),'a') | |||
#write Template at the head | |||
filepath = meragefiledir + "/" + filename | |||
if (flag_1 >= flag_2 * (file_list[operations + type] / split_number)) and ( | |||
flag_1 <= (flag_2 + 1) * (file_list[operations + type] / split_number) | |||
): | |||
file = open( | |||
file_path + "/{}_{}_{}.cu".format(operations, type, flag_2), "a" | |||
) | |||
# write Template at the head | |||
if wrapper_path is None: | |||
file.write( | |||
SubstituteTemplate( | |||
head, | |||
{ | |||
"required_cuda_ver_major": str( | |||
required_cuda_ver_major | |||
), | |||
"required_cuda_ver_minor": str( | |||
required_cuda_ver_minor | |||
), | |||
"required_cuda_ver_major": str(required_cuda_ver_major), | |||
"required_cuda_ver_minor": str(required_cuda_ver_minor), | |||
}, | |||
) | |||
) | |||
else: | |||
file.write( | |||
SubstituteTemplate( | |||
head, | |||
{ | |||
"wrapper_path": wrapper_path, | |||
"required_cuda_ver_major": str( | |||
required_cuda_ver_major | |||
), | |||
"required_cuda_ver_minor": str( | |||
required_cuda_ver_minor | |||
), | |||
}, | |||
) | |||
SubstituteTemplate( | |||
head, | |||
{ | |||
"wrapper_path": wrapper_path, | |||
"required_cuda_ver_major": str(required_cuda_ver_major), | |||
"required_cuda_ver_minor": str(required_cuda_ver_minor), | |||
}, | |||
) | |||
) | |||
# concat all the remaining files | |||
if flag_2 == (split_number - 1): | |||
for line in open(filepath): | |||
file.writelines(line) | |||
os.remove(filepath) | |||
file.write('\n') | |||
file.write("\n") | |||
file.write(epilogue) | |||
continue | |||
for line in open(filepath): | |||
file.writelines(line) | |||
os.remove(filepath) | |||
file.write('\n') | |||
file.write("\n") | |||
file.write(epilogue) | |||
else: | |||
#write Template at the head | |||
# write Template at the head | |||
if wrapper_path is None: | |||
file.write( | |||
SubstituteTemplate( | |||
head, | |||
{ | |||
"required_cuda_ver_major": str( | |||
required_cuda_ver_major | |||
), | |||
"required_cuda_ver_minor": str( | |||
required_cuda_ver_minor | |||
), | |||
"required_cuda_ver_major": str(required_cuda_ver_major), | |||
"required_cuda_ver_minor": str(required_cuda_ver_minor), | |||
}, | |||
) | |||
) | |||
else: | |||
file.write( | |||
SubstituteTemplate( | |||
head, | |||
{ | |||
"wrapper_path": wrapper_path, | |||
"required_cuda_ver_major": str( | |||
required_cuda_ver_major | |||
), | |||
"required_cuda_ver_minor": str( | |||
required_cuda_ver_minor | |||
), | |||
}, | |||
) | |||
SubstituteTemplate( | |||
head, | |||
{ | |||
"wrapper_path": wrapper_path, | |||
"required_cuda_ver_major": str(required_cuda_ver_major), | |||
"required_cuda_ver_minor": str(required_cuda_ver_minor), | |||
}, | |||
) | |||
) | |||
for line in open(filepath): | |||
file.writelines(line) | |||
os.remove(filepath) | |||
file.write('\n') | |||
file.write("\n") | |||
file.write(epilogue) | |||
file.close() | |||
flag_2 += 1 | |||
#concat file for windows | |||
# concat file for windows | |||
elif filename[0].isdigit() and ("all_" not in filename): | |||
flag_1 += 1 | |||
filepath=meragefiledir+'/'+filename | |||
if (flag_1 >= flag_2 * (len(filenames)/split_number)) and (flag_1 <= (flag_2 + 1) * (len(filenames)/split_number)): | |||
file =open(file_path + '/{}_{}_{}.cu'.format(operations,type, flag_2),'a') | |||
#write Template at the head | |||
filepath = meragefiledir + "/" + filename | |||
if (flag_1 >= flag_2 * (len(filenames) / split_number)) and ( | |||
flag_1 <= (flag_2 + 1) * (len(filenames) / split_number) | |||
): | |||
file = open( | |||
file_path + "/{}_{}_{}.cu".format(operations, type, flag_2), "a" | |||
) | |||
# write Template at the head | |||
if wrapper_path is None: | |||
file.write( | |||
SubstituteTemplate( | |||
head, | |||
{ | |||
"required_cuda_ver_major": str( | |||
required_cuda_ver_major | |||
), | |||
"required_cuda_ver_minor": str( | |||
required_cuda_ver_minor | |||
), | |||
"required_cuda_ver_major": str(required_cuda_ver_major), | |||
"required_cuda_ver_minor": str(required_cuda_ver_minor), | |||
}, | |||
) | |||
) | |||
else: | |||
file.write( | |||
SubstituteTemplate( | |||
head, | |||
{ | |||
"wrapper_path": wrapper_path, | |||
"required_cuda_ver_major": str( | |||
required_cuda_ver_major | |||
), | |||
"required_cuda_ver_minor": str( | |||
required_cuda_ver_minor | |||
), | |||
}, | |||
) | |||
SubstituteTemplate( | |||
head, | |||
{ | |||
"wrapper_path": wrapper_path, | |||
"required_cuda_ver_major": str(required_cuda_ver_major), | |||
"required_cuda_ver_minor": str(required_cuda_ver_minor), | |||
}, | |||
) | |||
) | |||
# concat all the remaining files | |||
if flag_2 == (split_number - 1): | |||
for line in open(filepath): | |||
file.writelines(line) | |||
os.remove(filepath) | |||
file.write('\n') | |||
file.write("\n") | |||
file.write(epilogue) | |||
continue | |||
for line in open(filepath): | |||
file.writelines(line) | |||
os.remove(filepath) | |||
file.write('\n') | |||
file.write("\n") | |||
file.write(epilogue) | |||
else: | |||
#write Template at the head | |||
# write Template at the head | |||
if wrapper_path is None: | |||
file.write( | |||
SubstituteTemplate( | |||
head, | |||
{ | |||
"required_cuda_ver_major": str( | |||
required_cuda_ver_major | |||
), | |||
"required_cuda_ver_minor": str( | |||
required_cuda_ver_minor | |||
), | |||
"required_cuda_ver_major": str(required_cuda_ver_major), | |||
"required_cuda_ver_minor": str(required_cuda_ver_minor), | |||
}, | |||
) | |||
) | |||
else: | |||
file.write( | |||
SubstituteTemplate( | |||
head, | |||
{ | |||
"wrapper_path": wrapper_path, | |||
"required_cuda_ver_major": str( | |||
required_cuda_ver_major | |||
), | |||
"required_cuda_ver_minor": str( | |||
required_cuda_ver_minor | |||
), | |||
}, | |||
) | |||
SubstituteTemplate( | |||
head, | |||
{ | |||
"wrapper_path": wrapper_path, | |||
"required_cuda_ver_major": str(required_cuda_ver_major), | |||
"required_cuda_ver_minor": str(required_cuda_ver_minor), | |||
}, | |||
) | |||
) | |||
for line in open(filepath): | |||
file.writelines(line) | |||
os.remove(filepath) | |||
file.write('\n') | |||
file.write("\n") | |||
file.write(epilogue) | |||
file.close() | |||
flag_2 += 1 | |||
################################################################################################### | |||
################################################################################################### | |||
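# Illustrative sketch (not part of the diff): the intent of ConcatFile's
# flag_1/flag_2 bookkeeping is to spread roughly total/split_number generated
# kernels into each merged .cu file. The helper below is a standalone
# approximation of that index-to-bucket mapping (assumed names, not the
# function's exact control flow).
def bucket_of(flag_1, total, split_number):
    flag_2 = 0
    while flag_1 > (flag_2 + 1) * (total / split_number) and flag_2 < split_number - 1:
        flag_2 += 1
    return flag_2

total = 10
print([bucket_of(i, total, split_number=4) for i in range(1, total + 1)])
# -> [0, 0, 1, 1, 1, 2, 2, 3, 3, 3]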
@@ -1940,39 +1944,97 @@ if __name__ == "__main__": | |||
args.output, operation, short_path | |||
) as emitter: | |||
emitter.emit() | |||
head = EmitConvSingleKernelWrapper(args.output, operations[0], short_path).header_template | |||
head = EmitConvSingleKernelWrapper( | |||
args.output, operations[0], short_path | |||
).header_template | |||
required_cuda_ver_major = operations[0].required_cuda_ver_major | |||
required_cuda_ver_minor = operations[0].required_cuda_ver_minor | |||
epilogue = EmitConvSingleKernelWrapper(args.output, operations[0], short_path).epilogue_template | |||
epilogue = EmitConvSingleKernelWrapper( | |||
args.output, operations[0], short_path | |||
).epilogue_template | |||
if "tensorop" in args.type: | |||
ConcatFile(4, args.output,args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue) | |||
ConcatFile( | |||
4, | |||
args.output, | |||
args.operations, | |||
args.type, | |||
head, | |||
required_cuda_ver_major, | |||
required_cuda_ver_minor, | |||
epilogue, | |||
) | |||
else: | |||
ConcatFile(2, args.output,args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue) | |||
ConcatFile( | |||
2, | |||
args.output, | |||
args.operations, | |||
args.type, | |||
head, | |||
required_cuda_ver_major, | |||
required_cuda_ver_minor, | |||
epilogue, | |||
) | |||
elif args.operations == "gemm": | |||
for operation in operations: | |||
with EmitGemmSingleKernelWrapper( | |||
args.output, operation, short_path | |||
) as emitter: | |||
emitter.emit() | |||
head = EmitGemmSingleKernelWrapper(args.output, operations[0], short_path).header_template | |||
head = EmitGemmSingleKernelWrapper( | |||
args.output, operations[0], short_path | |||
).header_template | |||
required_cuda_ver_major = operations[0].required_cuda_ver_major | |||
required_cuda_ver_minor = operations[0].required_cuda_ver_minor | |||
epilogue = EmitGemmSingleKernelWrapper(args.output, operations[0], short_path).epilogue_template | |||
epilogue = EmitGemmSingleKernelWrapper( | |||
args.output, operations[0], short_path | |||
).epilogue_template | |||
if args.type == "tensorop884": | |||
ConcatFile(30, args.output, args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue) | |||
ConcatFile( | |||
30, | |||
args.output, | |||
args.operations, | |||
args.type, | |||
head, | |||
required_cuda_ver_major, | |||
required_cuda_ver_minor, | |||
epilogue, | |||
) | |||
else: | |||
ConcatFile(2, args.output, args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue) | |||
ConcatFile( | |||
2, | |||
args.output, | |||
args.operations, | |||
args.type, | |||
head, | |||
required_cuda_ver_major, | |||
required_cuda_ver_minor, | |||
epilogue, | |||
) | |||
elif args.operations == "gemv": | |||
for operation in operations: | |||
with EmitGemvSingleKernelWrapper( | |||
args.output, operation, gemv_wrapper_path, short_path | |||
) as emitter: | |||
emitter.emit() | |||
head = EmitGemvSingleKernelWrapper(args.output, operations[0], gemv_wrapper_path, short_path).header_template | |||
head = EmitGemvSingleKernelWrapper( | |||
args.output, operations[0], gemv_wrapper_path, short_path | |||
).header_template | |||
required_cuda_ver_major = operations[0].required_cuda_ver_major | |||
required_cuda_ver_minor = operations[0].required_cuda_ver_minor | |||
epilogue = EmitGemvSingleKernelWrapper(args.output, operations[0], gemv_wrapper_path, short_path).epilogue_template | |||
ConcatFile(2, args.output,args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue, wrapper_path = gemv_wrapper_path) | |||
epilogue = EmitGemvSingleKernelWrapper( | |||
args.output, operations[0], gemv_wrapper_path, short_path | |||
).epilogue_template | |||
ConcatFile( | |||
2, | |||
args.output, | |||
args.operations, | |||
args.type, | |||
head, | |||
required_cuda_ver_major, | |||
required_cuda_ver_minor, | |||
epilogue, | |||
wrapper_path=gemv_wrapper_path, | |||
) | |||
if args.operations != "gemv": | |||
GenerateManifest(args, operations, args.output) | |||
@@ -4,11 +4,11 @@ | |||
# \brief Generates the CUTLASS Library's instances | |||
# | |||
import enum | |||
import re | |||
################################################################################################### | |||
import enum | |||
# The following block implements enum.auto() for Python 3.5 variants that don't include it such | |||
# as the default 3.5.2 on Ubuntu 16.04. | |||
@@ -8,9 +8,9 @@ import enum | |||
import os.path | |||
import shutil | |||
from library import * | |||
from gemm_operation import * | |||
from conv2d_operation import * | |||
from gemm_operation import * | |||
from library import * | |||
################################################################################################### | |||
@@ -1,59 +1,67 @@ | |||
#!/usr/bin/env python3 | |||
# -*- coding: utf-8 -*- | |||
import os | |||
import argparse | |||
import os | |||
from gen_elemwise_utils import DTYPES | |||
def main(): | |||
parser = argparse.ArgumentParser( | |||
description='generate elemwise impl files', | |||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
parser.add_argument('--type', type=str, choices=['cuda'], | |||
default='cuda', | |||
help='generate cuda cond take kernel file') | |||
parser.add_argument('output', help='output directory') | |||
description="generate elemwise impl files", | |||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||
) | |||
parser.add_argument( | |||
"--type", | |||
type=str, | |||
choices=["cuda"], | |||
default="cuda", | |||
help="generate cuda cond take kernel file", | |||
) | |||
parser.add_argument("output", help="output directory") | |||
args = parser.parse_args() | |||
if not os.path.isdir(args.output): | |||
os.makedirs(args.output) | |||
assert args.type =='cuda' | |||
cpp_ext = 'cu' | |||
assert args.type == "cuda" | |||
cpp_ext = "cu" | |||
for dtype in DTYPES.keys(): | |||
fname = '{}.{}'.format(dtype, cpp_ext) | |||
fname = "{}.{}".format(dtype, cpp_ext) | |||
fname = os.path.join(args.output, fname) | |||
with open(fname, 'w') as fout: | |||
with open(fname, "w") as fout: | |||
w = lambda s: print(s, file=fout) | |||
w('// generated by gen_cond_take_kern_impls.py') | |||
w("// generated by gen_cond_take_kern_impls.py") | |||
w('#include "../kern.inl"') | |||
w('') | |||
if dtype == 'dt_float16' or dtype == 'dt_bfloat16': | |||
w('#if !MEGDNN_DISABLE_FLOAT16') | |||
w('namespace megdnn {') | |||
w('namespace cuda {') | |||
w('namespace cond_take {') | |||
w('') | |||
w('inst_genidx(::megdnn::dtype::{})'.format(DTYPES[dtype][0])) | |||
w('#undef inst_genidx') | |||
w('') | |||
w('inst_copy(::megdnn::dtype::{})'.format(DTYPES[dtype][0])) | |||
w('#undef inst_copy') | |||
w('#undef inst_copy_') | |||
w('') | |||
w('} // cond_take') | |||
w('} // cuda') | |||
w('} // megdnn') | |||
if dtype == 'dt_float16' or dtype == 'dt_bfloat16': | |||
w('#endif') | |||
print('generated {}'.format(fname)) | |||
w("") | |||
if dtype == "dt_float16" or dtype == "dt_bfloat16": | |||
w("#if !MEGDNN_DISABLE_FLOAT16") | |||
w("namespace megdnn {") | |||
w("namespace cuda {") | |||
w("namespace cond_take {") | |||
w("") | |||
w("inst_genidx(::megdnn::dtype::{})".format(DTYPES[dtype][0])) | |||
w("#undef inst_genidx") | |||
w("") | |||
w("inst_copy(::megdnn::dtype::{})".format(DTYPES[dtype][0])) | |||
w("#undef inst_copy") | |||
w("#undef inst_copy_") | |||
w("") | |||
w("} // cond_take") | |||
w("} // cuda") | |||
w("} // megdnn") | |||
if dtype == "dt_float16" or dtype == "dt_bfloat16": | |||
w("#endif") | |||
print("generated {}".format(fname)) | |||
os.utime(args.output) | |||
if __name__ == '__main__': | |||
if __name__ == "__main__": | |||
main() |
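# Illustrative sketch (not part of the diff): the w-lambda above simply streams
# lines into the generated file. For "dt_float32" (which needs no FLOAT16 guard)
# the output begins as below, with Float32 substituted from DTYPES; io.StringIO
# stands in for the real file handle.
import io

fout = io.StringIO()
w = lambda s: print(s, file=fout)
w("// generated by gen_cond_take_kern_impls.py")
w('#include "../kern.inl"')
w("")
w("namespace megdnn {")
w("namespace cuda {")
w("namespace cond_take {")
w("")
w("inst_genidx(::megdnn::dtype::Float32)")
print(fout.getvalue())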
@@ -1,37 +1,47 @@ | |||
#!/usr/bin/env python3 | |||
# -*- coding: utf-8 -*- | |||
import os | |||
import argparse | |||
import itertools | |||
import os | |||
PREFIXES = { | |||
"dp4a": [ | |||
("batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4", True), | |||
("batch_conv_bias_int8_gemm_ncdiv4hw4", False), | |||
("batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128", False), | |||
] | |||
} | |||
PREFIXES = {"dp4a": [("batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4", True), ("batch_conv_bias_int8_gemm_ncdiv4hw4", False), ("batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128", False)]} | |||
ACTIVATIONS = {1: ("IDENTITY", "_id"), 2: ("RELU", "_relu"), 3: ("H_SWISH", "_hswish")} | |||
ACTIVATIONS = {1: ("IDENTITY", "_id"), | |||
2: ("RELU", "_relu"), | |||
3: ("H_SWISH", "_hswish")} | |||
BIASES = { | |||
1: ("PerElementBiasVisitor", "_per_elem"), | |||
2: ("PerChannelBiasVisitor", "_per_chan"), | |||
} | |||
BIASES = {1: ("PerElementBiasVisitor", "_per_elem"), | |||
2: ("PerChannelBiasVisitor", "_per_chan")} | |||
SUFFIXES = {"dp4a": [""], "imma": [""]} | |||
SUFFIXES = {"dp4a": [""], | |||
"imma": [""]} | |||
def main(): | |||
parser = argparse.ArgumentParser( | |||
description='generate cuda batch conv bias (dp4a/imma) kern impl files', | |||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
parser.add_argument('--type', type=str, choices=['dp4a', | |||
'imma'], | |||
default='dp4a', help='generate cuda conv bias kernel file') | |||
parser.add_argument('output', help='output directory') | |||
description="generate cuda batch conv bias (dp4a/imma) kern impl files", | |||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||
) | |||
parser.add_argument( | |||
"--type", | |||
type=str, | |||
choices=["dp4a", "imma"], | |||
default="dp4a", | |||
help="generate cuda conv bias kernel file", | |||
) | |||
parser.add_argument("output", help="output directory") | |||
args = parser.parse_args() | |||
if not os.path.isdir(args.output): | |||
os.makedirs(args.output) | |||
inst = ''' | |||
inst = """ | |||
template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS, | |||
IConvEpilogue<Activation<megdnn::param_enumv::BatchConvBias::NonlineMode::ACTIVATION>>>( | |||
const int8_t* d_src, | |||
@@ -41,7 +51,7 @@ template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS, | |||
const ConvParam& param, | |||
float alpha, | |||
float beta, | |||
cudaStream_t stream);''' | |||
cudaStream_t stream);""" | |||
for prefix in PREFIXES[args.type]: | |||
for suffix in SUFFIXES[args.type]: | |||
@@ -52,17 +62,23 @@ template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS, | |||
fname = os.path.join(args.output, fname) | |||
with open(fname, "w") as fout: | |||
w = lambda s: print(s, file=fout) | |||
w('// generated by gen_batch_cuda_conv_bias_kern_impls.py') | |||
cur_inst = inst.replace("PREFIX", prefix[0]).replace("SUFFIX", suffix).replace("BIAS", bias[0]).replace("ACTIVATION", act[0]) | |||
w("// generated by gen_batch_cuda_conv_bias_kern_impls.py") | |||
cur_inst = ( | |||
inst.replace("PREFIX", prefix[0]) | |||
.replace("SUFFIX", suffix) | |||
.replace("BIAS", bias[0]) | |||
.replace("ACTIVATION", act[0]) | |||
) | |||
if has_workspace: | |||
cur_inst = cur_inst.replace("WORKSPACE", "\nint* d_workspace, ") | |||
else: | |||
cur_inst = cur_inst.replace("WORKSPACE", "") | |||
cur_inst = cur_inst.replace("WORKSPACE", "") | |||
w('#include "../{}{}.cuinl"'.format(prefix[0], suffix)) | |||
w(cur_inst) | |||
print('generated {}'.format(fname)) | |||
print("generated {}".format(fname)) | |||
os.utime(args.output) | |||
if __name__ == '__main__': | |||
if __name__ == "__main__": | |||
main() |
@@ -1,39 +1,57 @@ | |||
#!/usr/bin/env python3 | |||
# -*- coding: utf-8 -*- | |||
import os | |||
import argparse | |||
import itertools | |||
import os | |||
PREFIXES = { | |||
"dp4a": "conv_bias_int8_implicit_gemm_cdiv4hwn4", | |||
"imma": "conv_bias_int8_implicit_gemm", | |||
} | |||
PREFIXES = {"dp4a": "conv_bias_int8_implicit_gemm_cdiv4hwn4", "imma": "conv_bias_int8_implicit_gemm"} | |||
ACTIVATIONS = {1: ("IDENTITY", "_id"), 2: ("RELU", "_relu"), 3: ("H_SWISH", "_hswish")} | |||
ACTIVATIONS = {1: ("IDENTITY", "_id"), | |||
2: ("RELU", "_relu"), | |||
3: ("H_SWISH", "_hswish")} | |||
BIASES = { | |||
1: ("PerElementBiasVisitor", "_per_elem"), | |||
2: ("PerChannelBiasVisitor", "_per_chan"), | |||
} | |||
BIASES = {1: ("PerElementBiasVisitor", "_per_elem"), | |||
2: ("PerChannelBiasVisitor", "_per_chan")} | |||
SUFFIXES = { | |||
"dp4a": ["", "_ld_64bit", "_ld_64bit_unroll_width", "_unroll_width"], | |||
"imma": [ | |||
"_imma16x16x16_cdiv4hwn4", | |||
"_imma8x32x16_cdiv4hwn4", | |||
"_imma32x8x16_cdiv4hwn4", | |||
"_imma16x16x16_cdiv4hwn4_reorder_filter", | |||
"_imma8x32x16_cdiv4hwn4_reorder_filter", | |||
"_imma32x8x16_cdiv4hwn4_reorder_filter", | |||
"_imma16x16x16_cdiv4hwn4_unroll_width", | |||
"_imma8x32x16_cdiv4hwn4_unroll_width", | |||
"_imma32x8x16_cdiv4hwn4_unroll_width", | |||
], | |||
} | |||
SUFFIXES = {"dp4a": ["", "_ld_64bit", "_ld_64bit_unroll_width", "_unroll_width"], | |||
"imma": ["_imma16x16x16_cdiv4hwn4", "_imma8x32x16_cdiv4hwn4", "_imma32x8x16_cdiv4hwn4", | |||
"_imma16x16x16_cdiv4hwn4_reorder_filter", "_imma8x32x16_cdiv4hwn4_reorder_filter", "_imma32x8x16_cdiv4hwn4_reorder_filter", | |||
"_imma16x16x16_cdiv4hwn4_unroll_width", "_imma8x32x16_cdiv4hwn4_unroll_width", "_imma32x8x16_cdiv4hwn4_unroll_width"]} | |||
def main(): | |||
parser = argparse.ArgumentParser( | |||
description='generate cuda conv bias (dp4a/imma) kern impl files', | |||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
parser.add_argument('--type', type=str, choices=['dp4a', | |||
'imma'], | |||
default='dp4a', help='generate cuda conv bias kernel file') | |||
parser.add_argument('output', help='output directory') | |||
description="generate cuda conv bias (dp4a/imma) kern impl files", | |||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||
) | |||
parser.add_argument( | |||
"--type", | |||
type=str, | |||
choices=["dp4a", "imma"], | |||
default="dp4a", | |||
help="generate cuda conv bias kernel file", | |||
) | |||
parser.add_argument("output", help="output directory") | |||
args = parser.parse_args() | |||
if not os.path.isdir(args.output): | |||
os.makedirs(args.output) | |||
inst = ''' | |||
inst = """ | |||
template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS, | |||
IConvEpilogue<Activation<megdnn::param_enumv::ConvBias::NonlineMode::ACTIVATION>>>( | |||
const int8_t* d_src, | |||
@@ -43,7 +61,7 @@ template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS, | |||
const ConvParam& param, | |||
float alpha, | |||
float beta, | |||
cudaStream_t stream);''' | |||
cudaStream_t stream);""" | |||
for suffix in SUFFIXES[args.type]: | |||
for _, act in ACTIVATIONS.items(): | |||
@@ -53,13 +71,19 @@ template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS, | |||
fname = os.path.join(args.output, fname) | |||
with open(fname, "w") as fout: | |||
w = lambda s: print(s, file=fout) | |||
w('// generated by gen_cuda_conv_bias_kern_impls.py') | |||
cur_inst = inst.replace("PREFIX", prefix).replace("SUFFIX", suffix).replace("BIAS", bias[0]).replace("ACTIVATION", act[0]) | |||
w("// generated by gen_cuda_conv_bias_kern_impls.py") | |||
cur_inst = ( | |||
inst.replace("PREFIX", prefix) | |||
.replace("SUFFIX", suffix) | |||
.replace("BIAS", bias[0]) | |||
.replace("ACTIVATION", act[0]) | |||
) | |||
w('#include "../{}{}.cuinl"'.format(prefix, suffix)) | |||
w(cur_inst) | |||
print('generated {}'.format(fname)) | |||
print("generated {}".format(fname)) | |||
os.utime(args.output) | |||
if __name__ == '__main__': | |||
if __name__ == "__main__": | |||
main() |
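# Illustrative sketch (not part of the diff): how the chained .replace() calls
# specialize the instantiation template. The prefix/suffix/bias/activation values
# come from the PREFIXES/SUFFIXES/BIASES/ACTIVATIONS tables in this file; the
# template string is abbreviated here.
inst = "template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS, ACTIVATION>(...);"
cur_inst = (
    inst.replace("PREFIX", "conv_bias_int8_implicit_gemm_cdiv4hwn4")
    .replace("SUFFIX", "_unroll_width")
    .replace("BIAS", "PerChannelBiasVisitor")
    .replace("ACTIVATION", "RELU")
)
print(cur_inst)
# -> template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width<PerChannelBiasVisitor, RELU>(...);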
@@ -1,34 +1,39 @@ | |||
#!/usr/bin/env python3 | |||
# -*- coding: utf-8 -*- | |||
import os | |||
import argparse | |||
import os | |||
from gen_elemwise_utils import ARITIES, MODES | |||
def main(): | |||
parser = argparse.ArgumentParser( | |||
description='generate elemwise each mode', | |||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
description="generate elemwise each mode", | |||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||
) | |||
parser.add_argument('output', help='output directory') | |||
parser.add_argument("output", help="output directory") | |||
args = parser.parse_args() | |||
with open(args.output, 'w') as fout: | |||
with open(args.output, "w") as fout: | |||
w = lambda s: print(s, file=fout) | |||
w('// generated by gen_elemwise_each_mode.py') | |||
w("// generated by gen_elemwise_each_mode.py") | |||
keys = list(MODES.keys()) | |||
keys.sort() | |||
for (anum, ctype) in keys: | |||
w('#define MEGDNN_FOREACH_ELEMWISE_MODE_{}_{}(cb) \\'.format( | |||
ARITIES[anum], ctype)) | |||
w( | |||
"#define MEGDNN_FOREACH_ELEMWISE_MODE_{}_{}(cb) \\".format( | |||
ARITIES[anum], ctype | |||
) | |||
) | |||
for mode in MODES[(anum, ctype)]: | |||
w(' MEGDNN_ELEMWISE_MODE_ENABLE({}, cb) \\'.format(mode)) | |||
w('') | |||
w(" MEGDNN_ELEMWISE_MODE_ENABLE({}, cb) \\".format(mode)) | |||
w("") | |||
print('generated each_mode.inl') | |||
print("generated each_mode.inl") | |||
os.utime(args.output) | |||
if __name__ == '__main__': | |||
if __name__ == "__main__": | |||
main() |
@@ -1,56 +1,63 @@ | |||
#!/usr/bin/env python3 | |||
# -*- coding: utf-8 -*- | |||
import os | |||
import argparse | |||
import itertools | |||
import os | |||
from gen_elemwise_utils import ARITIES, DTYPES, MODES | |||
def main(): | |||
parser = argparse.ArgumentParser( | |||
description='generate elemwise impl files', | |||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
parser.add_argument('--type', type=str, choices=['cuda', | |||
'hip', | |||
'cpp'], | |||
default='cpp', help='generate cuda/hip kernel file') | |||
parser.add_argument('output', help='output directory') | |||
description="generate elemwise impl files", | |||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||
) | |||
parser.add_argument( | |||
"--type", | |||
type=str, | |||
choices=["cuda", "hip", "cpp"], | |||
default="cpp", | |||
help="generate cuda/hip kernel file", | |||
) | |||
parser.add_argument("output", help="output directory") | |||
args = parser.parse_args() | |||
if not os.path.isdir(args.output): | |||
os.makedirs(args.output) | |||
if args.type == 'cuda': | |||
cpp_ext = 'cu' | |||
elif args.type == 'hip': | |||
cpp_ext = 'cpp.hip' | |||
if args.type == "cuda": | |||
cpp_ext = "cu" | |||
elif args.type == "hip": | |||
cpp_ext = "cpp.hip" | |||
else: | |||
assert args.type == 'cpp' | |||
cpp_ext = 'cpp' | |||
assert args.type == "cpp" | |||
cpp_ext = "cpp" | |||
for anum, ctype in itertools.product(ARITIES.keys(), DTYPES.keys()): | |||
for mode in MODES[(anum, DTYPES[ctype][1])]: | |||
formode = 'MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)'.format(mode) | |||
fname = '{}_{}.{}'.format(mode, ctype, cpp_ext) | |||
formode = "MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)".format(mode) | |||
fname = "{}_{}.{}".format(mode, ctype, cpp_ext) | |||
fname = os.path.join(args.output, fname) | |||
with open(fname, 'w') as fout: | |||
with open(fname, "w") as fout: | |||
w = lambda s: print(s, file=fout) | |||
w('// generated by gen_elemwise_kern_impls.py') | |||
w("// generated by gen_elemwise_kern_impls.py") | |||
if ctype == 'dt_float16' or ctype == 'dt_bfloat16': | |||
w('#if !MEGDNN_DISABLE_FLOAT16') | |||
if ctype == "dt_float16" or ctype == "dt_bfloat16": | |||
w("#if !MEGDNN_DISABLE_FLOAT16") | |||
w('#define KERN_IMPL_MODE(cb) {}'.format(formode)) | |||
w('#define KERN_IMPL_ARITY {}'.format(anum)) | |||
w('#define KERN_IMPL_CTYPE {}'.format(ctype)) | |||
w("#define KERN_IMPL_MODE(cb) {}".format(formode)) | |||
w("#define KERN_IMPL_ARITY {}".format(anum)) | |||
w("#define KERN_IMPL_CTYPE {}".format(ctype)) | |||
w('#include "../kern_impl.inl"') | |||
if ctype == 'dt_float16' or ctype == 'dt_bfloat16': | |||
w('#endif') | |||
if ctype == "dt_float16" or ctype == "dt_bfloat16": | |||
w("#endif") | |||
print('generated {}'.format(fname)) | |||
print("generated {}".format(fname)) | |||
os.utime(args.output) | |||
if __name__ == '__main__': | |||
if __name__ == "__main__": | |||
main() |
@@ -1,52 +1,66 @@ | |||
#!/usr/bin/env python3 | |||
# -*- coding: utf-8 -*- | |||
import os | |||
import argparse | |||
import itertools | |||
from gen_elemwise_multi_type_utils import SUPPORT_DTYPES, MODES, SUPPORT_QINT32_DTYPES, QINT32_MODES | |||
import os | |||
from gen_elemwise_multi_type_utils import ( # isort: skip | |||
MODES, | |||
QINT32_MODES, | |||
SUPPORT_DTYPES, | |||
SUPPORT_QINT32_DTYPES, | |||
) | |||
def generate(modes, support_dtypes, output, cpp_ext): | |||
for anum, ctype in itertools.product(modes.keys(), support_dtypes): | |||
print('{} : {}'.format(anum, ctype)) | |||
print("{} : {}".format(anum, ctype)) | |||
src_ctype = ctype[0] | |||
dst_ctype = ctype[1] | |||
for mode in modes[anum]: | |||
formode = 'MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)'.format(mode) | |||
fname = '{}_{}_{}.{}'.format(mode, src_ctype, dst_ctype, cpp_ext) | |||
formode = "MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)".format(mode) | |||
fname = "{}_{}_{}.{}".format(mode, src_ctype, dst_ctype, cpp_ext) | |||
fname = os.path.join(output, fname) | |||
with open(fname, 'w') as fout: | |||
with open(fname, "w") as fout: | |||
w = lambda s: print(s, file=fout) | |||
w('// generated by gen_elemwise_multi_type_kern_impls.py') | |||
w("// generated by gen_elemwise_multi_type_kern_impls.py") | |||
w('#define KERN_IMPL_MODE(cb) {}'.format(formode)) | |||
w('#define KERN_IMPL_ARITY {}'.format(anum)) | |||
w('#define KERN_IMPL_STYPE {}'.format(src_ctype)) | |||
w('#define KERN_IMPL_DTYPE {}'.format(dst_ctype)) | |||
w("#define KERN_IMPL_MODE(cb) {}".format(formode)) | |||
w("#define KERN_IMPL_ARITY {}".format(anum)) | |||
w("#define KERN_IMPL_STYPE {}".format(src_ctype)) | |||
w("#define KERN_IMPL_DTYPE {}".format(dst_ctype)) | |||
w('#include "../kern_impl.inl"') | |||
print('generated {}'.format(fname)) | |||
print("generated {}".format(fname)) | |||
def main(): | |||
parser = argparse.ArgumentParser( | |||
description='generate elemwise impl files', | |||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
parser.add_argument('--type', type=str, choices=['cuda'], | |||
default='cuda', help='generate cuda kernel file') | |||
parser.add_argument('output', help='output directory') | |||
description="generate elemwise impl files", | |||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||
) | |||
parser.add_argument( | |||
"--type", | |||
type=str, | |||
choices=["cuda"], | |||
default="cuda", | |||
help="generate cuda kernel file", | |||
) | |||
parser.add_argument("output", help="output directory") | |||
args = parser.parse_args() | |||
if not os.path.isdir(args.output): | |||
os.makedirs(args.output) | |||
assert args.type == 'cuda' | |||
if args.type == 'cuda': | |||
cpp_ext = 'cu' | |||
assert args.type == "cuda" | |||
if args.type == "cuda": | |||
cpp_ext = "cu" | |||
generate(MODES, SUPPORT_DTYPES, args.output, cpp_ext) | |||
generate(QINT32_MODES, SUPPORT_QINT32_DTYPES, args.output, cpp_ext) | |||
os.utime(args.output) | |||
if __name__ == '__main__': | |||
if __name__ == "__main__": | |||
main() |
@@ -1,48 +1,131 @@ | |||
# As CUDA does not currently support quint8, we just ignore it. | |||
SUPPORT_DTYPES = [('dt_qint8', 'dt_qint8')] | |||
SUPPORT_QINT32_DTYPES = [('dt_qint32', 'dt_qint8'), ('dt_qint8', 'dt_qint32'), | |||
('dt_qint4', 'dt_qint32'), ('dt_quint4', 'dt_qint32')] | |||
SUPPORT_DTYPES = [("dt_qint8", "dt_qint8")] | |||
SUPPORT_QINT32_DTYPES = [ | |||
("dt_qint32", "dt_qint8"), | |||
("dt_qint8", "dt_qint32"), | |||
("dt_qint4", "dt_qint32"), | |||
("dt_quint4", "dt_qint32"), | |||
] | |||
SUPPORT_DTYPES_Q4 = [('dt_qint4', 'dt_qint4'), ('dt_quint4', 'dt_quint4')] | |||
SUPPORT_QINT32_DTYPES_Q4 = [('dt_qint32', 'dt_qint4'), ('dt_qint32', 'dt_quint4')] | |||
SUPPORT_DTYPES_Q4 = [("dt_qint4", "dt_qint4"), ("dt_quint4", "dt_quint4")] | |||
SUPPORT_QINT32_DTYPES_Q4 = [("dt_qint32", "dt_qint4"), ("dt_qint32", "dt_quint4")] | |||
SUPPORT_ARRITY2_DTYPES = ['dt_int32', 'dt_uint8', 'dt_int8', 'dt_int16', 'dt_bool', 'dt_float32', | |||
'dt_float16', 'dt_bfloat16'] | |||
SUPPORT_ARRITY1_DTYPES = ['dt_float32','dt_float16', 'dt_bfloat16'] | |||
SUPPORT_ARRITY2_DTYPES = [ | |||
"dt_int32", | |||
"dt_uint8", | |||
"dt_int8", | |||
"dt_int16", | |||
"dt_bool", | |||
"dt_float32", | |||
"dt_float16", | |||
"dt_bfloat16", | |||
] | |||
SUPPORT_ARRITY1_DTYPES = ["dt_float32", "dt_float16", "dt_bfloat16"] | |||
MODES = { | |||
1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', | |||
'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', | |||
'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC', | |||
'ERFCINV', 'H_SWISH', 'SILU', 'GELU'], | |||
2: ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', | |||
'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', | |||
'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW', | |||
'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', | |||
'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | |||
'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], | |||
3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], | |||
1: [ | |||
"RELU", | |||
"ABS", | |||
"NEGATE", | |||
"ACOS", | |||
"ASIN", | |||
"CEIL", | |||
"COS", | |||
"EXP", | |||
"EXPM1", | |||
"FLOOR", | |||
"LOG", | |||
"LOG1P", | |||
"SIGMOID", | |||
"SIN", | |||
"TANH", | |||
"FAST_TANH", | |||
"ROUND", | |||
"ERF", | |||
"ERFINV", | |||
"ERFC", | |||
"ERFCINV", | |||
"H_SWISH", | |||
"SILU", | |||
"GELU", | |||
], | |||
2: [ | |||
"ABS_GRAD", | |||
"ADD", | |||
"FLOOR_DIV", | |||
"MAX", | |||
"MIN", | |||
"MOD", | |||
"MUL", | |||
"SIGMOID_GRAD", | |||
"SUB", | |||
"SWITCH_GT0", | |||
"TANH_GRAD", | |||
"LT", | |||
"LEQ", | |||
"EQ", | |||
"FUSE_ADD_RELU", | |||
"TRUE_DIV", | |||
"POW", | |||
"LOG_SUM_EXP", | |||
"FUSE_ADD_TANH", | |||
"FAST_TANH_GRAD", | |||
"FUSE_ADD_SIGMOID", | |||
"ATAN2", | |||
"H_SWISH_GRAD", | |||
"FUSE_ADD_H_SWISH", | |||
"SILU_GRAD", | |||
"GELU_GRAD", | |||
], | |||
3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"], | |||
} | |||
QINT4_MODES = { | |||
1: ['RELU', 'ABS', 'NEGATE', 'CEIL', 'FLOOR', 'SIGMOID', | |||
'TANH', 'FAST_TANH', 'ROUND', 'H_SWISH'], | |||
2: ['ADD', 'MAX', 'MIN', 'MUL', 'SUB', 'SWITCH_GT0', | |||
'LT', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'FUSE_ADD_TANH', | |||
'FUSE_ADD_SIGMOID', 'FUSE_ADD_H_SWISH'], | |||
3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], | |||
1: [ | |||
"RELU", | |||
"ABS", | |||
"NEGATE", | |||
"CEIL", | |||
"FLOOR", | |||
"SIGMOID", | |||
"TANH", | |||
"FAST_TANH", | |||
"ROUND", | |||
"H_SWISH", | |||
], | |||
2: [ | |||
"ADD", | |||
"MAX", | |||
"MIN", | |||
"MUL", | |||
"SUB", | |||
"SWITCH_GT0", | |||
"LT", | |||
"LEQ", | |||
"EQ", | |||
"FUSE_ADD_RELU", | |||
"FUSE_ADD_TANH", | |||
"FUSE_ADD_SIGMOID", | |||
"FUSE_ADD_H_SWISH", | |||
], | |||
3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"], | |||
} | |||
QINT32_MODES = { | |||
1: ['RELU', 'SIGMOID', 'TANH', 'FAST_TANH', 'H_SWISH'], | |||
2: ['ADD', 'FUSE_ADD_RELU', 'FUSE_ADD_SIGMOID', | |||
'FUSE_ADD_TANH', 'FUSE_ADD_H_SWISH'] | |||
1: ["RELU", "SIGMOID", "TANH", "FAST_TANH", "H_SWISH"], | |||
2: [ | |||
"ADD", | |||
"FUSE_ADD_RELU", | |||
"FUSE_ADD_SIGMOID", | |||
"FUSE_ADD_TANH", | |||
"FUSE_ADD_H_SWISH", | |||
], | |||
} | |||
ARRITY1_BOOL_MODES = { | |||
1: ['ISINF','ISNAN'], | |||
1: ["ISINF", "ISNAN"], | |||
} | |||
ARRITY2_BOOL_MODES = { | |||
2: ['EQ','LEQ','NEQ','LT'], | |||
2: ["EQ", "LEQ", "NEQ", "LT"], | |||
} |
@@ -1,52 +1,57 @@ | |||
#!/usr/bin/env python3 | |||
# -*- coding: utf-8 -*- | |||
import os | |||
import argparse | |||
import os | |||
from gen_elemwise_utils import DTYPES | |||
def main(): | |||
parser = argparse.ArgumentParser( | |||
description='generate elemwise impl files', | |||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
parser.add_argument('--type', type=str, choices=[ | |||
'cuda', | |||
'hip' | |||
], | |||
default='cuda', | |||
help='generate cuda/hip elemwise special kernel file') | |||
parser.add_argument('output', help='output directory') | |||
description="generate elemwise impl files", | |||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |||
) | |||
parser.add_argument( | |||
"--type", | |||
type=str, | |||
choices=["cuda", "hip"], | |||
default="cuda", | |||
help="generate cuda/hip elemwise special kernel file", | |||
) | |||
parser.add_argument("output", help="output directory") | |||
args = parser.parse_args() | |||
if not os.path.isdir(args.output): | |||
os.makedirs(args.output) | |||
if args.type == 'cuda': | |||
cpp_ext = 'cu' | |||
if args.type == "cuda": | |||
cpp_ext = "cu" | |||
else: | |||
assert args.type =='hip' | |||
cpp_ext = 'cpp.hip' | |||
assert args.type == "hip" | |||
cpp_ext = "cpp.hip" | |||
for dtype in DTYPES.keys(): | |||
fname = 'special_{}.{}'.format(dtype, cpp_ext) | |||
fname = "special_{}.{}".format(dtype, cpp_ext) | |||
fname = os.path.join(args.output, fname) | |||
with open(fname, 'w') as fout: | |||
with open(fname, "w") as fout: | |||
w = lambda s: print(s, file=fout) | |||
w('// generated by gen_elemwise_special_kern_impls.py') | |||
if dtype == 'dt_float16' or dtype == 'dt_bfloat16': | |||
w('#if !MEGDNN_DISABLE_FLOAT16') | |||
w("// generated by gen_elemwise_special_kern_impls.py") | |||
if dtype == "dt_float16" or dtype == "dt_bfloat16": | |||
w("#if !MEGDNN_DISABLE_FLOAT16") | |||
w('#include "../special_kerns.inl"') | |||
w('INST(::megdnn::dtype::{})'.format(DTYPES[dtype][0])) | |||
w('#undef INST') | |||
w('}') | |||
w('}') | |||
if dtype == 'dt_float16' or dtype == 'dt_bfloat16': | |||
w('#endif') | |||
w("INST(::megdnn::dtype::{})".format(DTYPES[dtype][0])) | |||
w("#undef INST") | |||
w("}") | |||
w("}") | |||
if dtype == "dt_float16" or dtype == "dt_bfloat16": | |||
w("#endif") | |||
print('generated {}'.format(fname)) | |||
print("generated {}".format(fname)) | |||
os.utime(args.output) | |||
if __name__ == '__main__': | |||
if __name__ == "__main__": | |||
main() |
@@ -1,35 +1,95 @@ | |||
ARITIES = {1: "UNARY", 2: "BINARY", 3: "TERNARY"} | |||
ARITIES = {1: 'UNARY', 2: 'BINARY', 3: 'TERNARY'} | |||
DTYPES = {'dt_int32': ('Int32', 'INT'), | |||
'dt_uint8': ('Uint8', 'INT'), | |||
'dt_int8': ('Int8', 'INT'), | |||
'dt_int16': ('Int16', 'INT'), | |||
'dt_bool': ('Bool', 'BOOL'), | |||
'dt_float32': ('Float32', 'FLOAT'), | |||
'dt_float16': ('Float16', 'FLOAT'), | |||
'dt_bfloat16': ('BFloat16', 'FLOAT') | |||
} | |||
DTYPES = { | |||
"dt_int32": ("Int32", "INT"), | |||
"dt_uint8": ("Uint8", "INT"), | |||
"dt_int8": ("Int8", "INT"), | |||
"dt_int16": ("Int16", "INT"), | |||
"dt_bool": ("Bool", "BOOL"), | |||
"dt_float32": ("Float32", "FLOAT"), | |||
"dt_float16": ("Float16", "FLOAT"), | |||
"dt_bfloat16": ("BFloat16", "FLOAT"), | |||
} | |||
MODES = { | |||
(1, 'INT'): ['RELU', 'ABS', 'NEGATE'], | |||
(2, 'INT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', | |||
'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', 'LEQ', | |||
'EQ', 'FUSE_ADD_RELU', 'SHL', 'SHR', 'RMULH'], | |||
(3, 'INT'): ['COND_LEQ_MOV', 'COND_LT_MOV'], | |||
(1, 'FLOAT'): ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', | |||
'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', | |||
'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC', | |||
'ERFCINV', 'H_SWISH', 'SILU', 'GELU'], | |||
(2, 'FLOAT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', | |||
'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', | |||
'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW', | |||
'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', | |||
'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | |||
'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], | |||
(3, 'FLOAT'): ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], | |||
(1, 'BOOL'): ['NOT'], | |||
(2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'], | |||
(3, 'BOOL'): [] | |||
(1, "INT"): ["RELU", "ABS", "NEGATE"], | |||
(2, "INT"): [ | |||
"ABS_GRAD", | |||
"ADD", | |||
"FLOOR_DIV", | |||
"MAX", | |||
"MIN", | |||
"MOD", | |||
"MUL", | |||
"SIGMOID_GRAD", | |||
"SUB", | |||
"SWITCH_GT0", | |||
"TANH_GRAD", | |||
"LT", | |||
"LEQ", | |||
"EQ", | |||
"FUSE_ADD_RELU", | |||
"SHL", | |||
"SHR", | |||
"RMULH", | |||
], | |||
(3, "INT"): ["COND_LEQ_MOV", "COND_LT_MOV"], | |||
(1, "FLOAT"): [ | |||
"RELU", | |||
"ABS", | |||
"NEGATE", | |||
"ACOS", | |||
"ASIN", | |||
"CEIL", | |||
"COS", | |||
"EXP", | |||
"EXPM1", | |||
"FLOOR", | |||
"LOG", | |||
"LOG1P", | |||
"SIGMOID", | |||
"SIN", | |||
"TANH", | |||
"FAST_TANH", | |||
"ROUND", | |||
"ERF", | |||
"ERFINV", | |||
"ERFC", | |||
"ERFCINV", | |||
"H_SWISH", | |||
"SILU", | |||
"GELU", | |||
], | |||
(2, "FLOAT"): [ | |||
"ABS_GRAD", | |||
"ADD", | |||
"FLOOR_DIV", | |||
"MAX", | |||
"MIN", | |||
"MOD", | |||
"MUL", | |||
"SIGMOID_GRAD", | |||
"SUB", | |||
"SWITCH_GT0", | |||
"TANH_GRAD", | |||
"LT", | |||
"LEQ", | |||
"EQ", | |||
"FUSE_ADD_RELU", | |||
"TRUE_DIV", | |||
"POW", | |||
"LOG_SUM_EXP", | |||
"FUSE_ADD_TANH", | |||
"FAST_TANH_GRAD", | |||
"FUSE_ADD_SIGMOID", | |||
"ATAN2", | |||
"H_SWISH_GRAD", | |||
"FUSE_ADD_H_SWISH", | |||
"SILU_GRAD", | |||
"GELU_GRAD", | |||
], | |||
(3, "FLOAT"): ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"], | |||
(1, "BOOL"): ["NOT"], | |||
(2, "BOOL"): ["AND", "OR", "XOR", "LT", "LEQ", "EQ"], | |||
(3, "BOOL"): [], | |||
} |
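# Illustrative sketch (not part of the diff): one way to see how many per-mode
# kernel files gen_elemwise_kern_impls.py will emit from these tables; it mirrors
# that script's itertools.product loop and assumes gen_elemwise_utils is importable.
import itertools

from gen_elemwise_utils import ARITIES, DTYPES, MODES

total = sum(
    len(MODES[(anum, DTYPES[ctype][1])])
    for anum, ctype in itertools.product(ARITIES.keys(), DTYPES.keys())
)
print(total)  # one source file per (mode, dtype) pair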
@@ -3,13 +3,14 @@ | |||
import argparse | |||
import collections | |||
import textwrap | |||
import os | |||
import hashlib | |||
import struct | |||
import io | |||
import os | |||
import struct | |||
import textwrap | |||
from gen_param_defs import IndentWriterBase, ParamDef, member_defs | |||
from gen_param_defs import member_defs, ParamDef, IndentWriterBase | |||
class ConverterWriter(IndentWriterBase): | |||
_skip_current_param = False | |||
@@ -20,7 +21,7 @@ class ConverterWriter(IndentWriterBase): | |||
def __call__(self, fout, defs): | |||
super().__call__(fout) | |||
self._write("// %s", self._get_header()) | |||
self._write('#include <flatbuffers/flatbuffers.h>') | |||
self._write("#include <flatbuffers/flatbuffers.h>") | |||
self._write("namespace mgb {") | |||
self._write("namespace serialization {") | |||
self._write("namespace fbs {") | |||
@@ -33,8 +34,9 @@ class ConverterWriter(IndentWriterBase): | |||
self._last_param = p | |||
self._param_fields = [] | |||
self._fb_fields = ["builder"] | |||
self._write("template<>\nstruct ParamConverter<megdnn::param::%s> {", | |||
p.name, indent=1) | |||
self._write( | |||
"template<>\nstruct ParamConverter<megdnn::param::%s> {", p.name, indent=1 | |||
) | |||
self._write("using MegDNNType = megdnn::param::%s;", p.name) | |||
self._write("using FlatBufferType = fbs::param::%s;\n", p.name) | |||
@@ -42,22 +44,22 @@ class ConverterWriter(IndentWriterBase): | |||
if self._skip_current_param: | |||
self._skip_current_param = False | |||
return | |||
self._write("static MegDNNType to_param(const FlatBufferType* fb) {", | |||
indent=1) | |||
line = 'return {' | |||
line += ', '.join(self._param_fields) | |||
line += '};' | |||
self._write("static MegDNNType to_param(const FlatBufferType* fb) {", indent=1) | |||
line = "return {" | |||
line += ", ".join(self._param_fields) | |||
line += "};" | |||
self._write(line) | |||
self._write("}\n", indent=-1) | |||
self._write( | |||
"static flatbuffers::Offset<FlatBufferType> to_flatbuffer(flatbuffers::FlatBufferBuilder& builder, const MegDNNType& param) {", | |||
indent=1) | |||
line = 'return fbs::param::Create{}('.format(str(p.name)) | |||
line += ', '.join(self._fb_fields) | |||
line += ');' | |||
indent=1, | |||
) | |||
line = "return fbs::param::Create{}(".format(str(p.name)) | |||
line += ", ".join(self._fb_fields) | |||
line += ");" | |||
self._write(line) | |||
self._write('}', indent=-1) | |||
self._write("}", indent=-1) | |||
self._write("};\n", indent=-1) | |||
@@ -68,18 +70,23 @@ class ConverterWriter(IndentWriterBase): | |||
return | |||
self._param_fields.append( | |||
"static_cast<megdnn::param::{}::{}>(fb->{}())".format( | |||
str(p.name), str(e.name), e.name_field)) | |||
self._fb_fields.append("static_cast<fbs::param::{}>(param.{})".format( | |||
key, e.name_field)) | |||
str(p.name), str(e.name), e.name_field | |||
) | |||
) | |||
self._fb_fields.append( | |||
"static_cast<fbs::param::{}>(param.{})".format(key, e.name_field) | |||
) | |||
def _on_member_field(self, f): | |||
if self._skip_current_param: | |||
return | |||
if f.dtype.cname == 'DTypeEnum': | |||
if f.dtype.cname == "DTypeEnum": | |||
self._param_fields.append( | |||
"intl::convert_dtype_to_megdnn(fb->{}())".format(f.name)) | |||
"intl::convert_dtype_to_megdnn(fb->{}())".format(f.name) | |||
) | |||
self._fb_fields.append( | |||
"intl::convert_dtype_to_fbs(param.{})".format(f.name)) | |||
"intl::convert_dtype_to_fbs(param.{})".format(f.name) | |||
) | |||
else: | |||
self._param_fields.append("fb->{}()".format(f.name)) | |||
self._fb_fields.append("param.{}".format(f.name)) | |||
@@ -93,28 +100,33 @@ class ConverterWriter(IndentWriterBase): | |||
enum_name = e.src_class + e.src_name | |||
self._param_fields.append( | |||
"static_cast<megdnn::param::{}::{}>(fb->{}())".format( | |||
e.src_class, e.src_name, e.name_field)) | |||
self._fb_fields.append("static_cast<fbs::param::{}>(param.{})".format( | |||
enum_name, e.name_field)) | |||
e.src_class, e.src_name, e.name_field | |||
) | |||
) | |||
self._fb_fields.append( | |||
"static_cast<fbs::param::{}>(param.{})".format(enum_name, e.name_field) | |||
) | |||
def main(): | |||
parser = argparse.ArgumentParser( | |||
'generate convert functions between FlatBuffers type and MegBrain type') | |||
parser.add_argument('input') | |||
parser.add_argument('output') | |||
"generate convert functions between FlatBuffers type and MegBrain type" | |||
) | |||
parser.add_argument("input") | |||
parser.add_argument("output") | |||
args = parser.parse_args() | |||
with open(args.input) as fin: | |||
inputs = fin.read() | |||
exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) | |||
exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc}) | |||
input_hash = hashlib.sha256() | |||
input_hash.update(inputs.encode(encoding='UTF-8')) | |||
input_hash.update(inputs.encode(encoding="UTF-8")) | |||
input_hash = input_hash.hexdigest() | |||
writer = ConverterWriter() | |||
with open(args.output, 'w') as fout: | |||
with open(args.output, "w") as fout: | |||
writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | |||
if __name__ == "__main__": | |||
main() |
@@ -3,13 +3,14 @@ | |||
import argparse | |||
import collections | |||
import textwrap | |||
import os | |||
import hashlib | |||
import struct | |||
import io | |||
import os | |||
import struct | |||
import textwrap | |||
from gen_param_defs import IndentWriterBase, ParamDef, member_defs | |||
from gen_param_defs import member_defs, ParamDef, IndentWriterBase | |||
def _cname_to_fbname(cname): | |||
return { | |||
@@ -22,17 +23,19 @@ def _cname_to_fbname(cname): | |||
"bool": "bool", | |||
}[cname] | |||
def scramble_enum_member_name(name): | |||
s = name.find('<<') | |||
s = name.find("<<") | |||
if s != -1: | |||
name = name[0:name.find('=') + 1] + ' ' + name[s+2:] | |||
name = name[0 : name.find("=") + 1] + " " + name[s + 2 :] | |||
if name in ("MIN", "MAX"): | |||
return name + "_" | |||
o_name = name.split(' ')[0].split('=')[0] | |||
o_name = name.split(" ")[0].split("=")[0] | |||
if o_name in ("MIN", "MAX"): | |||
return name.replace(o_name, o_name + "_") | |||
return name | |||
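# Illustrative sketch (not part of the diff): the definition above copied into a
# standalone snippet to show its effect. MIN/MAX members gain a trailing
# underscore (presumably to avoid clashing with identifiers the FlatBuffers
# compiler reserves), shift expressions are collapsed, and other names pass through.
def scramble_enum_member_name(name):
    s = name.find("<<")
    if s != -1:
        name = name[0 : name.find("=") + 1] + " " + name[s + 2 :]
    if name in ("MIN", "MAX"):
        return name + "_"
    o_name = name.split(" ")[0].split("=")[0]
    if o_name in ("MIN", "MAX"):
        return name.replace(o_name, o_name + "_")
    return name

print(scramble_enum_member_name("MIN"))       # -> MIN_
print(scramble_enum_member_name("MAX = 3"))   # -> MAX_ = 3
print(scramble_enum_member_name("IDENTITY"))  # -> IDENTITY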
class FlatBuffersWriter(IndentWriterBase): | |||
_skip_current_param = False | |||
_last_param = None | |||
@@ -66,12 +69,13 @@ class FlatBuffersWriter(IndentWriterBase): | |||
self._write("}\n", indent=-1) | |||
def _write_doc(self, doc): | |||
if not isinstance(doc, member_defs.Doc) or not doc.doc: return | |||
if not isinstance(doc, member_defs.Doc) or not doc.doc: | |||
return | |||
doc_lines = [] | |||
if doc.no_reformat: | |||
doc_lines = doc.raw_lines | |||
else: | |||
doc = doc.doc.replace('\n', ' ') | |||
doc = doc.doc.replace("\n", " ") | |||
text_width = 80 - len(self._cur_indent) - 4 | |||
doc_lines = textwrap.wrap(doc, text_width) | |||
for line in doc_lines: | |||
@@ -101,7 +105,8 @@ class FlatBuffersWriter(IndentWriterBase): | |||
default = e.compose_combined_enum(e.default) | |||
else: | |||
default = scramble_enum_member_name( | |||
str(e.members[e.default]).split(' ')[0].split('=')[0]) | |||
str(e.members[e.default]).split(" ")[0].split("=")[0] | |||
) | |||
self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default) | |||
def _resolve_const(self, v): | |||
@@ -113,8 +118,12 @@ class FlatBuffersWriter(IndentWriterBase): | |||
if self._skip_current_param: | |||
return | |||
self._write_doc(f.name) | |||
self._write("%s:%s = %s;", f.name, _cname_to_fbname(f.dtype.cname), | |||
self._get_fb_default(self._resolve_const(f.default))) | |||
self._write( | |||
"%s:%s = %s;", | |||
f.name, | |||
_cname_to_fbname(f.dtype.cname), | |||
self._get_fb_default(self._resolve_const(f.default)), | |||
) | |||
def _on_const_field(self, f): | |||
self._cur_const_val[str(f.name)] = str(f.default) | |||
@@ -129,7 +138,8 @@ class FlatBuffersWriter(IndentWriterBase): | |||
default = s.compose_combined_enum(e.get_default()) | |||
else: | |||
default = scramble_enum_member_name( | |||
str(s.members[e.get_default()]).split(' ')[0].split('=')[0]) | |||
str(s.members[e.get_default()]).split(" ")[0].split("=")[0] | |||
) | |||
self._write("%s:%s = %s;", e.name_field, enum_name, default) | |||
def _get_fb_default(self, cppdefault): | |||
@@ -137,9 +147,9 @@ class FlatBuffersWriter(IndentWriterBase): | |||
return cppdefault | |||
d = cppdefault | |||
if d.endswith('f'): # 1.f | |||
if d.endswith("f"): # 1.f | |||
return d[:-1] | |||
if d.endswith('ull'): | |||
if d.endswith("ull"): | |||
return d[:-3] | |||
if d.startswith("DTypeEnum::"): | |||
return d[11:] | |||
@@ -148,21 +158,23 @@ class FlatBuffersWriter(IndentWriterBase): | |||
def main(): | |||
parser = argparse.ArgumentParser( | |||
'generate FlatBuffers schema of operator param from description file') | |||
parser.add_argument('input') | |||
parser.add_argument('output') | |||
"generate FlatBuffers schema of operator param from description file" | |||
) | |||
parser.add_argument("input") | |||
parser.add_argument("output") | |||
args = parser.parse_args() | |||
with open(args.input) as fin: | |||
inputs = fin.read() | |||
exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) | |||
exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc}) | |||
input_hash = hashlib.sha256() | |||
input_hash.update(inputs.encode(encoding='UTF-8')) | |||
input_hash.update(inputs.encode(encoding="UTF-8")) | |||
input_hash = input_hash.hexdigest() | |||
writer = FlatBuffersWriter() | |||
with open(args.output, 'w') as fout: | |||
with open(args.output, "w") as fout: | |||
writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | |||
if __name__ == "__main__": | |||
main() |
@@ -1,14 +1,16 @@ | |||
#! /usr/local/env python3 | |||
import pickle | |||
import numpy as np | |||
import os | |||
import argparse | |||
import re | |||
import collections | |||
import os | |||
import pickle | |||
import re | |||
import numpy as np | |||
def define_template(**kwargs): | |||
template = ''' | |||
template = """ | |||
float cuda{cuda_arch}_{conv_type}_time_pred[{out_dim}] = {{0.0f}}; | |||
float cuda{cuda_arch}_{conv_type}_mask[{out_dim}] = {{0.0f}}; | |||
float cuda{cuda_arch}_{conv_type}_hidden_units[{hidden_num}] = {{0.0f}}; | |||
@@ -17,21 +19,23 @@ def define_template(**kwargs): | |||
const static float cuda{cuda_arch}_{conv_type}_biases[{biases_dim}] = {{{biases}}}; | |||
const static float cuda{cuda_arch}_{conv_type}_alpha[{out_dim}] = {{{alpha}}}; | |||
const static float cuda{cuda_arch}_{conv_type}_beta[{out_dim}] = {{{beta}}}; | |||
''' | |||
""" | |||
return template.format(**kwargs) | |||
def cudnn_slt_template(**kwargs): | |||
template = ("#if CUDNN_MAJOR == {cudnn_major} && CUDNN_MINOR == {cudnn_minor}\n" + | |||
" {define_cmd}\n" + | |||
" {select_cmd}\n" + | |||
" return true;\n" + | |||
"#endif\n" | |||
) | |||
template = ( | |||
"#if CUDNN_MAJOR == {cudnn_major} && CUDNN_MINOR == {cudnn_minor}\n" | |||
+ " {define_cmd}\n" | |||
+ " {select_cmd}\n" | |||
+ " return true;\n" | |||
+ "#endif\n" | |||
) | |||
return template.format(**kwargs) | |||
def select_template(**kwargs): | |||
template = \ | |||
'''if (conv_type == ConvolutionType::{conv_type} && cuda_major == {cuda_major} && | |||
template = """if (conv_type == ConvolutionType::{conv_type} && cuda_major == {cuda_major} && | |||
cuda_minor == {cuda_minor}) {{ | |||
*layer_num_p = {layer_num}; | |||
*hidden_units_p = cuda{cuda_arch}_{conv_type}_hidden_units; | |||
@@ -42,7 +46,7 @@ def select_template(**kwargs): | |||
*beta_p = cuda{cuda_arch}_{conv_type}_beta; | |||
*time_pred_p = cuda{cuda_arch}_{conv_type}_time_pred; | |||
*mask_p = cuda{cuda_arch}_{conv_type}_mask; | |||
}} else ''' | |||
}} else """ | |||
return template.format(**kwargs) | |||
@@ -58,48 +62,48 @@ def fill_src(): | |||
if len(matrix_files) == 0: | |||
print("Warning: no param files detected.") | |||
for fpath in matrix_files: | |||
cudnn_version = re.findall('cudnn([\d.]+)',fpath)[0] | |||
cudnn_version = re.findall("cudnn([\d.]+)", fpath)[0] | |||
gen_list[cudnn_version].append(fpath) | |||
for cudnn in gen_list: | |||
select_cmd = ("{\n" + | |||
" " * 8 + "return false;\n" + | |||
" " * 4 + "}") | |||
select_cmd = "{\n" + " " * 8 + "return false;\n" + " " * 4 + "}" | |||
define_cmd = "" | |||
cudnn_major, cudnn_minor = cudnn.split('.') | |||
cudnn_major, cudnn_minor = cudnn.split(".") | |||
for fpath in gen_list[cudnn]: | |||
cuda_arch = fpath.split("-")[1].replace(".", "_") | |||
print('cudnn_version: {}, cuda_arch: {}'.format(cudnn,cuda_arch)) | |||
print("cudnn_version: {}, cuda_arch: {}".format(cudnn, cuda_arch)) | |||
conv_type = fpath.split("-")[2].split(".")[0] | |||
with open(os.path.join(home, "params/{}".format(fpath)), "rb") as pobj: | |||
params = pickle.load(pobj) | |||
crt_define_cmd, crt_select_cmd = gen_cmds( | |||
cuda_arch, conv_type, params) | |||
crt_define_cmd, crt_select_cmd = gen_cmds(cuda_arch, conv_type, params) | |||
select_cmd = crt_select_cmd + select_cmd | |||
define_cmd = crt_define_cmd + define_cmd | |||
cudnn_slt_cmd += cudnn_slt_template(cudnn_major=cudnn_major, | |||
cudnn_minor=cudnn_minor, | |||
select_cmd=select_cmd, | |||
define_cmd=define_cmd) | |||
cudnn_slt_cmd += cudnn_slt_template( | |||
cudnn_major=cudnn_major, | |||
cudnn_minor=cudnn_minor, | |||
select_cmd=select_cmd, | |||
define_cmd=define_cmd, | |||
) | |||
#select_cmd = select_cmd | |||
# select_cmd = select_cmd | |||
with open(os.path.join(home, "get_params.template"), "r") as srcf: | |||
src = srcf.read() | |||
dst = src.replace("{cudnn_select}", cudnn_slt_cmd) | |||
MegDNN_path = os.path.join(home, "../..") | |||
with open(os.path.join(MegDNN_path, | |||
"src/cuda/convolution/get_params.cpp"), "w") as dstf: | |||
with open( | |||
os.path.join(MegDNN_path, "src/cuda/convolution/get_params.cpp"), "w" | |||
) as dstf: | |||
dstf.write(dst) | |||
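In `fill_src` above, each `select_template` expansion ends with `}} else ` and new clauses are prepended to `select_cmd`, so the emitted C++ is an `if / else if` chain that falls through to the initial `{ return false; }` block. A reduced sketch of that assembly (conv-type names and iteration order are illustrative):

```python
def select_clause(conv_type):
    # trimmed stand-in for select_template(); only the chaining matters here
    return (
        "if (conv_type == ConvolutionType::%s) { /* fill outputs */ } else "
        % conv_type
    )

select_cmd = "{\n" + " " * 8 + "return false;\n" + " " * 4 + "}"
for conv_type in ["FORWARD", "BACKWARD_DATA"]:
    select_cmd = select_clause(conv_type) + select_cmd
# The clause generated last ends up first in the chain:
#   if (... BACKWARD_DATA) { ... } else if (... FORWARD) { ... } else { return false; }
print(select_cmd)
```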
def gen_cmds(cuda_arch, conv_type, params): | |||
cuda_major, cuda_minor = cuda_arch.split("_") | |||
alphastr = format_array(params['alpha']).rstrip()[:-1] | |||
betastr = format_array(params['beta']).rstrip()[:-1] | |||
W_list = params['W'] | |||
b_list = params['b'] | |||
Wstr = '' | |||
bstr = '' | |||
alphastr = format_array(params["alpha"]).rstrip()[:-1] | |||
betastr = format_array(params["beta"]).rstrip()[:-1] | |||
W_list = params["W"] | |||
b_list = params["b"] | |||
Wstr = "" | |||
bstr = "" | |||
layer_num = str(len(b_list) + 1) | |||
layers_dim = [W_list[0].shape[1]] | |||
matrices_dim = 0 | |||
@@ -118,16 +122,27 @@ def gen_cmds(cuda_arch, conv_type, params): | |||
out_dim = layers_dim[-1] | |||
layers_dim_str = format_array(np.array(layers_dim)).rstrip()[:-1] | |||
select_cmd = select_template(conv_type=conv_type.upper(), cuda_major=cuda_major, | |||
cuda_minor=cuda_minor, layer_num=layer_num, | |||
cuda_arch=cuda_arch) | |||
define_cmd = define_template(cuda_arch=cuda_arch, conv_type=conv_type.upper(), | |||
hidden_num=hidden_num, | |||
layer_num=layer_num, out_dim=out_dim, | |||
layers_dim=layers_dim_str, | |||
matrices_dim=matrices_dim, matrices=Wstr, | |||
biases_dim=biases_dim, biases=bstr, | |||
alpha=alphastr, beta=betastr) | |||
select_cmd = select_template( | |||
conv_type=conv_type.upper(), | |||
cuda_major=cuda_major, | |||
cuda_minor=cuda_minor, | |||
layer_num=layer_num, | |||
cuda_arch=cuda_arch, | |||
) | |||
define_cmd = define_template( | |||
cuda_arch=cuda_arch, | |||
conv_type=conv_type.upper(), | |||
hidden_num=hidden_num, | |||
layer_num=layer_num, | |||
out_dim=out_dim, | |||
layers_dim=layers_dim_str, | |||
matrices_dim=matrices_dim, | |||
matrices=Wstr, | |||
biases_dim=biases_dim, | |||
biases=bstr, | |||
alpha=alphastr, | |||
beta=betastr, | |||
) | |||
return (define_cmd, select_cmd) | |||
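For reference, `gen_cmds` above indexes the unpickled object by the keys `W`, `b`, `alpha` and `beta`. A hedged sketch of the layout it appears to expect, a small per-architecture MLP (the shapes below are invented for illustration):

```python
import numpy as np

params = {
    "W": [np.zeros((8, 16)), np.zeros((16, 4))],  # per-layer weight matrices
    "b": [np.zeros(16), np.zeros(4)],             # per-layer bias vectors
    "alpha": np.zeros(4),                          # out_dim-sized scaling array
    "beta": np.zeros(4),                           # out_dim-sized offset array
}
layer_num = len(params["b"]) + 1                   # matches the computation above
```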
@@ -153,8 +168,9 @@ def format_array(array): | |||
if __name__ == "__main__": | |||
parser = argparse.ArgumentParser( | |||
description="Generate cuDNN heuristic code by neural network into" | |||
" {MEGDNN_ROOT}/src/cuda/convolution/get_params.cpp," | |||
" using parameter value from pickle files in" | |||
" {MEGDNN_ROOT}/scripts/gen_heuristic/params/") | |||
" {MEGDNN_ROOT}/src/cuda/convolution/get_params.cpp," | |||
" using parameter value from pickle files in" | |||
" {MEGDNN_ROOT}/scripts/gen_heuristic/params/" | |||
) | |||
args = parser.parse_args() | |||
main() |
@@ -3,19 +3,17 @@ | |||
import argparse | |||
import collections | |||
import textwrap | |||
import os | |||
import hashlib | |||
import struct | |||
import io | |||
import os | |||
import struct | |||
import textwrap | |||
from gen_param_defs import member_defs, ParamDef, IndentWriterBase | |||
from gen_param_defs import IndentWriterBase, ParamDef, member_defs | |||
# FIXME: move supportToString flag definition into the param def source file | |||
ENUM_TO_STRING_SPECIAL_RULES = [ | |||
("Elemwise", "Mode"), | |||
("ElemwiseMultiType", "Mode") | |||
] | |||
ENUM_TO_STRING_SPECIAL_RULES = [("Elemwise", "Mode"), ("ElemwiseMultiType", "Mode")] | |||
class ConverterWriter(IndentWriterBase): | |||
_skip_current_param = False | |||
@@ -33,21 +31,21 @@ class ConverterWriter(IndentWriterBase): | |||
self._write("#endif // MGB_PARAM") | |||
def _ctype2attr(self, ctype, value): | |||
if ctype == 'uint32_t': | |||
return 'MgbUI32Attr', value | |||
if ctype == 'uint64_t': | |||
return 'MgbUI64Attr', value | |||
if ctype == 'int32_t': | |||
return 'MgbI32Attr', value | |||
if ctype == 'float': | |||
return 'MgbF32Attr', value | |||
if ctype == 'double': | |||
return 'MgbF64Attr', value | |||
if ctype == 'bool': | |||
return 'MgbBoolAttr', value | |||
if ctype == 'DTypeEnum': | |||
if ctype == "uint32_t": | |||
return "MgbUI32Attr", value | |||
if ctype == "uint64_t": | |||
return "MgbUI64Attr", value | |||
if ctype == "int32_t": | |||
return "MgbI32Attr", value | |||
if ctype == "float": | |||
return "MgbF32Attr", value | |||
if ctype == "double": | |||
return "MgbF64Attr", value | |||
if ctype == "bool": | |||
return "MgbBoolAttr", value | |||
if ctype == "DTypeEnum": | |||
self._packed = False | |||
return 'MgbDTypeAttr', 'megdnn::DType::from_enum(megdnn::{})'.format(value) | |||
return "MgbDTypeAttr", "megdnn::DType::from_enum(megdnn::{})".format(value) | |||
raise RuntimeError("unknown ctype") | |||
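`_ctype2attr` above is a straight C-type to TableGen attribute mapping, with one special case: a `DTypeEnum` field also marks the parameter as non-packed and wraps its default in `megdnn::DType::from_enum`. A table-driven restatement of the same mapping (a sketch, not the repository's code; the side effect on `self._packed` is noted but not reproduced):

```python
CTYPE_TO_ATTR = {
    "uint32_t": "MgbUI32Attr",
    "uint64_t": "MgbUI64Attr",
    "int32_t": "MgbI32Attr",
    "float": "MgbF32Attr",
    "double": "MgbF64Attr",
    "bool": "MgbBoolAttr",
}

def ctype2attr(ctype, value):
    if ctype == "DTypeEnum":
        # in the class this branch also sets self._packed = False
        return "MgbDTypeAttr", "megdnn::DType::from_enum(megdnn::{})".format(value)
    if ctype in CTYPE_TO_ATTR:
        return CTYPE_TO_ATTR[ctype], value
    raise RuntimeError("unknown ctype")

# e.g. ctype2attr("float", "1.f") -> ("MgbF32Attr", "1.f")
```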
def _on_param_begin(self, p): | |||
@@ -61,21 +59,26 @@ class ConverterWriter(IndentWriterBase): | |||
self._skip_current_param = False | |||
return | |||
if self._packed: | |||
self._write("class {0}ParamBase<string accessor> : MgbPackedParamBase<\"{0}\", accessor> {{".format(p.name), indent=1) | |||
self._write( | |||
'class {0}ParamBase<string accessor> : MgbPackedParamBase<"{0}", accessor> {{'.format( | |||
p.name | |||
), | |||
indent=1, | |||
) | |||
else: | |||
self._write("def {0}Param: MgbParamBase<\"{0}\"> {{".format(p.name), indent=1) | |||
self._write('def {0}Param: MgbParamBase<"{0}"> {{'.format(p.name), indent=1) | |||
self._write("let fields = (ins", indent=1) | |||
self._write(",\n{}".format(self._cur_indent).join(self._current_tparams)) | |||
self._write(");", indent=-1) | |||
self._write("}\n", indent=-1) | |||
if self._packed: | |||
self._write("def {0}Param : {0}ParamBase<\"param\">;\n".format(p.name)) | |||
self._write('def {0}Param : {0}ParamBase<"param">;\n'.format(p.name)) | |||
self._current_tparams = None | |||
self._packed = None | |||
self._const = None | |||
def _wrapped_with_default_value(self, attr, default): | |||
return 'MgbDefaultValuedAttr<{}, \"{}\">'.format(attr, default) | |||
return 'MgbDefaultValuedAttr<{}, "{}">'.format(attr, default) | |||
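`_wrapped_with_default_value` only splices the attribute class and its default into a `MgbDefaultValuedAttr<...>` record; for example (the attribute name and default are illustrative):

```python
def wrapped_with_default_value(attr, default):
    # same string construction as the method above, shown standalone
    return 'MgbDefaultValuedAttr<{}, "{}">'.format(attr, default)

print(wrapped_with_default_value("MgbF32Attr", "1.f"))
# MgbDefaultValuedAttr<MgbF32Attr, "1.f">
```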
def _on_member_enum(self, e): | |||
p = self._last_param | |||
@@ -84,10 +87,12 @@ class ConverterWriter(IndentWriterBase): | |||
# directly used by any operator, or other enum couldn't alias to this enum | |||
td_class = "{}{}".format(p.name, e.name) | |||
fullname = "::megdnn::param::{}".format(p.name) | |||
enum_def = "MgbEnumAttr<\"{}\", \"{}\", [".format(fullname, e.name) | |||
enum_def = 'MgbEnumAttr<"{}", "{}", ['.format(fullname, e.name) | |||
def format(v): | |||
return '\"{}\"'.format(str(v).split(' ')[0].split('=')[0]) | |||
enum_def += ','.join(format(i) for i in e.members) | |||
return '"{}"'.format(str(v).split(" ")[0].split("=")[0]) | |||
enum_def += ",".join(format(i) for i in e.members) | |||
if e.combined: | |||
enum_def += "], 1" | |||
@@ -95,7 +100,7 @@ class ConverterWriter(IndentWriterBase): | |||
enum_def += "], 0" | |||
if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)): | |||
enum_def += ", 1" # whether generate ToStringTrait | |||
enum_def += ", 1" # whether generate ToStringTrait | |||
enum_def += ">" | |||
self._write("def {} : {};".format(td_class, enum_def)) | |||
@@ -105,10 +110,12 @@ class ConverterWriter(IndentWriterBase): | |||
# wrapped with default value | |||
if e.combined: | |||
default_val = "static_cast<{}::{}>({})".format( | |||
fullname, e.name, e.compose_combined_enum(e.default)) | |||
fullname, e.name, e.compose_combined_enum(e.default) | |||
) | |||
else: | |||
default_val = "{}::{}::{}".format( | |||
fullname, e.name, str(e.members[e.default]).split(' ')[0].split('=')[0]) | |||
fullname, e.name, str(e.members[e.default]).split(" ")[0].split("=")[0] | |||
) | |||
wrapped = self._wrapped_with_default_value(td_class, default_val) | |||
@@ -123,51 +130,58 @@ class ConverterWriter(IndentWriterBase): | |||
td_class = "{}{}".format(p.name, e.name) | |||
fullname = "::megdnn::param::{}".format(p.name) | |||
base_td_class = "{}{}".format(e.src_class, e.src_name) | |||
enum_def = "MgbEnumAliasAttr<\"{}\", \"{}\", {}>".format(fullname, e.name, base_td_class) | |||
enum_def = 'MgbEnumAliasAttr<"{}", "{}", {}>'.format( | |||
fullname, e.name, base_td_class | |||
) | |||
self._write("def {} : {};".format(td_class, enum_def)) | |||
# wrapped with default value | |||
s = e.src_enum | |||
if s.combined: | |||
default_val = "static_cast<{}::{}>({})".format( | |||
fullname, e.name, s.compose_combined_enum(e.get_default())) | |||
fullname, e.name, s.compose_combined_enum(e.get_default()) | |||
) | |||
else: | |||
default_val = "{}::{}::{}".format(fullname, e.name, str( | |||
s.members[e.get_default()]).split(' ')[0].split('=')[0]) | |||
default_val = "{}::{}::{}".format( | |||
fullname, | |||
e.name, | |||
str(s.members[e.get_default()]).split(" ")[0].split("=")[0], | |||
) | |||
wrapped = self._wrapped_with_default_value(td_class, default_val) | |||
self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) | |||
def _on_member_field(self, f): | |||
if self._skip_current_param: | |||
return | |||
attr, value = self._ctype2attr(f.dtype.cname, str(f.default)) | |||
if str(value) in self._const: | |||
value = '::megdnn::param::{}::{}'.format(self._last_param.name, value) | |||
value = "::megdnn::param::{}::{}".format(self._last_param.name, value) | |||
wrapped = self._wrapped_with_default_value(attr, value) | |||
self._current_tparams.append("{}:${}".format(wrapped, f.name)) | |||
def _on_const_field(self, f): | |||
self._const.add(str(f.name)) | |||
def main(): | |||
parser = argparse.ArgumentParser('generate op param tablegen file') | |||
parser.add_argument('input') | |||
parser.add_argument('output') | |||
parser = argparse.ArgumentParser("generate op param tablegen file") | |||
parser.add_argument("input") | |||
parser.add_argument("output") | |||
args = parser.parse_args() | |||
with open(args.input) as fin: | |||
inputs = fin.read() | |||
exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) | |||
exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc}) | |||
input_hash = hashlib.sha256() | |||
input_hash.update(inputs.encode(encoding='UTF-8')) | |||
input_hash.update(inputs.encode(encoding="UTF-8")) | |||
input_hash = input_hash.hexdigest() | |||
writer = ConverterWriter() | |||
with open(args.output, 'w') as fout: | |||
with open(args.output, "w") as fout: | |||
writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | |||
if __name__ == "__main__": | |||
main() |
@@ -19,6 +19,7 @@ device = { | |||
"thread_number": 3, | |||
} | |||
class SshConnector: | |||
"""imp ssh control master connector""" | |||
@@ -83,17 +84,17 @@ def main(): | |||
model_file = args.model_file | |||
# copy model file | |||
ssh.copy([args.model_file], workspace) | |||
m = model_file.split('\\')[-1] | |||
m = model_file.split("\\")[-1] | |||
# run single thread | |||
result = [] | |||
thread_number = [1, 2, 4] | |||
for b in thread_number : | |||
for b in thread_number: | |||
cmd = [] | |||
cmd1 = "cd {} && ./load_and_run {} -multithread {} --fast-run --fast_run_algo_policy fastrun.cache --iter 1 --warmup-iter 1 --no-sanity-check --weight-preprocess".format( | |||
workspace, m, b | |||
workspace, m, b | |||
) | |||
cmd2 = "cd {} && ./load_and_run {} -multithread {} --fast_run_algo_policy fastrun.cache --iter 20 --warmup-iter 5 --no-sanity-check --weight-preprocess ".format( | |||
workspace, m, b | |||
workspace, m, b | |||
) | |||
cmd.append(cmd1) | |||
cmd.append(cmd2) | |||
@@ -103,12 +104,20 @@ def main(): | |||
logging.debug("model: {} with backend: {} result is: {}".format(m, b, ret)) | |||
result.append(ret) | |||
thread_2 = result[0]/result[1] | |||
thread_4 = result[0]/result[2] | |||
thread_2 = result[0] / result[1] | |||
thread_4 = result[0] / result[2] | |||
if thread_2 > 1.6 or thread_4 > 3.0: | |||
print("model: {} can has good parallelism. 2 thread is {}, 4 thread is {}".format(m, thread_2, thread_4)) | |||
print( | |||
"model: {} can has good parallelism. 2 thread is {}, 4 thread is {}".format( | |||
m, thread_2, thread_4 | |||
) | |||
) | |||
else: | |||
print("model: {} can has bad parallelism. 2 thread is {}, 4 thread is {}".format(m, thread_2, thread_4)) | |||
print( | |||
"model: {} can has bad parallelism. 2 thread is {}, 4 thread is {}".format( | |||
m, thread_2, thread_4 | |||
) | |||
) | |||
if __name__ == "__main__": | |||
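The parallelism check above runs the same model at 1, 2 and 4 threads and treats `result[0] / result[n]` as the n-thread speedup, with 1.6x and 3.0x as hard-coded thresholds. A condensed view of that decision (the timings are invented; in the script they come from `load_and_run` executed over ssh):

```python
times = {1: 120.0, 2: 70.0, 4: 38.0}   # thread count -> measured time, units illustrative
thread_2 = times[1] / times[2]          # speedup with 2 threads
thread_4 = times[1] / times[4]          # speedup with 4 threads
good = thread_2 > 1.6 or thread_4 > 3.0
print("2-thread speedup {:.2f}, 4-thread speedup {:.2f} -> {} parallelism".format(
    thread_2, thread_4, "good" if good else "poor"))
```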
@@ -20,8 +20,12 @@ failed_files = Manager().list() | |||
def process_file(file, clang_format, write): | |||
original_source = open(file, "r").read() | |||
source = original_source | |||
source = re.sub(r"MGB_DEFINE(?P<r>([^\\]|\n)*?)// *{", r"class MGB_DEFINE\g<r>{", source) | |||
source, count = re.subn(r"(?<!#define )MGB_DEFINE(.*) +\\", r"class MGB_DEFINE\1{\\", source) | |||
source = re.sub( | |||
r"MGB_DEFINE(?P<r>([^\\]|\n)*?)// *{", r"class MGB_DEFINE\g<r>{", source | |||
) | |||
source, count = re.subn( | |||
r"(?<!#define )MGB_DEFINE(.*) +\\", r"class MGB_DEFINE\1{\\", source | |||
) | |||
result = subprocess.check_output( | |||
[ | |||
@@ -36,7 +40,9 @@ def process_file(file, clang_format, write): | |||
result = result.decode("utf-8") | |||
if count: | |||
result = re.sub(r"class MGB_DEFINE(.*){( *)\\", r"MGB_DEFINE\1\2 \\", result) | |||
result = re.sub( | |||
r"class MGB_DEFINE(.*){( *)\\", r"MGB_DEFINE\1\2 \\", result | |||
) | |||
result = re.sub(r"class MGB_DEFINE((.|\n)*?){", r"MGB_DEFINE\1// {", result) | |||
if write and original_source != result: | |||
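The regexes in `process_file` above are the core trick of this formatter: `MGB_DEFINE...` macro heads are temporarily rewritten as `class MGB_DEFINE...{` so clang-format indents their bodies like class bodies, and the substitutions after formatting restore the original macro plus its `// {` marker (or trailing backslash). A toy round-trip under the same `// {` patterns (the sample line is invented for illustration):

```python
import re

sample = "MGB_DEFINE_OPR_CLASS(Example, cg::SingleCNOperatorNodeBase)  // {\n"

# forward transform: disguise the macro head as a class so clang-format indents it
cooked = re.sub(
    r"MGB_DEFINE(?P<r>([^\\]|\n)*?)// *{", r"class MGB_DEFINE\g<r>{", sample
)
# ... clang-format would run on `cooked` at this point ...
# reverse transform: put the macro head and the "// {" marker back
restored = re.sub(r"class MGB_DEFINE((.|\n)*?){", r"MGB_DEFINE\1// {", cooked)
assert restored == sample
```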
@@ -109,19 +115,17 @@ def main(): | |||
raise ValueError("Invalid path {}".format(path)) | |||
# check version, we only support 12.0.1 now | |||
version = subprocess.check_output( | |||
[ | |||
args.clang_format, | |||
"--version", | |||
], | |||
) | |||
version = subprocess.check_output([args.clang_format, "--version",],) | |||
version = version.decode("utf-8") | |||
need_version = '12.0.1' | |||
need_version = "12.0.1" | |||
if version.find(need_version) < 0: | |||
print('We only support {} now, please install {} version, find version: {}' | |||
.format(need_version, need_version, version)) | |||
raise RuntimeError('clang-format version not equal {}'.format(need_version)) | |||
print( | |||
"We only support {} now, please install {} version, find version: {}".format( | |||
need_version, need_version, version | |||
) | |||
) | |||
raise RuntimeError("clang-format version not equal {}".format(need_version)) | |||
process_map( | |||
partial(process_file, clang_format=args.clang_format, write=args.write,), | |||
@@ -20,6 +20,7 @@ device = { | |||
"thread_number": 3, | |||
} | |||
class SshConnector: | |||
"""imp ssh control master connector""" | |||
@@ -54,6 +55,7 @@ class SshConnector: | |||
except: | |||
raise | |||
def main(): | |||
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) | |||
parser.add_argument("--model_file", help="megengine model", required=True) | |||
@@ -78,10 +80,10 @@ def main(): | |||
model_file = args.model_file | |||
# copy model file | |||
ssh.copy([model_file], workspace) | |||
m = model_file.split('\\')[-1] | |||
m = model_file.split("\\")[-1] | |||
# run single thread | |||
cmd = "cd {} && ./load_and_run {} --fast-run --record-comp-seq --iter 1 --warmup-iter 1".format( | |||
workspace, m | |||
workspace, m | |||
) | |||
try: | |||
raw_log = ssh.cmd([cmd]) | |||
@@ -91,6 +93,7 @@ def main(): | |||
print("model: {} is static model.".format(m)) | |||
if __name__ == "__main__": | |||
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s" | |||
DATE_FORMAT = "%Y/%m/%d %H:%M:%S" | |||