|
|
@@ -994,7 +994,8 @@ void initialize_${configuration_name}(Manifest &manifest) { |
|
|
|
################################################################################################### |
|
|
|
|
|
|
|
class EmitGemmSingleKernelWrapper: |
|
|
|
def __init__(self, kernel_path, gemm_operation): |
|
|
|
def __init__(self, kernel_path, gemm_operation, short_path=False): |
|
|
|
self.short_path = short_path |
|
|
|
self.kernel_path = kernel_path |
|
|
|
self.operation = gemm_operation |
|
|
|
|
|
|
@@ -1070,10 +1071,11 @@ void initialize_${operation_name}(Manifest &manifest) { |
|
|
|
################################################################################################### |
|
|
|
|
|
|
|
class EmitGemvSingleKernelWrapper: |
|
|
|
def __init__(self, kernel_path, gemm_operation, wrapper_path): |
|
|
|
def __init__(self, kernel_path, gemm_operation, wrapper_path, short_path=False): |
|
|
|
self.kernel_path = kernel_path |
|
|
|
self.wrapper_path = wrapper_path |
|
|
|
self.operation = gemm_operation |
|
|
|
self.short_path = short_path |
|
|
|
|
|
|
|
self.wrapper_template = """ |
|
|
|
template void megdnn::cuda::cutlass_wrapper:: |
|
|
@@ -1107,7 +1109,11 @@ ${operation_instance} |
|
|
|
""" |
|
|
|
# |
|
|
|
def __enter__(self): |
|
|
|
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) |
|
|
|
if self.short_path: |
|
|
|
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt) |
|
|
|
GlobalCnt.cnt += 1 |
|
|
|
else: |
|
|
|
self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name()) |
|
|
|
self.kernel_file = open(self.kernel_path, "w") |
|
|
|
self.kernel_file.write(SubstituteTemplate(self.header_template, { |
|
|
|
'wrapper_path': self.wrapper_path, |
|
|
|