# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""Tests for ``as_raw_tensor``: wrapping NumPy arrays as raw tensors."""
import numpy as np

import megengine.functional as F
from megengine.core.tensor.raw_tensor import as_raw_tensor


def test_as_raw_tensor():
    """Wrap a float32 ndarray, add 1 via ``F.add`` and compare with NumPy.

    Also checks the ``dtype`` and ``device`` attributes reported by the
    object returned from ``as_raw_tensor``.
    """
    flat = np.arange(6, dtype="float32")
    host_array = flat.reshape(2, 3)

    raw = as_raw_tensor(host_array, device="xpux")
    # Run the op and pull the result back to NumPy before asserting anything,
    # so an execution failure surfaces ahead of the attribute checks.
    incremented = F.add(raw, 1)
    result = incremented.numpy()

    assert raw.dtype == np.float32
    assert raw.device == "xpux"
    expected = host_array + 1
    np.testing.assert_almost_equal(result, expected)


def test_as_raw_tensor_from_int64():
    """Wrap an int64 ndarray while requesting ``dtype="float32"``.

    The result of adding 1 is compared against the same computation done in
    NumPy on the array cast to float32.
    """
    flat = np.arange(6, dtype="int64")
    host_array = flat.reshape(2, 3)

    raw = as_raw_tensor(host_array, dtype="float32", device="xpux")
    # Same ordering as the float32 test: execute first, then check attributes.
    incremented = F.add(raw, 1)
    result = incremented.numpy()

    assert raw.dtype == np.float32
    assert raw.device == "xpux"
    expected = host_array.astype("float32") + 1
    np.testing.assert_almost_equal(result, expected)