Browse Source

refactor(dnn/rocm): remove some useless includes

GitOrigin-RevId: 3d2c315a36
release-1.6
Megvii Engine Team 3 years ago
parent
commit
83cf4ee64e
5 changed files with 21 additions and 32 deletions
  1. +1
    -1
      dnn/src/rocm/argsort/argsort.cpp.hip
  2. +1
    -0
      dnn/src/rocm/argsort/argsort.h.hip
  3. +0
    -2
      dnn/src/rocm/argsort/backward.cpp.hip
  4. +13
    -17
      dnn/src/rocm/argsort/bitonic_sort.cpp.hip
  5. +6
    -12
      dnn/src/rocm/topk/topk_radix.cpp.hip

+ 1
- 1
dnn/src/rocm/argsort/argsort.cpp.hip View File

@@ -174,7 +174,7 @@ template void argsort::forward<dtype>(const dtype*, dtype*, int*, void*, \
ARGSORT_FOREACH_CTYPE(INST_FORWARD)
INST_CUB_SORT(uint32_t)
// INST_CUB_SORT(uint64_t)
INST_CUB_SORT(uint64_t)
#undef INST_CUB_SORT
#undef INST_FORWARD
}


+ 1
- 0
dnn/src/rocm/argsort/argsort.h.hip View File

@@ -40,6 +40,7 @@ void forward(const dtype* sptr, dtype* dptr, int* iptr, void* workspace,
const int* iptr_src = NULL);

//! iterate over all supported data types
// device_radix_sort does not support dt_float16 dtype(half_float::half in rocm)
#define ARGSORT_FOREACH_CTYPE(cb) \
cb(float) cb(int32_t) // DNN_INC_FLOAT16(cb(dt_float16))



+ 0
- 2
dnn/src/rocm/argsort/backward.cpp.hip View File

@@ -14,8 +14,6 @@
#include "./argsort.h.hip"
#include "./backward.h.hip"

// #include "src/rocm/utils.h"

using namespace megdnn;
using namespace rocm;
using namespace argsort;


+ 13
- 17
dnn/src/rocm/argsort/bitonic_sort.cpp.hip View File

@@ -11,13 +11,9 @@
#include "hcc_detail/hcc_defs_prologue.h"

#include "./bitonic_sort.h.hip"
// #include "src/cuda/query_blocksize.cuh"
// #include "megdnn/dtype.h"
#include "megdnn/dtype.h"

// #if __CUDACC_VER_MAJOR__ < 9
// #pragma message "warp sync disabled due to insufficient cuda version"
#define __syncwarp __syncthreads
// #endif

#include <algorithm>
#include <cmath>
@@ -84,17 +80,17 @@ struct NumTrait<int32_t> {
static __device__ __forceinline__ int32_t min() { return INT_MIN; }
};

// #if !MEGDNN_DISABLE_FLOAT16
// template <>
// struct NumTrait<dt_float16> {
// static __device__ __forceinline__ dt_float16 max() {
// return std::numeric_limits<dt_float16>::max();
// }
// static __device__ __forceinline__ dt_float16 min() {
// return std::numeric_limits<dt_float16>::lowest();
// }
// };
// #endif
#if !MEGDNN_DISABLE_FLOAT16
template <>
struct NumTrait<dt_float16> {
static __device__ __forceinline__ dt_float16 max() {
return std::numeric_limits<dt_float16>::max();
}
static __device__ __forceinline__ dt_float16 min() {
return std::numeric_limits<dt_float16>::lowest();
}
};
#endif

struct LessThan {
template <typename Key, typename Value>
@@ -310,7 +306,7 @@ namespace rocm {

INST(float, int);
INST(int32_t, int);
// DNN_INC_FLOAT16(INST(dt_float16, int));
DNN_INC_FLOAT16(INST(dt_float16, int));
#undef INST

} // namespace megdnn


+ 6
- 12
dnn/src/rocm/topk/topk_radix.cpp.hip View File

@@ -18,13 +18,7 @@
#include <algorithm>
#include <cmath>

#if __CUDACC_VER_MAJOR__ < 9
#pragma message "topk is a little slower on cuda earlier than 9.0"
// on cuda 9.0 and later, due to thread-divergent branches we should use
// __syncwarp; and I am too lazy to implement a correct legacy version, so just
// use __syncthreads instead for older cuda
#define __syncwarp __syncthreads
#endif

using namespace megdnn;
using namespace rocm;
@@ -256,12 +250,12 @@ static __global__ void update_prefix_and_k(const uint32_t* bucket_cnt,
}
}

//if ((cumsum_bucket_cnt[NR_BUCKET] < kv) |
// (cumsum_bucket_cnt[i] != cumsum_bucket_cnt[i - 1] + sum)) {
// // impossible
// int* bad = 0x0;
// *bad = 23;
//}
if ((cumsum_bucket_cnt[NR_BUCKET] < kv) |
(cumsum_bucket_cnt[i] != cumsum_bucket_cnt[i - 1] + sum)) {
// impossible
int* bad = 0x0;
*bad = 23;
}
}

static uint32_t get_grid_dim_x(uint32_t length) {


Loading…
Cancel
Save