You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph.py 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import threading
  11. import megengine._internal as mgb
  12. from .device import get_default_device
  class _DefaultGraph(threading.local):
      r"""
      Thread-local holder for the implicit default graph.

      Because this subclasses :class:`threading.local`, every thread sees its
      own ``_default_graph`` attribute, which ``__init__`` sets to ``None`` the
      first time the holder is touched in that thread.
      """

      def __init__(self):
          super().__init__()
          self._default_graph = None

      def get_default(self):
          r"""Returns a default Graph object for eager evaluation.

          The graph is built lazily on first use in the calling thread and then
          cached on this holder, so later calls in that thread return the same
          object.
          """
          graph = self._default_graph
          if graph is None:
              graph = Graph()
              self._default_graph = graph
          return graph


  _default_graph = _DefaultGraph()
  class Graph(mgb.CompGraph):
      r"""
      A computing graph that supports context management.

      Using a ``Graph`` in a ``with`` statement installs it as the current
      thread's default graph for the duration of the block and restores the
      previous default on exit. Nested ``with`` blocks, including re-entering
      the same instance, restore defaults in the correct order.

      :param check_env_var: whether to check environment vars including ``MGB_COMP_GRAPH_OPT``.
      :param eager_evaluation: use dynamic graph(``True``) or static graph(``False``).

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          from megengine.core import Graph

          with Graph(eager_evaluation=True):
              x = tensor([1, 2])
              print(x)

      Outputs:

      .. testoutput::

          Tensor([1 2], dtype=int32)

      """

      def __new__(
          cls, *, check_env_var: bool = True, eager_evaluation: bool = True, **kwargs
      ):
          # The graph object is produced by ``mgb.comp_graph``; its class is then
          # swapped to ``cls`` so the methods defined here apply to it.
          kwargs.update(eager_evaluation=eager_evaluation)
          self = mgb.comp_graph(extra_opts=kwargs, check_env_var=check_env_var)
          self.__class__ = cls
          return self

      def __init__(
          self, *, check_env_var: bool = True, eager_evaluation: bool = True, **kwargs
      ):
          # Construction happens entirely in ``__new__``; this no-op override
          # deliberately does not call the base-class initializer.
          # pylint: disable=super-init-not-called
          pass

      def __enter__(self):
          # Keep a per-instance stack of the defaults this graph displaced. A
          # single saved slot would be clobbered when the same graph is entered
          # twice (``with g: with g: ...``), making the outer exit restore the
          # wrong default and then fail.
          try:
              saved_graphs = self.__saved_graphs
          except AttributeError:
              saved_graphs = self.__saved_graphs = []
          saved_graphs.append(_default_graph._default_graph)
          _default_graph._default_graph = self
          return self

      def __exit__(self, exc_type, exc_value, traceback):
          # Restore the default that was active when the matching ``__enter__``
          # ran. Returning None means exceptions from the ``with`` body propagate.
          _default_graph._default_graph = self.__saved_graphs.pop()
  def _use_default_if_none(device, comp_graph):
      r"""
      Fill in missing placement arguments with the current defaults.

      :param device: a device, or ``None`` to use ``get_default_device()``.
      :param comp_graph: a graph, or ``None`` to use ``get_default_graph()``.
      :return: the ``(device, comp_graph)`` pair with any ``None`` replaced.
      """
      resolved_device = get_default_device() if device is None else device
      resolved_graph = get_default_graph() if comp_graph is None else comp_graph
      return resolved_device, resolved_graph
  def dump(outputs, fpath, optimize_options=None, **kwargs):
      r"""
      Serializes the computing graph of ``outputs`` and writes it to a file.

      :type outputs: ``Tensor`` or a collection of ``Tensor``
      :param outputs: output variables that need to be retrieved when
          deserializing
      :type fpath: ``str``
      :param fpath: path for the output file
      :type optimize_options: ``list``
      :param optimize_options: ``['f16_io_f32_comp', 'f16_io_comp', 'use_nhwcd4', 'fuse_conv_bias_nonlinearity']``,
          four elements are optional, it can be an empty list, None or a list containing any of them.
      :param kwargs: extra keyword arguments forwarded unchanged to
          ``mgb.serialize_comp_graph_to_file``.

      .. note::

          ``f16_io_f32_comp`` – whether to use float16 for I/O between oprs and use float32 as internal computation precision. Note the output var would be changed to float16;

          ``f16_io_comp`` – whether to use float16 for both I/O and computation precision;

          ``use_nhwcd4`` – whether to use NHWCD4 data format. This is faster on some OpenCL devices;

          ``fuse_conv_bias_nonlinearity`` – whether to fuse conv+bias+nonlinearity into one opr. This is supported only when ``use_nhwcd4`` is set.
      """
      import collections.abc

      # NOTE(review): imported locally, presumably to avoid a circular import
      # with ``.tensor`` -- confirm.
      from .tensor import Tensor

      assert optimize_options is None or isinstance(
          optimize_options, list
      ), "optimize_options must be a list"

      if isinstance(outputs, Tensor):
          outputs = [outputs]
      else:
          # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
          # ``collections.abc``.
          assert isinstance(
              outputs, collections.abc.Iterable
          ), "{} not iterable".format(outputs)
          outputs = list(outputs)

      for output in outputs:
          assert isinstance(output, Tensor), "All outputs must be Tensors."
      outputs = [o._symvar for o in outputs]

      if optimize_options:
          opt_dict = dict.fromkeys(optimize_options, True)
          # ``optimize_for_inference`` returns the transformed output vars
          # instead of modifying its argument, so keep its result; otherwise the
          # requested optimizations would never reach the serialized graph.
          outputs = mgb.optimize_for_inference(outputs, **opt_dict)

      mgb.serialize_comp_graph_to_file(fpath, outputs, **kwargs)
  def set_default_graph(default_graph):
      r"""
      Sets the default Graph object for the calling thread.

      The default is kept in a ``threading.local`` holder, so the change is only
      visible in the current thread; other threads keep their own default.
      Passing ``None`` clears it, so the next :func:`get_default_graph` call in
      this thread lazily creates a fresh Graph.

      :param default_graph: the graph to install as the default, or ``None``.
      """
      # Only an attribute of the module-level holder is mutated (the name itself
      # is never rebound), so no ``global`` declaration is required.
      _default_graph._default_graph = default_graph
  def get_default_graph():
      r"""
      Returns the calling thread's default Graph object.

      The lookup is delegated to the module's thread-local holder, which creates
      and caches a Graph on first use in the thread, so repeated calls in the
      same thread return the same object.
      """
      holder = _default_graph
      return holder.get_default()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。