You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.py 4.3 kB

(line numbers 1–135)
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import numpy as np
  11. from .core._imperative_rt import CompNode
  12. from .core._imperative_rt.core2 import Tensor as _Tensor
  13. from .core._imperative_rt.core2 import apply
  14. from .core._trace_option import use_symbolic_shape
  15. from .core._wrap import device as as_device
  16. from .core.ops.builtin import Copy, GetVarShape
  17. from .core.tensor.tensor_wrapper import ArrayMethodMixin
  18. from .device import _valid_device, get_default_device
  19. from .utils.deprecation import deprecated
  20. class Tensor(_Tensor, ArrayMethodMixin):
  21. grad = None
  22. dmap_callback = None
  23. q_dict = {"mode": None, "scale": None, "zero_point": None}
  24. def __new__(cls, data, dtype=None, device=None, is_const=False):
  25. if device is None:
  26. cn = get_default_device()
  27. elif isinstance(device, str):
  28. if cls.dmap_callback is not None:
  29. cn = CompNode(cls.dmap_callback(device))
  30. else:
  31. cn = CompNode(device)
  32. else:
  33. if isinstance(device, CompNode):
  34. cn = device
  35. else:
  36. cn = device._cn
  37. # import pdb; pdb.set_trace()
  38. if isinstance(data, _Tensor):
  39. obj = _Tensor.__new__(cls, data)
  40. else:
  41. if isinstance(data, np.ndarray):
  42. if 0 in data.strides:
  43. data = data.squeeze().reshape(data.shape)
  44. obj = _Tensor.__new__(cls, data, dtype, cn, is_const)
  45. return obj
  46. @property
  47. def shape(self):
  48. shape = super().shape
  49. if shape == () or not use_symbolic_shape():
  50. return shape
  51. return apply(GetVarShape(), self)[0]
  52. @property
  53. def _tuple_shape(self):
  54. return super().shape
  55. def __repr__(self):
  56. piece = "Tensor("
  57. with np.printoptions(precision=4, suppress=True):
  58. piece += "{}".format(str(self.numpy()))
  59. if self.dtype != np.float32:
  60. piece += ", dtype={}".format(np.dtype(self.dtype).name)
  61. piece += ", device={}".format(self.device) + ")"
  62. return piece
  63. @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
  64. def set_value(self, value):
  65. if not isinstance(value, _Tensor):
  66. value = Tensor(value, dtype=self.dtype, device=self.device)
  67. self._reset(value)
  68. @deprecated(version="1.0", reason="use *= 0 instead")
  69. def reset_zero(self):
  70. self *= 0
  71. def to(self, device):
  72. if isinstance(device, str) and not _valid_device(device):
  73. raise ValueError(
  74. "invalid device name {}. For the correct format of the device name, please refer to the instruction of megengine.device.set_default_device()".format(
  75. device
  76. )
  77. )
  78. cn = as_device(device).to_c()
  79. return apply(Copy(comp_node=cn), self)[0]
  80. @property
  81. def requires_grad(self):
  82. raise AttributeError("requires_grad is reserved for future use")
  83. @requires_grad.setter
  84. def requires_grad(self, value):
  85. raise AttributeError("requires_grad is reserved for future use")
  86. @requires_grad.deleter
  87. def requires_grad(self):
  88. raise AttributeError("requires_grad is reserved for future use")
  89. def __hash__(self):
  90. return id(self)
  91. def __getnewargs__(self):
  92. r""" __getnewargs__ will be called for pickle serialization or deep copy
  93. """
  94. return (self.numpy(), self.dtype, self.device.logical_name)
  95. def __getstate__(self):
  96. r""" __getstate__ will be called for pickle serialization or deep copy
  97. """
  98. state = {
  99. "qdict": self.q_dict,
  100. }
  101. return state
  102. def __setstate__(self, state):
  103. self.q_dict = state.pop("qdict")
  104. tensor = Tensor
  105. class Parameter(Tensor):
  106. r"""
  107. A kind of Tensor that is to be considered a module parameter.
  108. """

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。