Browse Source

feat(imperative): support tensor with uint16 date type

GitOrigin-RevId: 57ba0633c7
release-1.7
Megvii Engine Team 3 years ago
parent
commit
b6142bee9a
2 changed files with 8 additions and 0 deletions
  1. +1
    -0
      imperative/python/src/helper.cpp
  2. +7
    -0
      imperative/python/test/unit/core/test_raw_tensor.py

+ 1
- 0
imperative/python/src/helper.cpp View File

@@ -172,6 +172,7 @@ int to_mgb_supported_dtype_raw(int dtype) {
#define FOREACH_NPY_DTYPE_PAIR(cb) \
cb(Uint8, NPY_UINT8) \
cb(Int8, NPY_INT8) \
cb(Uint16, NPY_UINT16) \
cb(Int16, NPY_INT16) \
cb(Int32, NPY_INT32) \
cb(Float16, NPY_FLOAT16) \


+ 7
- 0
imperative/python/test/unit/core/test_raw_tensor.py View File

@@ -28,3 +28,10 @@ def test_as_raw_tensor_from_int64():
assert xx.dtype == np.float32
assert xx.device == "xpux"
np.testing.assert_almost_equal(yy, x.astype("float32") + 1)


def test_as_raw_tensor_uint16():
x = np.arange(6, dtype="uint16").reshape(2, 3)
xx = Tensor(x, device="xpux")
assert xx.dtype == np.uint16
assert xx.device == "xpux"

Loading…
Cancel
Save