diff --git a/dnn/scripts/cutlass_generator/conv2d_operation.py b/dnn/scripts/cutlass_generator/conv2d_operation.py index 7f16857a..129baa62 100644 --- a/dnn/scripts/cutlass_generator/conv2d_operation.py +++ b/dnn/scripts/cutlass_generator/conv2d_operation.py @@ -8,7 +8,7 @@ import enum import os.path import shutil -from typing import Tuple, List +from typing import List, Tuple from library import * diff --git a/dnn/scripts/cutlass_generator/gemm_operation.py b/dnn/scripts/cutlass_generator/gemm_operation.py index b4587b4e..5f4490f9 100644 --- a/dnn/scripts/cutlass_generator/gemm_operation.py +++ b/dnn/scripts/cutlass_generator/gemm_operation.py @@ -5,14 +5,13 @@ # import enum -import os.path -import shutil import functools import operator +import os.path +import shutil from library import * - ################################################################################################### # # Data structure modeling a GEMM operation diff --git a/dnn/scripts/cutlass_generator/gen_list.py b/dnn/scripts/cutlass_generator/gen_list.py index 3681a1e2..f652a4a0 100644 --- a/dnn/scripts/cutlass_generator/gen_list.py +++ b/dnn/scripts/cutlass_generator/gen_list.py @@ -1,11 +1,11 @@ -from generator import ( - GenerateGemmOperations, - GenerateGemvOperations, +from generator import ( # isort: skip; isort: skip GenerateConv2dOperations, GenerateDeconvOperations, - GenerateDwconv2dFpropOperations, GenerateDwconv2dDgradOperations, + GenerateDwconv2dFpropOperations, GenerateDwconv2dWgradOperations, + GenerateGemmOperations, + GenerateGemvOperations, ) @@ -35,12 +35,14 @@ def write_op_list(f, gen_op, gen_type): if gen_op != "gemv": f.write(' "all_%s_%s_operations.cu",\n' % (gen_op, gen_type)) + # Write down a list of merged filenames def write_merge_file_name(f, gen_op, gen_type, split_number): for i in range(0, split_number): - f.write(' "{}_{}_{}.cu",\n'.format(gen_op,gen_type,i)) + f.write(' "{}_{}_{}.cu",\n'.format(gen_op, gen_type, i)) if gen_op != "gemv": - f.write(' "all_{}_{}_operations.cu",\n'.format(gen_op,gen_type)) + f.write(' "all_{}_{}_operations.cu",\n'.format(gen_op, gen_type)) + if __name__ == "__main__": with open("list.bzl", "w") as f: diff --git a/dnn/scripts/cutlass_generator/generator.py b/dnn/scripts/cutlass_generator/generator.py index 1d1116e9..931108e2 100644 --- a/dnn/scripts/cutlass_generator/generator.py +++ b/dnn/scripts/cutlass_generator/generator.py @@ -4,12 +4,12 @@ # \brief Generates the CUTLASS Library's instances # +import argparse import enum import os.path -import shutil -import argparse import platform import string + from library import * from manifest import * @@ -899,9 +899,12 @@ def GenerateGemm_Simt(args): warpShapes.append([warp0, warp1]) # sgemm - precisionType, precisionBits, threadblockMaxElements, threadblockTilesL0 = precisions[ - "s" - ] + ( + precisionType, + precisionBits, + threadblockMaxElements, + threadblockTilesL0, + ) = precisions["s"] layouts = [ (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), # nn @@ -1091,9 +1094,12 @@ def GenerateDwconv2d_Simt(args, conv_kind): warpShapes.append([warp0, warp1]) # sgemm - precisionType, precisionBits, threadblockMaxElements, threadblockTilesL0 = precisions[ - "s" - ] + ( + precisionType, + precisionBits, + threadblockMaxElements, + threadblockTilesL0, + ) = precisions["s"] layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)] @@ -1304,7 +1310,7 @@ def GenerateDwconv2d_TensorOp_884(args, conv_kind): for dst_type, dst_layout in zip(dst_types, dst_layouts): for alignment_src in alignment_constraints: if conv_kind == ConvKind.Wgrad: - # skip io16xc16 + # skip io16xc16 if math_inst.element_accumulator == DataType.f16: continue for alignment_diff in alignment_constraints: @@ -1319,7 +1325,7 @@ def GenerateDwconv2d_TensorOp_884(args, conv_kind): min_cc, alignment_src, alignment_diff, - 32, # always f32 output + 32, # always f32 output SpecialOptimizeDesc.NoneSpecialOpt, ImplicitGemmMode.GemmNT, False, @@ -1656,6 +1662,7 @@ def GenerateGemvOperations(args): ) return GenerateGemv_Simt(args) + ################################################################################ # parameters # split_number - the concated file will be divided into split_number parts @@ -1668,10 +1675,21 @@ def GenerateGemvOperations(args): # epilogue - the epilogue in the file # wrapper_path - wrapper path ################################################################################ -def ConcatFile(split_number:int, file_path:str,operations:str,type:str,head:str,required_cuda_ver_major:str, required_cuda_ver_minor:str, epilogue:str, wrapper_path = None): +def ConcatFile( + split_number: int, + file_path: str, + operations: str, + type: str, + head: str, + required_cuda_ver_major: str, + required_cuda_ver_minor: str, + epilogue: str, + wrapper_path=None, +): import os + meragefiledir = file_path - filenames=os.listdir(meragefiledir) + filenames = os.listdir(meragefiledir) # filter file if "tensorop" in type: sub_string_1 = "tensorop" @@ -1679,197 +1697,183 @@ def ConcatFile(split_number:int, file_path:str,operations:str,type:str,head:str, else: sub_string_1 = sub_string_2 = "simt" if "dwconv2d_" in operations: - filtered_operations = operations[:2]+operations[9:] + filtered_operations = operations[:2] + operations[9:] elif ("conv2d" in operations) or ("deconv" in operations): filtered_operations = "cutlass" else: filtered_operations = operations - #get the file list number + # get the file list number file_list = {} file_list[operations + type] = 0 for filename in filenames: - if (filtered_operations in filename) and (sub_string_1 in filename) and (sub_string_2 in filename) and ("all_" not in filename): + if ( + (filtered_operations in filename) + and (sub_string_1 in filename) + and (sub_string_2 in filename) + and ("all_" not in filename) + ): file_list[operations + type] += 1 - #concat file for linux + # concat file for linux flag_1 = 0 flag_2 = 0 for filename in filenames: - if (filtered_operations in filename) and (sub_string_1 in filename) and (sub_string_2 in filename) and ("all_" not in filename): + if ( + (filtered_operations in filename) + and (sub_string_1 in filename) + and (sub_string_2 in filename) + and ("all_" not in filename) + ): flag_1 += 1 - filepath=meragefiledir+'/'+filename - if (flag_1 >= flag_2 * (file_list[operations + type]/split_number)) and (flag_1 <= (flag_2 + 1) * (file_list[operations + type]/split_number)): - file =open(file_path + '/{}_{}_{}.cu'.format(operations,type, flag_2),'a') - #write Template at the head + filepath = meragefiledir + "/" + filename + if (flag_1 >= flag_2 * (file_list[operations + type] / split_number)) and ( + flag_1 <= (flag_2 + 1) * (file_list[operations + type] / split_number) + ): + file = open( + file_path + "/{}_{}_{}.cu".format(operations, type, flag_2), "a" + ) + # write Template at the head if wrapper_path is None: file.write( SubstituteTemplate( head, { - "required_cuda_ver_major": str( - required_cuda_ver_major - ), - "required_cuda_ver_minor": str( - required_cuda_ver_minor - ), + "required_cuda_ver_major": str(required_cuda_ver_major), + "required_cuda_ver_minor": str(required_cuda_ver_minor), }, ) ) else: file.write( - SubstituteTemplate( - head, - { - "wrapper_path": wrapper_path, - "required_cuda_ver_major": str( - required_cuda_ver_major - ), - "required_cuda_ver_minor": str( - required_cuda_ver_minor - ), - }, - ) + SubstituteTemplate( + head, + { + "wrapper_path": wrapper_path, + "required_cuda_ver_major": str(required_cuda_ver_major), + "required_cuda_ver_minor": str(required_cuda_ver_minor), + }, ) + ) # concat all the remaining files if flag_2 == (split_number - 1): for line in open(filepath): file.writelines(line) os.remove(filepath) - file.write('\n') + file.write("\n") file.write(epilogue) continue for line in open(filepath): file.writelines(line) os.remove(filepath) - file.write('\n') + file.write("\n") file.write(epilogue) else: - #write Template at the head + # write Template at the head if wrapper_path is None: file.write( SubstituteTemplate( head, { - "required_cuda_ver_major": str( - required_cuda_ver_major - ), - "required_cuda_ver_minor": str( - required_cuda_ver_minor - ), + "required_cuda_ver_major": str(required_cuda_ver_major), + "required_cuda_ver_minor": str(required_cuda_ver_minor), }, ) ) else: file.write( - SubstituteTemplate( - head, - { - "wrapper_path": wrapper_path, - "required_cuda_ver_major": str( - required_cuda_ver_major - ), - "required_cuda_ver_minor": str( - required_cuda_ver_minor - ), - }, - ) + SubstituteTemplate( + head, + { + "wrapper_path": wrapper_path, + "required_cuda_ver_major": str(required_cuda_ver_major), + "required_cuda_ver_minor": str(required_cuda_ver_minor), + }, ) + ) for line in open(filepath): file.writelines(line) os.remove(filepath) - file.write('\n') + file.write("\n") file.write(epilogue) file.close() flag_2 += 1 - - #concat file for windows + # concat file for windows elif filename[0].isdigit() and ("all_" not in filename): flag_1 += 1 - filepath=meragefiledir+'/'+filename - if (flag_1 >= flag_2 * (len(filenames)/split_number)) and (flag_1 <= (flag_2 + 1) * (len(filenames)/split_number)): - file =open(file_path + '/{}_{}_{}.cu'.format(operations,type, flag_2),'a') - #write Template at the head + filepath = meragefiledir + "/" + filename + if (flag_1 >= flag_2 * (len(filenames) / split_number)) and ( + flag_1 <= (flag_2 + 1) * (len(filenames) / split_number) + ): + file = open( + file_path + "/{}_{}_{}.cu".format(operations, type, flag_2), "a" + ) + # write Template at the head if wrapper_path is None: file.write( SubstituteTemplate( head, { - "required_cuda_ver_major": str( - required_cuda_ver_major - ), - "required_cuda_ver_minor": str( - required_cuda_ver_minor - ), + "required_cuda_ver_major": str(required_cuda_ver_major), + "required_cuda_ver_minor": str(required_cuda_ver_minor), }, ) ) else: file.write( - SubstituteTemplate( - head, - { - "wrapper_path": wrapper_path, - "required_cuda_ver_major": str( - required_cuda_ver_major - ), - "required_cuda_ver_minor": str( - required_cuda_ver_minor - ), - }, - ) + SubstituteTemplate( + head, + { + "wrapper_path": wrapper_path, + "required_cuda_ver_major": str(required_cuda_ver_major), + "required_cuda_ver_minor": str(required_cuda_ver_minor), + }, ) + ) # concat all the remaining files if flag_2 == (split_number - 1): for line in open(filepath): file.writelines(line) os.remove(filepath) - file.write('\n') + file.write("\n") file.write(epilogue) continue for line in open(filepath): file.writelines(line) os.remove(filepath) - file.write('\n') + file.write("\n") file.write(epilogue) else: - #write Template at the head + # write Template at the head if wrapper_path is None: file.write( SubstituteTemplate( head, { - "required_cuda_ver_major": str( - required_cuda_ver_major - ), - "required_cuda_ver_minor": str( - required_cuda_ver_minor - ), + "required_cuda_ver_major": str(required_cuda_ver_major), + "required_cuda_ver_minor": str(required_cuda_ver_minor), }, ) ) else: file.write( - SubstituteTemplate( - head, - { - "wrapper_path": wrapper_path, - "required_cuda_ver_major": str( - required_cuda_ver_major - ), - "required_cuda_ver_minor": str( - required_cuda_ver_minor - ), - }, - ) + SubstituteTemplate( + head, + { + "wrapper_path": wrapper_path, + "required_cuda_ver_major": str(required_cuda_ver_major), + "required_cuda_ver_minor": str(required_cuda_ver_minor), + }, ) + ) for line in open(filepath): file.writelines(line) os.remove(filepath) - file.write('\n') + file.write("\n") file.write(epilogue) file.close() flag_2 += 1 + ################################################################################################### ################################################################################################### @@ -1940,39 +1944,97 @@ if __name__ == "__main__": args.output, operation, short_path ) as emitter: emitter.emit() - head = EmitConvSingleKernelWrapper(args.output, operations[0], short_path).header_template + head = EmitConvSingleKernelWrapper( + args.output, operations[0], short_path + ).header_template required_cuda_ver_major = operations[0].required_cuda_ver_major required_cuda_ver_minor = operations[0].required_cuda_ver_minor - epilogue = EmitConvSingleKernelWrapper(args.output, operations[0], short_path).epilogue_template + epilogue = EmitConvSingleKernelWrapper( + args.output, operations[0], short_path + ).epilogue_template if "tensorop" in args.type: - ConcatFile(4, args.output,args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue) + ConcatFile( + 4, + args.output, + args.operations, + args.type, + head, + required_cuda_ver_major, + required_cuda_ver_minor, + epilogue, + ) else: - ConcatFile(2, args.output,args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue) + ConcatFile( + 2, + args.output, + args.operations, + args.type, + head, + required_cuda_ver_major, + required_cuda_ver_minor, + epilogue, + ) elif args.operations == "gemm": for operation in operations: with EmitGemmSingleKernelWrapper( args.output, operation, short_path ) as emitter: emitter.emit() - head = EmitGemmSingleKernelWrapper(args.output, operations[0], short_path).header_template + head = EmitGemmSingleKernelWrapper( + args.output, operations[0], short_path + ).header_template required_cuda_ver_major = operations[0].required_cuda_ver_major required_cuda_ver_minor = operations[0].required_cuda_ver_minor - epilogue = EmitGemmSingleKernelWrapper(args.output, operations[0], short_path).epilogue_template + epilogue = EmitGemmSingleKernelWrapper( + args.output, operations[0], short_path + ).epilogue_template if args.type == "tensorop884": - ConcatFile(30, args.output, args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue) + ConcatFile( + 30, + args.output, + args.operations, + args.type, + head, + required_cuda_ver_major, + required_cuda_ver_minor, + epilogue, + ) else: - ConcatFile(2, args.output, args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue) + ConcatFile( + 2, + args.output, + args.operations, + args.type, + head, + required_cuda_ver_major, + required_cuda_ver_minor, + epilogue, + ) elif args.operations == "gemv": for operation in operations: with EmitGemvSingleKernelWrapper( args.output, operation, gemv_wrapper_path, short_path ) as emitter: emitter.emit() - head = EmitGemvSingleKernelWrapper(args.output, operations[0], gemv_wrapper_path, short_path).header_template + head = EmitGemvSingleKernelWrapper( + args.output, operations[0], gemv_wrapper_path, short_path + ).header_template required_cuda_ver_major = operations[0].required_cuda_ver_major required_cuda_ver_minor = operations[0].required_cuda_ver_minor - epilogue = EmitGemvSingleKernelWrapper(args.output, operations[0], gemv_wrapper_path, short_path).epilogue_template - ConcatFile(2, args.output,args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue, wrapper_path = gemv_wrapper_path) + epilogue = EmitGemvSingleKernelWrapper( + args.output, operations[0], gemv_wrapper_path, short_path + ).epilogue_template + ConcatFile( + 2, + args.output, + args.operations, + args.type, + head, + required_cuda_ver_major, + required_cuda_ver_minor, + epilogue, + wrapper_path=gemv_wrapper_path, + ) if args.operations != "gemv": GenerateManifest(args, operations, args.output) diff --git a/dnn/scripts/cutlass_generator/library.py b/dnn/scripts/cutlass_generator/library.py index b1b669d8..ea26b2ea 100644 --- a/dnn/scripts/cutlass_generator/library.py +++ b/dnn/scripts/cutlass_generator/library.py @@ -4,11 +4,11 @@ # \brief Generates the CUTLASS Library's instances # +import enum import re ################################################################################################### -import enum # The following block implements enum.auto() for Python 3.5 variants that don't include it such # as the default 3.5.2 on Ubuntu 16.04. diff --git a/dnn/scripts/cutlass_generator/manifest.py b/dnn/scripts/cutlass_generator/manifest.py index ee27e668..9dd50633 100644 --- a/dnn/scripts/cutlass_generator/manifest.py +++ b/dnn/scripts/cutlass_generator/manifest.py @@ -8,9 +8,9 @@ import enum import os.path import shutil -from library import * -from gemm_operation import * from conv2d_operation import * +from gemm_operation import * +from library import * ################################################################################################### diff --git a/dnn/scripts/gen_cond_take_kern_impls.py b/dnn/scripts/gen_cond_take_kern_impls.py index 8e6fe3e4..759299a7 100755 --- a/dnn/scripts/gen_cond_take_kern_impls.py +++ b/dnn/scripts/gen_cond_take_kern_impls.py @@ -1,59 +1,67 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import os import argparse +import os + from gen_elemwise_utils import DTYPES + def main(): parser = argparse.ArgumentParser( - description='generate elemwise impl files', - formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--type', type=str, choices=['cuda'], - default='cuda', - help='generate cuda cond take kernel file') - parser.add_argument('output', help='output directory') + description="generate elemwise impl files", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--type", + type=str, + choices=["cuda"], + default="cuda", + help="generate cuda cond take kernel file", + ) + parser.add_argument("output", help="output directory") args = parser.parse_args() if not os.path.isdir(args.output): os.makedirs(args.output) - assert args.type =='cuda' - cpp_ext = 'cu' + assert args.type == "cuda" + cpp_ext = "cu" for dtype in DTYPES.keys(): - fname = '{}.{}'.format(dtype, cpp_ext) + fname = "{}.{}".format(dtype, cpp_ext) fname = os.path.join(args.output, fname) - with open(fname, 'w') as fout: + with open(fname, "w") as fout: w = lambda s: print(s, file=fout) - w('// generated by gen_cond_take_kern_impls.py') + w("// generated by gen_cond_take_kern_impls.py") w('#include "../kern.inl"') - w('') - if dtype == 'dt_float16' or dtype == 'dt_bfloat16': - w('#if !MEGDNN_DISABLE_FLOAT16') - w('namespace megdnn {') - w('namespace cuda {') - w('namespace cond_take {') - w('') - - w('inst_genidx(::megdnn::dtype::{})'.format(DTYPES[dtype][0])) - w('#undef inst_genidx') - w('') - w('inst_copy(::megdnn::dtype::{})'.format(DTYPES[dtype][0])) - w('#undef inst_copy') - w('#undef inst_copy_') - - w('') - w('} // cond_take') - w('} // cuda') - w('} // megdnn') - if dtype == 'dt_float16' or dtype == 'dt_bfloat16': - w('#endif') - - print('generated {}'.format(fname)) + w("") + if dtype == "dt_float16" or dtype == "dt_bfloat16": + w("#if !MEGDNN_DISABLE_FLOAT16") + w("namespace megdnn {") + w("namespace cuda {") + w("namespace cond_take {") + w("") + + w("inst_genidx(::megdnn::dtype::{})".format(DTYPES[dtype][0])) + w("#undef inst_genidx") + w("") + w("inst_copy(::megdnn::dtype::{})".format(DTYPES[dtype][0])) + w("#undef inst_copy") + w("#undef inst_copy_") + + w("") + w("} // cond_take") + w("} // cuda") + w("} // megdnn") + if dtype == "dt_float16" or dtype == "dt_bfloat16": + w("#endif") + + print("generated {}".format(fname)) os.utime(args.output) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/dnn/scripts/gen_cuda_batch_conv_bias_kern_impls.py b/dnn/scripts/gen_cuda_batch_conv_bias_kern_impls.py index 2d71b02e..70fb0456 100755 --- a/dnn/scripts/gen_cuda_batch_conv_bias_kern_impls.py +++ b/dnn/scripts/gen_cuda_batch_conv_bias_kern_impls.py @@ -1,37 +1,47 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import os import argparse import itertools +import os + +PREFIXES = { + "dp4a": [ + ("batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4", True), + ("batch_conv_bias_int8_gemm_ncdiv4hw4", False), + ("batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128", False), + ] +} -PREFIXES = {"dp4a": [("batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4", True), ("batch_conv_bias_int8_gemm_ncdiv4hw4", False), ("batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128", False)]} +ACTIVATIONS = {1: ("IDENTITY", "_id"), 2: ("RELU", "_relu"), 3: ("H_SWISH", "_hswish")} -ACTIVATIONS = {1: ("IDENTITY", "_id"), - 2: ("RELU", "_relu"), - 3: ("H_SWISH", "_hswish")} +BIASES = { + 1: ("PerElementBiasVisitor", "_per_elem"), + 2: ("PerChannelBiasVisitor", "_per_chan"), +} -BIASES = {1: ("PerElementBiasVisitor", "_per_elem"), - 2: ("PerChannelBiasVisitor", "_per_chan")} +SUFFIXES = {"dp4a": [""], "imma": [""]} -SUFFIXES = {"dp4a": [""], - "imma": [""]} def main(): parser = argparse.ArgumentParser( - description='generate cuda batch conv bias (dp4a/imma) kern impl files', - formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--type', type=str, choices=['dp4a', - 'imma'], - default='dp4a', help='generate cuda conv bias kernel file') - parser.add_argument('output', help='output directory') + description="generate cuda batch conv bias (dp4a/imma) kern impl files", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--type", + type=str, + choices=["dp4a", "imma"], + default="dp4a", + help="generate cuda conv bias kernel file", + ) + parser.add_argument("output", help="output directory") args = parser.parse_args() if not os.path.isdir(args.output): os.makedirs(args.output) - - inst = ''' + inst = """ template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX>>( const int8_t* d_src, @@ -41,7 +51,7 @@ template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX>>( const int8_t* d_src, @@ -43,7 +61,7 @@ template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX') + self._write("#include ") self._write("namespace mgb {") self._write("namespace serialization {") self._write("namespace fbs {") @@ -33,8 +34,9 @@ class ConverterWriter(IndentWriterBase): self._last_param = p self._param_fields = [] self._fb_fields = ["builder"] - self._write("template<>\nstruct ParamConverter {", - p.name, indent=1) + self._write( + "template<>\nstruct ParamConverter {", p.name, indent=1 + ) self._write("using MegDNNType = megdnn::param::%s;", p.name) self._write("using FlatBufferType = fbs::param::%s;\n", p.name) @@ -42,22 +44,22 @@ class ConverterWriter(IndentWriterBase): if self._skip_current_param: self._skip_current_param = False return - self._write("static MegDNNType to_param(const FlatBufferType* fb) {", - indent=1) - line = 'return {' - line += ', '.join(self._param_fields) - line += '};' + self._write("static MegDNNType to_param(const FlatBufferType* fb) {", indent=1) + line = "return {" + line += ", ".join(self._param_fields) + line += "};" self._write(line) self._write("}\n", indent=-1) self._write( "static flatbuffers::Offset to_flatbuffer(flatbuffers::FlatBufferBuilder& builder, const MegDNNType& param) {", - indent=1) - line = 'return fbs::param::Create{}('.format(str(p.name)) - line += ', '.join(self._fb_fields) - line += ');' + indent=1, + ) + line = "return fbs::param::Create{}(".format(str(p.name)) + line += ", ".join(self._fb_fields) + line += ");" self._write(line) - self._write('}', indent=-1) + self._write("}", indent=-1) self._write("};\n", indent=-1) @@ -68,18 +70,23 @@ class ConverterWriter(IndentWriterBase): return self._param_fields.append( "static_cast(fb->{}())".format( - str(p.name), str(e.name), e.name_field)) - self._fb_fields.append("static_cast(param.{})".format( - key, e.name_field)) + str(p.name), str(e.name), e.name_field + ) + ) + self._fb_fields.append( + "static_cast(param.{})".format(key, e.name_field) + ) def _on_member_field(self, f): if self._skip_current_param: return - if f.dtype.cname == 'DTypeEnum': + if f.dtype.cname == "DTypeEnum": self._param_fields.append( - "intl::convert_dtype_to_megdnn(fb->{}())".format(f.name)) + "intl::convert_dtype_to_megdnn(fb->{}())".format(f.name) + ) self._fb_fields.append( - "intl::convert_dtype_to_fbs(param.{})".format(f.name)) + "intl::convert_dtype_to_fbs(param.{})".format(f.name) + ) else: self._param_fields.append("fb->{}()".format(f.name)) self._fb_fields.append("param.{}".format(f.name)) @@ -93,28 +100,33 @@ class ConverterWriter(IndentWriterBase): enum_name = e.src_class + e.src_name self._param_fields.append( "static_cast(fb->{}())".format( - e.src_class, e.src_name, e.name_field)) - self._fb_fields.append("static_cast(param.{})".format( - enum_name, e.name_field)) + e.src_class, e.src_name, e.name_field + ) + ) + self._fb_fields.append( + "static_cast(param.{})".format(enum_name, e.name_field) + ) def main(): parser = argparse.ArgumentParser( - 'generate convert functions between FlatBuffers type and MegBrain type') - parser.add_argument('input') - parser.add_argument('output') + "generate convert functions between FlatBuffers type and MegBrain type" + ) + parser.add_argument("input") + parser.add_argument("output") args = parser.parse_args() with open(args.input) as fin: inputs = fin.read() - exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) + exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc}) input_hash = hashlib.sha256() - input_hash.update(inputs.encode(encoding='UTF-8')) + input_hash.update(inputs.encode(encoding="UTF-8")) input_hash = input_hash.hexdigest() writer = ConverterWriter() - with open(args.output, 'w') as fout: + with open(args.output, "w") as fout: writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) + if __name__ == "__main__": main() diff --git a/dnn/scripts/gen_flatbuffers_schema.py b/dnn/scripts/gen_flatbuffers_schema.py index 11805b14..2b823588 100755 --- a/dnn/scripts/gen_flatbuffers_schema.py +++ b/dnn/scripts/gen_flatbuffers_schema.py @@ -3,13 +3,14 @@ import argparse import collections -import textwrap -import os import hashlib -import struct import io +import os +import struct +import textwrap + +from gen_param_defs import IndentWriterBase, ParamDef, member_defs -from gen_param_defs import member_defs, ParamDef, IndentWriterBase def _cname_to_fbname(cname): return { @@ -22,17 +23,19 @@ def _cname_to_fbname(cname): "bool": "bool", }[cname] + def scramble_enum_member_name(name): - s = name.find('<<') + s = name.find("<<") if s != -1: - name = name[0:name.find('=') + 1] + ' ' + name[s+2:] + name = name[0 : name.find("=") + 1] + " " + name[s + 2 :] if name in ("MIN", "MAX"): return name + "_" - o_name = name.split(' ')[0].split('=')[0] + o_name = name.split(" ")[0].split("=")[0] if o_name in ("MIN", "MAX"): return name.replace(o_name, o_name + "_") return name + class FlatBuffersWriter(IndentWriterBase): _skip_current_param = False _last_param = None @@ -66,12 +69,13 @@ class FlatBuffersWriter(IndentWriterBase): self._write("}\n", indent=-1) def _write_doc(self, doc): - if not isinstance(doc, member_defs.Doc) or not doc.doc: return + if not isinstance(doc, member_defs.Doc) or not doc.doc: + return doc_lines = [] if doc.no_reformat: doc_lines = doc.raw_lines else: - doc = doc.doc.replace('\n', ' ') + doc = doc.doc.replace("\n", " ") text_width = 80 - len(self._cur_indent) - 4 doc_lines = textwrap.wrap(doc, text_width) for line in doc_lines: @@ -101,7 +105,8 @@ class FlatBuffersWriter(IndentWriterBase): default = e.compose_combined_enum(e.default) else: default = scramble_enum_member_name( - str(e.members[e.default]).split(' ')[0].split('=')[0]) + str(e.members[e.default]).split(" ")[0].split("=")[0] + ) self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default) def _resolve_const(self, v): @@ -113,8 +118,12 @@ class FlatBuffersWriter(IndentWriterBase): if self._skip_current_param: return self._write_doc(f.name) - self._write("%s:%s = %s;", f.name, _cname_to_fbname(f.dtype.cname), - self._get_fb_default(self._resolve_const(f.default))) + self._write( + "%s:%s = %s;", + f.name, + _cname_to_fbname(f.dtype.cname), + self._get_fb_default(self._resolve_const(f.default)), + ) def _on_const_field(self, f): self._cur_const_val[str(f.name)] = str(f.default) @@ -129,7 +138,8 @@ class FlatBuffersWriter(IndentWriterBase): default = s.compose_combined_enum(e.get_default()) else: default = scramble_enum_member_name( - str(s.members[e.get_default()]).split(' ')[0].split('=')[0]) + str(s.members[e.get_default()]).split(" ")[0].split("=")[0] + ) self._write("%s:%s = %s;", e.name_field, enum_name, default) def _get_fb_default(self, cppdefault): @@ -137,9 +147,9 @@ class FlatBuffersWriter(IndentWriterBase): return cppdefault d = cppdefault - if d.endswith('f'): # 1.f + if d.endswith("f"): # 1.f return d[:-1] - if d.endswith('ull'): + if d.endswith("ull"): return d[:-3] if d.startswith("DTypeEnum::"): return d[11:] @@ -148,21 +158,23 @@ class FlatBuffersWriter(IndentWriterBase): def main(): parser = argparse.ArgumentParser( - 'generate FlatBuffers schema of operator param from description file') - parser.add_argument('input') - parser.add_argument('output') + "generate FlatBuffers schema of operator param from description file" + ) + parser.add_argument("input") + parser.add_argument("output") args = parser.parse_args() with open(args.input) as fin: inputs = fin.read() - exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) + exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc}) input_hash = hashlib.sha256() - input_hash.update(inputs.encode(encoding='UTF-8')) + input_hash.update(inputs.encode(encoding="UTF-8")) input_hash = input_hash.hexdigest() writer = FlatBuffersWriter() - with open(args.output, 'w') as fout: + with open(args.output, "w") as fout: writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) + if __name__ == "__main__": main() diff --git a/dnn/scripts/gen_heuristic/gen_heuristic.py b/dnn/scripts/gen_heuristic/gen_heuristic.py index f5579e65..55ed7e97 100755 --- a/dnn/scripts/gen_heuristic/gen_heuristic.py +++ b/dnn/scripts/gen_heuristic/gen_heuristic.py @@ -1,14 +1,16 @@ #! /usr/local/env python3 -import pickle -import numpy as np -import os import argparse -import re import collections +import os +import pickle +import re + +import numpy as np + def define_template(**kwargs): - template = ''' + template = """ float cuda{cuda_arch}_{conv_type}_time_pred[{out_dim}] = {{0.0f}}; float cuda{cuda_arch}_{conv_type}_mask[{out_dim}] = {{0.0f}}; float cuda{cuda_arch}_{conv_type}_hidden_units[{hidden_num}] = {{0.0f}}; @@ -17,21 +19,23 @@ def define_template(**kwargs): const static float cuda{cuda_arch}_{conv_type}_biases[{biases_dim}] = {{{biases}}}; const static float cuda{cuda_arch}_{conv_type}_alpha[{out_dim}] = {{{alpha}}}; const static float cuda{cuda_arch}_{conv_type}_beta[{out_dim}] = {{{beta}}}; - ''' + """ return template.format(**kwargs) + def cudnn_slt_template(**kwargs): - template = ("#if CUDNN_MAJOR == {cudnn_major} && CUDNN_MINOR == {cudnn_minor}\n" + - " {define_cmd}\n" + - " {select_cmd}\n" + - " return true;\n" + - "#endif\n" - ) + template = ( + "#if CUDNN_MAJOR == {cudnn_major} && CUDNN_MINOR == {cudnn_minor}\n" + + " {define_cmd}\n" + + " {select_cmd}\n" + + " return true;\n" + + "#endif\n" + ) return template.format(**kwargs) + def select_template(**kwargs): - template = \ - '''if (conv_type == ConvolutionType::{conv_type} && cuda_major == {cuda_major} && + template = """if (conv_type == ConvolutionType::{conv_type} && cuda_major == {cuda_major} && cuda_minor == {cuda_minor}) {{ *layer_num_p = {layer_num}; *hidden_units_p = cuda{cuda_arch}_{conv_type}_hidden_units; @@ -42,7 +46,7 @@ def select_template(**kwargs): *beta_p = cuda{cuda_arch}_{conv_type}_beta; *time_pred_p = cuda{cuda_arch}_{conv_type}_time_pred; *mask_p = cuda{cuda_arch}_{conv_type}_mask; - }} else ''' + }} else """ return template.format(**kwargs) @@ -58,48 +62,48 @@ def fill_src(): if len(matrix_files) == 0: print("Warning: no param files detected.") for fpath in matrix_files: - cudnn_version = re.findall('cudnn([\d.]+)',fpath)[0] + cudnn_version = re.findall("cudnn([\d.]+)", fpath)[0] gen_list[cudnn_version].append(fpath) for cudnn in gen_list: - select_cmd = ("{\n" + - " " * 8 + "return false;\n" + - " " * 4 + "}") + select_cmd = "{\n" + " " * 8 + "return false;\n" + " " * 4 + "}" define_cmd = "" - cudnn_major, cudnn_minor = cudnn.split('.') + cudnn_major, cudnn_minor = cudnn.split(".") for fpath in gen_list[cudnn]: cuda_arch = fpath.split("-")[1].replace(".", "_") - print('cudnn_version: {}, cuda_arch: {}'.format(cudnn,cuda_arch)) + print("cudnn_version: {}, cuda_arch: {}".format(cudnn, cuda_arch)) conv_type = fpath.split("-")[2].split(".")[0] with open(os.path.join(home, "params/{}".format(fpath)), "rb") as pobj: params = pickle.load(pobj) - crt_define_cmd, crt_select_cmd = gen_cmds( - cuda_arch, conv_type, params) + crt_define_cmd, crt_select_cmd = gen_cmds(cuda_arch, conv_type, params) select_cmd = crt_select_cmd + select_cmd define_cmd = crt_define_cmd + define_cmd - cudnn_slt_cmd += cudnn_slt_template(cudnn_major=cudnn_major, - cudnn_minor=cudnn_minor, - select_cmd=select_cmd, - define_cmd=define_cmd) + cudnn_slt_cmd += cudnn_slt_template( + cudnn_major=cudnn_major, + cudnn_minor=cudnn_minor, + select_cmd=select_cmd, + define_cmd=define_cmd, + ) - #select_cmd = select_cmd + # select_cmd = select_cmd with open(os.path.join(home, "get_params.template"), "r") as srcf: src = srcf.read() dst = src.replace("{cudnn_select}", cudnn_slt_cmd) MegDNN_path = os.path.join(home, "../..") - with open(os.path.join(MegDNN_path, - "src/cuda/convolution/get_params.cpp"), "w") as dstf: + with open( + os.path.join(MegDNN_path, "src/cuda/convolution/get_params.cpp"), "w" + ) as dstf: dstf.write(dst) def gen_cmds(cuda_arch, conv_type, params): cuda_major, cuda_minor = cuda_arch.split("_") - alphastr = format_array(params['alpha']).rstrip()[:-1] - betastr = format_array(params['beta']).rstrip()[:-1] - W_list = params['W'] - b_list = params['b'] - Wstr = '' - bstr = '' + alphastr = format_array(params["alpha"]).rstrip()[:-1] + betastr = format_array(params["beta"]).rstrip()[:-1] + W_list = params["W"] + b_list = params["b"] + Wstr = "" + bstr = "" layer_num = str(len(b_list) + 1) layers_dim = [W_list[0].shape[1]] matrices_dim = 0 @@ -118,16 +122,27 @@ def gen_cmds(cuda_arch, conv_type, params): out_dim = layers_dim[-1] layers_dim_str = format_array(np.array(layers_dim)).rstrip()[:-1] - select_cmd = select_template(conv_type=conv_type.upper(), cuda_major=cuda_major, - cuda_minor=cuda_minor, layer_num=layer_num, - cuda_arch=cuda_arch) - define_cmd = define_template(cuda_arch=cuda_arch, conv_type=conv_type.upper(), - hidden_num=hidden_num, - layer_num=layer_num, out_dim=out_dim, - layers_dim=layers_dim_str, - matrices_dim=matrices_dim, matrices=Wstr, - biases_dim=biases_dim, biases=bstr, - alpha=alphastr, beta=betastr) + select_cmd = select_template( + conv_type=conv_type.upper(), + cuda_major=cuda_major, + cuda_minor=cuda_minor, + layer_num=layer_num, + cuda_arch=cuda_arch, + ) + define_cmd = define_template( + cuda_arch=cuda_arch, + conv_type=conv_type.upper(), + hidden_num=hidden_num, + layer_num=layer_num, + out_dim=out_dim, + layers_dim=layers_dim_str, + matrices_dim=matrices_dim, + matrices=Wstr, + biases_dim=biases_dim, + biases=bstr, + alpha=alphastr, + beta=betastr, + ) return (define_cmd, select_cmd) @@ -153,8 +168,9 @@ def format_array(array): if __name__ == "__main__": parser = argparse.ArgumentParser( description="Generate cuDNN heuristic code by neural network into" - " {MEGDNN_ROOT}/src/cuda/convolution/get_params.cpp," - " using parameter value from pickle files in" - " {MEGDNN_ROOT}/scripts/gen_heuristic/params/") + " {MEGDNN_ROOT}/src/cuda/convolution/get_params.cpp," + " using parameter value from pickle files in" + " {MEGDNN_ROOT}/scripts/gen_heuristic/params/" + ) args = parser.parse_args() main() diff --git a/dnn/scripts/gen_param_defs.py b/dnn/scripts/gen_param_defs.py index f3e1b5ef..41d18ed8 100755 --- a/dnn/scripts/gen_param_defs.py +++ b/dnn/scripts/gen_param_defs.py @@ -3,36 +3,37 @@ import argparse import collections -import textwrap -import os import hashlib +import os import struct +import textwrap + class member_defs: """contain classes to define members of an opr param""" - Dtype = collections.namedtuple('Dtype', ['cname', 'pycvt', 'pyfmt', - 'cppjson', 'cname_attr']) - Dtype.__new__.__defaults__ = ('', ) - uint32 = Dtype('uint32_t', 'int', 'I', 'NumberInt') - uint64 = Dtype('uint64_t', 'int', 'Q', 'NumberInt', - 'alignas(sizeof(uint64_t)) ') - int32 = Dtype('int32_t', 'int', 'i', 'NumberInt') - float32 = Dtype('float', 'float', 'f', 'Number') - float64 = Dtype('double', 'float', 'd', 'Number') - dtype = Dtype('DTypeEnum', '_as_dtype_num', 'I', 'Number') - bool = Dtype('bool', 'bool', '?', 'Bool') + Dtype = collections.namedtuple( + "Dtype", ["cname", "pycvt", "pyfmt", "cppjson", "cname_attr"] + ) + Dtype.__new__.__defaults__ = ("",) + uint32 = Dtype("uint32_t", "int", "I", "NumberInt") + uint64 = Dtype("uint64_t", "int", "Q", "NumberInt", "alignas(sizeof(uint64_t)) ") + int32 = Dtype("int32_t", "int", "i", "NumberInt") + float32 = Dtype("float", "float", "f", "Number") + float64 = Dtype("double", "float", "d", "Number") + dtype = Dtype("DTypeEnum", "_as_dtype_num", "I", "Number") + bool = Dtype("bool", "bool", "?", "Bool") class Base: pass - class Doc: """wrap an identifier to associate document note: if the doc starts with a linebreak, it would not be reforamtted. """ - __slots__ = ['id', 'doc'] + + __slots__ = ["id", "doc"] def __init__(self, id_, doc): assert isinstance(id_, str) and isinstance(doc, str), (id_, doc) @@ -42,12 +43,12 @@ class member_defs: @property def no_reformat(self): """whether reformat is disallowed for this doc string""" - return self.doc.startswith('\n') + return self.doc.startswith("\n") @property def raw_lines(self): """the doc lines when ``no_format`` is true""" - ret = self.doc.split('\n') + ret = self.doc.split("\n") assert not ret[0] return ret[1:] @@ -57,7 +58,7 @@ class member_defs: if isinstance(v, cls): return v assert isinstance(v, str) - return cls(v, '') + return cls(v, "") def __str__(self): return self.id @@ -65,9 +66,7 @@ class member_defs: def __eq__(self, rhs): if isinstance(rhs, str): return self.id == rhs - return (isinstance(rhs, Doc) and - (self.id, self.doc) == (rhs.id, rhs.doc)) - + return isinstance(rhs, Doc) and (self.id, self.doc) == (rhs.id, rhs.doc) class Enum(Base): """define an enum; the result would contain both an enum class def and its @@ -89,14 +88,29 @@ class member_defs: for normal enum class: list of (member, alias) pairs for bit combined class: list of (tuple of members, alias) paris """ - __slots__ = ['name', 'name_field', 'members', 'default', - 'member_alias', 'combined'] + + __slots__ = [ + "name", + "name_field", + "members", + "default", + "member_alias", + "combined", + ] all_enums = {} """(param_name, name) => enum""" - def __init__(self, param_name, name, name_field, members, default, - member_alias, combined = False): + def __init__( + self, + param_name, + name, + name_field, + members, + default, + member_alias, + combined=False, + ): name = member_defs.Doc.make(name) assert name.id[0].isupper() members = tuple(map(member_defs.Doc.make, members)) @@ -122,13 +136,13 @@ class member_defs: def normalize(v): if isinstance(v, str): for idx, m in enumerate(self.members): - m = str(m).split(' ')[0].split('=')[0] - if v == m : + m = str(m).split(" ")[0].split("=")[0] + if v == m: return idx - raise ValueError( - "enum member '{}' does not exist.".format(v)) + raise ValueError("enum member '{}' does not exist.".format(v)) assert isinstance(v, int) return v + if self.combined: if isinstance(value, int): value = self.decompose_combined_enum(value) @@ -159,7 +173,8 @@ class member_defs: class Field(Base): """define a normal data field""" - __slots__ = ['name', 'dtype', 'default'] + + __slots__ = ["name", "dtype", "default"] def __init__(self, name, dtype, default): assert isinstance(dtype, member_defs.Dtype) @@ -169,7 +184,8 @@ class member_defs: class Const(Base): """define a const data field""" - __slots__ = ['name', 'dtype', 'default'] + + __slots__ = ["name", "dtype", "default"] def __init__(self, name, dtype, default): assert isinstance(dtype, member_defs.Dtype) @@ -179,7 +195,8 @@ class member_defs: class EnumAlias(Base): """alias of enum type from another param""" - __slots__ = ['name', 'name_field', 'src_class', 'src_name', 'default'] + + __slots__ = ["name", "name_field", "src_class", "src_name", "default"] def __init__(self, name, name_field, src_class, src_name, default): self.name = name @@ -209,26 +226,28 @@ class member_defs: class ParamDef: """""" + __all_tags = set() all_param_defs = [] - __slots__ = ['name', 'members', 'tag', 'is_legacy'] + __slots__ = ["name", "members", "tag", "is_legacy"] - def __init__(self, name, doc='', *, version=0, is_legacy=False): + def __init__(self, name, doc="", *, version=0, is_legacy=False): self.members = [] self.all_param_defs.append(self) - h = hashlib.sha256(name.encode('utf-8')) + h = hashlib.sha256(name.encode("utf-8")) if version: - h.update(struct.pack(' 0: self._indent() @@ -343,7 +369,8 @@ class IndentWriterBase(WriterBase): class PyWriter(IndentWriterBase): FieldDef = collections.namedtuple( - 'FieldDef', ['name', 'cvt', 'fmt', 'default', 'type', 'doc']) + "FieldDef", ["name", "cvt", "fmt", "default", "type", "doc"] + ) # see _on_param_end() for the use of those fields _cur_param_name = None @@ -358,79 +385,75 @@ class PyWriter(IndentWriterBase): def __call__(self, fout, defs): super().__call__(fout) self._enum_member2num = [] - self._write('# %s', self._get_header()) - self._write('import struct') - self._write('from . import enum36 as enum') + self._write("# %s", self._get_header()) + self._write("import struct") + self._write("from . import enum36 as enum") self._write( - 'class _ParamDefBase:\n' - ' def serialize(self):\n' + "class _ParamDefBase:\n" + " def serialize(self):\n" ' tag = struct.pack("I", type(self).TAG)\n' - ' pdata = [getattr(self, i) for i in self.__slots__]\n' - ' for idx, v in enumerate(pdata):\n' - ' if isinstance(v, _EnumBase):\n' - ' pdata[idx] = _enum_member2num[id(v)]\n' - ' elif isinstance(v, _BitCombinedEnumBase):\n' - ' pdata[idx] = v._value_\n' - ' return tag + self._packer.pack(*pdata)\n' - '\n' + " pdata = [getattr(self, i) for i in self.__slots__]\n" + " for idx, v in enumerate(pdata):\n" + " if isinstance(v, _EnumBase):\n" + " pdata[idx] = _enum_member2num[id(v)]\n" + " elif isinstance(v, _BitCombinedEnumBase):\n" + " pdata[idx] = v._value_\n" + " return tag + self._packer.pack(*pdata)\n" + "\n" ) # it's hard to mix custom implemention into enum, just do copy-paste instead classbody = ( - ' @classmethod\n' - ' def __normalize(cls, val):\n' - ' if isinstance(val, str):\n' + " @classmethod\n" + " def __normalize(cls, val):\n" + " if isinstance(val, str):\n" ' if not hasattr(cls, "__member_upper_dict__"):\n' - ' cls.__member_upper_dict__ = {k.upper(): v\n' - ' for k, v in cls.__members__.items()}\n' - ' val = cls.__member_upper_dict__.get(val.upper(),val)\n' - ' return val\n' - ' @classmethod\n' - ' def convert(cls, val):\n' - ' val = cls.__normalize(val)\n' - ' if isinstance(val, cls):\n' - ' return val\n' - ' return cls(val)\n' - ' @classmethod\n' - ' def _missing_(cls, value):\n' - ' vnorm = cls.__normalize(value)\n' - ' if vnorm is not value:\n' - ' return cls(vnorm)\n' - ' return super()._missing_(value)\n' - '\n' - ) - self._write( - 'class _EnumBase(enum.Enum):\n' + classbody - ) - self._write( - 'class _BitCombinedEnumBase(enum.Flag):\n' + classbody + " cls.__member_upper_dict__ = {k.upper(): v\n" + " for k, v in cls.__members__.items()}\n" + " val = cls.__member_upper_dict__.get(val.upper(),val)\n" + " return val\n" + " @classmethod\n" + " def convert(cls, val):\n" + " val = cls.__normalize(val)\n" + " if isinstance(val, cls):\n" + " return val\n" + " return cls(val)\n" + " @classmethod\n" + " def _missing_(cls, value):\n" + " vnorm = cls.__normalize(value)\n" + " if vnorm is not value:\n" + " return cls(vnorm)\n" + " return super()._missing_(value)\n" + "\n" ) + self._write("class _EnumBase(enum.Enum):\n" + classbody) + self._write("class _BitCombinedEnumBase(enum.Flag):\n" + classbody) if not self._imperative: self._write( - 'def _as_dtype_num(dtype):\n' - ' import megbrain.mgb as m\n' - ' return m._get_dtype_num(dtype)\n' - '\n' + "def _as_dtype_num(dtype):\n" + " import megbrain.mgb as m\n" + " return m._get_dtype_num(dtype)\n" + "\n" ) self._write( - 'def _as_serialized_dtype(dtype):\n' - ' import megbrain.mgb as m\n' - ' return m._get_serialized_dtype(dtype)\n' - '\n' + "def _as_serialized_dtype(dtype):\n" + " import megbrain.mgb as m\n" + " return m._get_serialized_dtype(dtype)\n" + "\n" ) else: self._write( - 'def _as_dtype_num(dtype):\n' - ' import megengine.core._imperative_rt.utils as m\n' - ' return m._get_dtype_num(dtype)\n' - '\n' + "def _as_dtype_num(dtype):\n" + " import megengine.core._imperative_rt.utils as m\n" + " return m._get_dtype_num(dtype)\n" + "\n" ) self._write( - 'def _as_serialized_dtype(dtype):\n' - ' import megengine.core._imperative_rt.utils as m\n' - ' return m._get_serialized_dtype(dtype)\n' - '\n' + "def _as_serialized_dtype(dtype):\n" + " import megengine.core._imperative_rt.utils as m\n" + " return m._get_serialized_dtype(dtype)\n" + "\n" ) self._process(defs) @@ -451,8 +474,7 @@ class SerializedDType(_ParamDefBase): self.dtype = _as_serialized_dtype(dtype) ''' ) - self._write('_enum_member2num = {\n %s}', - ',\n '.join(self._enum_member2num)) + self._write("_enum_member2num = {\n %s}", ",\n ".join(self._enum_member2num)) def _write_doc(self, doc): assert isinstance(doc, member_defs.Doc) @@ -465,158 +487,171 @@ class SerializedDType(_ParamDefBase): self._write('"""') return - doc = doc.doc.replace('\n', ' ') + doc = doc.doc.replace("\n", " ") textwidth = 80 - len(self._cur_indent) self._write('"""') for i in textwrap.wrap(doc, textwidth): self._write(i) self._write('"""') - def _on_param_begin(self, p): self._cur_param_name = str(p.name) self._cur_fields = [] self._cur_enum_names = [] - self._write('class %s(_ParamDefBase):', p.name, indent=1) + self._write("class %s(_ParamDefBase):", p.name, indent=1) self._write_doc(p.name) - self._write('TAG = %d', p.tag) + self._write("TAG = %d", p.tag) def _on_param_end(self, p): # gen slots and packer - self._write('__slots__ = [%s]', ', '.join( - map('"{.name}"'.format, self._cur_fields))) - struct_fmt = ''.join(i.fmt for i in self._cur_fields) + self._write( + "__slots__ = [%s]", ", ".join(map('"{.name}"'.format, self._cur_fields)) + ) + struct_fmt = "".join(i.fmt for i in self._cur_fields) if not struct_fmt: - struct_fmt = 'x' + struct_fmt = "x" else: # add padding at end max_t = max(struct_fmt, key=struct.calcsize) - struct_fmt += '0{}'.format(max_t) + struct_fmt += "0{}".format(max_t) self._write('_packer = struct.Struct("%s")', struct_fmt) # gen __init__ signature - self._write('def __init__(%s):', - ', '.join(['self'] + - list('{}={}'.format(i.name, i.default) - for i in self._cur_fields)), - indent=1) + self._write( + "def __init__(%s):", + ", ".join( + ["self"] + + list("{}={}".format(i.name, i.default) for i in self._cur_fields) + ), + indent=1, + ) # gen __init__ doc self._write('"""') for i in self._cur_fields: - self._write(':type {}: :class:`.{}`'.format(i.name, i.type)) + self._write(":type {}: :class:`.{}`".format(i.name, i.type)) if i.doc: - self._write(':param {}: {}'.format(i.name, i.doc)) + self._write(":param {}: {}".format(i.name, i.doc)) self._write('"""') # gen cvt in __init__ for i in self._cur_fields: - self._write('self.%s = %s', i.name, i.cvt) + self._write("self.%s = %s", i.name, i.cvt) self._unindent() self._unindent() - self._write('') + self._write("") def _on_member_enum(self, e): - qualname = '{}.{}'.format(self._cur_param_name, e.name) + qualname = "{}.{}".format(self._cur_param_name, e.name) if e.combined: - self._write('class %s(_BitCombinedEnumBase):', e.name, indent=1) + self._write("class %s(_BitCombinedEnumBase):", e.name, indent=1) else: - self._write('class %s(_EnumBase):', e.name, indent=1) + self._write("class %s(_EnumBase):", e.name, indent=1) self._write_doc(e.name) for emem in e.members: if e.combined: - self._write('%s', emem) + self._write("%s", emem) self._write_doc(emem) else: - v = str(emem).split(' ')[0].split('=')[0] - n = int(str(emem).split('=')[1]) + v = str(emem).split(" ")[0].split("=")[0] + n = int(str(emem).split("=")[1]) self._write('%s = "%s"', v, v) self._write_doc(emem) - self._enum_member2num.append('id({}.{}):{}'.format( - qualname, v, n)) + self._enum_member2num.append("id({}.{}):{}".format(qualname, v, n)) for emem, emem_alias in e.member_alias: - em_a = emem_alias.split(' ')[0].split('=')[0] + em_a = emem_alias.split(" ")[0].split("=")[0] if e.combined: - self._write('%s = %s', em_a, e.compose_combined_enum(emem)) + self._write("%s = %s", em_a, e.compose_combined_enum(emem)) else: - em = str(emem).split(' ')[0].split('=')[0] - self._write('%s = %s', em_a, em) + em = str(emem).split(" ")[0].split("=")[0] + self._write("%s = %s", em_a, em) self._unindent() - self._write('') + self._write("") if e.combined: default = e.compose_combined_enum(e.default) else: - default = "'{}'".format(str(e.members[e.default]).split(' ')[0].split('=')[0]) + default = "'{}'".format( + str(e.members[e.default]).split(" ")[0].split("=")[0] + ) - self._cur_fields.append(self.FieldDef( - name=e.name_field, - cvt='{}.convert({})'.format(qualname, e.name_field), - fmt='I', - default=default, - type=qualname, - doc=None)) + self._cur_fields.append( + self.FieldDef( + name=e.name_field, + cvt="{}.convert({})".format(qualname, e.name_field), + fmt="I", + default=default, + type=qualname, + doc=None, + ) + ) def _on_member_enum_alias(self, e): - self._write('%s = %s.%s', e.name, e.src_class, e.src_name) + self._write("%s = %s.%s", e.name, e.src_class, e.src_name) s = e.src_enum - qualname = '{}.{}'.format(e.src_class, e.src_name) + qualname = "{}.{}".format(e.src_class, e.src_name) if s.combined: default = s.compose_combined_enum(e.get_default()) else: - default = "'{}'".format(str(s.members[e.get_default()]).split(' ')[0].split('=')[0]) - self._cur_fields.append(self.FieldDef( - name=e.name_field, - cvt='{}.convert({})'.format(qualname, e.name_field), - fmt='I', - default=default, - type=qualname, - doc=None)) + default = "'{}'".format( + str(s.members[e.get_default()]).split(" ")[0].split("=")[0] + ) + self._cur_fields.append( + self.FieldDef( + name=e.name_field, + cvt="{}.convert({})".format(qualname, e.name_field), + fmt="I", + default=default, + type=qualname, + doc=None, + ) + ) def _get_py_default(self, cppdefault): if not isinstance(cppdefault, str): return cppdefault d = cppdefault - if d.endswith('f'): # 1.f + if d.endswith("f"): # 1.f return d[:-1] - if d.endswith('ull'): + if d.endswith("ull"): return d[:-3] - if d == 'false': - return 'False' - if d == 'true': - return 'True' - if d.startswith('DTypeEnum::'): - return '"{}"'.format(d.split(':')[2].lower()) + if d == "false": + return "False" + if d == "true": + return "True" + if d.startswith("DTypeEnum::"): + return '"{}"'.format(d.split(":")[2].lower()) return d def _on_member_field(self, f): d = self._get_py_default(f.default) - self._cur_fields.append(self.FieldDef( - name=f.name, - cvt='{}({})'.format(f.dtype.pycvt, f.name), - fmt=f.dtype.pyfmt, - default=d, - type=f.dtype.pycvt, - doc=f.name.doc - )) + self._cur_fields.append( + self.FieldDef( + name=f.name, + cvt="{}({})".format(f.dtype.pycvt, f.name), + fmt=f.dtype.pyfmt, + default=d, + type=f.dtype.pycvt, + doc=f.name.doc, + ) + ) def _on_const_field(self, f): d = self._get_py_default(f.default) self._write_doc(f.name) - self._write('%s = %s', f.name, d) - + self._write("%s = %s", f.name, d) class CPPWriter(IndentWriterBase): - _param_namespace = 'param' + _param_namespace = "param" _ctor_args = None """list of (text in func param, var name); func param name must be var name @@ -625,18 +660,18 @@ class CPPWriter(IndentWriterBase): def __call__(self, fout, defs): super().__call__(fout) - self._write('// %s', self._get_header()) - self._write('#pragma once') + self._write("// %s", self._get_header()) + self._write("#pragma once") self._write('#include "megdnn/dtype.h"') - self._write('#include ') - if self._param_namespace == 'param': - self._write('#include ') - self._write('namespace megdnn {') - self._write('namespace %s {', self._param_namespace) + self._write("#include ") + if self._param_namespace == "param": + self._write("#include ") + self._write("namespace megdnn {") + self._write("namespace %s {", self._param_namespace) self._process(defs) - self._write('} // namespace megdnn') - self._write('} // namespace %s', self._param_namespace) - self._write('// vim: syntax=cpp.doxygen') + self._write("} // namespace megdnn") + self._write("} // namespace %s", self._param_namespace) + self._write("// vim: syntax=cpp.doxygen") def _write_doc(self, doc): assert isinstance(doc, member_defs.Doc) @@ -644,37 +679,37 @@ class CPPWriter(IndentWriterBase): return if doc.no_reformat: - self._write('/*') + self._write("/*") for i in doc.raw_lines: - self._write('* ' + i) - self._write('*/') + self._write("* " + i) + self._write("*/") return - doc = doc.doc.replace('\n', ' ') + doc = doc.doc.replace("\n", " ") textwidth = 80 - len(self._cur_indent) - 4 if len(doc) <= textwidth: - self._write('//! ' + doc) + self._write("//! " + doc) return - self._write('/*!') + self._write("/*!") for i in textwrap.wrap(doc, textwidth): - self._write(' * ' + i) - self._write(' */') + self._write(" * " + i) + self._write(" */") def _on_param_begin(self, p): self._write_doc(p.name) - self._write('struct %s {', p.name, indent=1) - self._write('static MEGDNN_CONSTEXPR uint32_t TAG = %du;', p.tag) + self._write("struct %s {", p.name, indent=1) + self._write("static MEGDNN_CONSTEXPR uint32_t TAG = %du;", p.tag) self._ctor_args = [] self._non_static_members = [] def _add_ctor_args(self, typename, default, varname): - self._ctor_args.append(( - '{} {}_={}'.format(typename, varname, default), - varname)) + self._ctor_args.append( + ("{} {}_={}".format(typename, varname, default), varname) + ) def _on_param_end(self, p): - ''' + """ MegDNN param structures are not packed and we need to initialize the structure paddings to zero or it would break MegBrain hash system. We do memset(0) in default ctor and use a trick, wrapping non-static members in a anonymous union which would @@ -683,64 +718,78 @@ class CPPWriter(IndentWriterBase): > a memberwise copy/move of its bases and members. [class.copy.ctor 14] > The implicitly-defined copy/move constructor for a union X copies the object > representation (6.9) of X. [class.copy.ctor 15] - ''' + """ if self._non_static_members: - self._write('union { struct {') + self._write("union { struct {") for i in self._non_static_members: if isinstance(i, member_defs.Field): self._write_doc(i.name) - self._write('%s%s %s;', i.dtype.cname_attr, i.dtype.cname, i.name) + self._write("%s%s %s;", i.dtype.cname_attr, i.dtype.cname, i.name) else: assert isinstance(i, (member_defs.Enum, member_defs.EnumAlias)) - self._write('%s %s;', i.name, i.name_field) - self._write('}; };') + self._write("%s %s;", i.name, i.name_field) + self._write("}; };") if self._ctor_args: pdefs, varnames = zip(*self._ctor_args) - self._write('%s(%s) {', p.name, ', '.join(pdefs), indent=1) - self._write('memset(this, 0, sizeof(*this));') + self._write("%s(%s) {", p.name, ", ".join(pdefs), indent=1) + self._write("memset(this, 0, sizeof(*this));") for var in varnames: - self._write('this->%s = %s_;', var, var) - self._write('}', indent=-1) - self._write('};\n', indent=-1) + self._write("this->%s = %s_;", var, var) + self._write("}", indent=-1) + self._write("};\n", indent=-1) def _on_member_enum(self, e): self._write_doc(e.name) - self._write('enum class %s: uint32_t {', e.name, indent=1) + self._write("enum class %s: uint32_t {", e.name, indent=1) for i in e.members: self._write_doc(i) v = str(i) if i is not e.members[-1] or e.member_alias: - v += ',' + v += "," self._write(v) for mem, alias in e.member_alias: if e.combined: - self._write('%s = %s,', alias, e.compose_combined_enum(mem)) + self._write("%s = %s,", alias, e.compose_combined_enum(mem)) else: - self._write('%s = %s,', str(alias).split(' ')[0].split('=')[0], str(mem).split(' ')[0].split('=')[0]) - self._write('};', indent=-1) + self._write( + "%s = %s,", + str(alias).split(" ")[0].split("=")[0], + str(mem).split(" ")[0].split("=")[0], + ) + self._write("};", indent=-1) self._non_static_members.append(e) - self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', - str(e.name).upper(), len(e.members)) + self._write( + "static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;", + str(e.name).upper(), + len(e.members), + ) if e.combined: - default = 'static_cast<{}>({})'.format(e.name, e.compose_combined_enum(e.default)) + default = "static_cast<{}>({})".format( + e.name, e.compose_combined_enum(e.default) + ) else: value = str(e.members[e.default]) - value = value.split(' ')[0].split('=')[0] - default = '{}::{}'.format(e.name, value) + value = value.split(" ")[0].split("=")[0] + default = "{}::{}".format(e.name, value) self._add_ctor_args(e.name, default, e.name_field) def _on_member_enum_alias(self, e): s = e.src_enum - self._write('using %s = %s::%s;', e.name, e.src_class, e.src_name) + self._write("using %s = %s::%s;", e.name, e.src_class, e.src_name) self._non_static_members.append(e) - self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', - str(e.name).upper(), len(s.members)) + self._write( + "static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;", + str(e.name).upper(), + len(s.members), + ) if s.combined: - default = 'static_cast<{}>({})'.format(e.name, s.compose_combined_enum(e.default)) + default = "static_cast<{}>({})".format( + e.name, s.compose_combined_enum(e.default) + ) else: value = str(s.members[e.get_default()]) - value = value.split(' ')[0].split('=')[0] - default = '{}::{}'.format(e.name, value) + value = value.split(" ")[0].split("=")[0] + default = "{}::{}".format(e.name, value) self._add_ctor_args(e.name, default, e.name_field) def _on_member_field(self, f): @@ -749,30 +798,45 @@ class CPPWriter(IndentWriterBase): def _on_const_field(self, f): self._write_doc(f.name) - if 'int' in f.dtype.cname: - self._write('static constexpr %s%s %s = %s;', f.dtype.cname_attr, f.dtype.cname, f.name, f.default) + if "int" in f.dtype.cname: + self._write( + "static constexpr %s%s %s = %s;", + f.dtype.cname_attr, + f.dtype.cname, + f.name, + f.default, + ) else: - self._write('static const %s%s %s = %s;', f.dtype.cname_attr, f.dtype.cname, f.name, f.default) - + self._write( + "static const %s%s %s = %s;", + f.dtype.cname_attr, + f.dtype.cname, + f.name, + f.default, + ) class CPPEnumValueWriter(CPPWriter): - _param_namespace = 'param_enumv' + _param_namespace = "param_enumv" def _on_member_enum(self, e): self._write_doc(e.name) - self._write('struct %s {', e.name, indent=1) + self._write("struct %s {", e.name, indent=1) for val in e.members: self._write_doc(val) v = str(val) - self._write('static const uint32_t %s;', v) + self._write("static const uint32_t %s;", v) for mem, alias in e.member_alias: - self._write('static const uint32_t %s = %s;', str(alias).split(' ')[0].split('=')[0], str(mem).split(' ')[0].split('=')[0]) - self._write('};', indent=-1) + self._write( + "static const uint32_t %s = %s;", + str(alias).split(" ")[0].split("=")[0], + str(mem).split(" ")[0].split("=")[0], + ) + self._write("};", indent=-1) def _on_member_enum_alias(self, e): s = e.src_enum - self._write('typedef %s::%s %s;', e.src_class, e.src_name, e.name) + self._write("typedef %s::%s %s;", e.src_class, e.src_name, e.name) def _on_member_field(self, f): pass @@ -787,7 +851,7 @@ class CPPEnumItemWriter(WriterBase): _enable = False def __init__(self, enum_def): - self._class_name, self._enum_name = enum_def.split(':') + self._class_name, self._enum_name = enum_def.split(":") def __call__(self, fout, defs): super().__call__(fout) @@ -799,59 +863,73 @@ class CPPEnumItemWriter(WriterBase): def _on_member_enum(self, e): if self._enable and e.name == self._enum_name: for i in e.members: - self._fout.write('{}\n'.format(i)) + self._fout.write("{}\n".format(i)) + class CPPParamJsonFuncWriter(IndentWriterBase): - _param_namespace = 'param' + _param_namespace = "param" _param_name = None _items = None + def _write_json_item(self, json_cls, field): cls2ctype = { - 'NumberInt': 'int64_t', - 'Number': 'double', - 'Bool': 'bool', + "NumberInt": "int64_t", + "Number": "double", + "Bool": "bool", } - self._items.append('{"%s", json::%s::make(static_cast<%s>(p.%s))},' % ( - field, json_cls, cls2ctype[json_cls], field)) - + self._items.append( + '{"%s", json::%s::make(static_cast<%s>(p.%s))},' + % (field, json_cls, cls2ctype[json_cls], field) + ) def __call__(self, fout, defs): super().__call__(fout) - self._write('// %s', self._get_header()) - self._write('// this file can only be included in ' - 'megbrain/src/plugin/impl/opr_footprint.cpp\n' - '// please do not include it directly') + self._write("// %s", self._get_header()) + self._write( + "// this file can only be included in " + "megbrain/src/plugin/impl/opr_footprint.cpp\n" + "// please do not include it directly" + ) self._write('#include "megdnn/opr_param_defs.h"') - self._write('#pragma once') - self._write('using namespace megdnn;') - self._write('namespace mgb {') - self._write('namespace opr {') - self._write('template') - self._write('std::shared_ptr opr_param_to_json(const OprParam ¶m);') + self._write("#pragma once") + self._write("using namespace megdnn;") + self._write("namespace mgb {") + self._write("namespace opr {") + self._write("template") + self._write( + "std::shared_ptr opr_param_to_json(const OprParam ¶m);" + ) self._process(defs) - self._write('} // namespace opr') - self._write('} // namespace mgb') - self._write('\n// vim: syntax=cpp.doxygen') + self._write("} // namespace opr") + self._write("} // namespace mgb") + self._write("\n// vim: syntax=cpp.doxygen") def _on_param_begin(self, p): - self._write('template<>', indent=0) + self._write("template<>", indent=0) self._write( - 'std::shared_ptr opr_param_to_json(const param::%s &p) {', - p.name, indent=1) - self._param_name = 'param::{}'.format(p.name) + "std::shared_ptr opr_param_to_json(const param::%s &p) {", + p.name, + indent=1, + ) + self._param_name = "param::{}".format(p.name) self._items = [] def _on_param_end(self, p): - self._write('return json::Object::make({', indent=1) + self._write("return json::Object::make({", indent=1) for i in self._items: self._write(i, indent=0) - self._write('});', indent=-1) - self._write('}', indent=-1) + self._write("});", indent=-1) + self._write("}", indent=-1) def _on_member_enum(self, e): - self._write('auto %s2str = [](const %s::%s arg) -> std::string {', - e.name, self._param_name, e.name, indent=1) - self._write('switch (arg) {', indent=1) + self._write( + "auto %s2str = [](const %s::%s arg) -> std::string {", + e.name, + self._param_name, + e.name, + indent=1, + ) + self._write("switch (arg) {", indent=1) enum2str = [] if isinstance(e, member_defs.EnumAlias): members = e.src_enum.members @@ -859,15 +937,27 @@ class CPPParamJsonFuncWriter(IndentWriterBase): members = e.members for i in members: v = str(i) - v = v.split(' ')[0].split('=')[0] - self._write('case %s::%s::%s: return "%s";', - self._param_name, e.name, v, v, indent=0) - self._write('default: mgb_throw(MegBrainError, "Invalid %s::%s:%%d", static_cast(arg));', - self._param_name, e.name, indent=0) - self._write('}', indent=-1) - self._write('};', indent=-1) - self._items.append('{"%s", json::String::make(%s2str(p.%s))},' % ( - e.name_field, e.name, e.name_field)) + v = v.split(" ")[0].split("=")[0] + self._write( + 'case %s::%s::%s: return "%s";', + self._param_name, + e.name, + v, + v, + indent=0, + ) + self._write( + 'default: mgb_throw(MegBrainError, "Invalid %s::%s:%%d", static_cast(arg));', + self._param_name, + e.name, + indent=0, + ) + self._write("}", indent=-1) + self._write("};", indent=-1) + self._items.append( + '{"%s", json::String::make(%s2str(p.%s))},' + % (e.name_field, e.name, e.name_field) + ) def _on_member_enum_alias(self, e): self._on_member_enum(e) @@ -880,51 +970,58 @@ class CPPParamJsonFuncWriter(IndentWriterBase): def main(): - parser = argparse.ArgumentParser( - 'generate opr param defs from description file') - parser.add_argument('--enumv', action='store_true', - help='generate c++03 compatible code which only ' - 'contains enum values') - parser.add_argument('-t', '--type', choices=['c++', 'py'], default='c++', - help='output type') - parser.add_argument('--write-enum-items', - help='write enum item names to output file; argument ' - 'should be given in the CLASS:ENUM format') - parser.add_argument('--write-cppjson', - help='generate megbrain json serialization implemention' - 'cpp file') - parser.add_argument('input') - parser.add_argument('output') - parser.add_argument('--imperative', action='store_true', - help='generate files for imperatvie ') + parser = argparse.ArgumentParser("generate opr param defs from description file") + parser.add_argument( + "--enumv", + action="store_true", + help="generate c++03 compatible code which only " "contains enum values", + ) + parser.add_argument( + "-t", "--type", choices=["c++", "py"], default="c++", help="output type" + ) + parser.add_argument( + "--write-enum-items", + help="write enum item names to output file; argument " + "should be given in the CLASS:ENUM format", + ) + parser.add_argument( + "--write-cppjson", + help="generate megbrain json serialization implemention" "cpp file", + ) + parser.add_argument("input") + parser.add_argument("output") + parser.add_argument( + "--imperative", action="store_true", help="generate files for imperatvie " + ) args = parser.parse_args() for_imperative = args.imperative with open(args.input) as fin: inputs = fin.read() - exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) + exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc}) input_hash = hashlib.sha256() - input_hash.update(inputs.encode(encoding='UTF-8')) + input_hash.update(inputs.encode(encoding="UTF-8")) input_hash = input_hash.hexdigest() - if args.type == 'py': + if args.type == "py": writer = PyWriter(for_imperative=for_imperative) else: - assert args.type == 'c++' + assert args.type == "c++" if args.enumv: writer = CPPEnumValueWriter() elif args.write_enum_items: writer = CPPEnumItemWriter(args.write_enum_items) else: writer = CPPWriter() - with open(args.output, 'w') as fout: + with open(args.output, "w") as fout: writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) if args.write_cppjson: writer = CPPParamJsonFuncWriter() - with open(args.write_cppjson, 'w') as fout: + with open(args.write_cppjson, "w") as fout: writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/dnn/scripts/gen_tablegen.py b/dnn/scripts/gen_tablegen.py index 4de6eb65..c5e72a35 100755 --- a/dnn/scripts/gen_tablegen.py +++ b/dnn/scripts/gen_tablegen.py @@ -3,19 +3,17 @@ import argparse import collections -import textwrap -import os import hashlib -import struct import io +import os +import struct +import textwrap -from gen_param_defs import member_defs, ParamDef, IndentWriterBase +from gen_param_defs import IndentWriterBase, ParamDef, member_defs # FIXME: move supportToString flag definition into the param def source file -ENUM_TO_STRING_SPECIAL_RULES = [ - ("Elemwise", "Mode"), - ("ElemwiseMultiType", "Mode") -] +ENUM_TO_STRING_SPECIAL_RULES = [("Elemwise", "Mode"), ("ElemwiseMultiType", "Mode")] + class ConverterWriter(IndentWriterBase): _skip_current_param = False @@ -33,21 +31,21 @@ class ConverterWriter(IndentWriterBase): self._write("#endif // MGB_PARAM") def _ctype2attr(self, ctype, value): - if ctype == 'uint32_t': - return 'MgbUI32Attr', value - if ctype == 'uint64_t': - return 'MgbUI64Attr', value - if ctype == 'int32_t': - return 'MgbI32Attr', value - if ctype == 'float': - return 'MgbF32Attr', value - if ctype == 'double': - return 'MgbF64Attr', value - if ctype == 'bool': - return 'MgbBoolAttr', value - if ctype == 'DTypeEnum': + if ctype == "uint32_t": + return "MgbUI32Attr", value + if ctype == "uint64_t": + return "MgbUI64Attr", value + if ctype == "int32_t": + return "MgbI32Attr", value + if ctype == "float": + return "MgbF32Attr", value + if ctype == "double": + return "MgbF64Attr", value + if ctype == "bool": + return "MgbBoolAttr", value + if ctype == "DTypeEnum": self._packed = False - return 'MgbDTypeAttr', 'megdnn::DType::from_enum(megdnn::{})'.format(value) + return "MgbDTypeAttr", "megdnn::DType::from_enum(megdnn::{})".format(value) raise RuntimeError("unknown ctype") def _on_param_begin(self, p): @@ -61,21 +59,26 @@ class ConverterWriter(IndentWriterBase): self._skip_current_param = False return if self._packed: - self._write("class {0}ParamBase : MgbPackedParamBase<\"{0}\", accessor> {{".format(p.name), indent=1) + self._write( + 'class {0}ParamBase : MgbPackedParamBase<"{0}", accessor> {{'.format( + p.name + ), + indent=1, + ) else: - self._write("def {0}Param: MgbParamBase<\"{0}\"> {{".format(p.name), indent=1) + self._write('def {0}Param: MgbParamBase<"{0}"> {{'.format(p.name), indent=1) self._write("let fields = (ins", indent=1) self._write(",\n{}".format(self._cur_indent).join(self._current_tparams)) self._write(");", indent=-1) self._write("}\n", indent=-1) if self._packed: - self._write("def {0}Param : {0}ParamBase<\"param\">;\n".format(p.name)) + self._write('def {0}Param : {0}ParamBase<"param">;\n'.format(p.name)) self._current_tparams = None self._packed = None self._const = None def _wrapped_with_default_value(self, attr, default): - return 'MgbDefaultValuedAttr<{}, \"{}\">'.format(attr, default) + return 'MgbDefaultValuedAttr<{}, "{}">'.format(attr, default) def _on_member_enum(self, e): p = self._last_param @@ -84,10 +87,12 @@ class ConverterWriter(IndentWriterBase): # directly used by any operator, or other enum couldn't alias to this enum td_class = "{}{}".format(p.name, e.name) fullname = "::megdnn::param::{}".format(p.name) - enum_def = "MgbEnumAttr<\"{}\", \"{}\", [".format(fullname, e.name) + enum_def = 'MgbEnumAttr<"{}", "{}", ['.format(fullname, e.name) + def format(v): - return '\"{}\"'.format(str(v).split(' ')[0].split('=')[0]) - enum_def += ','.join(format(i) for i in e.members) + return '"{}"'.format(str(v).split(" ")[0].split("=")[0]) + + enum_def += ",".join(format(i) for i in e.members) if e.combined: enum_def += "], 1" @@ -95,7 +100,7 @@ class ConverterWriter(IndentWriterBase): enum_def += "], 0" if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)): - enum_def += ", 1" # whether generate ToStringTrait + enum_def += ", 1" # whether generate ToStringTrait enum_def += ">" self._write("def {} : {};".format(td_class, enum_def)) @@ -105,10 +110,12 @@ class ConverterWriter(IndentWriterBase): # wrapped with default value if e.combined: default_val = "static_cast<{}::{}>({})".format( - fullname, e.name, e.compose_combined_enum(e.default)) + fullname, e.name, e.compose_combined_enum(e.default) + ) else: default_val = "{}::{}::{}".format( - fullname, e.name, str(e.members[e.default]).split(' ')[0].split('=')[0]) + fullname, e.name, str(e.members[e.default]).split(" ")[0].split("=")[0] + ) wrapped = self._wrapped_with_default_value(td_class, default_val) @@ -123,51 +130,58 @@ class ConverterWriter(IndentWriterBase): td_class = "{}{}".format(p.name, e.name) fullname = "::megdnn::param::{}".format(p.name) base_td_class = "{}{}".format(e.src_class, e.src_name) - enum_def = "MgbEnumAliasAttr<\"{}\", \"{}\", {}>".format(fullname, e.name, base_td_class) + enum_def = 'MgbEnumAliasAttr<"{}", "{}", {}>'.format( + fullname, e.name, base_td_class + ) self._write("def {} : {};".format(td_class, enum_def)) # wrapped with default value s = e.src_enum if s.combined: default_val = "static_cast<{}::{}>({})".format( - fullname, e.name, s.compose_combined_enum(e.get_default())) + fullname, e.name, s.compose_combined_enum(e.get_default()) + ) else: - default_val = "{}::{}::{}".format(fullname, e.name, str( - s.members[e.get_default()]).split(' ')[0].split('=')[0]) + default_val = "{}::{}::{}".format( + fullname, + e.name, + str(s.members[e.get_default()]).split(" ")[0].split("=")[0], + ) wrapped = self._wrapped_with_default_value(td_class, default_val) self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) - def _on_member_field(self, f): if self._skip_current_param: return attr, value = self._ctype2attr(f.dtype.cname, str(f.default)) if str(value) in self._const: - value = '::megdnn::param::{}::{}'.format(self._last_param.name, value) + value = "::megdnn::param::{}::{}".format(self._last_param.name, value) wrapped = self._wrapped_with_default_value(attr, value) self._current_tparams.append("{}:${}".format(wrapped, f.name)) def _on_const_field(self, f): self._const.add(str(f.name)) + def main(): - parser = argparse.ArgumentParser('generate op param tablegen file') - parser.add_argument('input') - parser.add_argument('output') + parser = argparse.ArgumentParser("generate op param tablegen file") + parser.add_argument("input") + parser.add_argument("output") args = parser.parse_args() with open(args.input) as fin: inputs = fin.read() - exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) + exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc}) input_hash = hashlib.sha256() - input_hash.update(inputs.encode(encoding='UTF-8')) + input_hash.update(inputs.encode(encoding="UTF-8")) input_hash = input_hash.hexdigest() writer = ConverterWriter() - with open(args.output, 'w') as fout: + with open(args.output, "w") as fout: writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) + if __name__ == "__main__": main() diff --git a/tools/evaluation_model_parallelism.py b/tools/evaluation_model_parallelism.py index 1f547212..53228388 100644 --- a/tools/evaluation_model_parallelism.py +++ b/tools/evaluation_model_parallelism.py @@ -19,6 +19,7 @@ device = { "thread_number": 3, } + class SshConnector: """imp ssh control master connector""" @@ -83,17 +84,17 @@ def main(): model_file = args.model_file # copy model file ssh.copy([args.model_file], workspace) - m = model_file.split('\\')[-1] + m = model_file.split("\\")[-1] # run single thread result = [] thread_number = [1, 2, 4] - for b in thread_number : + for b in thread_number: cmd = [] cmd1 = "cd {} && ./load_and_run {} -multithread {} --fast-run --fast_run_algo_policy fastrun.cache --iter 1 --warmup-iter 1 --no-sanity-check --weight-preprocess".format( - workspace, m, b + workspace, m, b ) cmd2 = "cd {} && ./load_and_run {} -multithread {} --fast_run_algo_policy fastrun.cache --iter 20 --warmup-iter 5 --no-sanity-check --weight-preprocess ".format( - workspace, m, b + workspace, m, b ) cmd.append(cmd1) cmd.append(cmd2) @@ -103,12 +104,20 @@ def main(): logging.debug("model: {} with backend: {} result is: {}".format(m, b, ret)) result.append(ret) - thread_2 = result[0]/result[1] - thread_4 = result[0]/result[2] + thread_2 = result[0] / result[1] + thread_4 = result[0] / result[2] if thread_2 > 1.6 or thread_4 > 3.0: - print("model: {} can has good parallelism. 2 thread is {}, 4 thread is {}".format(m, thread_2, thread_4)) + print( + "model: {} can has good parallelism. 2 thread is {}, 4 thread is {}".format( + m, thread_2, thread_4 + ) + ) else: - print("model: {} can has bad parallelism. 2 thread is {}, 4 thread is {}".format(m, thread_2, thread_4)) + print( + "model: {} can has bad parallelism. 2 thread is {}, 4 thread is {}".format( + m, thread_2, thread_4 + ) + ) if __name__ == "__main__": diff --git a/tools/format.py b/tools/format.py index 0e6ce625..a1e04d2a 100755 --- a/tools/format.py +++ b/tools/format.py @@ -20,8 +20,12 @@ failed_files = Manager().list() def process_file(file, clang_format, write): original_source = open(file, "r").read() source = original_source - source = re.sub(r"MGB_DEFINE(?P([^\\]|\n)*?)// *{", r"class MGB_DEFINE\g{", source) - source, count = re.subn(r"(?([^\\]|\n)*?)// *{", r"class MGB_DEFINE\g{", source + ) + source, count = re.subn( + r"(?