diff --git a/imperative/python/src/helper.cpp b/imperative/python/src/helper.cpp index d4f32536..c17ff1d6 100644 --- a/imperative/python/src/helper.cpp +++ b/imperative/python/src/helper.cpp @@ -172,6 +172,7 @@ int to_mgb_supported_dtype_raw(int dtype) { #define FOREACH_NPY_DTYPE_PAIR(cb) \ cb(Uint8, NPY_UINT8) \ cb(Int8, NPY_INT8) \ + cb(Uint16, NPY_UINT16) \ cb(Int16, NPY_INT16) \ cb(Int32, NPY_INT32) \ cb(Float16, NPY_FLOAT16) \ diff --git a/imperative/python/test/unit/core/test_raw_tensor.py b/imperative/python/test/unit/core/test_raw_tensor.py index c8a88453..6ef599cc 100644 --- a/imperative/python/test/unit/core/test_raw_tensor.py +++ b/imperative/python/test/unit/core/test_raw_tensor.py @@ -28,3 +28,10 @@ def test_as_raw_tensor_from_int64(): assert xx.dtype == np.float32 assert xx.device == "xpux" np.testing.assert_almost_equal(yy, x.astype("float32") + 1) + + +def test_as_raw_tensor_uint16(): + x = np.arange(6, dtype="uint16").reshape(2, 3) + xx = Tensor(x, device="xpux") + assert xx.dtype == np.uint16 + assert xx.device == "xpux"