# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import collections.abc
import threading

import megengine._internal as mgb

from .device import get_default_device


class _DefaultGraph(threading.local):
    r"""
    Thread-local holder for the default :class:`Graph`.

    Because this subclasses :class:`threading.local`, every thread sees its
    own ``_default_graph`` slot.
    """

    def __init__(self):
        super().__init__()
        self._default_graph = None

    def get_default(self):
        r"""Returns the calling thread's default Graph object for eager evaluation.

        A new :class:`Graph` (constructed with its default arguments) is created
        and cached on first use if none has been set.
        """
        if self._default_graph is None:
            self._default_graph = Graph()
        return self._default_graph


_default_graph = _DefaultGraph()


class Graph(mgb.CompGraph):
    r"""
    A computing graph that supports context management.

    Entering the graph with ``with`` makes it the current thread's default
    graph; leaving the block restores the previous default.

    :param check_env_var: whether to check environment vars including ``MGB_COMP_GRAPH_OPT``.
    :param eager_evaluation: use dynamic graph(``True``) or static graph(``False``).
    :param kwargs: additional graph options, passed (together with
        ``eager_evaluation``) as ``extra_opts`` to ``mgb.comp_graph``.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        from megengine.core import Graph

        with Graph(eager_evaluation=True):
            x = tensor([1, 2])
            print(x)

    Outputs:

    .. testoutput::

        Tensor([1 2], dtype=int32)

    """

    def __new__(
        cls, *, check_env_var: bool = True, eager_evaluation: bool = True, **kwargs
    ):
        kwargs.update(eager_evaluation=eager_evaluation)
        self = mgb.comp_graph(extra_opts=kwargs, check_env_var=check_env_var)
        # mgb.comp_graph builds the underlying graph object; re-tag it as a
        # Graph so the context-manager methods below are available on it.
        self.__class__ = cls
        return self

    def __init__(
        self, *, check_env_var: bool = True, eager_evaluation: bool = True, **kwargs
    ):  # pylint: disable=super-init-not-called
        # All construction happens in __new__; this override only accepts (and
        # ignores) the same arguments so the base-class __init__ is not re-run.
        pass

    def __enter__(self):
        # Remember the current thread's default graph so __exit__ can restore
        # it. A per-instance stack (rather than a single slot) keeps nested
        # ``with`` blocks on the same Graph object from losing the outer value.
        try:
            saved_graphs = self.__saved_graphs
        except AttributeError:
            saved_graphs = self.__saved_graphs = []
        saved_graphs.append(_default_graph._default_graph)
        _default_graph._default_graph = self
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore whatever was the default when the matching __enter__ ran.
        # Returning None means exceptions raised in the block propagate.
        _default_graph._default_graph = self.__saved_graphs.pop()


def _use_default_if_none(device, comp_graph):
    r"""
    Fills in ``device`` and ``comp_graph`` with their defaults when ``None``.

    :return: ``(device, comp_graph)``, where a ``None`` device is replaced by
        ``get_default_device()`` and a ``None`` graph by ``get_default_graph()``.
    """
    if device is None:
        device = get_default_device()
    if comp_graph is None:
        comp_graph = get_default_graph()
    return device, comp_graph


def dump(outputs, fpath, optimize_options=None, **kwargs):
    r"""
    Serializes the computing graph that produces ``outputs`` and writes it to a file.

    :type outputs: ``Tensor`` or a collection of ``Tensor``
    :param outputs: output variables that need to be retrieved when deserializing
    :type fpath: ``str``
    :param fpath: path for the output file
    :type optimize_options: ``list``
    :param optimize_options: ``['f16_io_f32_comp', 'f16_io_comp', 'use_nhwcd4', 'fuse_conv_bias_nonlinearity']``;
        each element is optional, so this can be ``None``, an empty list, or a
        list containing any of them. When non-empty, the outputs are passed
        through ``mgb.optimize_for_inference`` with those options enabled and
        the optimized vars are what gets serialized.
    :param kwargs: extra keyword arguments forwarded to
        ``mgb.serialize_comp_graph_to_file``.
    :raises AssertionError: if ``optimize_options`` is neither ``None`` nor a
        list, if ``outputs`` is not a ``Tensor`` or an iterable, or if any
        element of ``outputs`` is not a ``Tensor``.

    .. note::

        ``f16_io_f32_comp`` – whether to use float16 for I/O between oprs and use
        float32 as internal computation precision. Note the output var would be
        changed to float16;

        ``f16_io_comp`` – whether to use float16 for both I/O and computation
        precision;

        ``use_nhwcd4`` – whether to use NHWCD4 data format. This is faster on some
        OpenCL devices;

        ``fuse_conv_bias_nonlinearity`` – whether to fuse conv+bias+nonlinearity
        into one opr. This is supported only when ``use_nhwcd4`` is set.
    """
    # Local import; presumably avoids a circular import with .tensor — TODO confirm.
    from .tensor import Tensor

    # NOTE(review): these checks use assert and are stripped under ``python -O``.
    assert optimize_options is None or isinstance(
        optimize_options, list
    ), "optimize_options must be a list"

    if isinstance(outputs, Tensor):
        outputs = [outputs]
    else:
        # ``collections.Iterable`` was removed in Python 3.10; the ABC lives
        # in ``collections.abc``.
        assert isinstance(
            outputs, collections.abc.Iterable
        ), "{} not iterable".format(outputs)
        outputs = list(outputs)
        for output in outputs:
            assert isinstance(output, Tensor), "All outputs must be Tensors."

    outputs = [o._symvar for o in outputs]

    if optimize_options:
        opt_dict = dict.fromkeys(optimize_options, True)
        # optimize_for_inference returns the transformed vars instead of
        # rewriting ``outputs`` in place; serialize its result, otherwise the
        # requested optimizations would silently have no effect.
        outputs = list(mgb.optimize_for_inference(outputs, **opt_dict))
    mgb.serialize_comp_graph_to_file(fpath, outputs, **kwargs)


def set_default_graph(default_graph):
    r"""
    Sets the default Graph object for the calling thread.

    The default is stored in a :class:`threading.local`, so other threads are
    not affected.
    """
    _default_graph._default_graph = default_graph


def get_default_graph():
    r"""
    Returns the calling thread's default Graph object, most probably for eager evaluation.

    If none has been set, a new :class:`Graph` with its default arguments is
    created and remembered for this thread.
    """
    return _default_graph.get_default()