
style(mgb/tools): add format for tools, dnn and ci

GitOrigin-RevId: 5684e5ea43
Branch: dev-support-lite-fork-debug-mode
Author: Megvii Engine Team (3 years ago)
Commit: 421bcfd3d8
23 changed files with 1295 additions and 843 deletions
  1. dnn/scripts/cutlass_generator/conv2d_operation.py (+1, -1)
  2. dnn/scripts/cutlass_generator/gemm_operation.py (+2, -3)
  3. dnn/scripts/cutlass_generator/gen_list.py (+8, -6)
  4. dnn/scripts/cutlass_generator/generator.py (+180, -118)
  5. dnn/scripts/cutlass_generator/library.py (+1, -1)
  6. dnn/scripts/cutlass_generator/manifest.py (+2, -2)
  7. dnn/scripts/gen_cond_take_kern_impls.py (+44, -36)
  8. dnn/scripts/gen_cuda_batch_conv_bias_kern_impls.py (+40, -24)
  9. dnn/scripts/gen_cuda_conv_bias_kern_impls.py (+49, -25)
 10. dnn/scripts/gen_elemwise_each_mode.py (+18, -13)
 11. dnn/scripts/gen_elemwise_kern_impls.py (+34, -27)
 12. dnn/scripts/gen_elemwise_multi_type_kern_impls.py (+35, -21)
 13. dnn/scripts/gen_elemwise_multi_type_utils.py (+113, -30)
 14. dnn/scripts/gen_elemwise_special_kern_impls.py (+32, -27)
 15. dnn/scripts/gen_elemwise_utils.py (+91, -31)
 16. dnn/scripts/gen_flatbuffers_converter.py (+44, -32)
 17. dnn/scripts/gen_flatbuffers_schema.py (+33, -21)
 18. dnn/scripts/gen_heuristic/gen_heuristic.py (+65, -49)
 19. dnn/scripts/gen_param_defs.py (+406, -309)
 20. dnn/scripts/gen_tablegen.py (+58, -44)
 21. tools/evaluation_model_parallelism.py (+17, -8)
 22. tools/format.py (+17, -13)
 23. tools/test_model_static.py (+5, -2)

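The hunks below are mechanical re-formatting: import blocks are re-sorted and grouped (the literal "# isort: skip" markers added in gen_list.py and gen_elemwise_multi_type_kern_impls.py come from isort), and the Python sources are re-wrapped into the double-quoted, trailing-comma style that matches black's defaults. As a rough illustration only, a minimal sketch of how such a pass could be driven is shown next; it assumes black and isort are installed, and it is not the repository's actual tools/format.py (whose options are not visible in this excerpt).

# Hypothetical helper, not the repository's tools/format.py: re-format the
# directories touched by this commit with isort + black (pip install isort black).
import subprocess

TARGETS = ["dnn/scripts", "tools"]  # directories seen in the change list above

for target in TARGETS:
    # Sort and group imports first, so the import layout is settled
    # before lines are re-wrapped.
    subprocess.run(["isort", target], check=True)
    # Apply black's default style: double quotes, trailing commas, 88 columns.
    subprocess.run(["black", target], check=True)

Running isort before black is the usual order, so the import re-grouping is finished before black re-wraps any long lines.
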
dnn/scripts/cutlass_generator/conv2d_operation.py (+1, -1)

@@ -8,7 +8,7 @@
 import enum
 import os.path
 import shutil
-from typing import Tuple, List
+from typing import List, Tuple


 from library import *




dnn/scripts/cutlass_generator/gemm_operation.py (+2, -3)

@@ -5,14 +5,13 @@
 #

 import enum
-import os.path
-import shutil
 import functools
 import operator
+import os.path
+import shutil

 from library import *

 ###################################################################################################
 #
 # Data structure modeling a GEMM operation


dnn/scripts/cutlass_generator/gen_list.py (+8, -6)

@@ -1,11 +1,11 @@
-from generator import (
-    GenerateGemmOperations,
-    GenerateGemvOperations,
+from generator import (  # isort: skip
     GenerateConv2dOperations,
     GenerateDeconvOperations,
-    GenerateDwconv2dFpropOperations,
     GenerateDwconv2dDgradOperations,
+    GenerateDwconv2dFpropOperations,
     GenerateDwconv2dWgradOperations,
+    GenerateGemmOperations,
+    GenerateGemvOperations,
 )

@@ -35,12 +35,14 @@ def write_op_list(f, gen_op, gen_type):
     if gen_op != "gemv":
         f.write(' "all_%s_%s_operations.cu",\n' % (gen_op, gen_type))

+
 # Write down a list of merged filenames
 def write_merge_file_name(f, gen_op, gen_type, split_number):
     for i in range(0, split_number):
-        f.write(' "{}_{}_{}.cu",\n'.format(gen_op,gen_type,i))
+        f.write(' "{}_{}_{}.cu",\n'.format(gen_op, gen_type, i))
     if gen_op != "gemv":
-        f.write(' "all_{}_{}_operations.cu",\n'.format(gen_op,gen_type))
+        f.write(' "all_{}_{}_operations.cu",\n'.format(gen_op, gen_type))

+
 if __name__ == "__main__":
     with open("list.bzl", "w") as f:


dnn/scripts/cutlass_generator/generator.py (+180, -118)

@@ -4,12 +4,12 @@
 # \brief Generates the CUTLASS Library's instances
 #

+import argparse
 import enum
 import os.path
-import shutil
-import argparse
 import platform
+import shutil
 import string

 from library import *
 from manifest import *

@@ -899,9 +899,12 @@ def GenerateGemm_Simt(args):
         warpShapes.append([warp0, warp1])

     # sgemm
-    precisionType, precisionBits, threadblockMaxElements, threadblockTilesL0 = precisions[
-        "s"
-    ]
+    (
+        precisionType,
+        precisionBits,
+        threadblockMaxElements,
+        threadblockTilesL0,
+    ) = precisions["s"]

     layouts = [
         (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # nn

@@ -1091,9 +1094,12 @@ def GenerateDwconv2d_Simt(args, conv_kind):
         warpShapes.append([warp0, warp1])

     # sgemm
-    precisionType, precisionBits, threadblockMaxElements, threadblockTilesL0 = precisions[
-        "s"
-    ]
+    (
+        precisionType,
+        precisionBits,
+        threadblockMaxElements,
+        threadblockTilesL0,
+    ) = precisions["s"]

     layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)]

@@ -1304,7 +1310,7 @@ def GenerateDwconv2d_TensorOp_884(args, conv_kind):
     for dst_type, dst_layout in zip(dst_types, dst_layouts):
         for alignment_src in alignment_constraints:
             if conv_kind == ConvKind.Wgrad:
-                # skip io16xc16
+                # skip io16xc16
                 if math_inst.element_accumulator == DataType.f16:
                     continue
                 for alignment_diff in alignment_constraints:

@@ -1319,7 +1325,7 @@ def GenerateDwconv2d_TensorOp_884(args, conv_kind):
                         min_cc,
                         alignment_src,
                         alignment_diff,
-                        32, # always f32 output
+                        32,  # always f32 output
                         SpecialOptimizeDesc.NoneSpecialOpt,
                         ImplicitGemmMode.GemmNT,
                         False,

@@ -1656,6 +1662,7 @@ def GenerateGemvOperations(args):
     )
     return GenerateGemv_Simt(args)

+
 ################################################################################
 # parameters
 # split_number - the concated file will be divided into split_number parts

@@ -1668,10 +1675,21 @@ def GenerateGemvOperations(args):
 # epilogue - the epilogue in the file
 # wrapper_path - wrapper path
 ################################################################################
-def ConcatFile(split_number:int, file_path:str,operations:str,type:str,head:str,required_cuda_ver_major:str, required_cuda_ver_minor:str, epilogue:str, wrapper_path = None):
+def ConcatFile(
+    split_number: int,
+    file_path: str,
+    operations: str,
+    type: str,
+    head: str,
+    required_cuda_ver_major: str,
+    required_cuda_ver_minor: str,
+    epilogue: str,
+    wrapper_path=None,
+):
     import os

     meragefiledir = file_path
-    filenames=os.listdir(meragefiledir)
+    filenames = os.listdir(meragefiledir)
     # filter file
     if "tensorop" in type:
         sub_string_1 = "tensorop"
@@ -1679,197 +1697,183 @@ def ConcatFile(split_number:int, file_path:str,operations:str,type:str,head:str,
else: else:
sub_string_1 = sub_string_2 = "simt" sub_string_1 = sub_string_2 = "simt"
if "dwconv2d_" in operations: if "dwconv2d_" in operations:
filtered_operations = operations[:2]+operations[9:]
filtered_operations = operations[:2] + operations[9:]
elif ("conv2d" in operations) or ("deconv" in operations): elif ("conv2d" in operations) or ("deconv" in operations):
filtered_operations = "cutlass" filtered_operations = "cutlass"
else: else:
filtered_operations = operations filtered_operations = operations
#get the file list number
# get the file list number
file_list = {} file_list = {}
file_list[operations + type] = 0 file_list[operations + type] = 0
for filename in filenames: for filename in filenames:
if (filtered_operations in filename) and (sub_string_1 in filename) and (sub_string_2 in filename) and ("all_" not in filename):
if (
(filtered_operations in filename)
and (sub_string_1 in filename)
and (sub_string_2 in filename)
and ("all_" not in filename)
):
file_list[operations + type] += 1 file_list[operations + type] += 1
#concat file for linux
# concat file for linux
flag_1 = 0 flag_1 = 0
flag_2 = 0 flag_2 = 0
for filename in filenames: for filename in filenames:
if (filtered_operations in filename) and (sub_string_1 in filename) and (sub_string_2 in filename) and ("all_" not in filename):
if (
(filtered_operations in filename)
and (sub_string_1 in filename)
and (sub_string_2 in filename)
and ("all_" not in filename)
):
flag_1 += 1 flag_1 += 1
filepath=meragefiledir+'/'+filename
if (flag_1 >= flag_2 * (file_list[operations + type]/split_number)) and (flag_1 <= (flag_2 + 1) * (file_list[operations + type]/split_number)):
file =open(file_path + '/{}_{}_{}.cu'.format(operations,type, flag_2),'a')
#write Template at the head
filepath = meragefiledir + "/" + filename
if (flag_1 >= flag_2 * (file_list[operations + type] / split_number)) and (
flag_1 <= (flag_2 + 1) * (file_list[operations + type] / split_number)
):
file = open(
file_path + "/{}_{}_{}.cu".format(operations, type, flag_2), "a"
)
# write Template at the head
if wrapper_path is None: if wrapper_path is None:
file.write( file.write(
SubstituteTemplate( SubstituteTemplate(
head, head,
{ {
"required_cuda_ver_major": str(
required_cuda_ver_major
),
"required_cuda_ver_minor": str(
required_cuda_ver_minor
),
"required_cuda_ver_major": str(required_cuda_ver_major),
"required_cuda_ver_minor": str(required_cuda_ver_minor),
}, },
) )
) )
else: else:
file.write( file.write(
SubstituteTemplate(
head,
{
"wrapper_path": wrapper_path,
"required_cuda_ver_major": str(
required_cuda_ver_major
),
"required_cuda_ver_minor": str(
required_cuda_ver_minor
),
},
)
SubstituteTemplate(
head,
{
"wrapper_path": wrapper_path,
"required_cuda_ver_major": str(required_cuda_ver_major),
"required_cuda_ver_minor": str(required_cuda_ver_minor),
},
) )
)
# concat all the remaining files # concat all the remaining files
if flag_2 == (split_number - 1): if flag_2 == (split_number - 1):
for line in open(filepath): for line in open(filepath):
file.writelines(line) file.writelines(line)
os.remove(filepath) os.remove(filepath)
file.write('\n')
file.write("\n")
file.write(epilogue) file.write(epilogue)
continue continue
for line in open(filepath): for line in open(filepath):
file.writelines(line) file.writelines(line)
os.remove(filepath) os.remove(filepath)
file.write('\n')
file.write("\n")
file.write(epilogue) file.write(epilogue)
else: else:
#write Template at the head
# write Template at the head
if wrapper_path is None: if wrapper_path is None:
file.write( file.write(
SubstituteTemplate( SubstituteTemplate(
head, head,
{ {
"required_cuda_ver_major": str(
required_cuda_ver_major
),
"required_cuda_ver_minor": str(
required_cuda_ver_minor
),
"required_cuda_ver_major": str(required_cuda_ver_major),
"required_cuda_ver_minor": str(required_cuda_ver_minor),
}, },
) )
) )
else: else:
file.write( file.write(
SubstituteTemplate(
head,
{
"wrapper_path": wrapper_path,
"required_cuda_ver_major": str(
required_cuda_ver_major
),
"required_cuda_ver_minor": str(
required_cuda_ver_minor
),
},
)
SubstituteTemplate(
head,
{
"wrapper_path": wrapper_path,
"required_cuda_ver_major": str(required_cuda_ver_major),
"required_cuda_ver_minor": str(required_cuda_ver_minor),
},
) )
)
for line in open(filepath): for line in open(filepath):
file.writelines(line) file.writelines(line)
os.remove(filepath) os.remove(filepath)
file.write('\n')
file.write("\n")
file.write(epilogue) file.write(epilogue)
file.close() file.close()
flag_2 += 1 flag_2 += 1



#concat file for windows
# concat file for windows
elif filename[0].isdigit() and ("all_" not in filename): elif filename[0].isdigit() and ("all_" not in filename):
flag_1 += 1 flag_1 += 1
filepath=meragefiledir+'/'+filename
if (flag_1 >= flag_2 * (len(filenames)/split_number)) and (flag_1 <= (flag_2 + 1) * (len(filenames)/split_number)):
file =open(file_path + '/{}_{}_{}.cu'.format(operations,type, flag_2),'a')
#write Template at the head
filepath = meragefiledir + "/" + filename
if (flag_1 >= flag_2 * (len(filenames) / split_number)) and (
flag_1 <= (flag_2 + 1) * (len(filenames) / split_number)
):
file = open(
file_path + "/{}_{}_{}.cu".format(operations, type, flag_2), "a"
)
# write Template at the head
if wrapper_path is None: if wrapper_path is None:
file.write( file.write(
SubstituteTemplate( SubstituteTemplate(
head, head,
{ {
"required_cuda_ver_major": str(
required_cuda_ver_major
),
"required_cuda_ver_minor": str(
required_cuda_ver_minor
),
"required_cuda_ver_major": str(required_cuda_ver_major),
"required_cuda_ver_minor": str(required_cuda_ver_minor),
}, },
) )
) )
else: else:
file.write( file.write(
SubstituteTemplate(
head,
{
"wrapper_path": wrapper_path,
"required_cuda_ver_major": str(
required_cuda_ver_major
),
"required_cuda_ver_minor": str(
required_cuda_ver_minor
),
},
)
SubstituteTemplate(
head,
{
"wrapper_path": wrapper_path,
"required_cuda_ver_major": str(required_cuda_ver_major),
"required_cuda_ver_minor": str(required_cuda_ver_minor),
},
) )
)
# concat all the remaining files # concat all the remaining files
if flag_2 == (split_number - 1): if flag_2 == (split_number - 1):
for line in open(filepath): for line in open(filepath):
file.writelines(line) file.writelines(line)
os.remove(filepath) os.remove(filepath)
file.write('\n')
file.write("\n")
file.write(epilogue) file.write(epilogue)
continue continue
for line in open(filepath): for line in open(filepath):
file.writelines(line) file.writelines(line)
os.remove(filepath) os.remove(filepath)
file.write('\n')
file.write("\n")
file.write(epilogue) file.write(epilogue)
else: else:
#write Template at the head
# write Template at the head
if wrapper_path is None: if wrapper_path is None:
file.write( file.write(
SubstituteTemplate( SubstituteTemplate(
head, head,
{ {
"required_cuda_ver_major": str(
required_cuda_ver_major
),
"required_cuda_ver_minor": str(
required_cuda_ver_minor
),
"required_cuda_ver_major": str(required_cuda_ver_major),
"required_cuda_ver_minor": str(required_cuda_ver_minor),
}, },
) )
) )
else: else:
file.write( file.write(
SubstituteTemplate(
head,
{
"wrapper_path": wrapper_path,
"required_cuda_ver_major": str(
required_cuda_ver_major
),
"required_cuda_ver_minor": str(
required_cuda_ver_minor
),
},
)
SubstituteTemplate(
head,
{
"wrapper_path": wrapper_path,
"required_cuda_ver_major": str(required_cuda_ver_major),
"required_cuda_ver_minor": str(required_cuda_ver_minor),
},
) )
)
for line in open(filepath): for line in open(filepath):
file.writelines(line) file.writelines(line)
os.remove(filepath) os.remove(filepath)
file.write('\n')
file.write("\n")
file.write(epilogue) file.write(epilogue)
file.close() file.close()
flag_2 += 1 flag_2 += 1



################################################################################################### ###################################################################################################
################################################################################################### ###################################################################################################


@@ -1940,39 +1944,97 @@ if __name__ == "__main__":
args.output, operation, short_path args.output, operation, short_path
) as emitter: ) as emitter:
emitter.emit() emitter.emit()
head = EmitConvSingleKernelWrapper(args.output, operations[0], short_path).header_template
head = EmitConvSingleKernelWrapper(
args.output, operations[0], short_path
).header_template
required_cuda_ver_major = operations[0].required_cuda_ver_major required_cuda_ver_major = operations[0].required_cuda_ver_major
required_cuda_ver_minor = operations[0].required_cuda_ver_minor required_cuda_ver_minor = operations[0].required_cuda_ver_minor
epilogue = EmitConvSingleKernelWrapper(args.output, operations[0], short_path).epilogue_template
epilogue = EmitConvSingleKernelWrapper(
args.output, operations[0], short_path
).epilogue_template
if "tensorop" in args.type: if "tensorop" in args.type:
ConcatFile(4, args.output,args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue)
ConcatFile(
4,
args.output,
args.operations,
args.type,
head,
required_cuda_ver_major,
required_cuda_ver_minor,
epilogue,
)
else: else:
ConcatFile(2, args.output,args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue)
ConcatFile(
2,
args.output,
args.operations,
args.type,
head,
required_cuda_ver_major,
required_cuda_ver_minor,
epilogue,
)
elif args.operations == "gemm": elif args.operations == "gemm":
for operation in operations: for operation in operations:
with EmitGemmSingleKernelWrapper( with EmitGemmSingleKernelWrapper(
args.output, operation, short_path args.output, operation, short_path
) as emitter: ) as emitter:
emitter.emit() emitter.emit()
head = EmitGemmSingleKernelWrapper(args.output, operations[0], short_path).header_template
head = EmitGemmSingleKernelWrapper(
args.output, operations[0], short_path
).header_template
required_cuda_ver_major = operations[0].required_cuda_ver_major required_cuda_ver_major = operations[0].required_cuda_ver_major
required_cuda_ver_minor = operations[0].required_cuda_ver_minor required_cuda_ver_minor = operations[0].required_cuda_ver_minor
epilogue = EmitGemmSingleKernelWrapper(args.output, operations[0], short_path).epilogue_template
epilogue = EmitGemmSingleKernelWrapper(
args.output, operations[0], short_path
).epilogue_template
if args.type == "tensorop884": if args.type == "tensorop884":
ConcatFile(30, args.output, args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue)
ConcatFile(
30,
args.output,
args.operations,
args.type,
head,
required_cuda_ver_major,
required_cuda_ver_minor,
epilogue,
)
else: else:
ConcatFile(2, args.output, args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue)
ConcatFile(
2,
args.output,
args.operations,
args.type,
head,
required_cuda_ver_major,
required_cuda_ver_minor,
epilogue,
)
elif args.operations == "gemv": elif args.operations == "gemv":
for operation in operations: for operation in operations:
with EmitGemvSingleKernelWrapper( with EmitGemvSingleKernelWrapper(
args.output, operation, gemv_wrapper_path, short_path args.output, operation, gemv_wrapper_path, short_path
) as emitter: ) as emitter:
emitter.emit() emitter.emit()
head = EmitGemvSingleKernelWrapper(args.output, operations[0], gemv_wrapper_path, short_path).header_template
head = EmitGemvSingleKernelWrapper(
args.output, operations[0], gemv_wrapper_path, short_path
).header_template
required_cuda_ver_major = operations[0].required_cuda_ver_major required_cuda_ver_major = operations[0].required_cuda_ver_major
required_cuda_ver_minor = operations[0].required_cuda_ver_minor required_cuda_ver_minor = operations[0].required_cuda_ver_minor
epilogue = EmitGemvSingleKernelWrapper(args.output, operations[0], gemv_wrapper_path, short_path).epilogue_template
ConcatFile(2, args.output,args.operations, args.type, head,required_cuda_ver_major, required_cuda_ver_minor, epilogue, wrapper_path = gemv_wrapper_path)
epilogue = EmitGemvSingleKernelWrapper(
args.output, operations[0], gemv_wrapper_path, short_path
).epilogue_template
ConcatFile(
2,
args.output,
args.operations,
args.type,
head,
required_cuda_ver_major,
required_cuda_ver_minor,
epilogue,
wrapper_path=gemv_wrapper_path,
)


if args.operations != "gemv": if args.operations != "gemv":
GenerateManifest(args, operations, args.output) GenerateManifest(args, operations, args.output)


dnn/scripts/cutlass_generator/library.py (+1, -1)

@@ -4,11 +4,11 @@
 # \brief Generates the CUTLASS Library's instances
 #

+import enum
 import re

 ###################################################################################################

-import enum

 # The following block implements enum.auto() for Python 3.5 variants that don't include it such
 # as the default 3.5.2 on Ubuntu 16.04.


dnn/scripts/cutlass_generator/manifest.py (+2, -2)

@@ -8,9 +8,9 @@ import enum
 import os.path
 import shutil

-from library import *
-from gemm_operation import *
 from conv2d_operation import *
+from gemm_operation import *
+from library import *

 ###################################################################################################




dnn/scripts/gen_cond_take_kern_impls.py (+44, -36)

@@ -1,59 +1,67 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-

-import os
 import argparse
+import os

 from gen_elemwise_utils import DTYPES

+
 def main():
     parser = argparse.ArgumentParser(
-        description='generate elemwise impl files',
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--type', type=str, choices=['cuda'],
-                        default='cuda',
-                        help='generate cuda cond take kernel file')
-    parser.add_argument('output', help='output directory')
+        description="generate elemwise impl files",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--type",
+        type=str,
+        choices=["cuda"],
+        default="cuda",
+        help="generate cuda cond take kernel file",
+    )
+    parser.add_argument("output", help="output directory")
     args = parser.parse_args()

     if not os.path.isdir(args.output):
         os.makedirs(args.output)

-    assert args.type =='cuda'
-    cpp_ext = 'cu'
+    assert args.type == "cuda"
+    cpp_ext = "cu"

     for dtype in DTYPES.keys():
-        fname = '{}.{}'.format(dtype, cpp_ext)
+        fname = "{}.{}".format(dtype, cpp_ext)
         fname = os.path.join(args.output, fname)
-        with open(fname, 'w') as fout:
+        with open(fname, "w") as fout:
             w = lambda s: print(s, file=fout)

-            w('// generated by gen_cond_take_kern_impls.py')
+            w("// generated by gen_cond_take_kern_impls.py")
             w('#include "../kern.inl"')
-            w('')
-            if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
-                w('#if !MEGDNN_DISABLE_FLOAT16')
-            w('namespace megdnn {')
-            w('namespace cuda {')
-            w('namespace cond_take {')
-            w('')
-            w('inst_genidx(::megdnn::dtype::{})'.format(DTYPES[dtype][0]))
-            w('#undef inst_genidx')
-            w('')
-            w('inst_copy(::megdnn::dtype::{})'.format(DTYPES[dtype][0]))
-            w('#undef inst_copy')
-            w('#undef inst_copy_')
-            w('')
-            w('} // cond_take')
-            w('} // cuda')
-            w('} // megdnn')
-            if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
-                w('#endif')
-        print('generated {}'.format(fname))
+            w("")
+            if dtype == "dt_float16" or dtype == "dt_bfloat16":
+                w("#if !MEGDNN_DISABLE_FLOAT16")
+            w("namespace megdnn {")
+            w("namespace cuda {")
+            w("namespace cond_take {")
+            w("")
+            w("inst_genidx(::megdnn::dtype::{})".format(DTYPES[dtype][0]))
+            w("#undef inst_genidx")
+            w("")
+            w("inst_copy(::megdnn::dtype::{})".format(DTYPES[dtype][0]))
+            w("#undef inst_copy")
+            w("#undef inst_copy_")
+            w("")
+            w("} // cond_take")
+            w("} // cuda")
+            w("} // megdnn")
+            if dtype == "dt_float16" or dtype == "dt_bfloat16":
+                w("#endif")
+        print("generated {}".format(fname))

     os.utime(args.output)

-if __name__ == '__main__':
+
+if __name__ == "__main__":
     main()

dnn/scripts/gen_cuda_batch_conv_bias_kern_impls.py (+40, -24)

@@ -1,37 +1,47 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-


import os
import argparse import argparse
import itertools import itertools
import os

PREFIXES = {
"dp4a": [
("batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4", True),
("batch_conv_bias_int8_gemm_ncdiv4hw4", False),
("batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128", False),
]
}


PREFIXES = {"dp4a": [("batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4", True), ("batch_conv_bias_int8_gemm_ncdiv4hw4", False), ("batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128", False)]}
ACTIVATIONS = {1: ("IDENTITY", "_id"), 2: ("RELU", "_relu"), 3: ("H_SWISH", "_hswish")}


ACTIVATIONS = {1: ("IDENTITY", "_id"),
2: ("RELU", "_relu"),
3: ("H_SWISH", "_hswish")}
BIASES = {
1: ("PerElementBiasVisitor", "_per_elem"),
2: ("PerChannelBiasVisitor", "_per_chan"),
}


BIASES = {1: ("PerElementBiasVisitor", "_per_elem"),
2: ("PerChannelBiasVisitor", "_per_chan")}
SUFFIXES = {"dp4a": [""], "imma": [""]}


SUFFIXES = {"dp4a": [""],
"imma": [""]}


def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='generate cuda batch conv bias (dp4a/imma) kern impl files',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--type', type=str, choices=['dp4a',
'imma'],
default='dp4a', help='generate cuda conv bias kernel file')
parser.add_argument('output', help='output directory')
description="generate cuda batch conv bias (dp4a/imma) kern impl files",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--type",
type=str,
choices=["dp4a", "imma"],
default="dp4a",
help="generate cuda conv bias kernel file",
)
parser.add_argument("output", help="output directory")
args = parser.parse_args() args = parser.parse_args()


if not os.path.isdir(args.output): if not os.path.isdir(args.output):
os.makedirs(args.output) os.makedirs(args.output)


inst = '''
inst = """
template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS, template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS,
IConvEpilogue<Activation<megdnn::param_enumv::BatchConvBias::NonlineMode::ACTIVATION>>>( IConvEpilogue<Activation<megdnn::param_enumv::BatchConvBias::NonlineMode::ACTIVATION>>>(
const int8_t* d_src, const int8_t* d_src,
@@ -41,7 +51,7 @@ template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS,
const ConvParam& param, const ConvParam& param,
float alpha, float alpha,
float beta, float beta,
cudaStream_t stream);'''
cudaStream_t stream);"""


for prefix in PREFIXES[args.type]: for prefix in PREFIXES[args.type]:
for suffix in SUFFIXES[args.type]: for suffix in SUFFIXES[args.type]:
@@ -52,17 +62,23 @@ template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS,
fname = os.path.join(args.output, fname) fname = os.path.join(args.output, fname)
with open(fname, "w") as fout: with open(fname, "w") as fout:
w = lambda s: print(s, file=fout) w = lambda s: print(s, file=fout)
w('// generated by gen_batch_cuda_conv_bias_kern_impls.py')
cur_inst = inst.replace("PREFIX", prefix[0]).replace("SUFFIX", suffix).replace("BIAS", bias[0]).replace("ACTIVATION", act[0])
w("// generated by gen_batch_cuda_conv_bias_kern_impls.py")
cur_inst = (
inst.replace("PREFIX", prefix[0])
.replace("SUFFIX", suffix)
.replace("BIAS", bias[0])
.replace("ACTIVATION", act[0])
)
if has_workspace: if has_workspace:
cur_inst = cur_inst.replace("WORKSPACE", "\nint* d_workspace, ") cur_inst = cur_inst.replace("WORKSPACE", "\nint* d_workspace, ")
else: else:
cur_inst = cur_inst.replace("WORKSPACE", "")
cur_inst = cur_inst.replace("WORKSPACE", "")
w('#include "../{}{}.cuinl"'.format(prefix[0], suffix)) w('#include "../{}{}.cuinl"'.format(prefix[0], suffix))
w(cur_inst) w(cur_inst)
print('generated {}'.format(fname))
print("generated {}".format(fname))
os.utime(args.output) os.utime(args.output)


if __name__ == '__main__':

if __name__ == "__main__":
main() main()

dnn/scripts/gen_cuda_conv_bias_kern_impls.py (+49, -25)

@@ -1,39 +1,57 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-


import os
import argparse import argparse
import itertools import itertools
import os

PREFIXES = {
"dp4a": "conv_bias_int8_implicit_gemm_cdiv4hwn4",
"imma": "conv_bias_int8_implicit_gemm",
}


PREFIXES = {"dp4a": "conv_bias_int8_implicit_gemm_cdiv4hwn4", "imma": "conv_bias_int8_implicit_gemm"}
ACTIVATIONS = {1: ("IDENTITY", "_id"), 2: ("RELU", "_relu"), 3: ("H_SWISH", "_hswish")}


ACTIVATIONS = {1: ("IDENTITY", "_id"),
2: ("RELU", "_relu"),
3: ("H_SWISH", "_hswish")}
BIASES = {
1: ("PerElementBiasVisitor", "_per_elem"),
2: ("PerChannelBiasVisitor", "_per_chan"),
}


BIASES = {1: ("PerElementBiasVisitor", "_per_elem"),
2: ("PerChannelBiasVisitor", "_per_chan")}
SUFFIXES = {
"dp4a": ["", "_ld_64bit", "_ld_64bit_unroll_width", "_unroll_width"],
"imma": [
"_imma16x16x16_cdiv4hwn4",
"_imma8x32x16_cdiv4hwn4",
"_imma32x8x16_cdiv4hwn4",
"_imma16x16x16_cdiv4hwn4_reorder_filter",
"_imma8x32x16_cdiv4hwn4_reorder_filter",
"_imma32x8x16_cdiv4hwn4_reorder_filter",
"_imma16x16x16_cdiv4hwn4_unroll_width",
"_imma8x32x16_cdiv4hwn4_unroll_width",
"_imma32x8x16_cdiv4hwn4_unroll_width",
],
}


SUFFIXES = {"dp4a": ["", "_ld_64bit", "_ld_64bit_unroll_width", "_unroll_width"],
"imma": ["_imma16x16x16_cdiv4hwn4", "_imma8x32x16_cdiv4hwn4", "_imma32x8x16_cdiv4hwn4",
"_imma16x16x16_cdiv4hwn4_reorder_filter", "_imma8x32x16_cdiv4hwn4_reorder_filter", "_imma32x8x16_cdiv4hwn4_reorder_filter",
"_imma16x16x16_cdiv4hwn4_unroll_width", "_imma8x32x16_cdiv4hwn4_unroll_width", "_imma32x8x16_cdiv4hwn4_unroll_width"]}


def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='generate cuda conv bias (dp4a/imma) kern impl files',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--type', type=str, choices=['dp4a',
'imma'],
default='dp4a', help='generate cuda conv bias kernel file')
parser.add_argument('output', help='output directory')
description="generate cuda conv bias (dp4a/imma) kern impl files",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--type",
type=str,
choices=["dp4a", "imma"],
default="dp4a",
help="generate cuda conv bias kernel file",
)
parser.add_argument("output", help="output directory")
args = parser.parse_args() args = parser.parse_args()


if not os.path.isdir(args.output): if not os.path.isdir(args.output):
os.makedirs(args.output) os.makedirs(args.output)


inst = '''
inst = """
template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS, template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS,
IConvEpilogue<Activation<megdnn::param_enumv::ConvBias::NonlineMode::ACTIVATION>>>( IConvEpilogue<Activation<megdnn::param_enumv::ConvBias::NonlineMode::ACTIVATION>>>(
const int8_t* d_src, const int8_t* d_src,
@@ -43,7 +61,7 @@ template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS,
const ConvParam& param, const ConvParam& param,
float alpha, float alpha,
float beta, float beta,
cudaStream_t stream);'''
cudaStream_t stream);"""


for suffix in SUFFIXES[args.type]: for suffix in SUFFIXES[args.type]:
for _, act in ACTIVATIONS.items(): for _, act in ACTIVATIONS.items():
@@ -53,13 +71,19 @@ template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS,
fname = os.path.join(args.output, fname) fname = os.path.join(args.output, fname)
with open(fname, "w") as fout: with open(fname, "w") as fout:
w = lambda s: print(s, file=fout) w = lambda s: print(s, file=fout)
w('// generated by gen_cuda_conv_bias_kern_impls.py')
cur_inst = inst.replace("PREFIX", prefix).replace("SUFFIX", suffix).replace("BIAS", bias[0]).replace("ACTIVATION", act[0])
w("// generated by gen_cuda_conv_bias_kern_impls.py")
cur_inst = (
inst.replace("PREFIX", prefix)
.replace("SUFFIX", suffix)
.replace("BIAS", bias[0])
.replace("ACTIVATION", act[0])
)
w('#include "../{}{}.cuinl"'.format(prefix, suffix)) w('#include "../{}{}.cuinl"'.format(prefix, suffix))
w(cur_inst) w(cur_inst)
print('generated {}'.format(fname))
print("generated {}".format(fname))
os.utime(args.output) os.utime(args.output)


if __name__ == '__main__':

if __name__ == "__main__":
main() main()

dnn/scripts/gen_elemwise_each_mode.py (+18, -13)

@@ -1,34 +1,39 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-

-import os
 import argparse
+import os

 from gen_elemwise_utils import ARITIES, MODES

+
 def main():
     parser = argparse.ArgumentParser(
-        description='generate elemwise each mode',
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+        description="generate elemwise each mode",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )

-    parser.add_argument('output', help='output directory')
+    parser.add_argument("output", help="output directory")
     args = parser.parse_args()

-
-    with open(args.output, 'w') as fout:
+    with open(args.output, "w") as fout:
         w = lambda s: print(s, file=fout)
-        w('// generated by gen_elemwise_each_mode.py')
+        w("// generated by gen_elemwise_each_mode.py")
         keys = list(MODES.keys())
         keys.sort()
         for (anum, ctype) in keys:
-            w('#define MEGDNN_FOREACH_ELEMWISE_MODE_{}_{}(cb) \\'.format(
-                ARITIES[anum], ctype))
+            w(
+                "#define MEGDNN_FOREACH_ELEMWISE_MODE_{}_{}(cb) \\".format(
+                    ARITIES[anum], ctype
+                )
+            )
             for mode in MODES[(anum, ctype)]:
-                w(' MEGDNN_ELEMWISE_MODE_ENABLE({}, cb) \\'.format(mode))
-            w('')
+                w(" MEGDNN_ELEMWISE_MODE_ENABLE({}, cb) \\".format(mode))
+            w("")

-    print('generated each_mode.inl')
+    print("generated each_mode.inl")
     os.utime(args.output)

-if __name__ == '__main__':
+
+if __name__ == "__main__":
     main()

dnn/scripts/gen_elemwise_kern_impls.py (+34, -27)

@@ -1,56 +1,63 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-


import os
import argparse import argparse
import itertools import itertools
import os

from gen_elemwise_utils import ARITIES, DTYPES, MODES from gen_elemwise_utils import ARITIES, DTYPES, MODES



def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='generate elemwise impl files',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--type', type=str, choices=['cuda',
'hip',
'cpp'],
default='cpp', help='generate cuda/hip kernel file')
parser.add_argument('output', help='output directory')
description="generate elemwise impl files",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--type",
type=str,
choices=["cuda", "hip", "cpp"],
default="cpp",
help="generate cuda/hip kernel file",
)
parser.add_argument("output", help="output directory")
args = parser.parse_args() args = parser.parse_args()


if not os.path.isdir(args.output): if not os.path.isdir(args.output):
os.makedirs(args.output) os.makedirs(args.output)


if args.type == 'cuda':
cpp_ext = 'cu'
elif args.type == 'hip':
cpp_ext = 'cpp.hip'
if args.type == "cuda":
cpp_ext = "cu"
elif args.type == "hip":
cpp_ext = "cpp.hip"
else: else:
assert args.type == 'cpp'
cpp_ext = 'cpp'
assert args.type == "cpp"
cpp_ext = "cpp"


for anum, ctype in itertools.product(ARITIES.keys(), DTYPES.keys()): for anum, ctype in itertools.product(ARITIES.keys(), DTYPES.keys()):
for mode in MODES[(anum, DTYPES[ctype][1])]: for mode in MODES[(anum, DTYPES[ctype][1])]:
formode = 'MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)'.format(mode)
fname = '{}_{}.{}'.format(mode, ctype, cpp_ext)
formode = "MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)".format(mode)
fname = "{}_{}.{}".format(mode, ctype, cpp_ext)
fname = os.path.join(args.output, fname) fname = os.path.join(args.output, fname)
with open(fname, 'w') as fout:
with open(fname, "w") as fout:
w = lambda s: print(s, file=fout) w = lambda s: print(s, file=fout)
w('// generated by gen_elemwise_kern_impls.py')
w("// generated by gen_elemwise_kern_impls.py")


if ctype == 'dt_float16' or ctype == 'dt_bfloat16':
w('#if !MEGDNN_DISABLE_FLOAT16')
if ctype == "dt_float16" or ctype == "dt_bfloat16":
w("#if !MEGDNN_DISABLE_FLOAT16")


w('#define KERN_IMPL_MODE(cb) {}'.format(formode))
w('#define KERN_IMPL_ARITY {}'.format(anum))
w('#define KERN_IMPL_CTYPE {}'.format(ctype))
w("#define KERN_IMPL_MODE(cb) {}".format(formode))
w("#define KERN_IMPL_ARITY {}".format(anum))
w("#define KERN_IMPL_CTYPE {}".format(ctype))
w('#include "../kern_impl.inl"') w('#include "../kern_impl.inl"')


if ctype == 'dt_float16' or ctype == 'dt_bfloat16':
w('#endif')
if ctype == "dt_float16" or ctype == "dt_bfloat16":
w("#endif")


print('generated {}'.format(fname))
print("generated {}".format(fname))


os.utime(args.output) os.utime(args.output)


if __name__ == '__main__':

if __name__ == "__main__":
main() main()

dnn/scripts/gen_elemwise_multi_type_kern_impls.py (+35, -21)

@@ -1,52 +1,66 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-


import os
import argparse import argparse
import itertools import itertools
from gen_elemwise_multi_type_utils import SUPPORT_DTYPES, MODES, SUPPORT_QINT32_DTYPES, QINT32_MODES
import os

from gen_elemwise_multi_type_utils import ( # isort: skip; isort: skip
MODES,
QINT32_MODES,
SUPPORT_DTYPES,
SUPPORT_QINT32_DTYPES,
)



def generate(modes, support_dtypes, output, cpp_ext): def generate(modes, support_dtypes, output, cpp_ext):
for anum, ctype in itertools.product(modes.keys(), support_dtypes): for anum, ctype in itertools.product(modes.keys(), support_dtypes):
print('{} : {}'.format(anum, ctype))
print("{} : {}".format(anum, ctype))
src_ctype = ctype[0] src_ctype = ctype[0]
dst_ctype = ctype[1] dst_ctype = ctype[1]
for mode in modes[anum]: for mode in modes[anum]:
formode = 'MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)'.format(mode)
fname = '{}_{}_{}.{}'.format(mode, src_ctype, dst_ctype, cpp_ext)
formode = "MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)".format(mode)
fname = "{}_{}_{}.{}".format(mode, src_ctype, dst_ctype, cpp_ext)
fname = os.path.join(output, fname) fname = os.path.join(output, fname)
with open(fname, 'w') as fout:
with open(fname, "w") as fout:
w = lambda s: print(s, file=fout) w = lambda s: print(s, file=fout)
w('// generated by gen_elemwise_multi_type_kern_impls.py')
w("// generated by gen_elemwise_multi_type_kern_impls.py")


w('#define KERN_IMPL_MODE(cb) {}'.format(formode))
w('#define KERN_IMPL_ARITY {}'.format(anum))
w('#define KERN_IMPL_STYPE {}'.format(src_ctype))
w('#define KERN_IMPL_DTYPE {}'.format(dst_ctype))
w("#define KERN_IMPL_MODE(cb) {}".format(formode))
w("#define KERN_IMPL_ARITY {}".format(anum))
w("#define KERN_IMPL_STYPE {}".format(src_ctype))
w("#define KERN_IMPL_DTYPE {}".format(dst_ctype))
w('#include "../kern_impl.inl"') w('#include "../kern_impl.inl"')


print('generated {}'.format(fname))
print("generated {}".format(fname))




def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='generate elemwise impl files',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--type', type=str, choices=['cuda'],
default='cuda', help='generate cuda kernel file')
parser.add_argument('output', help='output directory')
description="generate elemwise impl files",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--type",
type=str,
choices=["cuda"],
default="cuda",
help="generate cuda kernel file",
)
parser.add_argument("output", help="output directory")
args = parser.parse_args() args = parser.parse_args()


if not os.path.isdir(args.output): if not os.path.isdir(args.output):
os.makedirs(args.output) os.makedirs(args.output)


assert args.type == 'cuda'
if args.type == 'cuda':
cpp_ext = 'cu'
assert args.type == "cuda"
if args.type == "cuda":
cpp_ext = "cu"


generate(MODES, SUPPORT_DTYPES, args.output, cpp_ext) generate(MODES, SUPPORT_DTYPES, args.output, cpp_ext)
generate(QINT32_MODES, SUPPORT_QINT32_DTYPES, args.output, cpp_ext) generate(QINT32_MODES, SUPPORT_QINT32_DTYPES, args.output, cpp_ext)
os.utime(args.output) os.utime(args.output)


if __name__ == '__main__':

if __name__ == "__main__":
main() main()

dnn/scripts/gen_elemwise_multi_type_utils.py (+113, -30)

@@ -1,48 +1,131 @@
# As cuda currently do not support quint8, so we just ignore it. # As cuda currently do not support quint8, so we just ignore it.
SUPPORT_DTYPES = [('dt_qint8', 'dt_qint8')]
SUPPORT_QINT32_DTYPES = [('dt_qint32', 'dt_qint8'), ('dt_qint8', 'dt_qint32'),
('dt_qint4', 'dt_qint32'), ('dt_quint4', 'dt_qint32')]
SUPPORT_DTYPES = [("dt_qint8", "dt_qint8")]
SUPPORT_QINT32_DTYPES = [
("dt_qint32", "dt_qint8"),
("dt_qint8", "dt_qint32"),
("dt_qint4", "dt_qint32"),
("dt_quint4", "dt_qint32"),
]


SUPPORT_DTYPES_Q4 = [('dt_qint4', 'dt_qint4'), ('dt_quint4', 'dt_quint4')]
SUPPORT_QINT32_DTYPES_Q4 = [('dt_qint32', 'dt_qint4'), ('dt_qint32', 'dt_quint4')]
SUPPORT_DTYPES_Q4 = [("dt_qint4", "dt_qint4"), ("dt_quint4", "dt_quint4")]
SUPPORT_QINT32_DTYPES_Q4 = [("dt_qint32", "dt_qint4"), ("dt_qint32", "dt_quint4")]


SUPPORT_ARRITY2_DTYPES = ['dt_int32', 'dt_uint8', 'dt_int8', 'dt_int16', 'dt_bool', 'dt_float32',
'dt_float16', 'dt_bfloat16']
SUPPORT_ARRITY1_DTYPES = ['dt_float32','dt_float16', 'dt_bfloat16']
SUPPORT_ARRITY2_DTYPES = [
"dt_int32",
"dt_uint8",
"dt_int8",
"dt_int16",
"dt_bool",
"dt_float32",
"dt_float16",
"dt_bfloat16",
]
SUPPORT_ARRITY1_DTYPES = ["dt_float32", "dt_float16", "dt_bfloat16"]


MODES = { MODES = {
1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS',
'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN',
'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC',
'ERFCINV', 'H_SWISH', 'SILU', 'GELU'],
2: ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL',
'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT',
'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW',
'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD',
'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD',
'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'],
3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'],
1: [
"RELU",
"ABS",
"NEGATE",
"ACOS",
"ASIN",
"CEIL",
"COS",
"EXP",
"EXPM1",
"FLOOR",
"LOG",
"LOG1P",
"SIGMOID",
"SIN",
"TANH",
"FAST_TANH",
"ROUND",
"ERF",
"ERFINV",
"ERFC",
"ERFCINV",
"H_SWISH",
"SILU",
"GELU",
],
2: [
"ABS_GRAD",
"ADD",
"FLOOR_DIV",
"MAX",
"MIN",
"MOD",
"MUL",
"SIGMOID_GRAD",
"SUB",
"SWITCH_GT0",
"TANH_GRAD",
"LT",
"LEQ",
"EQ",
"FUSE_ADD_RELU",
"TRUE_DIV",
"POW",
"LOG_SUM_EXP",
"FUSE_ADD_TANH",
"FAST_TANH_GRAD",
"FUSE_ADD_SIGMOID",
"ATAN2",
"H_SWISH_GRAD",
"FUSE_ADD_H_SWISH",
"SILU_GRAD",
"GELU_GRAD",
],
3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"],
} }


QINT4_MODES = { QINT4_MODES = {
1: ['RELU', 'ABS', 'NEGATE', 'CEIL', 'FLOOR', 'SIGMOID',
'TANH', 'FAST_TANH', 'ROUND', 'H_SWISH'],
2: ['ADD', 'MAX', 'MIN', 'MUL', 'SUB', 'SWITCH_GT0',
'LT', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'FUSE_ADD_TANH',
'FUSE_ADD_SIGMOID', 'FUSE_ADD_H_SWISH'],
3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'],
1: [
"RELU",
"ABS",
"NEGATE",
"CEIL",
"FLOOR",
"SIGMOID",
"TANH",
"FAST_TANH",
"ROUND",
"H_SWISH",
],
2: [
"ADD",
"MAX",
"MIN",
"MUL",
"SUB",
"SWITCH_GT0",
"LT",
"LEQ",
"EQ",
"FUSE_ADD_RELU",
"FUSE_ADD_TANH",
"FUSE_ADD_SIGMOID",
"FUSE_ADD_H_SWISH",
],
3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"],
} }


QINT32_MODES = { QINT32_MODES = {
1: ['RELU', 'SIGMOID', 'TANH', 'FAST_TANH', 'H_SWISH'],
2: ['ADD', 'FUSE_ADD_RELU', 'FUSE_ADD_SIGMOID',
'FUSE_ADD_TANH', 'FUSE_ADD_H_SWISH']
1: ["RELU", "SIGMOID", "TANH", "FAST_TANH", "H_SWISH"],
2: [
"ADD",
"FUSE_ADD_RELU",
"FUSE_ADD_SIGMOID",
"FUSE_ADD_TANH",
"FUSE_ADD_H_SWISH",
],
} }


ARRITY1_BOOL_MODES = { ARRITY1_BOOL_MODES = {
1: ['ISINF','ISNAN'],
1: ["ISINF", "ISNAN"],
} }


ARRITY2_BOOL_MODES = { ARRITY2_BOOL_MODES = {
2: ['EQ','LEQ','NEQ','LT'],
2: ["EQ", "LEQ", "NEQ", "LT"],
} }

dnn/scripts/gen_elemwise_special_kern_impls.py (+32, -27)

@@ -1,52 +1,57 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-


import os
import argparse import argparse
import os

from gen_elemwise_utils import DTYPES from gen_elemwise_utils import DTYPES



def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='generate elemwise impl files',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--type', type=str, choices=[
'cuda',
'hip'
],
default='cuda',
help='generate cuda/hip elemwise special kernel file')
parser.add_argument('output', help='output directory')
description="generate elemwise impl files",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--type",
type=str,
choices=["cuda", "hip"],
default="cuda",
help="generate cuda/hip elemwise special kernel file",
)
parser.add_argument("output", help="output directory")
args = parser.parse_args() args = parser.parse_args()


if not os.path.isdir(args.output): if not os.path.isdir(args.output):
os.makedirs(args.output) os.makedirs(args.output)


if args.type == 'cuda':
cpp_ext = 'cu'
if args.type == "cuda":
cpp_ext = "cu"
else: else:
assert args.type =='hip'
cpp_ext = 'cpp.hip'
assert args.type == "hip"
cpp_ext = "cpp.hip"


for dtype in DTYPES.keys(): for dtype in DTYPES.keys():
fname = 'special_{}.{}'.format(dtype, cpp_ext)
fname = "special_{}.{}".format(dtype, cpp_ext)
fname = os.path.join(args.output, fname) fname = os.path.join(args.output, fname)
with open(fname, 'w') as fout:
with open(fname, "w") as fout:
w = lambda s: print(s, file=fout) w = lambda s: print(s, file=fout)


w('// generated by gen_elemwise_special_kern_impls.py')
if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
w('#if !MEGDNN_DISABLE_FLOAT16')
w("// generated by gen_elemwise_special_kern_impls.py")
if dtype == "dt_float16" or dtype == "dt_bfloat16":
w("#if !MEGDNN_DISABLE_FLOAT16")
w('#include "../special_kerns.inl"') w('#include "../special_kerns.inl"')
w('INST(::megdnn::dtype::{})'.format(DTYPES[dtype][0]))
w('#undef INST')
w('}')
w('}')
if dtype == 'dt_float16' or dtype == 'dt_bfloat16':
w('#endif')
w("INST(::megdnn::dtype::{})".format(DTYPES[dtype][0]))
w("#undef INST")
w("}")
w("}")
if dtype == "dt_float16" or dtype == "dt_bfloat16":
w("#endif")


print('generated {}'.format(fname))
print("generated {}".format(fname))


os.utime(args.output) os.utime(args.output)


if __name__ == '__main__':

if __name__ == "__main__":
main() main()

dnn/scripts/gen_elemwise_utils.py (+91, -31)

@@ -1,35 +1,95 @@
ARITIES = {1: "UNARY", 2: "BINARY", 3: "TERNARY"}


ARITIES = {1: 'UNARY', 2: 'BINARY', 3: 'TERNARY'}

DTYPES = {'dt_int32': ('Int32', 'INT'),
'dt_uint8': ('Uint8', 'INT'),
'dt_int8': ('Int8', 'INT'),
'dt_int16': ('Int16', 'INT'),
'dt_bool': ('Bool', 'BOOL'),
'dt_float32': ('Float32', 'FLOAT'),
'dt_float16': ('Float16', 'FLOAT'),
'dt_bfloat16': ('BFloat16', 'FLOAT')
}
DTYPES = {
"dt_int32": ("Int32", "INT"),
"dt_uint8": ("Uint8", "INT"),
"dt_int8": ("Int8", "INT"),
"dt_int16": ("Int16", "INT"),
"dt_bool": ("Bool", "BOOL"),
"dt_float32": ("Float32", "FLOAT"),
"dt_float16": ("Float16", "FLOAT"),
"dt_bfloat16": ("BFloat16", "FLOAT"),
}


MODES = { MODES = {
(1, 'INT'): ['RELU', 'ABS', 'NEGATE'],
(2, 'INT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL',
'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', 'LEQ',
'EQ', 'FUSE_ADD_RELU', 'SHL', 'SHR', 'RMULH'],
(3, 'INT'): ['COND_LEQ_MOV', 'COND_LT_MOV'],

(1, 'FLOAT'): ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS',
'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN',
'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC',
'ERFCINV', 'H_SWISH', 'SILU', 'GELU'],
(2, 'FLOAT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL',
'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT',
'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW',
'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD',
'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD',
'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'],
(3, 'FLOAT'): ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'],
(1, 'BOOL'): ['NOT'],
(2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'],
(3, 'BOOL'): []
(1, "INT"): ["RELU", "ABS", "NEGATE"],
(2, "INT"): [
"ABS_GRAD",
"ADD",
"FLOOR_DIV",
"MAX",
"MIN",
"MOD",
"MUL",
"SIGMOID_GRAD",
"SUB",
"SWITCH_GT0",
"TANH_GRAD",
"LT",
"LEQ",
"EQ",
"FUSE_ADD_RELU",
"SHL",
"SHR",
"RMULH",
],
(3, "INT"): ["COND_LEQ_MOV", "COND_LT_MOV"],
(1, "FLOAT"): [
"RELU",
"ABS",
"NEGATE",
"ACOS",
"ASIN",
"CEIL",
"COS",
"EXP",
"EXPM1",
"FLOOR",
"LOG",
"LOG1P",
"SIGMOID",
"SIN",
"TANH",
"FAST_TANH",
"ROUND",
"ERF",
"ERFINV",
"ERFC",
"ERFCINV",
"H_SWISH",
"SILU",
"GELU",
],
(2, "FLOAT"): [
"ABS_GRAD",
"ADD",
"FLOOR_DIV",
"MAX",
"MIN",
"MOD",
"MUL",
"SIGMOID_GRAD",
"SUB",
"SWITCH_GT0",
"TANH_GRAD",
"LT",
"LEQ",
"EQ",
"FUSE_ADD_RELU",
"TRUE_DIV",
"POW",
"LOG_SUM_EXP",
"FUSE_ADD_TANH",
"FAST_TANH_GRAD",
"FUSE_ADD_SIGMOID",
"ATAN2",
"H_SWISH_GRAD",
"FUSE_ADD_H_SWISH",
"SILU_GRAD",
"GELU_GRAD",
],
(3, "FLOAT"): ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"],
(1, "BOOL"): ["NOT"],
(2, "BOOL"): ["AND", "OR", "XOR", "LT", "LEQ", "EQ"],
(3, "BOOL"): [],
} }

dnn/scripts/gen_flatbuffers_converter.py (+44, -32)

@@ -3,13 +3,14 @@


import argparse import argparse
import collections import collections
import textwrap
import os
import hashlib import hashlib
import struct
import io import io
import os
import struct
import textwrap

from gen_param_defs import IndentWriterBase, ParamDef, member_defs


from gen_param_defs import member_defs, ParamDef, IndentWriterBase


class ConverterWriter(IndentWriterBase): class ConverterWriter(IndentWriterBase):
_skip_current_param = False _skip_current_param = False
@@ -20,7 +21,7 @@ class ConverterWriter(IndentWriterBase):
def __call__(self, fout, defs): def __call__(self, fout, defs):
super().__call__(fout) super().__call__(fout)
self._write("// %s", self._get_header()) self._write("// %s", self._get_header())
self._write('#include <flatbuffers/flatbuffers.h>')
self._write("#include <flatbuffers/flatbuffers.h>")
self._write("namespace mgb {") self._write("namespace mgb {")
self._write("namespace serialization {") self._write("namespace serialization {")
self._write("namespace fbs {") self._write("namespace fbs {")
@@ -33,8 +34,9 @@ class ConverterWriter(IndentWriterBase):
self._last_param = p self._last_param = p
self._param_fields = [] self._param_fields = []
self._fb_fields = ["builder"] self._fb_fields = ["builder"]
self._write("template<>\nstruct ParamConverter<megdnn::param::%s> {",
p.name, indent=1)
self._write(
"template<>\nstruct ParamConverter<megdnn::param::%s> {", p.name, indent=1
)
self._write("using MegDNNType = megdnn::param::%s;", p.name) self._write("using MegDNNType = megdnn::param::%s;", p.name)
self._write("using FlatBufferType = fbs::param::%s;\n", p.name) self._write("using FlatBufferType = fbs::param::%s;\n", p.name)


@@ -42,22 +44,22 @@ class ConverterWriter(IndentWriterBase):
if self._skip_current_param: if self._skip_current_param:
self._skip_current_param = False self._skip_current_param = False
return return
self._write("static MegDNNType to_param(const FlatBufferType* fb) {",
indent=1)
line = 'return {'
line += ', '.join(self._param_fields)
line += '};'
self._write("static MegDNNType to_param(const FlatBufferType* fb) {", indent=1)
line = "return {"
line += ", ".join(self._param_fields)
line += "};"
self._write(line) self._write(line)
self._write("}\n", indent=-1) self._write("}\n", indent=-1)


self._write( self._write(
"static flatbuffers::Offset<FlatBufferType> to_flatbuffer(flatbuffers::FlatBufferBuilder& builder, const MegDNNType& param) {", "static flatbuffers::Offset<FlatBufferType> to_flatbuffer(flatbuffers::FlatBufferBuilder& builder, const MegDNNType& param) {",
indent=1)
line = 'return fbs::param::Create{}('.format(str(p.name))
line += ', '.join(self._fb_fields)
line += ');'
indent=1,
)
line = "return fbs::param::Create{}(".format(str(p.name))
line += ", ".join(self._fb_fields)
line += ");"
self._write(line) self._write(line)
self._write('}', indent=-1)
self._write("}", indent=-1)


self._write("};\n", indent=-1) self._write("};\n", indent=-1)


@@ -68,18 +70,23 @@ class ConverterWriter(IndentWriterBase):
return return
self._param_fields.append( self._param_fields.append(
"static_cast<megdnn::param::{}::{}>(fb->{}())".format( "static_cast<megdnn::param::{}::{}>(fb->{}())".format(
str(p.name), str(e.name), e.name_field))
self._fb_fields.append("static_cast<fbs::param::{}>(param.{})".format(
key, e.name_field))
str(p.name), str(e.name), e.name_field
)
)
self._fb_fields.append(
"static_cast<fbs::param::{}>(param.{})".format(key, e.name_field)
)


def _on_member_field(self, f): def _on_member_field(self, f):
if self._skip_current_param: if self._skip_current_param:
return return
if f.dtype.cname == 'DTypeEnum':
if f.dtype.cname == "DTypeEnum":
self._param_fields.append( self._param_fields.append(
"intl::convert_dtype_to_megdnn(fb->{}())".format(f.name))
"intl::convert_dtype_to_megdnn(fb->{}())".format(f.name)
)
self._fb_fields.append( self._fb_fields.append(
"intl::convert_dtype_to_fbs(param.{})".format(f.name))
"intl::convert_dtype_to_fbs(param.{})".format(f.name)
)
else: else:
self._param_fields.append("fb->{}()".format(f.name)) self._param_fields.append("fb->{}()".format(f.name))
self._fb_fields.append("param.{}".format(f.name)) self._fb_fields.append("param.{}".format(f.name))
@@ -93,28 +100,33 @@ class ConverterWriter(IndentWriterBase):
enum_name = e.src_class + e.src_name enum_name = e.src_class + e.src_name
self._param_fields.append( self._param_fields.append(
"static_cast<megdnn::param::{}::{}>(fb->{}())".format( "static_cast<megdnn::param::{}::{}>(fb->{}())".format(
e.src_class, e.src_name, e.name_field))
self._fb_fields.append("static_cast<fbs::param::{}>(param.{})".format(
enum_name, e.name_field))
e.src_class, e.src_name, e.name_field
)
)
self._fb_fields.append(
"static_cast<fbs::param::{}>(param.{})".format(enum_name, e.name_field)
)




def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
'generate convert functions between FlatBuffers type and MegBrain type')
parser.add_argument('input')
parser.add_argument('output')
"generate convert functions between FlatBuffers type and MegBrain type"
)
parser.add_argument("input")
parser.add_argument("output")
args = parser.parse_args() args = parser.parse_args()


with open(args.input) as fin: with open(args.input) as fin:
inputs = fin.read() inputs = fin.read()
exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc})
exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc})
input_hash = hashlib.sha256() input_hash = hashlib.sha256()
input_hash.update(inputs.encode(encoding='UTF-8'))
input_hash.update(inputs.encode(encoding="UTF-8"))
input_hash = input_hash.hexdigest() input_hash = input_hash.hexdigest()


writer = ConverterWriter() writer = ConverterWriter()
with open(args.output, 'w') as fout:
with open(args.output, "w") as fout:
writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)



if __name__ == "__main__": if __name__ == "__main__":
main() main()

dnn/scripts/gen_flatbuffers_schema.py (+33, -21)

@@ -3,13 +3,14 @@


import argparse import argparse
import collections import collections
import textwrap
import os
import hashlib import hashlib
import struct
import io import io
import os
import struct
import textwrap

from gen_param_defs import IndentWriterBase, ParamDef, member_defs


from gen_param_defs import member_defs, ParamDef, IndentWriterBase


def _cname_to_fbname(cname): def _cname_to_fbname(cname):
return { return {
@@ -22,17 +23,19 @@ def _cname_to_fbname(cname):
"bool": "bool", "bool": "bool",
}[cname] }[cname]



def scramble_enum_member_name(name): def scramble_enum_member_name(name):
s = name.find('<<')
s = name.find("<<")
if s != -1: if s != -1:
name = name[0:name.find('=') + 1] + ' ' + name[s+2:]
name = name[0 : name.find("=") + 1] + " " + name[s + 2 :]
if name in ("MIN", "MAX"): if name in ("MIN", "MAX"):
return name + "_" return name + "_"
o_name = name.split(' ')[0].split('=')[0]
o_name = name.split(" ")[0].split("=")[0]
if o_name in ("MIN", "MAX"): if o_name in ("MIN", "MAX"):
return name.replace(o_name, o_name + "_") return name.replace(o_name, o_name + "_")
return name return name



class FlatBuffersWriter(IndentWriterBase): class FlatBuffersWriter(IndentWriterBase):
_skip_current_param = False _skip_current_param = False
_last_param = None _last_param = None
@@ -66,12 +69,13 @@ class FlatBuffersWriter(IndentWriterBase):
self._write("}\n", indent=-1) self._write("}\n", indent=-1)


def _write_doc(self, doc): def _write_doc(self, doc):
if not isinstance(doc, member_defs.Doc) or not doc.doc: return
if not isinstance(doc, member_defs.Doc) or not doc.doc:
return
doc_lines = [] doc_lines = []
if doc.no_reformat: if doc.no_reformat:
doc_lines = doc.raw_lines doc_lines = doc.raw_lines
else: else:
doc = doc.doc.replace('\n', ' ')
doc = doc.doc.replace("\n", " ")
text_width = 80 - len(self._cur_indent) - 4 text_width = 80 - len(self._cur_indent) - 4
doc_lines = textwrap.wrap(doc, text_width) doc_lines = textwrap.wrap(doc, text_width)
for line in doc_lines: for line in doc_lines:
@@ -101,7 +105,8 @@ class FlatBuffersWriter(IndentWriterBase):
default = e.compose_combined_enum(e.default) default = e.compose_combined_enum(e.default)
else: else:
default = scramble_enum_member_name( default = scramble_enum_member_name(
str(e.members[e.default]).split(' ')[0].split('=')[0])
str(e.members[e.default]).split(" ")[0].split("=")[0]
)
self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default) self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default)


def _resolve_const(self, v): def _resolve_const(self, v):
@@ -113,8 +118,12 @@ class FlatBuffersWriter(IndentWriterBase):
if self._skip_current_param: if self._skip_current_param:
return return
self._write_doc(f.name) self._write_doc(f.name)
self._write("%s:%s = %s;", f.name, _cname_to_fbname(f.dtype.cname),
self._get_fb_default(self._resolve_const(f.default)))
self._write(
"%s:%s = %s;",
f.name,
_cname_to_fbname(f.dtype.cname),
self._get_fb_default(self._resolve_const(f.default)),
)


def _on_const_field(self, f): def _on_const_field(self, f):
self._cur_const_val[str(f.name)] = str(f.default) self._cur_const_val[str(f.name)] = str(f.default)
@@ -129,7 +138,8 @@ class FlatBuffersWriter(IndentWriterBase):
default = s.compose_combined_enum(e.get_default()) default = s.compose_combined_enum(e.get_default())
else: else:
default = scramble_enum_member_name( default = scramble_enum_member_name(
str(s.members[e.get_default()]).split(' ')[0].split('=')[0])
str(s.members[e.get_default()]).split(" ")[0].split("=")[0]
)
self._write("%s:%s = %s;", e.name_field, enum_name, default) self._write("%s:%s = %s;", e.name_field, enum_name, default)


def _get_fb_default(self, cppdefault): def _get_fb_default(self, cppdefault):
@@ -137,9 +147,9 @@ class FlatBuffersWriter(IndentWriterBase):
return cppdefault return cppdefault


d = cppdefault d = cppdefault
if d.endswith('f'): # 1.f
if d.endswith("f"): # 1.f
return d[:-1] return d[:-1]
if d.endswith('ull'):
if d.endswith("ull"):
return d[:-3] return d[:-3]
if d.startswith("DTypeEnum::"): if d.startswith("DTypeEnum::"):
return d[11:] return d[11:]
@@ -148,21 +158,23 @@ class FlatBuffersWriter(IndentWriterBase):


def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
'generate FlatBuffers schema of operator param from description file')
parser.add_argument('input')
parser.add_argument('output')
"generate FlatBuffers schema of operator param from description file"
)
parser.add_argument("input")
parser.add_argument("output")
args = parser.parse_args() args = parser.parse_args()


with open(args.input) as fin: with open(args.input) as fin:
inputs = fin.read() inputs = fin.read()
exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc})
exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc})
input_hash = hashlib.sha256() input_hash = hashlib.sha256()
input_hash.update(inputs.encode(encoding='UTF-8'))
input_hash.update(inputs.encode(encoding="UTF-8"))
input_hash = input_hash.hexdigest() input_hash = input_hash.hexdigest()


writer = FlatBuffersWriter() writer = FlatBuffersWriter()
with open(args.output, 'w') as fout:
with open(args.output, "w") as fout:
writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)



if __name__ == "__main__": if __name__ == "__main__":
main() main()
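
The scramble_enum_member_name helper reformatted above adapts C++ enum member spellings to what a FlatBuffers schema will accept: a bit-shift initializer such as "X = 1 << 3" is reduced to the bit position (which suggests the enum is emitted as a FlatBuffers bit_flags enum), and MIN/MAX get a trailing underscore, presumably to avoid the MIN/MAX helpers flatc generates. A quick standalone check of that behaviour (member names below are made up):

    def scramble_enum_member_name(name):
        # Copied verbatim from the new version above.
        s = name.find("<<")
        if s != -1:
            name = name[0 : name.find("=") + 1] + " " + name[s + 2 :]
        if name in ("MIN", "MAX"):
            return name + "_"
        o_name = name.split(" ")[0].split("=")[0]
        if o_name in ("MIN", "MAX"):
            return name.replace(o_name, o_name + "_")
        return name


    # "1 << 0" collapses to the bit position (note the doubled space, harmless in .fbs).
    assert scramble_enum_member_name("NCHW = 1 << 0") == "NCHW =  0"
    # MIN/MAX would collide with generated helpers, so they are renamed.
    assert scramble_enum_member_name("MAX = 100") == "MAX_ = 100"
    assert scramble_enum_member_name("MIN") == "MIN_"
    # Ordinary members pass through untouched.
    assert scramble_enum_member_name("ADD") == "ADD"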

+ 65
- 49
dnn/scripts/gen_heuristic/gen_heuristic.py View File

@@ -1,14 +1,16 @@
#! /usr/local/env python3 #! /usr/local/env python3


import pickle
import numpy as np
import os
import argparse import argparse
import re
import collections import collections
import os
import pickle
import re

import numpy as np



def define_template(**kwargs): def define_template(**kwargs):
template = '''
template = """
float cuda{cuda_arch}_{conv_type}_time_pred[{out_dim}] = {{0.0f}}; float cuda{cuda_arch}_{conv_type}_time_pred[{out_dim}] = {{0.0f}};
float cuda{cuda_arch}_{conv_type}_mask[{out_dim}] = {{0.0f}}; float cuda{cuda_arch}_{conv_type}_mask[{out_dim}] = {{0.0f}};
float cuda{cuda_arch}_{conv_type}_hidden_units[{hidden_num}] = {{0.0f}}; float cuda{cuda_arch}_{conv_type}_hidden_units[{hidden_num}] = {{0.0f}};
@@ -17,21 +19,23 @@ def define_template(**kwargs):
const static float cuda{cuda_arch}_{conv_type}_biases[{biases_dim}] = {{{biases}}}; const static float cuda{cuda_arch}_{conv_type}_biases[{biases_dim}] = {{{biases}}};
const static float cuda{cuda_arch}_{conv_type}_alpha[{out_dim}] = {{{alpha}}}; const static float cuda{cuda_arch}_{conv_type}_alpha[{out_dim}] = {{{alpha}}};
const static float cuda{cuda_arch}_{conv_type}_beta[{out_dim}] = {{{beta}}}; const static float cuda{cuda_arch}_{conv_type}_beta[{out_dim}] = {{{beta}}};
'''
"""
return template.format(**kwargs) return template.format(**kwargs)



def cudnn_slt_template(**kwargs): def cudnn_slt_template(**kwargs):
template = ("#if CUDNN_MAJOR == {cudnn_major} && CUDNN_MINOR == {cudnn_minor}\n" +
" {define_cmd}\n" +
" {select_cmd}\n" +
" return true;\n" +
"#endif\n"
)
template = (
"#if CUDNN_MAJOR == {cudnn_major} && CUDNN_MINOR == {cudnn_minor}\n"
+ " {define_cmd}\n"
+ " {select_cmd}\n"
+ " return true;\n"
+ "#endif\n"
)
return template.format(**kwargs) return template.format(**kwargs)



def select_template(**kwargs): def select_template(**kwargs):
template = \
'''if (conv_type == ConvolutionType::{conv_type} && cuda_major == {cuda_major} &&
template = """if (conv_type == ConvolutionType::{conv_type} && cuda_major == {cuda_major} &&
cuda_minor == {cuda_minor}) {{ cuda_minor == {cuda_minor}) {{
*layer_num_p = {layer_num}; *layer_num_p = {layer_num};
*hidden_units_p = cuda{cuda_arch}_{conv_type}_hidden_units; *hidden_units_p = cuda{cuda_arch}_{conv_type}_hidden_units;
@@ -42,7 +46,7 @@ def select_template(**kwargs):
*beta_p = cuda{cuda_arch}_{conv_type}_beta; *beta_p = cuda{cuda_arch}_{conv_type}_beta;
*time_pred_p = cuda{cuda_arch}_{conv_type}_time_pred; *time_pred_p = cuda{cuda_arch}_{conv_type}_time_pred;
*mask_p = cuda{cuda_arch}_{conv_type}_mask; *mask_p = cuda{cuda_arch}_{conv_type}_mask;
}} else '''
}} else """
return template.format(**kwargs) return template.format(**kwargs)




@@ -58,48 +62,48 @@ def fill_src():
if len(matrix_files) == 0: if len(matrix_files) == 0:
print("Warning: no param files detected.") print("Warning: no param files detected.")
for fpath in matrix_files: for fpath in matrix_files:
cudnn_version = re.findall('cudnn([\d.]+)',fpath)[0]
cudnn_version = re.findall("cudnn([\d.]+)", fpath)[0]
gen_list[cudnn_version].append(fpath) gen_list[cudnn_version].append(fpath)
for cudnn in gen_list: for cudnn in gen_list:
select_cmd = ("{\n" +
" " * 8 + "return false;\n" +
" " * 4 + "}")
select_cmd = "{\n" + " " * 8 + "return false;\n" + " " * 4 + "}"
define_cmd = "" define_cmd = ""
cudnn_major, cudnn_minor = cudnn.split('.')
cudnn_major, cudnn_minor = cudnn.split(".")
for fpath in gen_list[cudnn]: for fpath in gen_list[cudnn]:
cuda_arch = fpath.split("-")[1].replace(".", "_") cuda_arch = fpath.split("-")[1].replace(".", "_")
print('cudnn_version: {}, cuda_arch: {}'.format(cudnn,cuda_arch))
print("cudnn_version: {}, cuda_arch: {}".format(cudnn, cuda_arch))
conv_type = fpath.split("-")[2].split(".")[0] conv_type = fpath.split("-")[2].split(".")[0]
with open(os.path.join(home, "params/{}".format(fpath)), "rb") as pobj: with open(os.path.join(home, "params/{}".format(fpath)), "rb") as pobj:
params = pickle.load(pobj) params = pickle.load(pobj)
crt_define_cmd, crt_select_cmd = gen_cmds(
cuda_arch, conv_type, params)
crt_define_cmd, crt_select_cmd = gen_cmds(cuda_arch, conv_type, params)
select_cmd = crt_select_cmd + select_cmd select_cmd = crt_select_cmd + select_cmd
define_cmd = crt_define_cmd + define_cmd define_cmd = crt_define_cmd + define_cmd


cudnn_slt_cmd += cudnn_slt_template(cudnn_major=cudnn_major,
cudnn_minor=cudnn_minor,
select_cmd=select_cmd,
define_cmd=define_cmd)
cudnn_slt_cmd += cudnn_slt_template(
cudnn_major=cudnn_major,
cudnn_minor=cudnn_minor,
select_cmd=select_cmd,
define_cmd=define_cmd,
)


#select_cmd = select_cmd
# select_cmd = select_cmd
with open(os.path.join(home, "get_params.template"), "r") as srcf: with open(os.path.join(home, "get_params.template"), "r") as srcf:
src = srcf.read() src = srcf.read()
dst = src.replace("{cudnn_select}", cudnn_slt_cmd) dst = src.replace("{cudnn_select}", cudnn_slt_cmd)
MegDNN_path = os.path.join(home, "../..") MegDNN_path = os.path.join(home, "../..")
with open(os.path.join(MegDNN_path,
"src/cuda/convolution/get_params.cpp"), "w") as dstf:
with open(
os.path.join(MegDNN_path, "src/cuda/convolution/get_params.cpp"), "w"
) as dstf:
dstf.write(dst) dstf.write(dst)




def gen_cmds(cuda_arch, conv_type, params): def gen_cmds(cuda_arch, conv_type, params):
cuda_major, cuda_minor = cuda_arch.split("_") cuda_major, cuda_minor = cuda_arch.split("_")
alphastr = format_array(params['alpha']).rstrip()[:-1]
betastr = format_array(params['beta']).rstrip()[:-1]
W_list = params['W']
b_list = params['b']
Wstr = ''
bstr = ''
alphastr = format_array(params["alpha"]).rstrip()[:-1]
betastr = format_array(params["beta"]).rstrip()[:-1]
W_list = params["W"]
b_list = params["b"]
Wstr = ""
bstr = ""
layer_num = str(len(b_list) + 1) layer_num = str(len(b_list) + 1)
layers_dim = [W_list[0].shape[1]] layers_dim = [W_list[0].shape[1]]
matrices_dim = 0 matrices_dim = 0
@@ -118,16 +122,27 @@ def gen_cmds(cuda_arch, conv_type, params):
out_dim = layers_dim[-1] out_dim = layers_dim[-1]
layers_dim_str = format_array(np.array(layers_dim)).rstrip()[:-1] layers_dim_str = format_array(np.array(layers_dim)).rstrip()[:-1]


select_cmd = select_template(conv_type=conv_type.upper(), cuda_major=cuda_major,
cuda_minor=cuda_minor, layer_num=layer_num,
cuda_arch=cuda_arch)
define_cmd = define_template(cuda_arch=cuda_arch, conv_type=conv_type.upper(),
hidden_num=hidden_num,
layer_num=layer_num, out_dim=out_dim,
layers_dim=layers_dim_str,
matrices_dim=matrices_dim, matrices=Wstr,
biases_dim=biases_dim, biases=bstr,
alpha=alphastr, beta=betastr)
select_cmd = select_template(
conv_type=conv_type.upper(),
cuda_major=cuda_major,
cuda_minor=cuda_minor,
layer_num=layer_num,
cuda_arch=cuda_arch,
)
define_cmd = define_template(
cuda_arch=cuda_arch,
conv_type=conv_type.upper(),
hidden_num=hidden_num,
layer_num=layer_num,
out_dim=out_dim,
layers_dim=layers_dim_str,
matrices_dim=matrices_dim,
matrices=Wstr,
biases_dim=biases_dim,
biases=bstr,
alpha=alphastr,
beta=betastr,
)
return (define_cmd, select_cmd) return (define_cmd, select_cmd)




@@ -153,8 +168,9 @@ def format_array(array):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Generate cuDNN heuristic code by neural network into" description="Generate cuDNN heuristic code by neural network into"
" {MEGDNN_ROOT}/src/cuda/convolution/get_params.cpp,"
" using parameter value from pickle files in"
" {MEGDNN_ROOT}/scripts/gen_heuristic/params/")
" {MEGDNN_ROOT}/src/cuda/convolution/get_params.cpp,"
" using parameter value from pickle files in"
" {MEGDNN_ROOT}/scripts/gen_heuristic/params/"
)
args = parser.parse_args() args = parser.parse_args()
main() main()
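
The reflowed fill_src/gen_cmds code above assembles C++ by string templating: each select_template fragment ends in "} else ", and the loop prepends fragments in front of a "{ return false; }" tail, so the result is an if/else-if ladder with the most recently processed arch tested first. A stripped-down sketch of that chaining (the arch and conv-type values are invented):

    # Tail that terminates the ladder, exactly as built in fill_src above.
    select_cmd = "{\n" + " " * 8 + "return false;\n" + " " * 4 + "}"

    template = (
        "if (conv_type == ConvolutionType::{ct} && cuda_major == {maj} &&\n"
        "    cuda_minor == {min}) {{\n"
        "    /* point *_p at the cuda{arch}_{ct} tables */\n"
        "}} else "
    )

    for cuda_arch, conv_type in [("6_1", "BACKWARD_DATA"), ("7_5", "FORWARD")]:
        cuda_major, cuda_minor = cuda_arch.split("_")
        branch = template.format(ct=conv_type, maj=cuda_major, min=cuda_minor, arch=cuda_arch)
        # Prepending means the last entry processed becomes the first branch tested.
        select_cmd = branch + select_cmd

    print(select_cmd)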

+ 406
- 309
dnn/scripts/gen_param_defs.py
File diff suppressed because it is too large
View File


+ 58
- 44
dnn/scripts/gen_tablegen.py View File

@@ -3,19 +3,17 @@


import argparse import argparse
import collections import collections
import textwrap
import os
import hashlib import hashlib
import struct
import io import io
import os
import struct
import textwrap


from gen_param_defs import member_defs, ParamDef, IndentWriterBase
from gen_param_defs import IndentWriterBase, ParamDef, member_defs


# FIXME: move supportToString flag definition into the param def source file # FIXME: move supportToString flag definition into the param def source file
ENUM_TO_STRING_SPECIAL_RULES = [
("Elemwise", "Mode"),
("ElemwiseMultiType", "Mode")
]
ENUM_TO_STRING_SPECIAL_RULES = [("Elemwise", "Mode"), ("ElemwiseMultiType", "Mode")]



class ConverterWriter(IndentWriterBase): class ConverterWriter(IndentWriterBase):
_skip_current_param = False _skip_current_param = False
@@ -33,21 +31,21 @@ class ConverterWriter(IndentWriterBase):
self._write("#endif // MGB_PARAM") self._write("#endif // MGB_PARAM")


def _ctype2attr(self, ctype, value): def _ctype2attr(self, ctype, value):
if ctype == 'uint32_t':
return 'MgbUI32Attr', value
if ctype == 'uint64_t':
return 'MgbUI64Attr', value
if ctype == 'int32_t':
return 'MgbI32Attr', value
if ctype == 'float':
return 'MgbF32Attr', value
if ctype == 'double':
return 'MgbF64Attr', value
if ctype == 'bool':
return 'MgbBoolAttr', value
if ctype == 'DTypeEnum':
if ctype == "uint32_t":
return "MgbUI32Attr", value
if ctype == "uint64_t":
return "MgbUI64Attr", value
if ctype == "int32_t":
return "MgbI32Attr", value
if ctype == "float":
return "MgbF32Attr", value
if ctype == "double":
return "MgbF64Attr", value
if ctype == "bool":
return "MgbBoolAttr", value
if ctype == "DTypeEnum":
self._packed = False self._packed = False
return 'MgbDTypeAttr', 'megdnn::DType::from_enum(megdnn::{})'.format(value)
return "MgbDTypeAttr", "megdnn::DType::from_enum(megdnn::{})".format(value)
raise RuntimeError("unknown ctype") raise RuntimeError("unknown ctype")


def _on_param_begin(self, p): def _on_param_begin(self, p):
@@ -61,21 +59,26 @@ class ConverterWriter(IndentWriterBase):
self._skip_current_param = False self._skip_current_param = False
return return
if self._packed: if self._packed:
self._write("class {0}ParamBase<string accessor> : MgbPackedParamBase<\"{0}\", accessor> {{".format(p.name), indent=1)
self._write(
'class {0}ParamBase<string accessor> : MgbPackedParamBase<"{0}", accessor> {{'.format(
p.name
),
indent=1,
)
else: else:
self._write("def {0}Param: MgbParamBase<\"{0}\"> {{".format(p.name), indent=1)
self._write('def {0}Param: MgbParamBase<"{0}"> {{'.format(p.name), indent=1)
self._write("let fields = (ins", indent=1) self._write("let fields = (ins", indent=1)
self._write(",\n{}".format(self._cur_indent).join(self._current_tparams)) self._write(",\n{}".format(self._cur_indent).join(self._current_tparams))
self._write(");", indent=-1) self._write(");", indent=-1)
self._write("}\n", indent=-1) self._write("}\n", indent=-1)
if self._packed: if self._packed:
self._write("def {0}Param : {0}ParamBase<\"param\">;\n".format(p.name))
self._write('def {0}Param : {0}ParamBase<"param">;\n'.format(p.name))
self._current_tparams = None self._current_tparams = None
self._packed = None self._packed = None
self._const = None self._const = None


def _wrapped_with_default_value(self, attr, default): def _wrapped_with_default_value(self, attr, default):
return 'MgbDefaultValuedAttr<{}, \"{}\">'.format(attr, default)
return 'MgbDefaultValuedAttr<{}, "{}">'.format(attr, default)


def _on_member_enum(self, e): def _on_member_enum(self, e):
p = self._last_param p = self._last_param
@@ -84,10 +87,12 @@ class ConverterWriter(IndentWriterBase):
# directly used by any operator, or other enum couldn't alias to this enum # directly used by any operator, or other enum couldn't alias to this enum
td_class = "{}{}".format(p.name, e.name) td_class = "{}{}".format(p.name, e.name)
fullname = "::megdnn::param::{}".format(p.name) fullname = "::megdnn::param::{}".format(p.name)
enum_def = "MgbEnumAttr<\"{}\", \"{}\", [".format(fullname, e.name)
enum_def = 'MgbEnumAttr<"{}", "{}", ['.format(fullname, e.name)

def format(v): def format(v):
return '\"{}\"'.format(str(v).split(' ')[0].split('=')[0])
enum_def += ','.join(format(i) for i in e.members)
return '"{}"'.format(str(v).split(" ")[0].split("=")[0])

enum_def += ",".join(format(i) for i in e.members)


if e.combined: if e.combined:
enum_def += "], 1" enum_def += "], 1"
@@ -95,7 +100,7 @@ class ConverterWriter(IndentWriterBase):
enum_def += "], 0" enum_def += "], 0"


if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)): if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)):
enum_def += ", 1" # whether generate ToStringTrait
enum_def += ", 1" # whether generate ToStringTrait
enum_def += ">" enum_def += ">"


self._write("def {} : {};".format(td_class, enum_def)) self._write("def {} : {};".format(td_class, enum_def))
@@ -105,10 +110,12 @@ class ConverterWriter(IndentWriterBase):
# wrapped with default value # wrapped with default value
if e.combined: if e.combined:
default_val = "static_cast<{}::{}>({})".format( default_val = "static_cast<{}::{}>({})".format(
fullname, e.name, e.compose_combined_enum(e.default))
fullname, e.name, e.compose_combined_enum(e.default)
)
else: else:
default_val = "{}::{}::{}".format( default_val = "{}::{}::{}".format(
fullname, e.name, str(e.members[e.default]).split(' ')[0].split('=')[0])
fullname, e.name, str(e.members[e.default]).split(" ")[0].split("=")[0]
)


wrapped = self._wrapped_with_default_value(td_class, default_val) wrapped = self._wrapped_with_default_value(td_class, default_val)


@@ -123,51 +130,58 @@ class ConverterWriter(IndentWriterBase):
td_class = "{}{}".format(p.name, e.name) td_class = "{}{}".format(p.name, e.name)
fullname = "::megdnn::param::{}".format(p.name) fullname = "::megdnn::param::{}".format(p.name)
base_td_class = "{}{}".format(e.src_class, e.src_name) base_td_class = "{}{}".format(e.src_class, e.src_name)
enum_def = "MgbEnumAliasAttr<\"{}\", \"{}\", {}>".format(fullname, e.name, base_td_class)
enum_def = 'MgbEnumAliasAttr<"{}", "{}", {}>'.format(
fullname, e.name, base_td_class
)
self._write("def {} : {};".format(td_class, enum_def)) self._write("def {} : {};".format(td_class, enum_def))


# wrapped with default value # wrapped with default value
s = e.src_enum s = e.src_enum
if s.combined: if s.combined:
default_val = "static_cast<{}::{}>({})".format( default_val = "static_cast<{}::{}>({})".format(
fullname, e.name, s.compose_combined_enum(e.get_default()))
fullname, e.name, s.compose_combined_enum(e.get_default())
)
else: else:
default_val = "{}::{}::{}".format(fullname, e.name, str(
s.members[e.get_default()]).split(' ')[0].split('=')[0])
default_val = "{}::{}::{}".format(
fullname,
e.name,
str(s.members[e.get_default()]).split(" ")[0].split("=")[0],
)


wrapped = self._wrapped_with_default_value(td_class, default_val) wrapped = self._wrapped_with_default_value(td_class, default_val)


self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) self._current_tparams.append("{}:${}".format(wrapped, e.name_field))



def _on_member_field(self, f): def _on_member_field(self, f):
if self._skip_current_param: if self._skip_current_param:
return return
attr, value = self._ctype2attr(f.dtype.cname, str(f.default)) attr, value = self._ctype2attr(f.dtype.cname, str(f.default))
if str(value) in self._const: if str(value) in self._const:
value = '::megdnn::param::{}::{}'.format(self._last_param.name, value)
value = "::megdnn::param::{}::{}".format(self._last_param.name, value)
wrapped = self._wrapped_with_default_value(attr, value) wrapped = self._wrapped_with_default_value(attr, value)
self._current_tparams.append("{}:${}".format(wrapped, f.name)) self._current_tparams.append("{}:${}".format(wrapped, f.name))


def _on_const_field(self, f): def _on_const_field(self, f):
self._const.add(str(f.name)) self._const.add(str(f.name))



def main(): def main():
parser = argparse.ArgumentParser('generate op param tablegen file')
parser.add_argument('input')
parser.add_argument('output')
parser = argparse.ArgumentParser("generate op param tablegen file")
parser.add_argument("input")
parser.add_argument("output")
args = parser.parse_args() args = parser.parse_args()


with open(args.input) as fin: with open(args.input) as fin:
inputs = fin.read() inputs = fin.read()
exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc})
exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc})
input_hash = hashlib.sha256() input_hash = hashlib.sha256()
input_hash.update(inputs.encode(encoding='UTF-8'))
input_hash.update(inputs.encode(encoding="UTF-8"))
input_hash = input_hash.hexdigest() input_hash = input_hash.hexdigest()


writer = ConverterWriter() writer = ConverterWriter()
with open(args.output, 'w') as fout:
with open(args.output, "w") as fout:
writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)



if __name__ == "__main__": if __name__ == "__main__":
main() main()
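
The _ctype2attr chain rewritten above is a one-to-one lookup from C types to TableGen attribute classes, with DTypeEnum as the only case that carries a side effect (it clears self._packed) and a wrapped default value. A table-driven sketch of the same mapping (sample values only):

    # Plain C types map 1:1 onto TableGen attribute classes.
    _SIMPLE_CTYPE_TO_ATTR = {
        "uint32_t": "MgbUI32Attr",
        "uint64_t": "MgbUI64Attr",
        "int32_t": "MgbI32Attr",
        "float": "MgbF32Attr",
        "double": "MgbF64Attr",
        "bool": "MgbBoolAttr",
    }


    def ctype2attr(ctype, value):
        if ctype in _SIMPLE_CTYPE_TO_ATTR:
            return _SIMPLE_CTYPE_TO_ATTR[ctype], value
        if ctype == "DTypeEnum":
            # The real writer also flips self._packed to False here.
            return "MgbDTypeAttr", "megdnn::DType::from_enum(megdnn::{})".format(value)
        raise RuntimeError("unknown ctype")


    assert ctype2attr("float", "1.f") == ("MgbF32Attr", "1.f")
    assert ctype2attr("DTypeEnum", "DTypeEnum::Float32") == (
        "MgbDTypeAttr",
        "megdnn::DType::from_enum(megdnn::DTypeEnum::Float32)",
    )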

+ 17
- 8
tools/evaluation_model_parallelism.py View File

@@ -19,6 +19,7 @@ device = {
"thread_number": 3, "thread_number": 3,
} }



class SshConnector: class SshConnector:
"""imp ssh control master connector""" """imp ssh control master connector"""


@@ -83,17 +84,17 @@ def main():
model_file = args.model_file model_file = args.model_file
# copy model file # copy model file
ssh.copy([args.model_file], workspace) ssh.copy([args.model_file], workspace)
m = model_file.split('\\')[-1]
m = model_file.split("\\")[-1]
# run single thread # run single thread
result = [] result = []
thread_number = [1, 2, 4] thread_number = [1, 2, 4]
for b in thread_number :
for b in thread_number:
cmd = [] cmd = []
cmd1 = "cd {} && ./load_and_run {} -multithread {} --fast-run --fast_run_algo_policy fastrun.cache --iter 1 --warmup-iter 1 --no-sanity-check --weight-preprocess".format( cmd1 = "cd {} && ./load_and_run {} -multithread {} --fast-run --fast_run_algo_policy fastrun.cache --iter 1 --warmup-iter 1 --no-sanity-check --weight-preprocess".format(
workspace, m, b
workspace, m, b
) )
cmd2 = "cd {} && ./load_and_run {} -multithread {} --fast_run_algo_policy fastrun.cache --iter 20 --warmup-iter 5 --no-sanity-check --weight-preprocess ".format( cmd2 = "cd {} && ./load_and_run {} -multithread {} --fast_run_algo_policy fastrun.cache --iter 20 --warmup-iter 5 --no-sanity-check --weight-preprocess ".format(
workspace, m, b
workspace, m, b
) )
cmd.append(cmd1) cmd.append(cmd1)
cmd.append(cmd2) cmd.append(cmd2)
@@ -103,12 +104,20 @@ def main():
logging.debug("model: {} with backend: {} result is: {}".format(m, b, ret)) logging.debug("model: {} with backend: {} result is: {}".format(m, b, ret))
result.append(ret) result.append(ret)


thread_2 = result[0]/result[1]
thread_4 = result[0]/result[2]
thread_2 = result[0] / result[1]
thread_4 = result[0] / result[2]
if thread_2 > 1.6 or thread_4 > 3.0: if thread_2 > 1.6 or thread_4 > 3.0:
print("model: {} can has good parallelism. 2 thread is {}, 4 thread is {}".format(m, thread_2, thread_4))
print(
"model: {} can has good parallelism. 2 thread is {}, 4 thread is {}".format(
m, thread_2, thread_4
)
)
else: else:
print("model: {} can has bad parallelism. 2 thread is {}, 4 thread is {}".format(m, thread_2, thread_4))
print(
"model: {} can has bad parallelism. 2 thread is {}, 4 thread is {}".format(
m, thread_2, thread_4
)
)




if __name__ == "__main__": if __name__ == "__main__":
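
The verdict printed above boils down to two speedup ratios against the single-thread run, with 1.6x (2 threads) and 3.0x (4 threads) as the thresholds for "good parallelism". A self-contained restatement of that check (the timings below are invented):

    def parallelism_ok(times_by_thread):
        """times_by_thread: wall times for 1, 2 and 4 threads, in that order."""
        t1, t2, t4 = times_by_thread
        speedup_2 = t1 / t2
        speedup_4 = t1 / t4
        return speedup_2 > 1.6 or speedup_4 > 3.0, (speedup_2, speedup_4)


    ok, (s2, s4) = parallelism_ok([100.0, 55.0, 31.0])
    print("good parallelism" if ok else "bad parallelism", round(s2, 2), round(s4, 2))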


+ 17
- 13
tools/format.py View File

@@ -20,8 +20,12 @@ failed_files = Manager().list()
def process_file(file, clang_format, write): def process_file(file, clang_format, write):
original_source = open(file, "r").read() original_source = open(file, "r").read()
source = original_source source = original_source
source = re.sub(r"MGB_DEFINE(?P<r>([^\\]|\n)*?)// *{", r"class MGB_DEFINE\g<r>{", source)
source, count = re.subn(r"(?<!#define )MGB_DEFINE(.*) +\\", r"class MGB_DEFINE\1{\\", source)
source = re.sub(
r"MGB_DEFINE(?P<r>([^\\]|\n)*?)// *{", r"class MGB_DEFINE\g<r>{", source
)
source, count = re.subn(
r"(?<!#define )MGB_DEFINE(.*) +\\", r"class MGB_DEFINE\1{\\", source
)


result = subprocess.check_output( result = subprocess.check_output(
[ [
@@ -36,7 +40,9 @@ def process_file(file, clang_format, write):


result = result.decode("utf-8") result = result.decode("utf-8")
if count: if count:
result = re.sub(r"class MGB_DEFINE(.*){( *)\\", r"MGB_DEFINE\1\2 \\", result)
result = re.sub(
r"class MGB_DEFINE(.*){( *)\\", r"MGB_DEFINE\1\2 \\", result
)
result = re.sub(r"class MGB_DEFINE((.|\n)*?){", r"MGB_DEFINE\1// {", result) result = re.sub(r"class MGB_DEFINE((.|\n)*?){", r"MGB_DEFINE\1// {", result)


if write and original_source != result: if write and original_source != result:
@@ -109,19 +115,17 @@ def main():
raise ValueError("Invalid path {}".format(path)) raise ValueError("Invalid path {}".format(path))


# check version, we only support 12.0.1 now # check version, we only support 12.0.1 now
version = subprocess.check_output(
[
args.clang_format,
"--version",
],
)
version = subprocess.check_output([args.clang_format, "--version",],)
version = version.decode("utf-8") version = version.decode("utf-8")


need_version = '12.0.1'
need_version = "12.0.1"
if version.find(need_version) < 0: if version.find(need_version) < 0:
print('We only support {} now, please install {} version, find version: {}'
.format(need_version, need_version, version))
raise RuntimeError('clang-format version not equal {}'.format(need_version))
print(
"We only support {} now, please install {} version, find version: {}".format(
need_version, need_version, version
)
)
raise RuntimeError("clang-format version not equal {}".format(need_version))


process_map( process_map(
partial(process_file, clang_format=args.clang_format, write=args.write,), partial(process_file, clang_format=args.clang_format, write=args.write,),
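
format.py's most interesting move is the MGB_DEFINE round-trip: before piping a file through clang-format it disguises MGB_DEFINE... macro headers as class definitions so the formatter indents their bodies sensibly, then undoes the disguise afterwards. A toy round-trip using the same regular expressions (the sample source is invented, and the clang-format step is skipped):

    import re

    sample = (
        "MGB_DEFINE_OPR_CLASS(Example, cg::SingleCNOperatorNodeBase) // {\n"
        "public:\n"
        "    int x;\n"
        "};\n"
    )

    # Disguise: make the macro header look like a class so clang-format will indent it.
    wrapped = re.sub(r"MGB_DEFINE(?P<r>([^\\]|\n)*?)// *{", r"class MGB_DEFINE\g<r>{", sample)
    wrapped, count = re.subn(r"(?<!#define )MGB_DEFINE(.*) +\\", r"class MGB_DEFINE\1{\\", wrapped)

    # ... clang-format would run on `wrapped` at this point ...

    # Undo: strip the injected "class " and restore the trailing "// {" marker.
    restored = wrapped
    if count:
        restored = re.sub(r"class MGB_DEFINE(.*){( *)\\", r"MGB_DEFINE\1\2 \\", restored)
    restored = re.sub(r"class MGB_DEFINE((.|\n)*?){", r"MGB_DEFINE\1// {", restored)

    assert restored == sample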


+ 5
- 2
tools/test_model_static.py View File

@@ -20,6 +20,7 @@ device = {
"thread_number": 3, "thread_number": 3,
} }



class SshConnector: class SshConnector:
"""imp ssh control master connector""" """imp ssh control master connector"""


@@ -54,6 +55,7 @@ class SshConnector:
except: except:
raise raise



def main(): def main():
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--model_file", help="megengine model", required=True) parser.add_argument("--model_file", help="megengine model", required=True)
@@ -78,10 +80,10 @@ def main():
model_file = args.model_file model_file = args.model_file
# copy model file # copy model file
ssh.copy([model_file], workspace) ssh.copy([model_file], workspace)
m = model_file.split('\\')[-1]
m = model_file.split("\\")[-1]
# run single thread # run single thread
cmd = "cd {} && ./load_and_run {} --fast-run --record-comp-seq --iter 1 --warmup-iter 1".format( cmd = "cd {} && ./load_and_run {} --fast-run --record-comp-seq --iter 1 --warmup-iter 1".format(
workspace, m
workspace, m
) )
try: try:
raw_log = ssh.cmd([cmd]) raw_log = ssh.cmd([cmd])
@@ -91,6 +93,7 @@ def main():


print("model: {} is static model.".format(m)) print("model: {} is static model.".format(m))



if __name__ == "__main__": if __name__ == "__main__":
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s" LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
DATE_FORMAT = "%Y/%m/%d %H:%M:%S" DATE_FORMAT = "%Y/%m/%d %H:%M:%S"

