|
|
@@ -24,21 +24,22 @@ from megengine.utils.network_node import VarNode |
|
|
|
|
|
|
|
|
|
|
|
def test_eye(): |
|
|
|
dtype = np.float32 |
|
|
|
dtypes = [np.float32, np.bool] |
|
|
|
cases = [{"input": [10, 20]}, {"input": [30]}] |
|
|
|
for case in cases: |
|
|
|
np.testing.assert_allclose( |
|
|
|
F.eye(case["input"], dtype=dtype).numpy(), |
|
|
|
np.eye(*case["input"]).astype(dtype), |
|
|
|
) |
|
|
|
np.testing.assert_allclose( |
|
|
|
F.eye(*case["input"], dtype=dtype).numpy(), |
|
|
|
np.eye(*case["input"]).astype(dtype), |
|
|
|
) |
|
|
|
np.testing.assert_allclose( |
|
|
|
F.eye(tensor(case["input"]), dtype=dtype).numpy(), |
|
|
|
np.eye(*case["input"]).astype(dtype), |
|
|
|
) |
|
|
|
for dtype in dtypes: |
|
|
|
for case in cases: |
|
|
|
np.testing.assert_allclose( |
|
|
|
F.eye(case["input"], dtype=dtype).numpy(), |
|
|
|
np.eye(*case["input"]).astype(dtype), |
|
|
|
) |
|
|
|
np.testing.assert_allclose( |
|
|
|
F.eye(*case["input"], dtype=dtype).numpy(), |
|
|
|
np.eye(*case["input"]).astype(dtype), |
|
|
|
) |
|
|
|
np.testing.assert_allclose( |
|
|
|
F.eye(tensor(case["input"]), dtype=dtype).numpy(), |
|
|
|
np.eye(*case["input"]).astype(dtype), |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def test_full(): |
|
|
|