You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

optimization.py 2.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from copy import deepcopy
  9. from typing import List, Set
  10. from ...logger import get_logger
  11. from ..traced_module import TracedModule
  12. from .pass_base import get_default_pass_context, get_registered_pass
  13. logger = get_logger(__name__)
  14. def optimize(
  15. module: TracedModule, enabled_pass: List[str] = ["FuseConvBn"],
  16. ) -> TracedModule:
  17. r"""Performs a set of optimization passes to optimize a `TracedModule` for inference.
  18. The following passes are currently supported:
  19. * FuseConvBn: fuse BN layers into to conv2d
  20. * FuseAddMul: fold adjacent const add or mul binary operations
  21. * BackwardFoldScale: backward fold const scaling into weights of conv2d
  22. Args:
  23. module: the :class:`TracedModule` to be optimized.
  24. enabled_pass: optimization passes to be enabled during optimization.
  25. Default: ["FuseConvBn"]
  26. Returns:
  27. the optimized :class:`TracedModule`.
  28. """
  29. defalut_passes_list = [
  30. "FuseConvBn",
  31. "FuseAddMul",
  32. ]
  33. if isinstance(enabled_pass, str):
  34. enabled_pass = [enabled_pass]
  35. if "BackwardFoldScale" in enabled_pass:
  36. if "FuseConvBn" not in enabled_pass:
  37. logger.warning(
  38. "Since BackwardFoldScale requires FuseConvBn"
  39. ", FuseConvBn will be enabled."
  40. )
  41. enabled_pass.append("FuseConvBn")
  42. defalut_passes_list.extend(
  43. ["BackwardFoldScale", "FuseAddMul",]
  44. )
  45. pass_ctx = get_default_pass_context()
  46. def run_pass(mod: TracedModule):
  47. for pass_name in defalut_passes_list:
  48. if pass_name in enabled_pass:
  49. pass_func = get_registered_pass(pass_name)()
  50. mod = pass_func(mod, pass_ctx)
  51. return mod
  52. module = deepcopy(module)
  53. module = run_pass(module)
  54. return module

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台