|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- import numpy as np
-
- from megengine.core.tensor.tensor_wrapper import TensorWrapper
-
-
def test_basic():
    """Squaring a float32 TensorWrapper with ``*`` must agree with NumPy."""
    source = np.random.rand(10).astype("float32")
    wrapped = TensorWrapper(source)
    # The same wrapper is used as both operands, mirroring ``source * source``.
    squared = wrapped * wrapped
    expected = source * source
    np.testing.assert_almost_equal(squared.numpy(), expected)
-
-
def test_literal_arith():
    """Multiplying a TensorWrapper by a plain Python int must agree with NumPy."""
    source = np.random.rand(10).astype("float32")
    scaled = TensorWrapper(source) * 2
    np.testing.assert_almost_equal(scaled.numpy(), source * 2)
-
-
def test_matmul():
    """``@`` on two TensorWrappers must match NumPy matmul of the source arrays.

    The reference result is computed from the original NumPy inputs rather
    than from ``A.numpy()`` / ``B.numpy()``. If the reference were built from
    the round-tripped values, a bug in TensorWrapper's construction or
    ``.numpy()`` conversion would appear on both sides of the comparison and
    go unnoticed.
    """
    a_np = np.random.rand(5, 7).astype("float32")
    b_np = np.random.rand(7, 10).astype("float32")
    A = TensorWrapper(a_np)
    B = TensorWrapper(b_np)
    C = A @ B
    # decimal=6 presumably tolerates float32 rounding/accumulation-order
    # differences between the backend and NumPy.
    np.testing.assert_almost_equal(C.numpy(), a_np @ b_np, decimal=6)
-
-
def test_reduce():
    """Each reduction method over the last axis must match the ndarray method.

    For every name in the list, the method of that name is called on a
    TensorWrapper and on the source ndarray, each with axis ``-1``. The reduction
    name is passed as ``err_msg`` so a failure reports which one diverged,
    instead of an anonymous mismatch from inside the loop.
    """
    for m in ["sum", "prod", "min", "max", "mean"]:
        x_np = np.random.rand(10).astype("float32")
        x = TensorWrapper(x_np)
        y = getattr(x, m)(-1)
        np.testing.assert_almost_equal(
            y.numpy(),
            getattr(x_np, m)(-1),
            decimal=6,
            err_msg=f"reduction {m!r} disagrees with numpy",
        )
-
-
def test_set_subtensor():
    """Successive in-place ``__setitem__`` writes must update the tensor.

    Covers a full slice, an integer-list (advanced) index, and a partial
    slice. After each write, the whole tensor is compared with the expected
    contents.
    """
    tensor = TensorWrapper([1, 2, 3])
    # (index, assigned value, expected contents after the write)
    steps = (
        (slice(None), [1, 1, 1], [1, 1, 1]),
        ([0, 2], [3, 2], [3, 1, 2]),
        (slice(1, 3), [4, 5], [3, 4, 5]),
    )
    for index, value, expected in steps:
        tensor[index] = value
        np.testing.assert_almost_equal(tensor.numpy(), expected, decimal=6)
|