- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- import collections
- import contextlib
- import functools
- import itertools
- import json
- import os
- import pickle
- import re
- import struct
- from typing import Any
-
- import cv2
- import numpy as np
-
- from megengine.logger import get_logger
-
- from .. import tensor
- from ..core import _imperative_rt as rt
- from ..core._imperative_rt import GraphProfiler, GraphProfiler2, SerializationMetadata
- from ..core._imperative_rt.core2 import Tensor as RawTensor
- from ..core._imperative_rt.core2 import (
- TensorWeakRef,
- apply,
- set_tracing,
- skip_tracing,
- unset_tracing,
- )
- from ..core._imperative_rt.ops import (
- AssertEqual,
- CollectiveComm,
- ExternOpr,
- RemoteRecv,
- RemoteSend,
- )
- from ..core._trace_option import set_symbolic_shape
- from ..core._wrap import as_device
- from ..core.ops.builtin import BatchNorm, OpDef
- from ..core.tensor import megbrain_graph as G
- from ..core.tensor.utils import setscalar
- from ..utils import comp_graph_tools as cgtools
- from ..utils.naming import AutoNaming
- from ..utils.profiler import is_profiling
- from .dtr_config import DTRConfig
- from .graph_opt_config import GraphOptimizationConfig
- from .sublinear_memory_config import SublinearMemoryConfig
-
- logger = get_logger(__name__)
-
-
- def _input_node_use_static_shape():
- return os.environ.get("MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE") is not None
-
-
- class TraceMismatchError(RuntimeError):
- pass
-
-
- active_trace = None
-
-
- def is_tracing():
- if active_trace is None:
- return False
- else:
- return not skip_tracing
-
-
- @contextlib.contextmanager
- def exclude_from_trace():
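- r"""Context manager that excludes the statements inside it from the active trace:
- within the block, operations run eagerly and are not recorded. It cannot be
- combined with ``capture_as_const=True``.
-
- Illustrative sketch (``x`` is a placeholder tensor used inside a traced function)::
-
-     with exclude_from_trace():
-         x = x + 1  # runs eagerly; not recorded into the trace
- """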
- global skip_tracing
- if skip_tracing or (active_trace is None):
- yield
- return
- try:
- skip_tracing = True
- unset_tracing()
- if active_trace is not None:
- active_trace._begin_excluded_region()
- yield
- finally:
- skip_tracing = False
- set_tracing()
-
-
- class TensorInfo:
- __slots__ = (
- # collected attributes
- "name",
- "external",
- "data_read",
- "shape_read",
- "value_read",
- "exported",
- "device",
- "dtype",
- "shape",
- "is_const",
- "bound_data",
- # resources for execution
- "varnode",
- "data_setter",
- "shape_reader",
- "value_reader",
- "data_reader",
- )
-
- def __init__(self):
- self.name = None
- self.exported = None
- self.data_read = None
- self.shape_read = None
- self.value_read = None
- self.bound_data = None
-
- self.data_setter = None
- self.shape_reader = None
- self.value_reader = None
- self.data_reader = None
-
-
- _io_op_types = {AssertEqual, CollectiveComm, RemoteSend, RemoteRecv}
-
-
- class trace:
- """Wraps a callable and provide:
-
- * tracing via :meth:`.trace` and :meth:`.dump`
- * accelerated evaluation via :meth:`.__call__`
-
- Args:
- function: the function to be traced.
- symbolic: whether to apply symbolic execution for tracing. Default: False
- capture_as_const: capture global variables or closures as const values. Default: False
- record_only: if True, the function is only recorded for dump and will not
- actually run when called. Default: False
- sublinear_memory_config: configuration for sublinear memory optimization.
- If not None, it enables sublinear memory optimization with the given settings.
- dtr_config: configuration for DTR (dynamic tensor rematerialization) memory
- optimization. If not None, it enables DTR memory optimization with the given settings.
- profiling: whether to profile the compiled trace. Default: False
- opt_level: optimization level for compiling trace. Default: 2
- graph_opt_config: configuration for graph optimization. Default: None
- symbolic_shape: whether to use symbolic shape for tracing. Default: True
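-
- An illustrative usage sketch (the function body and inputs below are placeholders)::
-
-     from megengine import tensor
-     from megengine.jit import trace
-
-     @trace(symbolic=True, capture_as_const=True)
-     def f(x):
-         return x * 2
-
-     y = f(tensor([1.0, 2.0]))  # first call records the trace
-     y = f(tensor([3.0, 4.0]))  # later calls replay the compiled graph
-     f.dump("f.mge", arg_names=["x"])  # dump requires capture_as_const=True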
- """
-
- def __new__(cls, *args, **kwargs):
- if not args:
- return functools.partial(cls, **kwargs)
- return super().__new__(cls)
-
- def __init__(
- self,
- function,
- symbolic=False,
- capture_as_const=False,
- record_only=False,
- sublinear_memory_config: SublinearMemoryConfig = None,
- dtr_config: DTRConfig = None,
- profiling: bool = False,
- opt_level: int = 2,
- graph_opt_config: GraphOptimizationConfig = None,
- symbolic_shape: bool = True,
- ):
- self.__wrapped__ = function
- self._symbolic = symbolic or record_only
- self._capture_as_const = capture_as_const or record_only
- self._record_only = record_only
- self._sublinear_memory_config = sublinear_memory_config
- self._dtr_config = dtr_config
- self._profiling = profiling
- self._profiler = None
- self._profiler2 = None
- self._graph_opt_level = opt_level
- self._graph_opt_config = graph_opt_config
- self._symbolic_shape = symbolic_shape
- self._output_handles = set()
-
- self._reset()
-
- def _reset(self):
- self._untraced = True
- self._tinfo = [] # handle -> TensorInfo
- self._seq = []
- self._pc = 0
- self._graph = None
- self._need_reset_nodes = None
- self._lazy_eval_graph = None
- self._lazy_eval_tensors = set()
- self._lazy_eval_links = None
- self._active_tensors = set()
- self._tensor_remaps = None
- self._inputs_to_restore = None
- self._arg_bindings = None
- self._kwarg_bindings = None
- self._output_bindings = None
- self._output_names = None
-
- def _new_handle(self):
- handle = len(self._tinfo)
- info = TensorInfo()
- self._tinfo.append(info)
- return handle, info
-
- def _apply_op(self, op, args):
- assert not self._untraced
- # check against trace
- if self._pc >= len(self._seq):
- raise TraceMismatchError("trace should end here, but more op observed")
- record = self._seq[self._pc]
- op_, ihandles, ohandles = record
- if (isinstance(op_, str) and op_ == "Const") or (op != op_):
- raise TraceMismatchError("op different from last time")
- if len(ihandles) != len(args):
- raise TraceMismatchError("op input size different from last time")
-
- # check all inputs of current op
- for h, x in zip(ihandles, args):
- info = self._tinfo[h]
- if info.external:
- if (
- x._compiled_info is not None
- and not self._tinfo[x._mixin_handle].exported
- ):
- raise TraceMismatchError(
- "failed to capture: input was an external tensor "
- "last time, got an internal tensor this time"
- )
- if info.bound_data:
- if x._compiled_info is not None:
- raise TraceMismatchError(
- "const capture violated: was an external tensor "
- "last time, got an internal tensor this time"
- )
- if x._handle != info.bound_data._handle:
- if not np.array_equal(x.numpy(), info.bound_data.numpy()):
- raise TraceMismatchError(
- "const capture violated: got "
- "a different tensor this time"
- )
- else:
- if info.dtype != x.dtype:
- raise TraceMismatchError(
- "failed to capture: different dtype from last time"
- )
- if info.device != x.device:
- raise TraceMismatchError(
- "failed to capture: different device from last time"
- )
- info.data_setter.set_value(x._dev_tensor())
- else:
- if x._mixin_handle == -1:
- if x._handle not in self._tensor_remaps:
- raise TraceMismatchError(
- "unexpected capture: trying to use an external tensor as "
- "input, but that input was an internal tensor last time"
- )
- else:
- x._mixin_handle = self._tensor_remaps[
- x._handle
- ]._CompiledTensorProxy__handle
- if x._mixin_handle != h:
- raise TraceMismatchError(
- "mis-wiring: input edge to an data flow "
- "graph node is different from last time"
- )
-
- self._pc += 1
- outputs = []
- for h in ohandles:
- info = self._tinfo[h]
- # generate output tensor and create compiled info
- y = RawTensor(info.varnode)
- y._compiled_info = CompiledTensorProxy(h)
- y._mixin_handle = h
- outputs += [y]
- self._active_tensors.add(TensorWeakRef(y))
- self._output_handles.update(ohandles)
- return outputs
-
- def _apply_const(self, value, dtype, device):
- assert not self._untraced
- # check against trace
- if self._pc >= len(self._seq):
- raise TraceMismatchError("trace should end here, but more op observed")
- record = self._seq[self._pc]
- op_, ihandles, ohandles = record
- # Const op is represented by a str
- assert isinstance(op_, str) and op_ == "Const"
-
- expected = self._tinfo[ohandles[0]].bound_data.numpy()
- shape = value.shape
- if shape != expected.shape or dtype != expected.dtype:
- eq = False
- elif shape == ():
- eq = expected.item() == value.item()
- elif shape == (1,):
- eq = expected[0] == value[0]
- else:
- eq = np.all(value == expected)
- if not eq:
- raise TraceMismatchError(
- "const tensor violated: got a different tensor this time"
- )
-
- self._pc += 1
- (h,) = ohandles
- outputs = [self._tinfo[h].bound_data]
- return outputs
-
- # runs during the first call; records information for the trace
- def _record_op(self, op, inputs, outputs):
- if skip_tracing:
- for x in inputs:
- h = getattr(x, "_mixin_handle", -1)
- if h >= 0:
- self._tinfo[h].data_read = True
- return
-
- ihandles = []
- for x in inputs:
- h = getattr(x, "_mixin_handle", -1)
- if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
- h, info = self._new_handle()
- name = AutoNaming.gen_name(x)
- info.name = name
- info.external = True
- info.device = x.device
- info.dtype = x.dtype
- info.shape = x.shape
- if self._capture_as_const:
- info.bound_data = RawTensor(
- x.numpy(), x.dtype, x.device, False, name
- )
-
- ihandles.append(h)
-
- ohandles = []
- for x in outputs:
- h, info = self._new_handle()
- ohandles.append(h)
- info.external = False
- x._mixin_handle = h
- x._recording = True
- x._trace_mixin_info = info
- self._active_tensors.add(TensorWeakRef(x))
- if self._symbolic:
- self._lazy_eval_tensors.add(TensorWeakRef(x))
-
- self._seq.append((op, tuple(ihandles), tuple(ohandles)))
-
- def _record_const(self, outputs):
- if skip_tracing:
- (x,) = outputs
- h = getattr(x, "_mixin_handle", -1)
- if h >= 0:
- self._tinfo[h].data_read = True
- return
-
- (x,) = outputs
- h, info = self._new_handle()
- ohandles = [h]
- info.external = True
- info.device = x.device
- info.dtype = x.dtype
- info.shape = x.shape
- info.bound_data = x
- info.is_const = True
- x._mixin_handle = h
- x._recording = True
- x._trace_mixin_info = info
- if self._symbolic:
- self._lazy_eval_tensors.add(TensorWeakRef(x))
- self._seq.append(("Const", tuple(), tuple(ohandles)))
-
- def _set_active(self, active: bool):
- global active_trace
- if active:
- if active_trace:
- raise NotImplementedError("sorry, not implemented: nested trace")
- active_trace = self
- else:
- assert active_trace is self
- active_trace = None
-
- def _init_trace(self, symbolic: bool):
- if symbolic:
- self._lazy_eval_graph = G.Graph()
- self._apply_graph_options(self._lazy_eval_graph)
- self._lazy_eval_links = ()
-
- def _take_escaped_tensors(self):
- escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors))
- self._active_tensors.clear()
- return escaped_tensors
-
- def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
- lazy_eval_tensors = [x() for x in lazy_eval_tensors]
- lazy_eval_tensors = [x for x in lazy_eval_tensors if x is not None]
- readers = [G.OutputNode(x._varnode).outputs[0] for x in lazy_eval_tensors]
- self._apply_graph_options(lazy_eval_graph)
- lazy_eval_graph.options.graph_opt_level = self._graph_opt_level
- lazy_eval_graph._set_priority_to_id([*lazy_eval_links, *readers])
- lazy_eval_graph.compile(*lazy_eval_links, *readers)
- self._execute_graph(lazy_eval_graph)
- lazy_eval_graph.wait()
- for r, x in zip(readers, lazy_eval_tensors):
- # get values from lazy_eval_graph and assign to lazy_eval tensor
- x._handle = RawTensor(r.op.get_value())._handle
- x._reset_varnode()
-
- @contextlib.contextmanager
- def _setup(self):
- interrupted = False
-
- def do_enter():
- set_tracing()
- self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape)
- self._set_active(True)
- if self._untraced:
- self._init_trace(self._symbolic)
- else:
- if self._graph is None:
- self._compile()
- self._execute_graph(self._graph)
-
- def do_finalize():
- escaped_tensors = self._take_escaped_tensors()
- if self._untraced:
- if self._record_only:
- self._lazy_eval_graph = None
- self._lazy_eval_tensors = None
- self._lazy_eval_links = None
- else:
- for x in escaped_tensors:
- if x():
- info = self._tinfo[x()._mixin_handle]
- info.data_read = True
- x()._mixin_handle = -1
- x()._recording = False
- if self._inputs_to_restore:
- for x in self._inputs_to_restore:
- x._mixin_handle = -1
- x._recording = False
- if self._symbolic and (
- self._lazy_eval_tensors or self._lazy_eval_links
- ):
- # eval lazy eval tensors
- self._lazy_eval(
- self._lazy_eval_graph,
- self._lazy_eval_tensors,
- self._lazy_eval_links,
- )
- self._lazy_eval_graph = None
- self._lazy_eval_tensors = None
- self._lazy_eval_links = None
- self._untraced = False
- else:
- # compiled_tensor leaks
- if self._pc == len(self._seq):
- for x in escaped_tensors:
- try:
- x().__init__(RawTensor(x()._dev_tensor()))
- except RuntimeError:
- # TraceMismatchError thrown in do_exit
- pass
- self._graph.wait()
- self._reset_exec_env()
-
- # reset status
- self._pc = 0
- self._tensor_remaps = None
- self._set_active(False)
- set_symbolic_shape(self._save_symbolic_shape)
- unset_tracing()
-
- def do_exit():
- unset_tracing()
- if not self._untraced and self._pc != len(self._seq):
- raise TraceMismatchError("premature end")
- if not self._symbolic or not self._untraced:
- # reset output tensors
- for x in self._active_tensors.copy():
- strong_x = x()
- if strong_x is not None:
- strong_x._dev_tensor()
- strong_x._reset_varnode()
- strong_x._mixin_handle = -1
- strong_x._recording = False
- strong_x._trace_mixin_info = None
-
- try:
- do_enter()
- yield
- do_exit()
- except:
- interrupted = True
- raise
- finally:
- do_finalize()
- if interrupted:
- self._reset()
-
- def _begin_excluded_region(self):
- if self._capture_as_const:
- raise RuntimeError(
- "exclude_from_trace cannot be used with capture_as_const"
- )
- if self._untraced:
- # conditionally reading a compiled tensor in excluded region
- # is permitted, so we have to assume every tensor might be read
- for x in self._active_tensors:
- strong_x = x()
- if strong_x:
- info = self._tinfo[strong_x._mixin_handle]
- info.exported = True
- info.data_read = True
- else:
- for x in self._active_tensors:
- strong_x = x()
- if strong_x:
- strong_x._dev_tensor()
-
- def _apply_graph_options(self, graph):
-
- graph.options.no_force_inplace = True
- graph.options.seq_opt.enable_seq_comp_node_opt = False
- graph.options.graph_opt_level = self._graph_opt_level
- if self._dtr_config is not None:
- graph.options.enable_dtr_memory_opt = True
- graph.options.dtr_config.eviction_threshold = (
- self._dtr_config.eviction_threshold
- )
- graph.options.dtr_config.evictee_minimum_size = (
- self._dtr_config.evictee_minimum_size
- )
- graph.options.dtr_config.recomp_memory_factor = (
- self._dtr_config.recomp_memory_factor
- )
- graph.options.dtr_config.recomp_time_factor = (
- self._dtr_config.recomp_time_factor
- )
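- # Illustrative sketch (not executed here; values and names are placeholders):
- # a user enables DTR by passing a DTRConfig when constructing the trace, e.g.
- #
- #     config = DTRConfig(eviction_threshold=8 * 1024 ** 3)  # assumed ctor arg, in bytes
- #     f = trace(my_func, symbolic=True, dtr_config=config)
- #
- # and the dtr_config attributes read above come from that config object.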
- # graph optimization
- if self._graph_opt_config is not None:
- mapping = {None: 0, False: 1, True: 2}
- jit_config = graph.options.graph_opt.jit_config
- jit_config.fuse_dimshuffle = mapping[
- self._graph_opt_config.jit_fuse_dimshuffle
- ]
- jit_config.fuse_reduce = mapping[self._graph_opt_config.jit_fuse_reduce]
- # sublinear
- if self._sublinear_memory_config is not None:
- graph.options.enable_sublinear_memory_opt = True
- sublinear_config = graph.options.sublinear_mem_config
- sublinear_config.lb_memory_mb = self._sublinear_memory_config.lb_memory_mb
- sublinear_config.genetic_nr_iter = (
- self._sublinear_memory_config.genetic_nr_iter
- )
- sublinear_config.genetic_pool_size = (
- self._sublinear_memory_config.genetic_pool_size
- )
- sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try
- sublinear_config.num_worker = self._sublinear_memory_config.num_worker
- # profile
- if self._profiling:
- self._profiler = GraphProfiler(graph)
- self._profiler2 = None
- if int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")):
- graph.options.var_sanity_check_first_run = False
-
- def _execute_graph(self, graph: G.Graph, *args):
- if is_profiling() and (self._profiler2 is None):
- self._profiler2 = GraphProfiler2(graph)
- elif not is_profiling() and (self._profiler2 is not None):
- self._profiler2 = None
- graph.execute(*args)
-
- def _compile(self):
- graph = self._graph = G.Graph()
- graph.options.async_exec_level = 0b100
- self._apply_graph_options(graph)
- need_reset_nodes = self._need_reset_nodes = []
- # links enforce ordering of I/O nodes
- in_out_links = ()
- io_links = ()
- readers = []
-
- if self._capture_as_const:
- for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
- info = self._tinfo[h]
- opnode = info.data_setter = G.InputNode(
- device=info.device,
- dtype=info.dtype,
- shape=info.shape or (1,),
- graph=graph,
- use_static_shape=_input_node_use_static_shape(),
- )
- need_reset_nodes.append(opnode)
- info.varnode = opnode.outputs[0]
- in_out_links += opnode.outputs[1:]
-
- for op, ihandles, ohandles in self._seq:
- if isinstance(op, str) and op == "Const":
- assert len(ihandles) == 0
- (h,) = ohandles
- info = self._tinfo[h]
- if not hasattr(info, "varnode"):
- assert info.external
- assert info.bound_data
- info.varnode = graph.make_const(
- info.bound_data.numpy(),
- info.bound_data.dtype,
- info.bound_data.device,
- )
- continue
-
- require_links = type(op) in _io_op_types
- ivars = []
- for i, h in enumerate(ihandles):
- info = self._tinfo[h]
- if not hasattr(info, "varnode"):
- assert info.external
- if info.bound_data:
- if getattr(info, "is_const", False):
- info.varnode = graph.make_const(
- info.bound_data.numpy(),
- info.bound_data.dtype,
- info.bound_data.device,
- )
- else:
- info.varnode = graph.make_const(
- info.bound_data._dev_tensor()
- # info.bound_data.numpy()
- )
- else:
- opnode = info.data_setter = G.InputNode(
- *in_out_links,
- device=info.device,
- dtype=info.dtype,
- shape=info.shape or (1,),
- graph=graph,
- use_static_shape=_input_node_use_static_shape(),
- )
- need_reset_nodes.append(opnode)
- info.varnode, *in_out_links = opnode.outputs
- if require_links and i == 0 and len(io_links) > 0:
- opnode = G.VirtualDepNode(
- [info.varnode, *io_links], str(io_links[0].device)
- )
- info.varnode = opnode.outputs[0]
- io_links = (info.varnode,)
-
- ivars.append(info.varnode)
-
- ovars = G.apply_normal_varnode(op, *ivars)
-
- if require_links and len(ovars) > 0:
- io_links = (ovars[0],)
- assert len(ovars) == len(ohandles)
- for h, v in zip(ohandles, ovars):
- info = self._tinfo[h]
- info.varnode = v
-
- def add_reader(opnode):
- nonlocal in_out_links
- need_reset_nodes.append(opnode)
- readers.append(opnode.outputs[0])
- in_out_links = opnode.outputs
-
- if info.data_read:
- # Shape can be obtained from data so doesn't need its own
- # output node. On the other hand, value is read separately
- # to leverage eager h2d copy
- info.shape_read = False
- opnode = info.data_reader = G.OutputNode(v, *in_out_links)
- add_reader(opnode)
- if info.value_read:
- opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links)
- add_reader(opnode)
- if info.shape_read:
- opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
- add_reader(opnode)
-
- graph.options.graph_opt_level = self._graph_opt_level
- graph._set_priority_to_id([*readers, *in_out_links, *io_links])
- graph.compile(*readers, *in_out_links, *io_links)
-
- def _reset_exec_env(self):
- for opnode in self._need_reset_nodes:
- opnode.reset()
-
- def __call__(self, *args, **kwargs):
- with self._setup():
- if self._capture_as_const:
- self._process_inputs(*args, **kwargs)
- outputs = self.__wrapped__(*args, **kwargs)
- if self._capture_as_const:
- self._process_outputs(outputs)
- return outputs
-
- def _make_feed(
- self,
- graph,
- outputs,
- input_data,
- repeat,
- silent,
- no_assert,
- maxerr,
- resize_input,
- input_transform,
- ):
- def auto_reformat_image(path, data, dst_shape):
- """reformat image to target shape
-
- :param data: image data as numpy array
- :param dst_shape: target shape
- """
- dim3_format = False # required input format does not contain batch
- hwc_format = False # required input format is NHWC
-
- if not dst_shape: # input tensor shape is not predefined
- if len(data.shape) == 2:
- chl = 1
- h = data.shape[0]
- w = data.shape[1]
- else:
- assert (
- len(data.shape) == 3
- ), "Input image must be of dimension 2 or 3"
- h, w, chl = data.shape
- dst_shape = (1, chl, h, w)
-
- if len(dst_shape) == 3:
- dst_shape = (1,) + dst_shape
- dim3_format = True
-
- assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape)
- chl = dst_shape[1]
- if chl in [1, 3]:
- n, c, h, w = dst_shape
- dst_shape = (n, h, w, c)
- else:
- chl = dst_shape[3]
- assert chl in [
- 1,
- 3,
- ], "can not infer input format from shape: {}".format(dst_shape)
- hwc_format = True
-
- # dst_shape has now been normalized to NHWC format
-
- if resize_input:
- h, w = dst_shape[1:3]
- data = cv2.resize(data, (w, h))
- logger.info("input {} resized to {}".format(path, data.shape))
-
- if chl == 1:
- data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY)
- data = data[:, :, np.newaxis]
-
- assert data.ndim == 3
- data = data[np.newaxis]
- # data normalized to NHWC format
-
- if not hwc_format:
- data = np.transpose(data, (0, 3, 1, 2))
-
- if dim3_format:
- data = np.squeeze(data, 0)
-
- return data
-
- def read_input_data(dst_shape, dtype, path):
- def check_shape_equal(dst_shape, data_shape):
- if len(dst_shape):
- assert len(data_shape) == len(
- dst_shape
- ), "input/data shapes mismatch: {} vs {}".format(
- dst_shape, data_shape
- )
-
- if data_shape[1:] != dst_shape[1:]:
- logger.warning(
- "dst_shape is {}; data_shape is {}".format(
- dst_shape, data_shape
- )
- )
-
- if path.startswith("#"):
- assert not resize_input
- assert not input_transform
- spec = path
- m = re.match(
- r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec
- )
- assert m, "bad spec {}".format(spec)
-
- rng_min = float(m.group(1))
- rng_max = float(m.group(2))
- if m.group(3):
- shape_str = m.group(3)
- try:
- shape = shape_str[1:].split(",")
- if shape[-1].strip() == "...":
- shape = shape[:-1]
- shape.extend(list(dst_shape[len(shape) :]))
- data_shape = tuple(map(int, shape))
- except ValueError as e:
- raise ValueError("bad spec {}: {}".format(spec, e.args))
- else:
- data_shape = dst_shape
-
- check_shape_equal(dst_shape, data_shape)
- return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype)
-
- # try to load image
- data = cv2.imread(path, cv2.IMREAD_COLOR)
- if data is None:
- assert not resize_input
- data = np.load(path)
- assert isinstance(data, np.ndarray)
- else:
- # image loaded successfully, so we expect the input to be in image format
- data = auto_reformat_image(path, data, dst_shape)
-
- data = np.repeat(data, repeat, axis=0)
- if repeat > 1:
- logger.info(
- "repeat input for {} times, data shape is {}".format(
- repeat, data.shape
- )
- )
-
- check_shape_equal(dst_shape, data.shape)
-
- if input_transform:
- data = eval(input_transform, {"data": data, "np": np})
-
- return data
-
- def gen_one_testcase(inputs, spec):
- paths = spec.split(";")
- if len(paths) != len(inputs):
- if len(paths) == 1 and paths[0].startswith("#"):
- paths = ["{}:{}".format(name, paths[0]) for name in inputs.keys()]
- assert len(paths) == len(
- inputs
- ), "required inputs: {}; data paths: {}".format(inputs.keys(), paths)
- if len(paths) == 1 and ":" not in paths[0]:
- paths[0] = next(iter(inputs.keys())) + ":" + paths[0]
-
- ret = {}
- for path in paths:
- var, path = path.split(":")
- ret[var] = read_input_data(inputs[var].shape, inputs[var].dtype, path)
- return ret
-
- inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy")
- inputs = {i.name: i for i in inputs}
-
- if not no_assert:
-
- replace_varmap = {}
- inp_map = {}
- # replace vars with InputNode outputs
- for name, var in inputs.items():
- inp = G.InputNode(
- device="xpux", dtype=var.dtype, shape=var.shape, graph=graph
- )
- replace_varmap[var] = inp.outputs[0]._node
- inp_map[name] = inp
-
- new = cgtools.replace_vars(outputs, replace_varmap)
- if isinstance(new, rt.VarNode):
- new = [new]
-
- output_nodes = [G.OutputNode(var) for var in new]
- func = graph.compile(*[node.outputs[0]._node for node in output_nodes])
-
- def make_dev_tensor(value, dtype=None, device=None):
- return tensor(value, dtype=dtype, device=device)._dev_tensor()
-
- def calculate(*args, **kwargs):
- output_val = []
- # set input values
- for name, var in inputs.items():
- val = kwargs.pop(name, None)
- assert val is not None, "missing input name: {}".format(name)
- dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux")
- inp_map[name].set_value(dev_tensor)
-
- func.execute()
-
- for res in output_nodes:
- output_val.append(res.get_value().numpy())
- return output_val
-
- def expect_name(var):
- return "{}:expect".format(var.name)
-
- testcases = []
-
- np.set_printoptions(precision=2, threshold=4, suppress=True)
-
- data_list = []
- for item in input_data:
- if item.startswith("@"):
- with open(item[1:], "r") as f:
- data_list.extend(
- [line.rstrip() for line in f if line.rstrip() != ""]
- )
- else:
- data_list.append(item)
-
- for inp_spec in data_list:
- cur_testcase = gen_one_testcase(inputs, inp_spec)
- assert len(cur_testcase) == len(
- inputs
- ), "required inputs: {}; given data: {}".format(
- inputs.keys(), cur_testcase.keys()
- )
-
- if not no_assert:
- outputs_get = calculate(**cur_testcase)
- for var, val in zip(outputs, outputs_get):
- cur_testcase[expect_name(var)] = val
- logger.info(
- "generate test groundtruth: var={} shape={} range=({}, {})"
- " mean={} var={}".format(
- var,
- val.shape,
- val.min(),
- val.max(),
- np.mean(val),
- np.var(val),
- )
- )
- testcases.append(cur_testcase)
- logger.info(
- "add testcase: \n {}".format(
- "\n ".join(
- "{}: shape={} dtype={} range=({:.2f},{:.2f}) "
- "mean={:.2f} sd={:.2f}".format(
- k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v)
- )
- for k, v in sorted(cur_testcase.items())
- )
- )
- )
-
- if not no_assert:
-
- def expect_shp(var):
- ret = var.shape
- if ret:
- return ret
- return testcases[0][expect_name(var)].shape
-
- def assert_equal(expect, real, **kwargs):
- op = AssertEqual(**kwargs)
- (res,) = G.apply_normal_varnode(op, expect, real)
- return res._node
-
- verbose = not silent
-
- outputs_new = []
- for i in outputs:
- device = rt.CompNode("xpux")
- dtype = i.dtype
- name = expect_name(i)
- shape = expect_shp(i)
- # make the expected output one of the model's inputs.
- expect_get = rt.make_h2d(graph, device, dtype, shape, name)
- # insert assert opr to check expected vs. actual values.
- outputs_new.append(
- assert_equal(expect_get, i, verbose=verbose, maxerr=maxerr,)
- )
- inputs[expect_name(i)] = expect_get
- outputs = outputs_new
-
- return {"outputs": outputs, "testcases": testcases}
-
- def dump(
- self,
- file,
- *,
- arg_names=None,
- output_names=None,
- append=False,
- keep_var_name: int = 1,
- keep_opr_name: bool = False,
- keep_param_name: bool = False,
- keep_opr_priority: bool = False,
- strip_info_file=None,
- append_json=False,
- optimize_for_inference=True,
- user_info: Any = None,
- enable_metadata: bool = True,
- input_data=None,
- repeat=1,
- silent=False,
- no_assert=False,
- maxerr=1e-4,
- resize_input=False,
- input_transform=None,
- dump_format: str = None,
- **kwargs
- ):
- r"""Serializes trace to file system.
-
- Args:
- file: output file, could be file object or filename.
- arg_names: names of the input tensors in the traced function.
- output_names: names of the output tensors in the traced function,
- use the default name if not specified.
- append: whether output is appended to ``file``.
- Only works when ``file`` is str.
- keep_var_name: level for keeping variable names:
-
- * 0: none of the names are kept
- * 1: (default) keep names of output vars
- * 2: keep names of all (output and internal) vars
-
- keep_opr_name: whether to keep operator names.
- keep_param_name: whether to keep param names, so param values can be
- easily manipulated after loading the model
- keep_opr_priority: whether to keep priority settings for operators
- strip_info_file: a path string or a file handler. If not None,
- the dump information for code strip will be written to ``strip_info_file``
- append_json: only checked when ``strip_info_file`` is not None. If True,
- the code-strip information will be appended to ``strip_info_file``;
- if False, ``strip_info_file`` will be overwritten
- optimize_for_inference: enable optimizations;
- all optimize options are skipped if this is False. Default: True
- user_info: any type object, which will be pickled to bytes.
- enable_metadata: whether to save metadata into output file.
- input_data: input test data and current network output would be used as groundtruth.
- The format is "var0:file0;var1:file1..." to specify data files for input vars.
- It can also be "#rand(min,max,shape...)" for generating random input data, for
- example, "#rand(0,255)", "#rand(0,255,1,3,224,224)" or "#rand(0, 255, 1, ...)"
- where `...` means the remaining part of the original shape. If the shape is not
- specified, the shape of corresponding input tensors in the network will be used.
- If there is only one input var, its name can be omitted. Each data file can either
- be an image which can be loaded by opencv, or a pickled numpy.ndarray. This option
- can be given multiple times to add multiple testcases. If you start the data
- with the letter @, the rest should be a filename, and each line in the file should
- be a single datum in the format described above. *NOTE* If `input_data` is not None,
- you can only use load-and-run to run the output file.
- repeat: how many times the input image is repeated. Useful when benchmarking with
- a batch size other than one. Has no effect on randomly generated input data.
- silent: whether to set verbose to False in the assert_equal opr.
- no_assert: whether to skip inserting the assert_equal opr that checks results against
- the recorded ground truth; this option is useful for benchmarking.
- maxerr: max error for the assert_equal check during runtime.
- resize_input: whether to resize the input image to fit the input var shape.
- input_transform: a python expression to transform the input data.
- Example: data / np.std(data)
- dump_format: the dump format to use. The open-source MegEngine defaults to the FBS
- format; internal MegEngine builds can choose between FBS and internal proprietary formats.
-
- Keyword Arguments:
-
- * enable_io16xc32 --
- whether to use float16 for I/O between oprs and use
- float32 as internal computation precision. Note the output var would be
- changed to float16.
- * enable_ioc16 --
- whether to use float16 for both I/O and computation
- precision.
- * enable_hwcd4 --
- whether to use NHWCD4 data layout. This is faster on some
- OpenCL backend.
- * enable_nchw88 --
- whether to use NCHW88 data layout, currently
- used in X86 AVX backend.
- * enable_nchw44 --
- whether to use NCHW44 data layout, currently
- used in arm backend.
- * enable_nchw44_dot --
- whether to use NCHW44_dot data layout, currently
- used in armv8.2+dotprod backend.
- * enable_nchw4 --
- whether to use NCHW4 data layout, currently
- used in nvidia backend (based on cudnn).
- * enable_nchw32 --
- whether to use NCHW32 data layout, currently
- used in nvidia backend with tensorcore (based on cudnn).
- * enable_chwn4 --
- whether to use CHWN4 data layout, currently
- used in nvidia backend with tensorcore.
- * enable_nchw64 --
- whether to use NCHW64 data layout, used for fast int4
- support on Nvidia GPU.
- * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
- into one opr.
- * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
- input for inference on nvidia backend (this optimization pass will
- result in a precision mismatch between the outputs of training and
- inference)
- * enable_fuse_preprocess: whether to fuse preprocessing oprs such as
- astype, pad_channel and dimshuffle
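-
- Example (an illustrative sketch; ``net``, ``data`` and the file name are placeholders)::
-
-     @trace(symbolic=True, capture_as_const=True)
-     def pred(data):
-         return net(data)
-
-     pred(data)  # the trace must be recorded by at least one call before dump
-     pred.dump(
-         "model.mge",
-         arg_names=["data"],
-         optimize_for_inference=True,
-         enable_io16xc32=True,
-     )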
- """
- if not self._capture_as_const:
- raise ValueError(
- "you must specify capture_as_const=True at __init__ to use dump"
- )
- if self._untraced and len(self._seq) == 0:
- raise RuntimeError("should record the trace (call the function at least once) before dump")
- if self._output_names and output_names:
- raise TypeError(
- "cannot specify output_names when output is already in dict format"
- )
- if output_names and not isinstance(output_names, collections.abc.Sequence):
- output_names = (output_names,)
- if output_names and len(output_names) != len(self._output_bindings):
- raise ValueError(
- "wrong number of output_names, should be {} values".format(
- len(self._output_bindings)
- )
- )
- without_arg_names = arg_names is None
- if without_arg_names:
- arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))]
- if arg_names and not isinstance(arg_names, collections.abc.Sequence):
- arg_names = (arg_names,)
- if arg_names and len(arg_names) != len(self._arg_bindings):
- raise ValueError(
- "wrong number of arg_names, should be {} values".format(
- len(self._arg_bindings)
- )
- )
- output_names = output_names or self._output_names
-
- def dumped_device(info):
- device_name = info.device.logical_name
- if device_name[:3] in ("cpu", "gpu", "xpu"):
- return as_device("xpux")
- return info.device
-
- h2v = {}
- graph = G.Graph()
-
- # apply graph_opt_level in dump
- if self._graph_opt_level is not None:
- graph.options.graph_opt_level = self._graph_opt_level
- for i, h in enumerate(self._arg_bindings):
- info = self._tinfo[h]
- h2v[h] = graph.make_h2d(
- dtype=info.dtype,
- device=dumped_device(info),
- shape=info.shape or (1,),
- name=info.name if without_arg_names and info.name else arg_names[i],
- )
- for k, h in self._kwarg_bindings.items():
- info = self._tinfo[h]
- h2v[h] = graph.make_h2d(
- dtype=info.dtype,
- device=dumped_device(info),
- shape=info.shape or (1,),
- name=k,
- )
-
- for op, ihandles, ohandles in self._seq:
- if isinstance(op, str) and op == "Const":
- assert len(ihandles) == 0
- (h,) = ohandles
- info = self._tinfo[h]
- if h not in h2v:
- assert info.external
- assert info.bound_data
- h2v[h] = graph.make_const(
- info.bound_data.numpy(),
- dtype=info.dtype,
- device=dumped_device(info),
- name=info.name,
- )
- continue
- ivars = []
- for h in ihandles:
- info = self._tinfo[h]
- if h not in h2v:
- assert info.external
- assert info.bound_data
- h2v[h] = graph.make_const(
- info.bound_data.numpy(),
- dtype=info.dtype,
- device=dumped_device(info),
- name=info.name,
- )
- ivars.append(h2v[h])
- if isinstance(op, BatchNorm):
- assert (
- op.fwd_mode == BatchNorm.FwdMode.INFERENCE
- ), "can not dump BatchNorm in training mode, maybe you forget to do model.eval()?"
- ovars = G.apply_normal_varnode(op, *ivars)
-
- AutoNaming.record_opnode(ovars[0].op)
-
- assert len(ovars) == len(ohandles)
- h2v.update(zip(ohandles, ovars))
-
- for i in ohandles:
- name = AutoNaming.get_var_name(i)
- if name is not None:
- h2v[i].name = name
-
- AutoNaming.remove_duplicate_names()
-
- dest_vars = []
- for i, h in enumerate(self._output_bindings):
- v = h2v[h]
- if output_names:
- v.name = output_names[i]
- dest_vars.append(v)
-
- dest_vars = [i._node for i in dest_vars]
-
- if input_data is not None:
- feeds = self._make_feed(
- graph,
- dest_vars,
- input_data,
- repeat,
- silent,
- no_assert,
- maxerr,
- resize_input,
- input_transform,
- )
- assert (
- isinstance(feeds, dict) and feeds["testcases"]
- ), "testcases can not be empty"
- dest_vars = feeds["outputs"]
-
- if optimize_for_inference:
- dest_vars, optimize_options = G.optimize_for_inference(dest_vars, **kwargs)
- dest_vars = [i._node for i in dest_vars]
-
- metadata = SerializationMetadata()
- if enable_metadata:
- metadata.user_info = pickle.dumps(user_info)
- metadata.is_valid = True
- metadata.graph_modified = False
- if optimize_for_inference:
- metadata.optimize_options = optimize_options
-
- if isinstance(file, str):
- permission = "ab" if append else "wb"
- file = open(file, permission)
-
- if keep_opr_priority:
- graph._set_priority_to_id(dest_vars)
-
- if input_data is not None:
- file.write(b"mgbtest0")
- file.write(struct.pack("I", len(feeds["testcases"])))
- dump_content, dump_info = G.dump_graph(
- dest_vars,
- keep_var_name=keep_var_name,
- keep_opr_name=keep_opr_name,
- keep_param_name=keep_param_name,
- keep_opr_priority=keep_opr_priority,
- strip_info_file=strip_info_file,
- append_json=append_json,
- metadata=metadata,
- dump_format=dump_format,
- )
- file.write(dump_content)
-
- if input_data is not None:
- inputs = cgtools.get_dep_vars(dest_vars, "Host2DeviceCopy")
- inputs = sorted((i.name, i.dtype) for i in inputs)
-
- def make_dev_tensor(value, dtype=None, device=None):
- return tensor(value, dtype=dtype, device=device)._dev_tensor()
-
- for testcase in feeds["testcases"]:
- assert isinstance(testcase, dict)
- cg = G.Graph()
- output_mgbvars = []
- for name, dtype in inputs:
- output_mgbvars.append(
- cg.make_const(
- make_dev_tensor(
- testcase.pop(name), dtype=dtype, device="cpux"
- )
- )
- )
- assert not testcase, "extra inputs provided in testcase: {}".format(
- testcase.keys()
- )
- dump_content, _ = G.dump_graph(
- output_mgbvars, strip_info_file=strip_info_file, append_json=True,
- )
- file.write(dump_content)
-
- return dump_info
-
- def _process_inputs(self, *args, **kwargs):
- if self._untraced:
- self._inputs_to_restore = []
-
- def record_input(x):
- if x is None:
- return
- h, info = self._new_handle()
- info.external = False
- info.name = x.c_name
- info.device = x.device
- info.dtype = x.dtype
- info.shape = x.numpy().shape
- x._mixin_handle = h
- x._recording = True
- x._trace_mixin_info = info
- self._inputs_to_restore.append(x)
- return h
-
- self._arg_bindings = []
- for i, x in enumerate(args):
- if not isinstance(x, RawTensor):
- raise TypeError(
- "positional arguments should all be tensor "
- "but args[%d] cannot be recognized as one" % i
- )
- self._arg_bindings.append(record_input(x))
-
- self._kwarg_bindings = {}
- for k, x in kwargs.items():
- if isinstance(x, RawTensor):
- self._kwarg_bindings[k] = record_input(x)
- else:
- if len(args) != len(self._arg_bindings):
- raise TraceMismatchError("positional argument length mismatch")
-
- self._tensor_remaps = {}
-
- for i, (h, x) in enumerate(zip(self._arg_bindings, args)):
- if not isinstance(x, RawTensor):
- raise TypeError(
- "positional arguments should all be tensor "
- "but args[%d] cannot be recognized as one" % i
- )
- info = self._tinfo[h]
- if x.dtype != info.dtype:
- raise TypeError("args[%d].dtype different from last time" % i)
- if x.device != info.device:
- raise TypeError("args[%d].device different from last time" % i)
- info.data_setter.set_value(x._dev_tensor())
- self._tensor_remaps[x._handle] = CompiledTensorProxy(h)
-
- kwargs_tensors = {}
- for k, x in kwargs.items():
- if isinstance(x, RawTensor):
- kwargs_tensors[k] = x
- if set(kwargs_tensors) != set(self._kwarg_bindings):
- too_many = set(kwargs_tensors) - set(self._kwarg_bindings)
- too_few = set(self._kwarg_bindings) - set(kwargs_tensors)
- if too_many:
- raise TraceMismatchError(
- "keyword arguments found to be tensor this time "
- "but were non-tensor previously: %s" % " ".join(too_many)
- )
- if too_few:
- raise TraceMismatchError(
- "keyword arguments found to be non-tensor this time "
- "but were tensor previously: %s" % " ".join(too_few)
- )
- for k, h in self._kwarg_bindings.items():
- x = kwargs_tensors[k]
- info = self._tinfo[h]
- if x.dtype != info.dtype:
- raise TypeError("kwargs[%s].dtype different from last time" % k)
- if x.device != info.device:
- raise TypeError("kwargs[%s].device different from last time" % k)
- info.data_setter.set_value(x._dev_tensor())
- self._tensor_remaps[x._handle] = CompiledTensorProxy(h)
-
- def _process_outputs(self, outputs):
- output_names = None
- if isinstance(outputs, collections.abc.Mapping):
- output_names, outputs = zip(*sorted(outputs.items()))
- elif not isinstance(outputs, collections.abc.Sequence):
- outputs = (outputs,)
-
- if not self._untraced:
- if output_names != self._output_names:
- too_many = set(output_names) - set(self._output_names)
- too_few = set(self._output_names) - set(output_names)
- if too_many:
- raise TraceMismatchError(
- "output has more keys than last time: %s" % " ".join(too_many)
- )
- if too_few:
- raise TraceMismatchError(
- "output has less keys than last time: %s" % " ".join(too_few)
- )
- if len(outputs) != len(self._output_bindings):
- raise TraceMismatchError("output size differs from last time")
- else:
- self._output_names = output_names
- self._output_bindings = []
-
- for i, x in enumerate(outputs):
- if not isinstance(x, RawTensor):
- raise TypeError("every item of return value should be tensor")
- if self._untraced:
- h = x._mixin_handle
- if h < 0:
- raise RuntimeError("output is not computed from inputs")
- self._output_bindings.append(h)
- else:
- h = x._mixin_handle
- if h not in self._output_handles:
- raise RuntimeError("output is not computed from inputs")
- if h != self._output_bindings[i]:
- raise TraceMismatchError(
- "retval[%s] is a different tensor than last time"
- % (output_names and output_names[i] or i)
- )
-
- def get_profile(self):
- r"""Get profiling result for compiled trace.
-
- Return:
- a json compatible object.
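-
- Illustrative sketch (``func`` and ``x`` are placeholders; the trace must be
- constructed with ``profiling=True`` and called until a graph has been compiled)::
-
-     import json
-
-     f = trace(func, profiling=True)
-     f(x)
-     f(x)
-     prof = f.get_profile()
-     with open("profile.json", "w") as fout:
-         json.dump(prof, fout)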
- """
- if not self._profiler:
- raise RuntimeError("trace is not set with profiling=True")
- return json.loads(self._profiler.get())
-
-
- class CompiledTensorProxy:
- r"""Duck-typed RawTensor"""
-
- def __init__(self, handle):
- self.__handle = handle
- self._isscalar = False
- self.__info = active_trace._tinfo[handle]
- self.__shape = None
- self.__data = None
- self.__value = None
-
- @property
- def dtype(self):
- return self.__info.varnode.dtype
-
- @property
- def device(self):
- return self.__info.varnode.device
-
- @property
- def shape(self):
- if self._isscalar:
- return ()
- if self.__shape is None:
- if self.__info.shape_read:
- self.__shape = self.__info.shape_reader.get_value().shape
- elif self.__info.data_read:
- self.__shape = self._dev_tensor().shape
- else:
- # c++ will throw TraceReadError
- return None
- return self.__shape
-
- def numpy(self):
- if self.__value is None:
- if self.__info.value_read:
- self.__value = self.__info.value_reader.get_value()
- elif self.__info.data_read:
- self.__value = self._dev_tensor().numpy()
- else:
- # c++ will throw TraceReadError
- return None
- # c++ side will handle scalar case
- return self.__value
-
- def _dev_tensor(self):
- if self.__data is None:
- if not self.__info.data_read:
- # c++ will throw TraceReadError
- return None
- self.__data = self.__info.data_reader.get_value()
- return self.__data
-
- def __del__(self):
- if self.__info.shape_read and self.__shape is not None:
- self.__info.shape_reader.drop_value()
- if self.__info.value_read and self.__value is not None:
- self.__info.value_reader.drop_value()
- if self.__info.data_read and self.__data is not None:
- self.__info.data_reader.drop_value()
-
-
- def apply_symbolic_mode(op: OpDef, *args: RawTensor):
- graph = active_trace._lazy_eval_graph
- ivars = []
- for x in args:
- var = getattr(x, "_varnode", None)
- if var:
- ivars.append(var)
- else:
- data_setter = G.InputNode(
- device=x.device,
- dtype=x.dtype,
- shape=x.numpy().shape or (1,),
- graph=graph,
- use_static_shape=True,
- )
- var = data_setter.outputs[0]
- ivars.append(var)
- data_setter.set_value(x._dev_tensor())
-
- require_links = type(op) in _io_op_types
-
- if require_links and active_trace._lazy_eval_links:
- assert len(ivars) > 0, "op should have at least one input"
- opnode = G.VirtualDepNode(
- [ivars[0], *active_trace._lazy_eval_links],
- str(active_trace._lazy_eval_links[0].device),
- )
- ivars[0] = opnode.outputs[0]
- active_trace._lazy_eval_links = (ivars[0],)
-
- ovars = G.apply_normal_varnode(op, *ivars)
- outputs = [RawTensor(o) for o in ovars]
-
- if require_links:
- active_trace._lazy_eval_links = (G.VarNode(outputs[0]._varnode),)
-
- return outputs
-
-
- def apply_const_symbolic_mode(value, dtype, device, name):
- graph = active_trace._lazy_eval_graph
- # don't need to unset tracing
- # because varnode construction will ignore tracing flag
- ret = RawTensor(graph.make_const(value, dtype=dtype, device=device, name=name))
- if np.array(value).ndim == 0:
- setscalar(ret)
- return (ret,)
-
-
- def apply_compiled_mode(op: OpDef, *args: RawTensor):
- if skip_tracing:
- args = [
- RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
- for x in args
- ]
- unset_tracing()
- ret = apply(op, *args)
- set_tracing()
- return ret
- return active_trace._apply_op(op, args)
-
-
- def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name):
- if skip_tracing:
- unset_tracing()
- ret = RawTensor(value, dtype, device, False, name)
- set_tracing()
- return ret
- return active_trace._apply_const(value, dtype, device)
-
-
- def apply_with_tracing(op: OpDef, *args: RawTensor):
- if active_trace._graph:
- # if member _graph exists, the trace has already been compiled
- return apply_compiled_mode(op, *args)
- if hasattr(op, "scope"):
- op.scope = AutoNaming.get_scope()
- if active_trace._symbolic:
- outputs = apply_symbolic_mode(op, *args)
- else:
- unset_tracing()
- outputs = apply(op, *args)
- set_tracing()
-
- active_trace._record_op(op, args, outputs)
- return list(outputs)
-
-
- def apply_const_with_tracing(value, dtype, device, is_const, no_cache, name):
- if active_trace._graph:
- return apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name)
- if active_trace._symbolic:
- outputs = apply_const_symbolic_mode(value, dtype, device, name)
- else:
- unset_tracing()
- outputs = RawTensor(value, dtype, device, False, name)
- if np.array(value).ndim == 0:
- setscalar(outputs)
- outputs = (outputs,)
- set_tracing()
- active_trace._record_const(outputs)
- return list(outputs)