You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

comp_graph_tools.py 3.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. """tools for graph manipulation"""
  10. import collections
  11. from . import mgb as _mgb
  12. def get_dep_vars(var, var_type=None):
  13. """return :class:`.SymbolVar` of type ``var_type`` that input ``var``
  14. depands on. If ``var_type`` is None, return all types.
  15. :type var: an instance or iterable of :class:`.SymbolVar`
  16. :type var_type: ``str`` or an iterable of ``str``
  17. "rtype: list of :class:`.SymbolVar`
  18. """
  19. outputs = []
  20. memo = set()
  21. if isinstance(var, _mgb.SymbolVar):
  22. var = [var]
  23. if isinstance(var_type, str):
  24. var_type = [var_type]
  25. q = list(var)
  26. while q:
  27. v = q.pop()
  28. if v in memo:
  29. continue
  30. memo.add(v)
  31. q.extend(get_inputs(v))
  32. if var_type is not None:
  33. if get_type(v) in var_type:
  34. outputs.append(v)
  35. else:
  36. outputs.append(v)
  37. return outputs
  38. def get_inputs(var):
  39. """get the inputs of owner opr of a variable
  40. :type var: :class:`.SymbolVar`
  41. :rtype: list of :class:`.SymbolVar`
  42. """
  43. assert isinstance(var, _mgb.SymbolVar)
  44. return _mgb._get_owner_opr_inputs(var)
  45. def get_type(var):
  46. """get the type of owner opr of a variable
  47. :type var: :class:`.SymbolVar`
  48. :rtype: ``str``
  49. """
  50. assert isinstance(var, _mgb.SymbolVar)
  51. return _mgb._get_owner_opr_type(var)
  52. def replace_vars(dst, varmap):
  53. """replace vars in the graph
  54. :param dst: target vars representing the graph
  55. :type dst: list of :class:`.SymbolVar`
  56. :param varmap: the map that specifies how to replace the vars
  57. :type varmap: dict that maps from src var to dst var
  58. :return: new vars that correspond to ``dst`` with all the dependencies
  59. replaced
  60. :rtype: list of :class:`.SymbolVar`
  61. """
  62. dst_vec = _mgb._VectorSymbolVar()
  63. repl_src_vec = _mgb._VectorSymbolVar()
  64. repl_dst_vec = _mgb._VectorSymbolVar()
  65. for i in dst:
  66. assert isinstance(i, _mgb.SymbolVar)
  67. dst_vec.push_back(i)
  68. for i, j in getattr(varmap, "items", lambda: varmap)():
  69. assert isinstance(i, _mgb.SymbolVar)
  70. assert isinstance(j, _mgb.SymbolVar)
  71. repl_src_vec.push_back(i)
  72. repl_dst_vec.push_back(j)
  73. return _mgb._replace_vars(repl_src_vec, repl_dst_vec, dst_vec)
  74. def replace_oprs(dst, oprmap):
  75. """Replace operators in the graph. Roughly equivalent to
  76. :param dst: target vars representing the graph
  77. :type dst: list of :class:`.SymbolVar`
  78. :param oprmap: the map that specifies how to replace the operators
  79. :type oprmap: dict that maps from src operator to dst operator
  80. :return: new vars that correspond to ``dst`` with all the dependencies
  81. replaced
  82. :rtype: list of :class:`.SymbolVar`
  83. """
  84. dst_vec = _mgb._VectorSymbolVar()
  85. repl_src_vec = _mgb._VectorOperator()
  86. repl_dst_vec = _mgb._VectorOperator()
  87. for i in dst:
  88. assert isinstance(i, _mgb.SymbolVar)
  89. dst_vec.push_back(i)
  90. for i, j in getattr(oprmap, "items", lambda: oprmap)():
  91. assert isinstance(i, _mgb.Operator)
  92. assert isinstance(j, _mgb.Operator)
  93. repl_src_vec.push_back(i)
  94. repl_dst_vec.push_back(j)
  95. return _mgb._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台