|
|
@@ -6,9 +6,7 @@ |
|
|
|
# Unless required by applicable law or agreed to in writing, |
|
|
|
# software distributed under the License is distributed on an |
|
|
|
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
|
|
|
|
|
|
|
|
import collections |
|
|
|
from typing import Union |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
@@ -53,7 +51,7 @@ class Tensor(_Tensor, ArrayMethodMixin): |
|
|
|
return obj |
|
|
|
|
|
|
|
@property |
|
|
|
def shape(self): |
|
|
|
def shape(self) -> Union[tuple, "Tensor"]: |
|
|
|
shape = super().shape |
|
|
|
if shape == () or not use_symbolic_shape(): |
|
|
|
return shape |
|
|
@@ -63,6 +61,16 @@ class Tensor(_Tensor, ArrayMethodMixin): |
|
|
|
def _tuple_shape(self): |
|
|
|
return super().shape |
|
|
|
|
|
|
|
@property |
|
|
|
def dtype(self) -> np.dtype: |
|
|
|
return super().dtype |
|
|
|
|
|
|
|
def numpy(self) -> np.ndarray: |
|
|
|
return super().numpy() |
|
|
|
|
|
|
|
def _reset(self, other): |
|
|
|
super()._reset(other) |
|
|
|
|
|
|
|
def __repr__(self): |
|
|
|
piece = "Tensor(" |
|
|
|
with np.printoptions(precision=4, suppress=True): |
|
|
|