Browse Source

feat(mge): make F.eye numpy compatible

GitOrigin-RevId: fee32537b4
release-1.1
Megvii Engine Team 4 years ago
parent
commit
aa62672629
2 changed files with 21 additions and 5 deletions
  1. +12
    -4
      imperative/python/megengine/functional/tensor.py
  2. +9
    -1
      imperative/python/test/unit/functional/test_tensor.py

+ 12
- 4
imperative/python/megengine/functional/tensor.py View File

@@ -57,7 +57,7 @@ __all__ = [
]


def eye(shape, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor:
def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor:
"""Returns a 2D tensor with ones on the diagonal and zeros elsewhere.

:param shape: expected shape of output tensor.
@@ -72,8 +72,7 @@ def eye(shape, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor:
import numpy as np
import megengine.functional as F

data_shape = (4, 6)
out = F.eye(data_shape, dtype=np.float32)
out = F.eye(4, 6, dtype=np.float32)
print(out.numpy())

Outputs:
@@ -86,8 +85,17 @@ def eye(shape, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor:
[0. 0. 0. 1. 0. 0.]]

"""
if M is not None:
if isinstance(N, Tensor) or isinstance(M, Tensor):
shape = astensor1d((N, M))
else:
shape = Tensor([N, M], dtype="int32", device=device)
elif isinstance(N, Tensor):
shape = N
else:
shape = Tensor(N, dtype="int32", device=device)
op = builtin.Eye(k=0, dtype=dtype, comp_node=device)
(result,) = apply(op, Tensor(shape, dtype="int32", device=device))
(result,) = apply(op, shape)
return result




+ 9
- 1
imperative/python/test/unit/functional/test_tensor.py View File

@@ -22,12 +22,20 @@ from megengine.distributed.helper import get_device_count_by_fork

def test_eye():
dtype = np.float32
cases = [{"input": [10, 20]}, {"input": [20, 30]}]
cases = [{"input": [10, 20]}, {"input": [30]}]
for case in cases:
np.testing.assert_allclose(
F.eye(case["input"], dtype=dtype).numpy(),
np.eye(*case["input"]).astype(dtype),
)
np.testing.assert_allclose(
F.eye(*case["input"], dtype=dtype).numpy(),
np.eye(*case["input"]).astype(dtype),
)
np.testing.assert_allclose(
F.eye(tensor(case["input"]), dtype=dtype).numpy(),
np.eye(*case["input"]).astype(dtype),
)


def test_concat():


Loading…
Cancel
Save