You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

optimization.py 1.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. from copy import deepcopy
  2. from typing import List, Set
  3. from ...logger import get_logger
  4. from ..traced_module import TracedModule
  5. from .pass_base import get_default_pass_context, get_registered_pass
  6. logger = get_logger(__name__)
  7. def optimize(
  8. module: TracedModule, enabled_pass: List[str] = ["FuseConvBn"],
  9. ) -> TracedModule:
  10. r"""Performs a set of optimization passes to optimize a `TracedModule` for inference.
  11. The following passes are currently supported:
  12. * FuseConvBn: fuse BN layers into to conv2d
  13. * FuseAddMul: fold adjacent const add or mul binary operations
  14. * BackwardFoldScale: backward fold const scaling into weights of conv2d
  15. Args:
  16. module: the :class:`TracedModule` to be optimized.
  17. enabled_pass: optimization passes to be enabled during optimization.
  18. Default: ["FuseConvBn"]
  19. Returns:
  20. the optimized :class:`TracedModule`.
  21. """
  22. defalut_passes_list = [
  23. "FuseConvBn",
  24. "FuseAddMul",
  25. ]
  26. if isinstance(enabled_pass, str):
  27. enabled_pass = [enabled_pass]
  28. if "BackwardFoldScale" in enabled_pass:
  29. if "FuseConvBn" not in enabled_pass:
  30. logger.warning(
  31. "Since BackwardFoldScale requires FuseConvBn"
  32. ", FuseConvBn will be enabled."
  33. )
  34. enabled_pass.append("FuseConvBn")
  35. defalut_passes_list.extend(
  36. ["BackwardFoldScale", "FuseAddMul",]
  37. )
  38. pass_ctx = get_default_pass_context()
  39. def run_pass(mod: TracedModule):
  40. for pass_name in defalut_passes_list:
  41. if pass_name in enabled_pass:
  42. pass_func = get_registered_pass(pass_name)()
  43. mod = pass_func(mod, pass_ctx)
  44. return mod
  45. module = deepcopy(module)
  46. module = run_pass(module)
  47. return module