@@ -124,7 +124,7 @@ size_t argsort::get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype,
         ARGSORT_FOREACH_CTYPE(cb)
 #undef cb
         default:
-            megdnn_throw("argsort only supports float and int32");
+            megdnn_throw("argsort only supports float, int32 and float16");
     }
     if (!iptr_src_given) {
         size = DIVUP(size, sizeof(float)) * sizeof(float) + M * N * sizeof(int);
@@ -33,7 +33,8 @@ void forward(const dtype* sptr, dtype* dptr, int* iptr, void* workspace,
              const int* iptr_src = NULL);

 //! iterate over all supported data types
-#define ARGSORT_FOREACH_CTYPE(cb) cb(float) cb(int32_t)
+#define ARGSORT_FOREACH_CTYPE(cb) \
+    cb(float) cb(int32_t) DNN_INC_FLOAT16(cb(dt_float16))

 }  // namespace argsort
 }  // namespace cuda
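A note on the macro mechanics (an inference from usage, not taken from this patch): DNN_INC_FLOAT16 comes from megdnn/dtype.h and presumably expands to its argument when float16 support is compiled in, and to nothing when MEGDNN_DISABLE_FLOAT16 is set, so the cb(dt_float16) entry drops out of the foreach list in float16-free builds:

    // Assumed definition; the names are from the patch, the exact
    // expansion is inferred from how the macro is used above.
    #if !MEGDNN_DISABLE_FLOAT16
    #define DNN_INC_FLOAT16(_x) _x
    #else
    #define DNN_INC_FLOAT16(_x)
    #endif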
@@ -11,6 +11,7 @@
 #include "./bitonic_sort.cuh"
 #include "src/cuda/query_blocksize.cuh"
+#include "megdnn/dtype.h"

 #if __CUDACC_VER_MAJOR__ < 9
 #pragma message "warp sync disabled due to insufficient cuda version"
@@ -82,6 +83,18 @@ struct NumTrait<int32_t> {
     static __device__ __forceinline__ int32_t min() { return INT_MIN; }
 };

+#if !MEGDNN_DISABLE_FLOAT16
+template <>
+struct NumTrait<dt_float16> {
+    static __device__ __forceinline__ dt_float16 max() {
+        return std::numeric_limits<dt_float16>::max();
+    }
+    static __device__ __forceinline__ dt_float16 min() {
+        return std::numeric_limits<dt_float16>::lowest();
+    }
+};
+#endif
+
 struct LessThan {
     template <typename Key, typename Value>
     static __device__ __forceinline__ bool cmp(Key k0, Value v0, Key k1,
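For context on why NumTrait needs max() and min(): bitonic sort works on fixed power-of-two block sizes, so rows shorter than the block are padded with sentinel keys that always sort to the end. A minimal sketch of that padding, assuming this is how the trait is consumed (it reuses the NumTrait specializations from the hunk above; the real kernel lives elsewhere in bitonic_sort.cu):

    // Hedged sketch: pad the tail of a short row with sentinels so the
    // padding loses every comparison regardless of sort direction.
    template <typename T>
    __global__ void pad_keys(T* keys, int valid, int padded, bool ascending) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i >= valid && i < padded)
            keys[i] = ascending ? NumTrait<T>::max() : NumTrait<T>::min();
    }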
@@ -295,6 +308,7 @@ namespace cuda {
 INST(float, int);
 INST(int32_t, int);
+DNN_INC_FLOAT16(INST(dt_float16, int));
 #undef INST

 }  // namespace megdnn
@@ -1146,7 +1146,7 @@ template <> struct NumericTraits<double> : BaseTraits<FLOATING_POIN
 #if (__CUDACC_VER_MAJOR__ >= 9)
 template <> struct NumericTraits<__half> : BaseTraits<FLOATING_POINT, true, false, unsigned short, __half> {};
 #endif
+template <> struct NumericTraits<half_float::half> : BaseTraits<FLOATING_POINT, true, false, unsigned short, half_float::half> {};
 template <> struct NumericTraits<bool> : BaseTraits<UNSIGNED_INTEGER, true, false, typename UnitWord<bool>::VolatileWord, bool> {};
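Why the extra NumericTraits specialization: cub's radix sort keys on raw bits and consults NumericTraits to learn how to twiddle them, and without a specialization for half_float::half (the storage type behind MegDNN's dt_float16 when __half is unavailable) it cannot order such keys. A sketch of the FLOATING_POINT twiddle the trait opts into, as understood from cub's BaseTraits (verify against the bundled cub):

    #include <cassert>
    #include <cstdint>

    // Flip all bits of negatives and only the sign bit of positives, so
    // ascending unsigned order equals ascending float order.
    static std::uint16_t twiddle_in(std::uint16_t key) {
        std::uint16_t mask = (key & 0x8000u) ? 0xffffu : 0x8000u;
        return static_cast<std::uint16_t>(key ^ mask);
    }

    int main() {
        assert(twiddle_in(0xBC00) < twiddle_in(0x3C00));  // -1.0h < +1.0h
        return 0;
    }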
@@ -81,9 +81,18 @@ void TopKImpl::do_exec(int k, _megdnn_tensor_in data, _megdnn_tensor_out values,
                                         values.ptr<int32_t>(), indices,
                                         workspace.raw_ptr);
             return;
+#if !MEGDNN_DISABLE_FLOAT16
+        case DTypeEnum::Float16:
+            dispatch_with_ctype<dt_float16>(k, data.layout[0], data.layout[1],
+                                            data.layout.stride[0], data.ptr<dt_float16>(),
+                                            values.ptr<dt_float16>(), indices,
+                                            workspace.raw_ptr);
+            return;
+#endif
         default:
             megdnn_throw(
-                    ssprintf("only float32 and int32 supported for cuda topk, got: %s",
+                    ssprintf("only float32, int32 and float16 supported for "
+                             "cuda topk, got: %s",
                              data.layout.dtype.name()));
     }
 }
@@ -489,7 +489,7 @@ cudaError_t topk::find_kth_radix(const ctype* input, ctype* output,
     if (k < 0) {
         k = length + k + 1;
     }
-    if (!(BUCKET_BITS == 8 && sizeof(ctype) == 4)) {
+    if (!(BUCKET_BITS == 8 && (sizeof(ctype) == 4 || sizeof(ctype) == 2))) {
        // no c++11 in megdnn cuda; so we just trap instead of using static
        // assert
        megdnn_trap();
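As the in-code comment notes, this is a runtime trap only because the CUDA build predates C++11; with C++11 the same invariant compiles away entirely. A hypothetical equivalent (not in the patch):

    #include <cstdint>

    template <typename ctype, int BUCKET_BITS>
    void check_radix_config() {
        static_assert(
                BUCKET_BITS == 8 && (sizeof(ctype) == 4 || sizeof(ctype) == 2),
                "radix top-k expects 8-bit buckets over 16- or 32-bit keys");
    }

    int main() {
        check_radix_config<float, 8>();          // 32-bit keys
        check_radix_config<std::uint16_t, 8>();  // 16-bit keys (float16 bits)
        return 0;
    }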
@@ -668,6 +668,7 @@ namespace topk {
                  int32_t, uint32_t, cudaStream_t)
 INST(float);
 INST(int32_t);
+DNN_INC_FLOAT16(INST(dt_float16));
 #undef INST

 }  // namespace topk
@@ -10,7 +10,7 @@
  */
 #pragma once
+#include "megdnn/dtype.h"
 #include <cuda_runtime.h>
 #include <stdint.h>
@@ -60,6 +60,29 @@ struct RadixConverter<int32_t> {
     }
 };

+#if !MEGDNN_DISABLE_FLOAT16
+template <>
+struct RadixConverter<dt_float16> {
+    union FIunion {
+        FIunion() {}
+        dt_float16 fv;
+        uint16_t iv;
+    };
+    static __forceinline__ __device__ __host__ uint16_t to_radix(dt_float16 val) {
+        FIunion fi;
+        fi.fv = val;
+        return fi.iv ^ (((!(fi.iv >> 15u)) - 1u) | 0x8000u);
+    }
+    static __forceinline__ __device__ __host__ dt_float16 from_radix(uint16_t val) {
+        FIunion fi;
+        // do not write as to_radix() to work around a compiler bug in cuda-9.0
+        uint16_t m = 0x8000u;
+        fi.iv = val ^ (m | (m - !(val >> 15u)));
+        return fi.fv;
+    }
+};
+#endif
+
 }  // namespace internal

 /*!
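The to_radix expression deserves unpacking: for a non-negative half (sign bit 0), (!(iv >> 15u)) - 1u is 0, so the mask is 0x8000 and only the sign bit flips; for a negative half it is all-ones, so every bit flips. This maps IEEE binary16 bit patterns to uint16 values whose unsigned order matches the float order, which is what the radix passes rely on. A host-side check of that claim, a sketch using hand-picked half constants (1.0h = 0x3C00, 2.0h = 0x4000, and so on):

    #include <cassert>
    #include <cstdint>

    static std::uint16_t to_radix_bits(std::uint16_t iv) {
        // Same expression as RadixConverter<dt_float16>::to_radix above.
        return static_cast<std::uint16_t>(
                iv ^ (((!(iv >> 15u)) - 1u) | 0x8000u));
    }

    int main() {
        // -2.0h, -1.0h, 0.0h, 1.0h, 2.0h in IEEE binary16, ascending.
        const std::uint16_t order[] = {0xC000, 0xBC00, 0x0000, 0x3C00, 0x4000};
        for (int i = 0; i + 1 < 5; ++i)
            assert(to_radix_bits(order[i]) < to_radix_bits(order[i + 1]));
        return 0;
    }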
@@ -27,6 +27,10 @@ TEST_F(CUDA, TOP_K) {
 TEST_F(CUDA, TOP_K_I32) {
     run_topk_test<dtype::Int32>(handle_cuda());
 }
+#if !MEGDNN_DISABLE_FLOAT16
+TEST_F(CUDA, TOP_K_F16) {
+    run_topk_test<dtype::Float16>(handle_cuda());
+}
+#endif

 // vim: syntax=cpp.doxygen