|
|
@@ -57,7 +57,7 @@ __all__ = [ |
|
|
|
] |
|
|
|
|
|
|
|
|
|
|
|
def eye(shape, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor: |
|
|
|
def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor: |
|
|
|
"""Returns a 2D tensor with ones on the diagonal and zeros elsewhere. |
|
|
|
|
|
|
|
:param shape: expected shape of output tensor. |
|
|
@@ -72,8 +72,7 @@ def eye(shape, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor: |
|
|
|
import numpy as np |
|
|
|
import megengine.functional as F |
|
|
|
|
|
|
|
data_shape = (4, 6) |
|
|
|
out = F.eye(data_shape, dtype=np.float32) |
|
|
|
out = F.eye(4, 6, dtype=np.float32) |
|
|
|
print(out.numpy()) |
|
|
|
|
|
|
|
Outputs: |
|
|
@@ -86,8 +85,17 @@ def eye(shape, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor: |
|
|
|
[0. 0. 0. 1. 0. 0.]] |
|
|
|
|
|
|
|
""" |
|
|
|
if M is not None: |
|
|
|
if isinstance(N, Tensor) or isinstance(M, Tensor): |
|
|
|
shape = astensor1d((N, M)) |
|
|
|
else: |
|
|
|
shape = Tensor([N, M], dtype="int32", device=device) |
|
|
|
elif isinstance(N, Tensor): |
|
|
|
shape = N |
|
|
|
else: |
|
|
|
shape = Tensor(N, dtype="int32", device=device) |
|
|
|
op = builtin.Eye(k=0, dtype=dtype, comp_node=device) |
|
|
|
(result,) = apply(op, Tensor(shape, dtype="int32", device=device)) |
|
|
|
(result,) = apply(op, shape) |
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|