|
|
@@ -65,9 +65,9 @@ def quantize(module: Module, inplace=True): |
|
|
|
# Predicate handed to module._flatten(predicate=is_qat): true when `mod` is an instance of any type in qat_modules. |
def is_qat(mod: Module): |
|
|
|
return isinstance(mod, qat_modules) |
|
|
|
|
|
|
|
# no need to pass prefix and get pure key of parent Module. |
|
|
|
for key, submodule, parent in module._flatten( |
|
|
|
with_key=True, with_parent=True, predicate=is_qat |
|
|
|
# Materialize the traversal with list() first, so that replacing a submodule cannot affect the modules still to be visited. |
|
|
|
for key, submodule, parent in list( |
|
|
|
module._flatten(with_key=True, with_parent=True, predicate=is_qat) |
|
|
|
): |
|
|
|
new_mod = _qat2quantized_dict[type(submodule)].from_qat_module(submodule) |
|
|
|
if isinstance(parent, Float.Sequential): |
|
|
@@ -100,9 +100,9 @@ def quantize_qat( |
|
|
|
# Predicate handed to module._flatten(predicate=is_quantable) and applied to `parent` in the loop: true when `mod` is an instance of any type in quantable_modules. |
def is_quantable(mod: Module): |
|
|
|
return isinstance(mod, quantable_modules) |
|
|
|
|
|
|
|
# no need to pass prefix and get pure key of parent Module. |
|
|
|
for key, submodule, parent in module._flatten( |
|
|
|
with_key=True, with_parent=True, predicate=is_quantable |
|
|
|
# Materialize the traversal with list() first, so that replacing a submodule cannot affect the modules still to be visited. |
|
|
|
for key, submodule, parent in list( |
|
|
|
module._flatten(with_key=True, with_parent=True, predicate=is_quantable) |
|
|
|
): |
|
|
|
# Only convert the topmost quantable module. |
|
|
|
if is_quantable(parent): |
|
|
|