|
- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-
- """helper utils for the core mgb module"""
-
- import collections
- import inspect
- import json
- import threading
- from abc import ABCMeta, abstractmethod
-
-
class callback_lazycopy:
    """wraps around a callable to be passed to :meth:`.CompGraph.compile`.

    This is used to disable eager copy, so we could get rid of an h2d copy and
    a d2h if values are to be passed from one callback to another
    :class:`.SharedND`.

    :param func: the callable to wrap; must pass the builtin ``callable()``
        check (enforced by an assertion, so skipped under ``python -O``)
    """

    def __init__(self, func):
        # ``collections.Callable`` was a deprecated alias removed in Python
        # 3.10; the builtin ``callable`` performs the same check on every
        # Python 3 version without needing any import.
        assert callable(
            func
        ), "callback_lazycopy expects a callable, got {} instead".format(func)
        self.__func = func

    @property
    def func(self):
        """the wrapped callable (read-only)"""
        return self.__func
-
-
class SharedNDLazyInitializer(metaclass=ABCMeta):
    """Abstract policy for lazily initializing a :class:`.SharedND`.

    Concrete subclasses must provide both the shape query and the value
    loader; the shape is meant to be obtainable without loading the value.
    """

    @abstractmethod
    def get_shape(self):
        """Return the shape, without loading the value."""
        return None

    @abstractmethod
    def get_value(self):
        """Return the value as a numpy ndarray."""
        return None
-
-
class copy_output:
    """Marker that wraps a :class:`.SymbolVar` inside the outspec passed to
    :meth:`.CompGraph.compile`, requesting that the var's output be copied
    into the compiled function's return value."""

    # class-level defaults; __init__ overwrites both on every instance
    symvar = None
    borrow_mem = None

    def __init__(self, symvar, *, borrow_mem=False):
        """
        :param symvar: the var whose value should be returned; must be a
            :class:`.SymbolVar` (checked by an assertion)
        :param borrow_mem: see :meth:`.CompGraphCallbackValueProxy.get_value`
        """
        # imported at call time rather than module load; presumably to avoid
        # a circular import with .mgb — TODO confirm
        from .mgb import SymbolVar

        assert isinstance(symvar, SymbolVar), (
            "copy_output expects an SymbolVar, got {} instead".format(symvar)
        )
        self.symvar, self.borrow_mem = symvar, borrow_mem
-
-
class FuncOutputSaver:
    """Callable usable as a :meth:`.CompGraph.compile` callback; each call
    copies the received output to a host buffer, retrievable via :meth:`get`.
    """

    # class-level defaults; instances override them in __init__ / __call__
    _value = None
    _borrow_mem = None

    def __init__(self, borrow_mem=False):
        """
        :param borrow_mem: forwarded unchanged as the ``borrow_mem`` keyword
            of ``get_value`` on the object this saver is called with
        """
        self._borrow_mem = borrow_mem

    def __call__(self, v):
        """Fetch the value from *v* and store it, replacing any earlier one."""
        fetched = v.get_value(borrow_mem=self._borrow_mem)
        self._value = fetched

    def get(self):
        """Return the stored value.

        Raises AssertionError while the stored value is still ``None``, e.g.
        when the callback has not run yet (assertion-based, so the check is
        skipped under ``python -O``).
        """
        saved = self._value
        assert saved is not None, (
            "{} not called; maybe due to unwaited async func".format(self)
        )
        return saved
|