|
|
@@ -185,7 +185,7 @@ class HeaderGen: |
|
|
|
).stdout.decode('utf-8') |
|
|
|
self._fout.write('// midout \n') |
|
|
|
self._fout.write(cvt) |
|
|
|
if cvt.find(" half,"): |
|
|
|
if cvt.find(" half,") > 0: |
|
|
|
change = open(self._fout.name).read().replace(" half,", " __fp16,") |
|
|
|
with open("fix_fp16_bin_reduce.h", "w") as fix_fp16: |
|
|
|
fix_fp16.write(change) |
|
|
|