# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections

from .core import Tensor as _Tensor
from .core.ops.builtin import Copy
from .core.tensor.core import apply
from .core.tensor.raw_tensor import as_device
from .device import _valid_device, get_default_device
from .utils.deprecation import deprecated


class Tensor(_Tensor):
    grad = None
    # optional callback used in __setstate__ to remap the logical device name
    dmap_callback = None

    def __init__(self, data, dtype=None, device=None):
        if device is None:
            device = get_default_device()
        # quantization metadata: mode, scale and zero point
        self.q_dict = {"mode": None, "scale": None, "zero_point": None}
        super().__init__(data, dtype=dtype, device=device)

    @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
    def set_value(self, value):
        self._reset(value)

    @deprecated(version="1.0", reason="use *= 0 instead")
    def reset_zero(self):
        self *= 0

    def to(self, device):
        if isinstance(device, str) and not _valid_device(device):
            raise ValueError(
                "invalid device name {}. For the correct format of the device "
                "name, please refer to the documentation of "
                "megengine.device.set_default_device()".format(device)
            )
        cn = as_device(device).to_c()
        return apply(Copy(comp_node=cn), self)[0]

    @property
    def requires_grad(self):
        raise AttributeError("requires_grad is reserved for future use")

    @requires_grad.setter
    def requires_grad(self, value):
        raise AttributeError("requires_grad is reserved for future use")

    @requires_grad.deleter
    def requires_grad(self):
        raise AttributeError("requires_grad is reserved for future use")

    def __hash__(self):
        return id(self)

    def __getstate__(self):
        r"""__getstate__ is called for pickle serialization or deep copy."""
        state = {
            "data": self.numpy(),
            "device": self.device.logical_name,
            "dtype": self.dtype,
            "qdict": self.q_dict,
        }
        return state

    def __setstate__(self, state):
        data = state.pop("data")
        logical_device = state.pop("device")
        if self.dmap_callback is not None:
            assert isinstance(logical_device, str)
            logical_device = self.dmap_callback(logical_device)
        dtype = state.pop("dtype")
        self.q_dict = state.pop("qdict")
        super().__init__(data, dtype=dtype, device=logical_device)

    def detach(self):
        r"""
        Returns a new tensor which is treated as a constant during backward
        gradient calculation, i.e. its gradient is zero.
        """
        Wrapper = type(self)
        Tensor = type(self.__wrapped__)
        return Wrapper(Tensor(self.__wrapped__._data))


tensor = Tensor


class Parameter(Tensor):
    r"""A kind of Tensor that is to be considered a module parameter."""
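

# A minimal usage sketch (illustrative only, not executed by this module).
# It assumes an installed MegEngine that exposes this class as
# ``megengine.tensor`` and a machine providing a "cpu0" device:
#
#     import pickle
#     import numpy as np
#     import megengine as mge
#
#     x = mge.tensor(np.arange(4, dtype="float32"))
#     y = x.to("cpu0")                    # copy onto cpu0 via the Copy op
#     c = x.detach()                      # constant during backward passes
#     x2 = pickle.loads(pickle.dumps(x))  # __getstate__/__setstate__ round-trip
#     assert (x2.numpy() == x.numpy()).all()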