Browse Source

fix(mge/tensor): implement abtract method to fix lint errors

GitOrigin-RevId: d53f2eac6a
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
412d1f0cdc
1 changed files with 12 additions and 4 deletions
  1. +12
    -4
      imperative/python/megengine/tensor.py

+ 12
- 4
imperative/python/megengine/tensor.py View File

@@ -6,9 +6,7 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.


import collections
from typing import Union


import numpy as np import numpy as np


@@ -53,7 +51,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
return obj return obj


@property @property
def shape(self):
def shape(self) -> Union[tuple, "Tensor"]:
shape = super().shape shape = super().shape
if shape == () or not use_symbolic_shape(): if shape == () or not use_symbolic_shape():
return shape return shape
@@ -63,6 +61,16 @@ class Tensor(_Tensor, ArrayMethodMixin):
def _tuple_shape(self): def _tuple_shape(self):
return super().shape return super().shape


@property
def dtype(self) -> np.dtype:
return super().dtype

def numpy(self) -> np.ndarray:
return super().numpy()

def _reset(self, other):
super()._reset(other)

def __repr__(self): def __repr__(self):
piece = "Tensor(" piece = "Tensor("
with np.printoptions(precision=4, suppress=True): with np.printoptions(precision=4, suppress=True):


Loading…
Cancel
Save