|
|
@@ -32,6 +32,11 @@ namespace { |
|
|
|
for (size_t i = 0, it = span.dist_elem(); i < it; ++ i) { |
|
|
|
ptr[i] = i; |
|
|
|
} |
|
|
|
} else if (tensor.layout.dtype == dtype::Float16()) { |
|
|
|
auto ptr = tensor.ptr<dt_float16>() + span.low_elem; |
|
|
|
for (size_t i = 0, it = span.dist_elem(); i < it; ++ i) { |
|
|
|
ptr[i] = i; |
|
|
|
} |
|
|
|
} else { |
|
|
|
auto ptr = tensor.ptr<int>() + span.low_elem; |
|
|
|
for (size_t i = 0, it = span.dist_elem(); i < it; ++ i) { |
|
|
@@ -135,6 +140,19 @@ TEST_F(CUDA, INDEXING_MULTI_AXIS_VEC) { |
|
|
|
|
|
|
|
TEST_F(CUDA, INDEXING_INCR_MULTI_AXIS_VEC) { |
|
|
|
run_check<IndexingIncrMultiAxisVec>(handle_cuda()); |
|
|
|
Checker<IndexingIncrMultiAxisVec> checker(handle_cuda()); |
|
|
|
OrderedRNG rng; |
|
|
|
checker. |
|
|
|
set_dtype(0, dtype::Float16()). // data |
|
|
|
set_dtype(1, dtype::Float16()). // value |
|
|
|
set_dtype(2, dtype::Int32()). // idx0 |
|
|
|
set_rng(0, &rng). |
|
|
|
set_rng(1, &rng). |
|
|
|
set_rng(2, &rng); |
|
|
|
|
|
|
|
checker. |
|
|
|
set_proxy({{1}}). |
|
|
|
execs({{5, 8, 3}, {5, 2, 3}, {2}}); |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(CUDA, INDEXING_SET_MULTI_AXIS_VEC) { |
|
|
|