From b6142bee9a3ffb2a13552138acd259046677c3ba Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 24 Sep 2021 19:05:57 +0800 Subject: [PATCH] feat(imperative): support tensor with uint16 date type GitOrigin-RevId: 57ba0633c7f344922a95f2ac03a7502231453e52 --- imperative/python/src/helper.cpp | 1 + imperative/python/test/unit/core/test_raw_tensor.py | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/imperative/python/src/helper.cpp b/imperative/python/src/helper.cpp index d4f32536..c17ff1d6 100644 --- a/imperative/python/src/helper.cpp +++ b/imperative/python/src/helper.cpp @@ -172,6 +172,7 @@ int to_mgb_supported_dtype_raw(int dtype) { #define FOREACH_NPY_DTYPE_PAIR(cb) \ cb(Uint8, NPY_UINT8) \ cb(Int8, NPY_INT8) \ + cb(Uint16, NPY_UINT16) \ cb(Int16, NPY_INT16) \ cb(Int32, NPY_INT32) \ cb(Float16, NPY_FLOAT16) \ diff --git a/imperative/python/test/unit/core/test_raw_tensor.py b/imperative/python/test/unit/core/test_raw_tensor.py index c8a88453..6ef599cc 100644 --- a/imperative/python/test/unit/core/test_raw_tensor.py +++ b/imperative/python/test/unit/core/test_raw_tensor.py @@ -28,3 +28,10 @@ def test_as_raw_tensor_from_int64(): assert xx.dtype == np.float32 assert xx.device == "xpux" np.testing.assert_almost_equal(yy, x.astype("float32") + 1) + + +def test_as_raw_tensor_uint16(): + x = np.arange(6, dtype="uint16").reshape(2, 3) + xx = Tensor(x, device="xpux") + assert xx.dtype == np.uint16 + assert xx.device == "xpux"