GitOrigin-RevId: 3d2c315a36
release-1.6
@@ -174,7 +174,7 @@ template void argsort::forward<dtype>(const dtype*, dtype*, int*, void*, \ | |||||
ARGSORT_FOREACH_CTYPE(INST_FORWARD) | ARGSORT_FOREACH_CTYPE(INST_FORWARD) | ||||
INST_CUB_SORT(uint32_t) | INST_CUB_SORT(uint32_t) | ||||
// INST_CUB_SORT(uint64_t) | |||||
INST_CUB_SORT(uint64_t) | |||||
#undef INST_CUB_SORT | #undef INST_CUB_SORT | ||||
#undef INST_FORWARD | #undef INST_FORWARD | ||||
} | } | ||||
@@ -40,6 +40,7 @@ void forward(const dtype* sptr, dtype* dptr, int* iptr, void* workspace, | |||||
const int* iptr_src = NULL); | const int* iptr_src = NULL); | ||||
//! iterate over all supported data types | //! iterate over all supported data types | ||||
// device_radix_sort does not support dt_float16 dtype(half_float::half in rocm) | |||||
#define ARGSORT_FOREACH_CTYPE(cb) \ | #define ARGSORT_FOREACH_CTYPE(cb) \ | ||||
cb(float) cb(int32_t) // DNN_INC_FLOAT16(cb(dt_float16)) | cb(float) cb(int32_t) // DNN_INC_FLOAT16(cb(dt_float16)) | ||||
@@ -14,8 +14,6 @@ | |||||
#include "./argsort.h.hip" | #include "./argsort.h.hip" | ||||
#include "./backward.h.hip" | #include "./backward.h.hip" | ||||
// #include "src/rocm/utils.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace rocm; | using namespace rocm; | ||||
using namespace argsort; | using namespace argsort; | ||||
@@ -11,13 +11,9 @@ | |||||
#include "hcc_detail/hcc_defs_prologue.h" | #include "hcc_detail/hcc_defs_prologue.h" | ||||
#include "./bitonic_sort.h.hip" | #include "./bitonic_sort.h.hip" | ||||
// #include "src/cuda/query_blocksize.cuh" | |||||
// #include "megdnn/dtype.h" | |||||
#include "megdnn/dtype.h" | |||||
// #if __CUDACC_VER_MAJOR__ < 9 | |||||
// #pragma message "warp sync disabled due to insufficient cuda version" | |||||
#define __syncwarp __syncthreads | #define __syncwarp __syncthreads | ||||
// #endif | |||||
#include <algorithm> | #include <algorithm> | ||||
#include <cmath> | #include <cmath> | ||||
@@ -84,17 +80,17 @@ struct NumTrait<int32_t> { | |||||
static __device__ __forceinline__ int32_t min() { return INT_MIN; } | static __device__ __forceinline__ int32_t min() { return INT_MIN; } | ||||
}; | }; | ||||
// #if !MEGDNN_DISABLE_FLOAT16 | |||||
// template <> | |||||
// struct NumTrait<dt_float16> { | |||||
// static __device__ __forceinline__ dt_float16 max() { | |||||
// return std::numeric_limits<dt_float16>::max(); | |||||
// } | |||||
// static __device__ __forceinline__ dt_float16 min() { | |||||
// return std::numeric_limits<dt_float16>::lowest(); | |||||
// } | |||||
// }; | |||||
// #endif | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
template <> | |||||
struct NumTrait<dt_float16> { | |||||
static __device__ __forceinline__ dt_float16 max() { | |||||
return std::numeric_limits<dt_float16>::max(); | |||||
} | |||||
static __device__ __forceinline__ dt_float16 min() { | |||||
return std::numeric_limits<dt_float16>::lowest(); | |||||
} | |||||
}; | |||||
#endif | |||||
struct LessThan { | struct LessThan { | ||||
template <typename Key, typename Value> | template <typename Key, typename Value> | ||||
@@ -310,7 +306,7 @@ namespace rocm { | |||||
INST(float, int); | INST(float, int); | ||||
INST(int32_t, int); | INST(int32_t, int); | ||||
// DNN_INC_FLOAT16(INST(dt_float16, int)); | |||||
DNN_INC_FLOAT16(INST(dt_float16, int)); | |||||
#undef INST | #undef INST | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -18,13 +18,7 @@ | |||||
#include <algorithm> | #include <algorithm> | ||||
#include <cmath> | #include <cmath> | ||||
#if __CUDACC_VER_MAJOR__ < 9 | |||||
#pragma message "topk is a little slower on cuda earlier than 9.0" | |||||
// on cuda 9.0 and later, due to thread-divergent branches we should use | |||||
// __syncwarp; and I am too lazy to implement a correct legacy version, so just | |||||
// use __syncthreads instead for older cuda | |||||
#define __syncwarp __syncthreads | #define __syncwarp __syncthreads | ||||
#endif | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace rocm; | using namespace rocm; | ||||
@@ -256,12 +250,12 @@ static __global__ void update_prefix_and_k(const uint32_t* bucket_cnt, | |||||
} | } | ||||
} | } | ||||
//if ((cumsum_bucket_cnt[NR_BUCKET] < kv) | | |||||
// (cumsum_bucket_cnt[i] != cumsum_bucket_cnt[i - 1] + sum)) { | |||||
// // impossible | |||||
// int* bad = 0x0; | |||||
// *bad = 23; | |||||
//} | |||||
if ((cumsum_bucket_cnt[NR_BUCKET] < kv) | | |||||
(cumsum_bucket_cnt[i] != cumsum_bucket_cnt[i - 1] + sum)) { | |||||
// impossible | |||||
int* bad = 0x0; | |||||
*bad = 23; | |||||
} | |||||
} | } | ||||
static uint32_t get_grid_dim_x(uint32_t length) { | static uint32_t get_grid_dim_x(uint32_t length) { | ||||