|
|
@@ -185,6 +185,18 @@ class HeaderGen: |
|
|
|
).stdout.decode('utf-8') |
|
|
|
self._fout.write('// midout \n') |
|
|
|
self._fout.write(cvt) |
|
|
|
if cvt.find(" half,"): |
|
|
|
change = open(self._fout.name).read().replace(" half,", " __fp16,") |
|
|
|
with open("fix_fp16_bin_reduce.h", "w") as fix_fp16: |
|
|
|
fix_fp16.write(change) |
|
|
|
msg = ( |
|
|
|
"WARNING:\n" |
|
|
|
"hit half in trace, try use fix_fp16_bin_reduce.h when build failed with bin_reduce.h\n" |
|
|
|
"which caused by LLVM mangle issue on __fp16 dtype, if you find msg 'error: use of undeclared identifier 'half'\n" |
|
|
|
"then try use fix_fp16_bin_reduce.h, if build failed again, submit a issue to Engine team!!!" |
|
|
|
) |
|
|
|
print(msg) |
|
|
|
|
|
|
|
|
|
|
|
def main(): |
|
|
|
parser = argparse.ArgumentParser( |
|
|
|