# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""used for creating a megbrain operator from python"""

import copy
import itertools
from abc import ABCMeta, abstractmethod, abstractproperty

from . import helper as _helper
from . import mgb as _mgb


class _CraniotomeBaseMeta(ABCMeta):
    _base_created = False

    def __init__(cls, name, bases, member_dict):
        if _CraniotomeBaseMeta._base_created:
            assert "__init__" not in member_dict, (
                "Craniotome operators should not overwrite __init__ method; "
                "use setup() instead."
            )
            forbidden = set(
                k for k in dir(CraniotomeBase) if k[0] == "_" and k[1] != "_"
            )
            forbidden.add("get_io_vars")
            check_key = member_dict.get("__check_key__", True)
            whitelist = ["__classcell__"]
            for k in member_dict.keys():
                assert k not in forbidden, "{} must not be overwritten".format(k)
                if (
                    check_key
                    and k.startswith("__")
                    and k.endswith("__")
                    and k not in whitelist
                    and not hasattr(CraniotomeBase, k)
                ):
                    raise KeyError(
                        "name {} in class {} does not exist in the baseclass".format(
                            k, name
                        )
                    )
        else:
            _CraniotomeBaseMeta._base_created = True
        super().__init__(name, bases, member_dict)
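
# The metaclass above rejects any name wrapped in double underscores that does
# not already exist on CraniotomeBase. A hedged sketch of the behavior
# (hypothetical subclasses, shown here only as comments):
#
#     class BadOp(CraniotomeBase):
#         __my_flag__ = True       # KeyError: not present on the base class
#
#     class RelaxedOp(CraniotomeBase):
#         __check_key__ = False    # disables the dunder-name check
#         __my_flag__ = True       # now accepted
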
class CraniotomeBase(_mgb.CraniotomeDesc, metaclass=_CraniotomeBaseMeta):
    """base class used for extending megbrain core operators in python

    Note: all names starting and ending with two underscores in the
    subclasses would be checked and KeyError would be raised if the name does
    not exist in the base class. This behavior can be disabled by setting
    ``__check_key__`` to ``False`` (see the testcase for more details)
    """

    # methods and attributes to be overwritten by subclasses

    __expand_single_outputs__ = True
    """if :attr:`__nr_outputs__` is 1, whether to return a single
    :class:`.SymbolVar` instead of a tuple in :meth:`make`"""

    __is_dynamic_output_shape__ = False
    """whether output shape could not be inferred from input shape. If value
    of this attribute is ``False``, :meth:`infer_shape` must be implemented.
    If this attribute is ``True`` but the operator has no inputs, then
    :meth:`infer_shape` would also be called to infer output shape before
    operator execution.
    """

    __disable_sys_mem_alloc__ = False
    """whether to disable system memory allocator. This is used when
    :attr:`__is_dynamic_output_shape__` is ``False`` but the output memory
    should not be managed by megbrain system (so it can be forwarded from
    external buffer)"""

    __allow_duplicate__ = True
    """whether this operator can be duplicated (e.g. used in sublinear
    memory)"""

    __allow_empty_out__ = False
    """whether empty output shape is allowed; if it is set as ``False``, then
    an exception would be raised if output var is empty to prevent
    erroneously forgetting to initialize output vars"""

    @abstractproperty
    def __nr_inputs__(self):
        """number of input vars"""

    @abstractproperty
    def __nr_outputs__(self):
        """number of output vars"""

    @abstractmethod
    def execute(self, inputs, outputs):
        """execute the operator, read values from *inputs* by calling
        :meth:`.CompGraphCallbackValueProxy.get_value` and write results into
        *outputs* by calling :meth:`.SharedND.set_value`

        :param inputs: values for each input var
        :type inputs: tuple of :class:`.CompGraphCallbackValueProxy`
        :param outputs: values for each output var
        :type outputs: tuple of :class:`.SharedND`
        """

    def setup(self):
        """overwritten by subclass to accept kwargs passed to :meth:`make` to
        setup the operator"""

    def infer_shape(self, inp_shapes):
        """infer output shape from input shapes

        :type inp_shapes: tuple of tuple of ints
        :param inp_shapes: input shapes for each input var
        :rtype: tuple of tuple of ints
        :return: output shapes for each output var
        """
        raise NotImplementedError(
            "{}: infer_shape() not implemented; for operators with dynamic "
            "output shape, __is_dynamic_output_shape__ should be set to True".format(
                self
            )
        )

    def grad(self, wrt_idx, inputs, outputs, out_grad):
        """compute symbolic gradient; should be overwritten by differentiable
        subclasses

        :type wrt_idx: int
        :param wrt_idx: the input var with respect to which the gradient
            should be computed; please also see the notes below
        :type inputs: tuple of :class:`.SymbolVar`
        :param inputs: input symbol vars
        :type outputs: tuple of :class:`.SymbolVar`
        :param outputs: output symbol vars
        :type out_grad: tuple of (:class:`.SymbolVar` or None)
        :param out_grad: gradients of loss with respect to each output var

        .. note::

            In case the loss does not depend on some var (i.e. zero grad),
            the corresponding value in *out_grad* would be ``None``. It is
            guaranteed that at least one element in *out_grad* is not
            ``None``.

        .. note::

            This function can return either of the following:

            1. Gradient of the input specified by ``wrt_idx``
            2. A list containing gradients of all inputs. In this case,
               ``wrt_idx`` can be ignored.

            And the so-called gradient can be either one of:

            1. A :class:`.SymbolVar` representing the symbolic gradient value
            2. ``0`` representing zero gradient
        """
        raise NotImplementedError("grad for {} not implemented".format(self))

    def init_output_dtype(self, input_dtypes):
        """infer output dtypes from input dtypes; return None to use the
        default infer function in megbrain.

        .. note::

            This method must be implemented if there is no input var

        :param input_dtypes: input dtypes
        :type input_dtypes: list of :class:`numpy.dtype`
        :rtype: None or list of :class:`numpy.dtype`-compatible
        """

    def get_serialize_params(self):
        """get params for megbrain graph serialization. This function should
        return a list or tuple, containing one or two elements: the first
        element must be a string, representing the name passed to
        ``opr_loader_maker`` during deserializing; the second element, if it
        exists, must be convertible to ``bytes`` and is used for dumping any
        extra opr params, which can be retrieved by ``load_buf_with_len``
        during deserializing.
        """
        raise NotImplementedError(
            "get_serialize_params() for {} not implemented".format(self)
        )

    def copy(self):
        """copy this craniotome descriptor; the default implementation
        creates a new object, and copies object ``__dict__``"""
        ret = type(self)()
        d0 = self.__dict__.copy()
        d0.pop("this")
        ret.__dict__.update(copy.deepcopy(d0))
        return ret

    def on_graph_compiled(self, used_outputs):
        """a callback that would be invoked when the graph is compiled; it
        would always have a matching :meth:`on_compiled_func_deleted` call

        :param used_outputs: indices of outputs that are needed for the
            computation
        :type used_outputs: ``tuple of int``
        """

    def on_compiled_func_deleted(self):
        """a callback that would be invoked when the compiled function is
        destructed; it would always have a matching :meth:`on_graph_compiled`
        call"""

    def get_io_vars(self):
        """get input vars, comp order dep vars and output vars

        :return: a dict with keys ``'input'``, ``'output'`` and
            ``'comp_order'`` that maps to the corresponding list of vars
        """
        all_vars = list(self._get_all_io_vars())
        nr_inp = self.__nr_inputs__
        nr_out = self.__nr_outputs__
        nr_comp_order = self._get_nr_dev_comp_order_deps()
        s0 = nr_inp + nr_comp_order
        return dict(
            input=all_vars[:nr_inp],
            comp_order=all_vars[nr_inp:s0],
            output=all_vars[s0:],
        )

    @property
    def owner_opr_id(self):
        """ID of the operator that owns this descriptor"""
        return self._get_opr_id()

    @property
    def comp_node(self):
        """comp node on which this operator runs"""
        return self._get_comp_node()

    # below are methods that should not be changed

    def _hash(self):
        return int(hash(self)) % (1 << 64)

    def _setup_self(self, dst):
        dst.append(self)

    def _is_same(self, rhs):
        return bool(self == rhs)

    def _node_flag(self):
        # pack the four boolean attributes into a bit field
        return (
            (int(bool(self.__is_dynamic_output_shape__)) << 0)
            | (int(not self.__allow_duplicate__) << 1)
            | (int(bool(self.__allow_empty_out__)) << 2)
            | (int(bool(self.__disable_sys_mem_alloc__)) << 3)
        )

    def _get_opr_type_name(self):
        return str(self.__class__.__name__)

    def _get_nr_outputs(self):
        return int(self.__nr_outputs__)

    def _execute(self, inputs, outputs):
        inputs = tuple(inputs)
        outputs = tuple(outputs)
        if not self.__is_dynamic_output_shape__:
            out_shapes = [i.shape for i in outputs]
        self.execute(inputs, outputs)
        if not self.__is_dynamic_output_shape__:
            new_shapes = [i.shape for i in outputs]
            assert (
                out_shapes == new_shapes
            ), "output shape changed after executing {}: before={} after={}".format(
                self, out_shapes, new_shapes
            )

    def _infer_shape(self, inp_shapes):
        inp_shapes = tuple(tuple(map(int, i)) for i in inp_shapes)
        oshp_get = self.infer_shape(inp_shapes)
        assert (
            len(oshp_get) == self.__nr_outputs__
        ), "{}: expect {} outputs; got {}(val: {}) from infer_shape".format(
            self, self.__nr_outputs__, len(oshp_get), oshp_get
        )
        return _helper.cvt_to_vector_of_shape(oshp_get)

    def _grad(self, wrt_idx, inputs, outputs, out_grad):
        og = []
        for i in out_grad:
            if i.valid:
                og.append(i)
            else:
                og.append(None)
        rst = self.grad(int(wrt_idx), tuple(inputs), tuple(outputs), tuple(og))
        if not isinstance(rst, (list, tuple)):
            rst = [rst]
        else:
            assert len(rst) == len(
                inputs
            ), "{}: opr has {} inputs but {} grads are returned".format(
                self, len(inputs), len(rst)
            )
        for i in range(len(rst)):
            cur = rst[i]
            # grad() may return the int 0 to denote zero gradient; avoid
            # identity comparison (``is 0``), which relies on CPython
            # small-int caching
            if isinstance(cur, int) and cur == 0:
                rst[i] = _mgb.SymbolVar()
            else:
                assert isinstance(cur, _mgb.SymbolVar), (
                    "{}: invalid grad result; it should be either "
                    "0 or a SymbolVar, got {!r} instead".format(self, cur)
                )
        return rst

    def _get_nr_dev_comp_order_deps(self):
        return 0
    def _init_output_dtype(self, input_dtypes, ret):
        get = self.init_output_dtype(input_dtypes)
        if get is not None:
            assert isinstance(ret, (list, tuple)) and len(get) == len(ret)
            ret[:] = get
            return True
        assert self.__nr_inputs__, (
            "{}: init_output_dtype must be implemented "
            "if there is no input var".format(self)
        )
        return False

    def _setup_serialize_params(self, output):
        val = list(self.get_serialize_params())
        assert len(val) in [1, 2]
        name = val[0]
        assert isinstance(name, str)
        output.append(name)
        if len(val) == 2:
            output.append(bytes(val[1]))

    def _copy(self):
        ret = self.copy()
        assert type(ret) is type(
            self
        ), "copy() returned different type: src={} copied={}".format(
            type(self), type(ret)
        )
        assert ret is not self
        ret.__disown__()
        self._set_copy_result(ret)

    def _on_graph_compile_or_func_del(self, used_outputs):
        if used_outputs:
            self.on_graph_compiled(used_outputs)
        else:
            self.on_compiled_func_deleted()

    def __repr__(self):
        return "craniotome:{}".format(self.__class__.__name__)

    @classmethod
    def make(
        cls,
        *inputs,
        comp_graph=None,
        name=None,
        comp_node=None,
        config=None,
        dev_comp_order_deps=(),
        **kwargs
    ):
        """apply this operator on some input vars and return corresponding
        output vars

        :type inputs: tuple of :class:`.SymbolVar`
        :param inputs: input symvars; immediate values could also be
            accepted, as long as there is a symvar to infer comp node and
            comp graph
        :param comp_graph: if there are no input vars, *comp_graph* must be
            provided to specify which computing graph to insert this operator
        :param dev_comp_order_deps: vars that must have been computed before
            executing this operator
        :param kwargs: extra keyword arguments to be passed to :meth:`setup`
            of this class
        :param name: name of the resulting operator
        :rtype: tuple of :class:`.SymbolVar`
        :return: output symvars
        """
        if not inputs and not dev_comp_order_deps:
            assert isinstance(
                comp_graph, _mgb.CompGraph
            ), "{}: comp_graph must be given if no inputs provided".format(cls)
        desc = cls()
        desc.setup(**kwargs)
        assert (
            len(inputs) == desc.__nr_inputs__
        ), "{}: expected {} inputs, got {}".format(
            desc, desc.__nr_inputs__, len(inputs)
        )
        config = _helper.gen_config(name, comp_node, config)

        # build the input var vector, including comp order deps
        inp_vec = _mgb._VectorSymbolVar()
        for i in _helper.canonize_input_vars(
            itertools.chain(inputs, dev_comp_order_deps),
            comp_graph=comp_graph,
            config=config,
        ):
            inp_vec.push_back(i)
        desc._get_nr_dev_comp_order_deps = lambda *, val=len(dev_comp_order_deps): val
        if comp_graph is not None:
            desc._get_comp_graph = lambda: comp_graph
        expand_single_outputs = desc.__expand_single_outputs__
        desc.__disown__()
        rst = _mgb.make_opr_from_craniotome_desc(desc, inp_vec, config)
        if expand_single_outputs and len(rst) == 1:
            return rst[0]
        return tuple(rst)


def make_opr(cls):
    """decorator used to wrap a :class:`.CraniotomeBase` subclass and return
    its :meth:`~.CraniotomeBase.make` method
    """
    assert issubclass(cls, CraniotomeBase)
    return cls.make
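
# A minimal end-to-end sketch of a custom operator built on CraniotomeBase
# (illustrative only: the operator ``Square``, its ``scale`` parameter and
# the loader name "square_v0" are hypothetical, a working megbrain runtime
# is assumed, and SymbolVar is assumed to support arithmetic operators as it
# does elsewhere in these bindings):
#
#     import struct
#
#     class Square(CraniotomeBase):
#         __nr_inputs__ = 1
#         __nr_outputs__ = 1
#
#         def setup(self, scale=1):
#             # keyword arguments passed to make() arrive here
#             self._scale = scale
#
#         def execute(self, inputs, outputs):
#             # compute scale * x ** 2 on concrete values
#             outputs[0].set_value(self._scale * inputs[0].get_value() ** 2)
#
#         def infer_shape(self, inp_shapes):
#             # elementwise op: each output shape equals the input shape
#             return inp_shapes
#
#         def grad(self, wrt_idx, inputs, outputs, out_grad):
#             # d(s * x^2)/dx = 2 * s * x; returning a single SymbolVar
#             # (the grad w.r.t. inputs[wrt_idx]) is one allowed convention;
#             # returning a list of grads for all inputs is the other
#             return out_grad[0] * inputs[0] * (2 * self._scale)
#
#         def get_serialize_params(self):
#             # loader name plus extra params packed as bytes
#             return ["square_v0", struct.pack("<i", self._scale)]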
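#
# Applying the sketch above (again illustrative; ``x`` is assumed to be an
# existing SymbolVar living on a computing graph):
#
#     y = Square.make(x, scale=3)  # a single SymbolVar is returned because
#                                  # __expand_single_outputs__ is True
#
# The make_opr decorator replaces the class with its bound make() method:
#
#     @make_opr
#     class Cube(CraniotomeBase):
#         ...  # same structure as Square
#
#     z = Cube(x)  # equivalent to Cube.make(x) before decoration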