feat(dnn/cuda): topk support fp16

GitOrigin-RevId: c6610d4cf0
Branch: release-1.5
Author: Megvii Engine Team (huangxinda), 4 years ago
Commit: b87af9f77f
8 changed files with 59 additions and 7 deletions
  1. dnn/src/cuda/argsort/argsort.cu (+1, -1)
  2. dnn/src/cuda/argsort/argsort.cuh (+2, -1)
  3. dnn/src/cuda/argsort/bitonic_sort.cu (+14, -0)
  4. dnn/src/cuda/cub/util_type.cuh (+1, -1)
  5. dnn/src/cuda/topk/opr_impl.cpp (+10, -1)
  6. dnn/src/cuda/topk/topk_radix.cu (+2, -1)
  7. dnn/src/cuda/topk/topk_radix.cuh (+24, -1)
  8. dnn/test/cuda/topk.cpp (+5, -1)

dnn/src/cuda/argsort/argsort.cu (+1, -1)

@@ -124,7 +124,7 @@ size_t argsort::get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype,
        ARGSORT_FOREACH_CTYPE(cb)
#undef cb
        default:
-            megdnn_throw("argsort only supports float and int32");
+            megdnn_throw("argsort only supports float, int32 and float16");
    }
    if (!iptr_src_given) {
        size = DIVUP(size, sizeof(float)) * sizeof(float) + M * N * sizeof(int);


dnn/src/cuda/argsort/argsort.cuh (+2, -1)

@@ -33,7 +33,8 @@ void forward(const dtype* sptr, dtype* dptr, int* iptr, void* workspace,
             const int* iptr_src = NULL);

//! iterate over all supported data types
-#define ARGSORT_FOREACH_CTYPE(cb) cb(float) cb(int32_t)
+#define ARGSORT_FOREACH_CTYPE(cb) \
+    cb(float) cb(int32_t) DNN_INC_FLOAT16(cb(dt_float16))

} // namespace argsort
} // namespace cuda
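The new two-line #define is an X-macro: every supported ctype is passed to the callback `cb`, and the fp16 entry is wrapped in `DNN_INC_FLOAT16`, which presumably expands to its argument when fp16 support is compiled in and to nothing when `MEGDNN_DISABLE_FLOAT16` is set. A minimal, self-contained sketch of that pattern (the `DNN_INC_FLOAT16` definition, the `dt_float16` struct, and the `workspace_bytes_*` helpers below are illustrative stand-ins, not MegDNN code):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Stand-in for MegDNN's fp16 gate: expand to the argument only when fp16 is enabled.
#ifndef MEGDNN_DISABLE_FLOAT16
#define DNN_INC_FLOAT16(x) x
#else
#define DNN_INC_FLOAT16(x)
#endif

// Stand-in for dt_float16 so the sketch compiles without the real header.
struct dt_float16 { uint16_t bits; };

// X-macro listing the supported ctypes, mirroring the new ARGSORT_FOREACH_CTYPE.
#define FOREACH_CTYPE(cb) \
    cb(float) cb(int32_t) DNN_INC_FLOAT16(cb(dt_float16))

// Stamp out one helper per ctype; the real code generates switch cases the same way.
#define cb(ctype)                                              \
    size_t workspace_bytes_##ctype(uint32_t M, uint32_t N) {   \
        return size_t(M) * N * sizeof(ctype);                  \
    }
FOREACH_CTYPE(cb)
#undef cb

int main() {
    std::printf("%zu %zu %zu\n", workspace_bytes_float(2, 3),
                workspace_bytes_int32_t(2, 3), workspace_bytes_dt_float16(2, 3));
    // prints: 24 24 12
}
```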


dnn/src/cuda/argsort/bitonic_sort.cu (+14, -0)

@@ -11,6 +11,7 @@

#include "./bitonic_sort.cuh"
#include "src/cuda/query_blocksize.cuh"
+#include "megdnn/dtype.h"

#if __CUDACC_VER_MAJOR__ < 9
#pragma message "warp sync disabled due to insufficient cuda version"
@@ -82,6 +83,18 @@ struct NumTrait<int32_t> {
    static __device__ __forceinline__ int32_t min() { return INT_MIN; }
};

+#if !MEGDNN_DISABLE_FLOAT16
+template <>
+struct NumTrait<dt_float16> {
+    static __device__ __forceinline__ dt_float16 max() {
+        return std::numeric_limits<dt_float16>::max();
+    }
+    static __device__ __forceinline__ dt_float16 min() {
+        return std::numeric_limits<dt_float16>::lowest();
+    }
+};
+#endif
+
struct LessThan {
    template <typename Key, typename Value>
    static __device__ __forceinline__ bool cmp(Key k0, Value v0, Key k1,
@@ -295,6 +308,7 @@ namespace cuda {

INST(float, int);
INST(int32_t, int);
+DNN_INC_FLOAT16(INST(dt_float16, int));
#undef INST

} // namespace megdnn
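`NumTrait<dt_float16>` supplies the extreme values of the new key type. The usual reason such traits exist in a bitonic sort is padding: bitonic networks work on power-of-two sized sequences, so short rows are padded with the largest (or smallest) representable key, which stays at one end after sorting. A host-side sketch of that idea, with `std::sort` standing in for the device bitonic network (an assumption about how the kernel uses the sentinels, not a claim about bitonic_sort.cu internals):

```cpp
#include <algorithm>
#include <cstdio>
#include <limits>
#include <vector>

// Host-side stand-in for the device-side NumTrait: extreme values of a key type.
template <typename T>
struct NumTrait {
    static T max() { return std::numeric_limits<T>::max(); }
    static T min() { return std::numeric_limits<T>::lowest(); }
};

int main() {
    std::vector<float> keys = {3.f, -1.f, 7.f};          // row of length 3
    const size_t padded_len = 4;                          // next power of two
    keys.resize(padded_len, NumTrait<float>::max());      // pad with the +max sentinel
    std::sort(keys.begin(), keys.end());                  // stand-in for the bitonic network
    for (float k : keys) std::printf("%g ", k);           // -1 3 7 3.40282e+38 (sentinel last)
    std::printf("\n");
}
```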


dnn/src/cuda/cub/util_type.cuh (+1, -1)

@@ -1146,7 +1146,7 @@ template <> struct NumericTraits<double> : BaseTraits<FLOATING_POIN
#if (__CUDACC_VER_MAJOR__ >= 9)
template <> struct NumericTraits<__half> : BaseTraits<FLOATING_POINT, true, false, unsigned short, __half> {};
#endif
template <> struct NumericTraits<half_float::half> : BaseTraits<FLOATING_POINT, true, false, unsigned short, half_float::half> {};
template <> struct NumericTraits<bool> : BaseTraits<UNSIGNED_INTEGER, true, false, typename UnitWord<bool>::VolatileWord, bool> {};




dnn/src/cuda/topk/opr_impl.cpp (+10, -1)

@@ -81,9 +81,18 @@ void TopKImpl::do_exec(int k, _megdnn_tensor_in data, _megdnn_tensor_out values,
                                        values.ptr<int32_t>(), indices,
                                        workspace.raw_ptr);
            return;
+#if !MEGDNN_DISABLE_FLOAT16
+        case DTypeEnum::Float16:
+            dispatch_with_ctype<dt_float16>(k, data.layout[0], data.layout[1],
+                                            data.layout.stride[0], data.ptr<dt_float16>(),
+                                            values.ptr<dt_float16>(), indices,
+                                            workspace.raw_ptr);
+            return;
+#endif
        default:
            megdnn_throw(
-                    ssprintf("only float32 and int32 supported for cuda topk, got: %s",
+                    ssprintf("only float32, int32 and float16 supported for "
+                             "cuda topk, got: %s",
                             data.layout.dtype.name()));
    }
}
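The new case mirrors the existing Float32/Int32 branches: `do_exec` switches on the runtime dtype tag and forwards to the templated `dispatch_with_ctype<ctype>`, and the fp16 branch is compiled only when `MEGDNN_DISABLE_FLOAT16` is not set. A stripped-down sketch of this enum-to-template dispatch (the `Dtype`, `run_topk`, and `half16` names are illustrative stand-ins, not MegDNN's API):

```cpp
#include <cstdint>
#include <cstdio>
#include <stdexcept>

enum class Dtype { Float32, Int32, Float16 };

struct half16 { uint16_t bits; };  // stand-in for dt_float16

// The compile-time ctype selects the kernel instantiation.
template <typename ctype>
void run_topk(int k, const void* /*data*/) {
    std::printf("topk: k=%d, sizeof(ctype)=%zu\n", k, sizeof(ctype));
}

// Runtime dtype tag -> compile-time ctype, the pattern used by do_exec.
void do_exec(Dtype dtype, int k, const void* data) {
    switch (dtype) {
        case Dtype::Float32: return run_topk<float>(k, data);
        case Dtype::Int32:   return run_topk<int32_t>(k, data);
        case Dtype::Float16: return run_topk<half16>(k, data);
    }
    throw std::runtime_error("only float32, int32 and float16 supported for cuda topk");
}

int main() { do_exec(Dtype::Float16, 5, nullptr); }
```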


dnn/src/cuda/topk/topk_radix.cu (+2, -1)

@@ -489,7 +489,7 @@ cudaError_t topk::find_kth_radix(const ctype* input, ctype* output,
    if (k < 0) {
        k = length + k + 1;
    }
-    if (!(BUCKET_BITS == 8 && sizeof(ctype) == 4)) {
+    if (!(BUCKET_BITS == 8 && (sizeof(ctype) == 4 || sizeof(ctype) == 2))) {
        // no c++11 in megdnn cuda; so we just trap instead of using static
        // assert
        megdnn_trap();
@@ -668,6 +668,7 @@ namespace topk {
                     int32_t, uint32_t, cudaStream_t)
INST(float);
INST(int32_t);
+DNN_INC_FLOAT16(INST(dt_float16));
#undef INST

} // namespace topk
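Two things to note in this hunk: the guard now also accepts 2-byte ctypes because fp16 radix keys are 16-bit, and the `k = length + k + 1` line above it is the existing convention that a negative k counts from the largest element. A small host-side check of that convention, using `std::nth_element` as a reference (assuming, consistently with the normalization shown, that the kernel's k is 1-based):

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> v = {4.f, 1.f, 3.f, 2.f};
    const int length = static_cast<int>(v.size());
    int k = -1;                        // "1st counting from the largest"
    if (k < 0) {
        k = length + k + 1;            // same normalization as find_kth_radix
    }
    // 1-based k-th smallest == 0-based index k - 1.
    std::nth_element(v.begin(), v.begin() + (k - 1), v.end());
    std::printf("k=%d -> %g\n", k, v[k - 1]);  // k=4 -> 4 (the maximum)
}
```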


dnn/src/cuda/topk/topk_radix.cuh (+24, -1)

@@ -10,7 +10,7 @@
 */

#pragma once
+#include "megdnn/dtype.h"
#include <cuda_runtime.h>
#include <stdint.h>

@@ -60,6 +60,29 @@ struct RadixConverter<int32_t> {
    }
};

+#if !MEGDNN_DISABLE_FLOAT16
+template <>
+struct RadixConverter<dt_float16> {
+    union FIunion {
+        FIunion() {}
+        dt_float16 fv;
+        uint16_t iv;
+    };
+    static __forceinline__ __device__ __host__ uint16_t to_radix(dt_float16 val) {
+        FIunion fi;
+        fi.fv = val;
+        return fi.iv ^ (((!(fi.iv >> 15u)) - 1u) | 0x8000u);
+    }
+    static __forceinline__ __device__ __host__ dt_float16 from_radix(uint16_t val) {
+        FIunion fi;
+        // do not write as to_radix() to work around a compiler bug in cuda-9.0
+        uint16_t m = 0x8000u;
+        fi.iv = val ^ (m | (m - !(val >> 15u)));
+        return fi.fv;
+    }
+};
+#endif
+
} // namespace internal

/*!

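`RadixConverter<dt_float16>` maps the raw half bits to a `uint16_t` key whose unsigned order matches the numeric order of the halves: a non-negative value (sign bit 0) gets only its sign bit flipped, while a negative value gets all 16 bits flipped, which is what the `(((!(fi.iv >> 15u)) - 1u) | 0x8000u)` mask computes. A host-only sanity check of that mapping on a few hand-written IEEE 754 binary16 bit patterns (no half type needed, so this is independent of dt_float16 itself):

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// Same key mapping as RadixConverter<dt_float16>::to_radix, applied directly to
// the raw binary16 bit pattern: flip the sign bit for non-negative values,
// flip all bits for negative values.
static uint16_t to_radix(uint16_t half_bits) {
    return half_bits ^ (((!(half_bits >> 15u)) - 1u) | 0x8000u);
}

int main() {
    // binary16 bit patterns in non-decreasing numeric order:
    // -2.0, -1.0, -0.0, +0.0, 1.0, 2.0, 65504 (largest finite half).
    const uint16_t ascending[] = {0xC000, 0xBC00, 0x8000, 0x0000,
                                  0x3C00, 0x4000, 0x7BFF};
    for (int i = 1; i < 7; ++i) {
        // Unsigned comparison of the keys must reproduce the numeric order.
        assert(to_radix(ascending[i - 1]) < to_radix(ascending[i]));
    }
    std::printf("fp16 radix keys are monotone\n");
    return 0;
}
```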

dnn/test/cuda/topk.cpp (+5, -1)

@@ -27,6 +27,10 @@ TEST_F(CUDA, TOP_K) {
TEST_F(CUDA, TOP_K_I32) {
    run_topk_test<dtype::Int32>(handle_cuda());
}

+#if !MEGDNN_DISABLE_FLOAT16
+TEST_F(CUDA, TOP_K_F16) {
+    run_topk_test<dtype::Float16>(handle_cuda());
+}
+#endif

// vim: syntax=cpp.doxygen
