From 323a4642e6ed265cc68dd2200435e9aaf00140b7 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Mon, 16 Aug 2021 14:43:57 +0800
Subject: [PATCH] feat(dnn/rocm): add topk opr

GitOrigin-RevId: 5ecb07985491359bb8063427cc142fbcec3da943
---
 dnn/src/rocm/handle.cpp              |   2 +
 dnn/src/rocm/topk/opr_impl.cpp       | 133 +++++++
 dnn/src/rocm/topk/opr_impl.h         |  40 +++
 dnn/src/rocm/topk/topk_radix.cpp.hip | 667 +++++++++++++++++++++++++++++++++++
 dnn/src/rocm/topk/topk_radix.h.hip   | 127 +++++++
 dnn/test/rocm/topk.cpp               |  38 ++
 6 files changed, 1007 insertions(+)
 create mode 100644 dnn/src/rocm/topk/opr_impl.cpp
 create mode 100644 dnn/src/rocm/topk/opr_impl.h
 create mode 100644 dnn/src/rocm/topk/topk_radix.cpp.hip
 create mode 100644 dnn/src/rocm/topk/topk_radix.h.hip
 create mode 100644 dnn/test/rocm/topk.cpp

diff --git a/dnn/src/rocm/handle.cpp b/dnn/src/rocm/handle.cpp
index d1ec445d..b58af148 100644
--- a/dnn/src/rocm/handle.cpp
+++ b/dnn/src/rocm/handle.cpp
@@ -24,6 +24,7 @@
 #include "src/rocm/pooling/opr_impl.h"
 #include "src/rocm/reduce/opr_impl.h"
 #include "src/rocm/type_cvt/opr_impl.h"
+#include "src/rocm/topk/opr_impl.h"
 #include "src/rocm/add_update/opr_impl.h"
 #include "src/rocm/matrix_mul/opr_impl.h"
 #include "src/rocm/batched_matrix_mul/opr_impl.h"
@@ -161,6 +162,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingForward);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingBackward);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(ReduceForward);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(TypeCvt);
+MEGDNN_SPECIALIZE_CREATE_OPERATOR(TopK);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(AddUpdateForward);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(MatrixMulForward);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMatrixMulForward);
diff --git a/dnn/src/rocm/topk/opr_impl.cpp b/dnn/src/rocm/topk/opr_impl.cpp
new file mode 100644
index 00000000..19a53b21
--- /dev/null
+++ b/dnn/src/rocm/topk/opr_impl.cpp
@@ -0,0 +1,133 @@
+/**
+ * \file dnn/src/rocm/topk/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "./opr_impl.h"
+#include "./topk_radix.h.hip"
+#include "src/common/utils.h"
+#include "src/rocm/argsort/argsort.h.hip"
+#include "src/rocm/utils.h"
+
+using namespace megdnn;
+using namespace rocm;
+
+template <typename ctype>
+void TopKImpl::dispatch_with_ctype(int k, size_t m, size_t n, ptrdiff_t lda,
+                                   const ctype* data, ctype* values,
+                                   int* indices, void* workspace) {
+    auto _handle = concrete_handle(handle());
+    auto stream = _handle->stream();
+    size_t grid_dim_y_limit = _handle->device_prop().maxGridSize[1];
+    switch (param().mode) {
+        case Param::Mode::KTH_ONLY:
+            hip_check(topk::find_kth_radix<ctype>(data, values, workspace, m,
+                                                  n, lda, k, grid_dim_y_limit,
+                                                  stream));
+            return;
+        case Param::Mode::VALUE_IDX_NOSORT: {
+            WorkspaceBundle wk_bundle{workspace, {m * sizeof(ctype), 1}};
+            auto thresh = static_cast<ctype*>(wk_bundle.get(0));
+            auto real_wk = wk_bundle.get(1);
+            hip_check(topk::find_kth_radix<ctype>(data, thresh, real_wk, m, n,
+                                                  lda, k, grid_dim_y_limit,
+                                                  stream));
+            hip_check(topk::topk_select<ctype>(data, thresh, values, indices,
+                                               real_wk, m, n, lda, k,
+                                               grid_dim_y_limit, stream));
+            return;
+        }
+        case Param::Mode::VALUE_IDX_SORTED: {
+            WorkspaceBundle wk_bundle{
+                    workspace,
+                    {m * sizeof(ctype), m * std::abs(k) * sizeof(ctype),
+                     m * std::abs(k) * sizeof(int32_t), 1}};
+            auto thresh = static_cast<ctype*>(wk_bundle.get(0)),
+                 nosort_values = static_cast<ctype*>(wk_bundle.get(1));
+            auto nosort_idx = static_cast<int32_t*>(wk_bundle.get(2));
+            auto real_wk = wk_bundle.get(3);
+            hip_check(topk::find_kth_radix<ctype>(data, thresh, real_wk, m, n,
+                                                  lda, k, grid_dim_y_limit,
+                                                  stream));
+            hip_check(topk::topk_select<ctype>(data, thresh, nosort_values,
+                                               nosort_idx, real_wk, m, n, lda,
+                                               k, grid_dim_y_limit, stream));
+            argsort::forward(nosort_values, values, indices, real_wk, m,
+                             std::abs(k), k > 0, stream, nosort_idx);
+            return;
+        }
+    }
+    megdnn_throw("bad topk mode");
+}
+
+void TopKImpl::do_exec(int k, _megdnn_tensor_in data, _megdnn_tensor_out values,
+                       int32_t* indices, _megdnn_workspace workspace) {
+    switch (data.layout.dtype.enumv()) {
+        case DTypeEnum::Float32:
+            dispatch_with_ctype<float>(k, data.layout[0], data.layout[1],
+                                       data.layout.stride[0], data.ptr<float>(),
+                                       values.ptr<float>(), indices,
+                                       workspace.raw_ptr);
+            return;
+        case DTypeEnum::Int32:
+            dispatch_with_ctype<int32_t>(k, data.layout[0], data.layout[1],
+                                         data.layout.stride[0],
+                                         data.ptr<int32_t>(),
+                                         values.ptr<int32_t>(), indices,
+                                         workspace.raw_ptr);
+            return;
+// #if !MEGDNN_DISABLE_FLOAT16
+//         case DTypeEnum::Float16:
+//             dispatch_with_ctype<dt_float16>(k, data.layout[0], data.layout[1],
+//                                             data.layout.stride[0],
+//                                             data.ptr<dt_float16>(),
+//                                             values.ptr<dt_float16>(), indices,
+//                                             workspace.raw_ptr);
+//             return;
+// #endif
+        default:
+            megdnn_throw(
+                    ssprintf("only float32, int32 and float16 supported for "
+                             "rocm topk, got: %s",
+                             data.layout.dtype.name()));
+    }
+}
+
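+/*
+ * Workspace layout notes (illustrative, not part of the kernel contract):
+ * for VALUE_IDX_SORTED the bundle below holds the per-row threshold
+ * (m * ctsize), the unsorted top-k values (m * |k| * ctsize), the unsorted
+ * indices (m * |k| * sizeof(int32_t)), and one shared scratch area sized for
+ * whichever of find_kth_radix / topk_select / argsort needs the most space.
+ * E.g. assuming float32 input with m = 4 rows and k = -10 (largest top-10),
+ * the first three chunks take 4*4 + 4*10*4 + 4*10*4 = 336 bytes before
+ * per-chunk alignment, plus the shared scratch.
+ */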
+size_t TopKImpl::get_workspace_in_bytes(int k, const TensorLayout& data,
+                                        const TensorLayout& values,
+                                        const TensorLayout& indices) {
+    MEGDNN_MARK_USED_VAR(values);
+    MEGDNN_MARK_USED_VAR(indices);
+    size_t m = data[0], n = data[1];
+    size_t kabs = std::abs(k);
+    size_t grid_dim_y_limit =
+            concrete_handle(handle())->device_prop().maxGridSize[1];
+    megdnn_assert(std::max(m, n) <=
+                  static_cast<size_t>(std::numeric_limits<uint32_t>::max()));
+    size_t kth = topk::find_kth_radix_workspace(m, n, grid_dim_y_limit),
+           sel = topk::topk_select_workspace(m, n);
+    auto ctsize = data.dtype.size();
+    switch (param().mode) {
+        case Param::Mode::KTH_ONLY:
+            return kth;
+        case Param::Mode::VALUE_IDX_NOSORT:
+            return WorkspaceBundle{nullptr, {m * ctsize, std::max(kth, sel)}}
+                    .total_size_in_bytes();
+        case Param::Mode::VALUE_IDX_SORTED:
+            return WorkspaceBundle{
+                    nullptr,
+                    {m * ctsize, m * kabs * ctsize, m * kabs * sizeof(int32_t),
+                     std::max(std::max(kth, sel),
+                              argsort::get_fwd_workspace_in_bytes(
+                                      m, kabs, data.dtype, k > 0, true))}}
+                    .total_size_in_bytes();
+    }
+    megdnn_throw("bad topk mode");
+}
+
+// vim: syntax=cpp.doxygen
+
diff --git a/dnn/src/rocm/topk/opr_impl.h b/dnn/src/rocm/topk/opr_impl.h
new file mode 100644
index 00000000..a1cb2ff2
--- /dev/null
+++ b/dnn/src/rocm/topk/opr_impl.h
@@ -0,0 +1,40 @@
+/**
+ * \file dnn/src/rocm/topk/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#pragma once
+
+#include "megdnn/oprs/general.h"
+
+namespace megdnn {
+namespace rocm {
+
+class TopKImpl : public TopK {
+protected:
+    template <typename ctype>
+    void dispatch_with_ctype(int k, size_t m, size_t n, ptrdiff_t lda,
+                             const ctype* data, ctype* values, int* indices,
+                             void* workspace);
+
+    void do_exec(int k, _megdnn_tensor_in data, _megdnn_tensor_out values,
+                 int32_t* indices, _megdnn_workspace workspace) override;
+
+public:
+    using TopK::TopK;
+
+    size_t get_workspace_in_bytes(int k, const TensorLayout& data,
+                                  const TensorLayout& values,
+                                  const TensorLayout& indices) override;
+};
+
+} // namespace rocm
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
+
diff --git a/dnn/src/rocm/topk/topk_radix.cpp.hip b/dnn/src/rocm/topk/topk_radix.cpp.hip
new file mode 100644
index 00000000..2cae0bfb
--- /dev/null
+++ b/dnn/src/rocm/topk/topk_radix.cpp.hip
@@ -0,0 +1,667 @@
+/**
+ * \file dnn/src/rocm/topk/topk_radix.cpp.hip
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "./topk_radix.h.hip"
+#include "src/rocm/utils.h.hip"
+
+#include <hipcub/hipcub.hpp>
+#include "hipcub/device/device_scan.hpp"
+
+#include <algorithm>
+#include <cmath>
+
+#if __CUDACC_VER_MAJOR__ < 9
+#pragma message "topk is a little slower on cuda earlier than 9.0"
+// on cuda 9.0 and later, due to thread-divergent branches we should use
+// __syncwarp; and I am too lazy to implement a correct legacy version, so just
+// use __syncthreads instead for older cuda
+#define __syncwarp __syncthreads
+#endif
+
+using namespace megdnn;
+using namespace rocm;
+using namespace topk;
+using namespace internal;
+
+namespace rocm_topk_impl {
+
+const uint32_t WARP_SIZE = 32;
+
+static __device__ __forceinline__ uint32_t u32_from_64_low(uint64_t x) {
+    return x;
+}
+static __device__ __forceinline__ uint32_t u32_from_64_high(uint64_t x) {
+    return x >> 32;
+}
+
+template <uint32_t x>
+struct static_log2 {
+    static const uint32_t val = static_log2<x / 2>::val + 1;
+};
+template <>
+struct static_log2<1> {
+    static const uint32_t val = 0;
+};
+
+template <uint32_t pack_size, typename T>
+struct DeviceScanPackedItem;
+
+template <typename T>
+struct DeviceScanPackedItem<1, T> {
+    __device__ __forceinline__ T load(T* data, uint32_t tid) {
+        return data[tid];
+    }
+
+    __device__ __forceinline__ void store(T* data, uint32_t tid, uint32_t s) {
+        data[tid] = s;
+    }
+};
+
+template <>
+struct DeviceScanPackedItem<4, uint8_t> {
+    uint8_t d0, d1, d2, d3;
+    __device__ __forceinline__ uint32_t load(uint8_t* data, uint32_t tid) {
+        uint32_t item = reinterpret_cast<uint32_t*>(data)[tid];
+        d3 = item >> 24;
+        d2 = (item >> 16) & 0xFF;
+        d1 = (item >> 8) & 0xFF;
+        d0 = item & 0xFF;
+        return d0 + d1 + d2 + d3;
+    }
+
+    __device__ __forceinline__ void store(uint8_t* data, uint32_t tid,
+                                          uint32_t s) {
+        uint8_t o3 = s, o2 = o3 - d3, o1 = o2 - d2, o0 = o1 - d1;
+        reinterpret_cast<uint32_t*>(data)[tid] =
+                (o3 << 24) | (o2 << 16) | (o1 << 8) | o0;
+    }
+};
+
+//! inclusive scan within a warp using register shuffle
+template <uint32_t SIZE>
+__device__ __forceinline__ uint32_t device_scan_shfl_core(uint32_t s,
+                                                          uint32_t tid) {
+    static const uint32_t SIZE_LOG2 = static_log2<SIZE>::val;
+
+    uint32_t self_lane = tid % SIZE;
+#pragma unroll
+    for (uint32_t step_log2 = 1; step_log2 <= SIZE_LOG2; ++step_log2) {
+        uint32_t from_lane = (self_lane & ~((1u << step_log2) - 1)) +
+                             ((1 << (step_log2 - 1)) - 1);
+        uint32_t valid_mask = (from_lane >= self_lane) - 1;
+        uint32_t s_below = __shfl_up(s, self_lane - from_lane, SIZE);
+        s += s_below & valid_mask;
+    }
+    return s;
+}
+
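+/*
+ * Illustrative trace of device_scan_shfl_core (assuming SIZE = 4): lanes
+ * start with s = [a, b, c, d]. Step 1 pulls from the partner below within
+ * each pair: [a, a+b, c, c+d]; step 2 pulls the running sum of the lower
+ * half: [a, a+b, a+b+c, a+b+c+d], i.e. an inclusive prefix sum in
+ * log2(SIZE) shuffle rounds. Lanes whose source lane lies at or above them
+ * keep their value unchanged via valid_mask.
+ */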
+/*!
+ * \brief compute inplace inclusive prefix sum of \p data
+ *
+ * Note: no synchronization at the end
+ */
+template <uint32_t SIZE, uint32_t NR_SHARD>
+__device__ __forceinline__ void device_scan(uint32_t* data, uint32_t tid,
+                                            uint32_t shard) {
+    const uint32_t NR_WARP = SIZE / NR_SHARD / WARP_SIZE;
+#if __cplusplus > 199711L
+    static_assert(NR_WARP <= WARP_SIZE && !(NR_WARP & (NR_WARP - 1)),
+                  "bad params");
+#endif
+
+    __syncthreads();
+    DeviceScanPackedItem<1, uint32_t> packed_item;
+
+    uint32_t s = packed_item.load(data, tid);
+    s = device_scan_shfl_core<WARP_SIZE>(s, tid);
+
+    // sync between warps
+    __shared__ uint32_t warp_sums_storage[NR_SHARD][NR_WARP];
+    uint32_t warp_id = tid / WARP_SIZE;
+    uint32_t* warp_sums = warp_sums_storage[shard];
+    if ((tid & (WARP_SIZE - 1)) == WARP_SIZE - 1) {
+        warp_sums[warp_id] = s;
+    }
+    __syncthreads();
+
+    for (uint32_t i = 0; i < warp_id; ++i) {
+        s += warp_sums[i];
+    }
+
+    packed_item.store(data, tid, s);
+}
+
+template <uint32_t pack_size, typename T>
+__device__ __forceinline__ void device_scan_packed_accu32(T* data,
+                                                          uint32_t tid) {
+    DeviceScanPackedItem<pack_size, T> scan_pack;
+    __syncwarp();
+    uint32_t sum = scan_pack.load(data, tid);
+    sum = device_scan_shfl_core<WARP_SIZE>(sum, tid);
+    scan_pack.store(data, tid, sum);
+    __syncwarp();
+}
+
+namespace kth {
+
+const uint32_t BUCKET_BITS = 8, NR_BUCKET = 1 << BUCKET_BITS,
+               LOCAL_CNT_SHARD = 16, BLOCK_DIM = NR_BUCKET;
+
+template <uint32_t v>
+struct enforce_const_u32 {
+    static const uint32_t val = v;
+};
+
+/*!
+ * \brief compute scattered histogram for the whole input
+ *
+ * launch config: grid(X, batch), thread(BLOCK_DIM)
+ *
+ * Keys not starting with the given prefix are treated as max
+ *
+ * \param[in] input [batch, length]
+ * \param[out] bucket_cnt [batch, X, NR_BUCKET]
+ */
+template <typename ctype, uint32_t shift, bool prefix_valid>
+static __global__ void compute_histogram(const ctype* input,
+                                         uint32_t* bucket_cnt, uint32_t length,
+                                         int32_t lda, uint32_t* prefix_ptr) {
+    int32_t batch = blockIdx.y;
+    input += batch * lda;
+    bucket_cnt += (batch * gridDim.x + blockIdx.x) * NR_BUCKET;
+
+    uint32_t prefix;
+    if (prefix_valid) {
+        prefix = prefix_ptr[batch];
+    }
+
+    {
+        // init bucket_cnt
+        for (uint32_t i = threadIdx.x; i < NR_BUCKET; i += BLOCK_DIM) {
+            bucket_cnt[i] = 0;
+        }
+        __syncthreads();
+    }
+
+    {
+        // accumulate
+        uint32_t i = blockIdx.x * BLOCK_DIM + threadIdx.x,
+                 stride = BLOCK_DIM * gridDim.x;
+        while (i < length) {
+            uint32_t key = RadixConverter<ctype>::to_radix(input[i]);
+            if (prefix_valid) {
+                const uint32_t mask =
+                        ((~0u) << ((prefix_valid ? shift : 0) + BUCKET_BITS));
+                key |= ((key & enforce_const_u32<mask>::val) == prefix) - 1;
+            }
+            uint32_t idx = (key >> shift) & ((1 << BUCKET_BITS) - 1);
+            atomicAdd(bucket_cnt + idx, 1u);
+            i += stride;
+        }
+    }
+    __syncthreads();
+}
+
+/*!
+ * \brief update the values in \p prefix to the k'th value according to the
+ * bucket count, and update \p k
+ *
+ * launch config: grid(batch), thread(NR_BUCKET)
+ */
+template <typename ctype, uint32_t shift, bool first, bool last>
+static __global__ void update_prefix_and_k(const uint32_t* bucket_cnt,
+                                           uint32_t* prefix, uint32_t* k,
+                                           uint32_t k_init,
+                                           uint32_t bucket_sharding_size,
+                                           ctype* result) {
+    __shared__ uint32_t cumsum_bucket_cnt[NR_BUCKET + 1];
+    uint32_t batch = blockIdx.x;
+    bucket_cnt += batch * bucket_sharding_size * NR_BUCKET;
+
+    uint32_t sum = 0;
+    for (uint32_t i = 0; i < bucket_sharding_size; ++i) {
+        sum += bucket_cnt[i * NR_BUCKET + threadIdx.x];
+    }
+    if (!threadIdx.x) {
+        cumsum_bucket_cnt[0] = 0;
+    }
+    const uint32_t i = threadIdx.x + 1;
+    cumsum_bucket_cnt[i] = sum;
+
+    device_scan<NR_BUCKET, 1>(cumsum_bucket_cnt + 1, threadIdx.x, 0);
+    __syncthreads();
+
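+    // At this point cumsum_bucket_cnt[j] holds the number of keys whose
+    // current digit is < j (keys outside the current prefix were mapped to
+    // all-ones and land in the last bucket). The k'th key therefore lives in
+    // the unique bucket i - 1 where the cumulative count first reaches kv,
+    // and the residual rank kv - cumsum_bucket_cnt[i - 1] becomes the new k
+    // for the next, less significant digit.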
+    uint32_t kv = first ? k_init : k[batch];
+    if ((cumsum_bucket_cnt[i] >= kv) & (cumsum_bucket_cnt[i - 1] < kv)) {
+        uint32_t b = (i - 1) << shift;
+        if (first) {
+            prefix[batch] = b;
+        } else if (last) {
+            result[batch] =
+                    RadixConverter<ctype>::from_radix(prefix[batch] | b);
+        } else {
+            prefix[batch] |= b;
+        }
+        if (!last) {
+            k[batch] = kv - cumsum_bucket_cnt[i - 1];
+        }
+    }
+
+    //if ((cumsum_bucket_cnt[NR_BUCKET] < kv) |
+    //    (cumsum_bucket_cnt[i] != cumsum_bucket_cnt[i - 1] + sum)) {
+    //    // impossible
+    //    int* bad = 0x0;
+    //    *bad = 23;
+    //}
+}
+
+static uint32_t get_grid_dim_x(uint32_t length) {
+    return std::max<uint32_t>(length / (128 * BLOCK_DIM), 1);
+}
+} // namespace kth
+
+/*!
+ * \brief select values smaller or larger than given threshold
+ *
+ * Note: we use register shuffle extensively to perform both reduce and scan.
+ */
+namespace select {
+
+struct LessPred {
+    template <typename ctype>
+    __device__ __forceinline__ static bool cmp(ctype x, ctype y) {
+        return x < y;
+    }
+};
+struct GreaterPred {
+    template <typename ctype>
+    __device__ __forceinline__ static bool cmp(ctype x, ctype y) {
+        return x > y;
+    }
+};
+
+const uint32_t REDUCE_WARP_SIZE = 16, REDUCE_SIZE = WARP_SIZE * 4,
+               REDUCE_SHARD = 64;
+/*!
+ * \brief reduce number of elements satisfying Pred in (N, M) mat to
+ * (N, ceil(M / REDUCE_SIZE))
+ *
+ * launch config: grid(X, batch),
+ *                thread(REDUCE_WARP_SIZE, REDUCE_SHARD)
+ *
+ * Each block computes REDUCE_SHARD outputs
+ */
+template <typename ctype, class Pred>
+static __global__ void kern_reduce_block_cnt(const ctype* input_data,
+                                             const ctype* input_thresh,
+                                             uint32_t length, int32_t lda,
+                                             uint64_t* output,
+                                             uint32_t output_width) {
+    static const uint32_t BLOCK_DIM_X = REDUCE_WARP_SIZE,
+                          BLOCK_DIM_Y = REDUCE_SHARD;
+    uint32_t batch = blockIdx.y,
+             out_col = blockIdx.x * BLOCK_DIM_Y + threadIdx.y,
+             col_begin = out_col * REDUCE_SIZE,
+             col_end = min(col_begin + REDUCE_SIZE, length),
+             tid_local = threadIdx.x;
+
+    if (out_col >= output_width) {
+        return;
+    }
+
+    uint32_t thresh = RadixConverter<ctype>::to_radix(input_thresh[batch]);
+    input_data += static_cast<int64_t>(batch) * lda;
+    uint32_t sum_eq = 0, sum_lt = 0;
+    for (uint32_t i = col_begin + tid_local; i < col_end; i += BLOCK_DIM_X) {
+        uint32_t iv = RadixConverter<ctype>::to_radix(input_data[i]);
+        sum_eq += iv == thresh;
+        sum_lt += Pred::cmp(iv, thresh);
+    }
+
+#pragma unroll
+    for (uint32_t step = REDUCE_WARP_SIZE / 2; step >= 1; step >>= 1) {
+        sum_eq += __shfl_down(sum_eq, step, REDUCE_WARP_SIZE);
+        sum_lt += __shfl_down(sum_lt, step, REDUCE_WARP_SIZE);
+    }
+
+    // reduce warp results to a single scalar
+    if (!tid_local) {
+        output[batch * output_width + out_col] =
+                (static_cast<uint64_t>(sum_eq) << 32) | sum_lt;
+    }
+}
+
+static MEGDNN_NOINLINE hipError_t
+invoke_cub_scan(const uint64_t* input, uint64_t* output, void* workspace,
+                size_t& workspace_size, uint32_t size, hipStream_t stream) {
+    return hipcub::DeviceScan::InclusiveSum(workspace, workspace_size,
+                                            input, output, size, stream);
+}
+
+static __global__ void kern_init_zero(uint64_t* dst) {
+    dst[0] = 0;
+}
+
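+/*
+ * Implementation note (illustrative): each uint64_t produced by
+ * kern_reduce_block_cnt packs two counters, the number of elements equal to
+ * the threshold in the high 32 bits and the number satisfying Pred in the
+ * low 32 bits, so a single hipcub inclusive scan over the packed words
+ * yields both running counts at once; kern_copy later unpacks them with
+ * u32_from_64_high / u32_from_64_low.
+ */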
+/*!
+ * \brief copy top-k values of each row from input to output
+ *
+ * launch config: grid(X, batch),
+ *                thread(WARP_SIZE, COPY_SHARD)
+ */
+template <typename ctype, class Pred, uint32_t COPY_SHARD>
+static __global__ void kern_copy(const ctype* input_data,
+                                 const ctype* input_thresh,
+                                 const uint64_t* scan, uint32_t scan_width,
+                                 ctype* output_value, int32_t* output_idx,
+                                 uint32_t length, uint32_t k, int32_t lda) {
+#if __cplusplus > 199711L
+    static_assert(REDUCE_SIZE < 256, "local_sum_storage can not be uint8_t");
+#endif
+    static const uint32_t BLOCK_DIM_X = WARP_SIZE, BLOCK_DIM_Y = COPY_SHARD;
+
+    uint32_t scan_col = blockIdx.x * BLOCK_DIM_Y + threadIdx.y;
+
+    if (scan_col >= scan_width) {
+        return;
+    }
+
+    uint32_t batch = blockIdx.y,
+             inp_col_begin = min(scan_col * REDUCE_SIZE, length),
+             inp_col_length =
+                     min(inp_col_begin + REDUCE_SIZE, length) - inp_col_begin,
+             tid_local = threadIdx.x;
+    uint32_t thresh = RadixConverter<ctype>::to_radix(input_thresh[batch]);
+    input_data += static_cast<int64_t>(batch) * lda +
+                  static_cast<int64_t>(inp_col_begin);
+    __shared__ uint8_t local_sum_storage[BLOCK_DIM_Y][2][REDUCE_SIZE + 4];
+    uint8_t *local_sum_eq = local_sum_storage[threadIdx.y][0],
+            *local_sum_lt = local_sum_storage[threadIdx.y][1];
+    if (!tid_local) {
+        local_sum_eq[3] = 0;
+        local_sum_lt[3] = 0;
+    }
+    local_sum_eq += 4;
+    local_sum_lt += 4;
+    const uint32_t WORKLOAD = REDUCE_SIZE / WARP_SIZE;
+#pragma unroll
+    for (uint32_t j = 0; j < WORKLOAD; ++j) {
+        uint32_t i = j * BLOCK_DIM_X + tid_local;
+        if (i < inp_col_length) {
+            uint32_t iv = RadixConverter<ctype>::to_radix(input_data[i]);
+            local_sum_eq[i] = iv == thresh;
+            local_sum_lt[i] = Pred::cmp(iv, thresh);
+        } else {
+            local_sum_eq[i] = 0;
+            local_sum_lt[i] = 0;
+        }
+    }
+
+    device_scan_packed_accu32<WORKLOAD>(local_sum_eq, tid_local);
+    device_scan_packed_accu32<WORKLOAD>(local_sum_lt, tid_local);
+
+    scan += batch * scan_width;
+    uint64_t scan_prev_pack = scan[static_cast<int32_t>(scan_col) - 1],
+             k_offset_pack = scan_prev_pack - scan[-1],
+             scan_self_pack = scan[scan_col] - scan_prev_pack;
+#define unpack(name)                                    \
+    uint32_t name##_eq = u32_from_64_high(name##_pack), \
+             name##_lt = u32_from_64_low(name##_pack)
+    unpack(k_offset);
+    unpack(scan_self);
+#undef unpack
+    uint32_t allowed_eq = k - min(k, (u32_from_64_low(scan[scan_width - 1]) -
+                                      u32_from_64_low(scan[-1]))),
+             ls_lt_max = k - min(k_offset_lt, k),
+             ls_eq_max = allowed_eq - min(allowed_eq, k_offset_eq);
+    if ((scan_self_lt && ls_lt_max) || (scan_self_eq && ls_eq_max)) {
+#pragma unroll
+        for (uint32_t j = 0; j < WORKLOAD; ++j) {
+            int32_t i = j * BLOCK_DIM_X + tid_local;
+            uint32_t cur_lt = local_sum_lt[i], cur_eq = local_sum_eq[i];
+            bool is_lt = cur_lt <= ls_lt_max && cur_lt != local_sum_lt[i - 1];
+            bool is_eq = cur_eq <= ls_eq_max && cur_eq != local_sum_eq[i - 1];
+            // exactly one should be true
+            if (is_lt || is_eq) {
+                uint32_t off_lt = cur_lt + k_offset_lt - 1;
+                uint32_t off_eq = cur_eq + k_offset_eq - 1 + (k - allowed_eq);
+                uint32_t ocol = is_lt ? off_lt : off_eq;
+                output_value[batch * k + ocol] = input_data[i];
+                output_idx[batch * k + ocol] = i + inp_col_begin;
+            }
+        }
+    }
+}
+
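+/*
+ * Worked example for kern_copy (illustrative): take one row [5, 1, 3, 3, 9]
+ * with k = 3 and LessPred; find_kth_radix gives thresh = 3. One element is
+ * strictly below the threshold, so allowed_eq = 3 - 1 = 2: value 1 goes to
+ * output slot 0 by its rank in the lt scan, and the two elements equal to 3
+ * fill slots (k - allowed_eq) + {0, 1} = {1, 2}, giving values [1, 3, 3]
+ * along with their original column indices.
+ */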
+//! get workspace for scan, aligned to uint64_t
+static size_t get_scan_workspace(uint32_t size) {
+    size_t wk = 0;
+    hipError_t err = invoke_cub_scan(NULL, NULL, NULL, wk, size, NULL);
+    if (err != hipSuccess) {
+        fprintf(stderr, "topk: cub scan failed: %s (%d)\n",
+                hipGetErrorString(err), static_cast<int>(err));
+        megdnn_trap();
+    }
+    return ((wk - 1) / sizeof(uint64_t) + 1) * sizeof(uint64_t);
+}
+
+} // namespace select
+} // namespace rocm_topk_impl
+
+uint32_t topk::find_kth_radix_workspace(uint32_t batch, uint32_t length,
+                                        uint32_t grid_dim_y_limit) {
+    using namespace rocm_topk_impl::kth;
+    uint32_t limit = batch > grid_dim_y_limit ? grid_dim_y_limit : batch;
+    return (limit * get_grid_dim_x(length) * NR_BUCKET + limit * 2) *
+           sizeof(uint32_t);
+}
+
+template <typename ctype>
+hipError_t topk::find_kth_radix(const ctype* input, ctype* output,
+                                void* workspace, uint32_t batch,
+                                uint32_t length, int32_t lda, int32_t k,
+                                uint32_t grid_dim_y_limit,
+                                hipStream_t stream) {
+    using namespace rocm_topk_impl::kth;
+    if (!k) {
+        return hipErrorInvalidValue;
+    }
+    if (k < 0) {
+        k = length + k + 1;
+    }
+    if (!(BUCKET_BITS == 8 && (sizeof(ctype) == 4 || sizeof(ctype) == 2))) {
+        // no c++11 in megdnn cuda; so we just trap instead of using static
+        // assert
+        megdnn_trap();
+    }
+
+    uint32_t batch_idx = 0;
+    uint32_t grid_dim_x = get_grid_dim_x(length);
+    uint32_t grid_dim_y = 1;
+
+    while (batch_idx < batch) {
+        if (batch - batch_idx >= grid_dim_y_limit) {
+            grid_dim_y = grid_dim_y_limit;
+        } else {
+            grid_dim_y = batch - batch_idx;
+        }
+
+        dim3 grid_dim(grid_dim_x, grid_dim_y);
+        uint32_t* dev_k = static_cast<uint32_t*>(workspace);
+        uint32_t* dev_prefix = dev_k + grid_dim_y;
+        uint32_t* bucket_cnt = dev_prefix + grid_dim_y;
+
+        compute_histogram<ctype, 24, false>
+                <<<grid_dim, BLOCK_DIM, 0, stream>>>(
+                        input + batch_idx * lda, bucket_cnt, length, lda,
+                        nullptr);
+
+        // use float to make compiler happy; it is not used since last == false
+        update_prefix_and_k<float, 24, true, false>
+                <<<grid_dim_y, BLOCK_DIM, 0, stream>>>(
+                        bucket_cnt, dev_prefix, dev_k, k, grid_dim_x, nullptr);
+
+        compute_histogram<ctype, 16, true>
+                <<<grid_dim, BLOCK_DIM, 0, stream>>>(
+                        input + batch_idx * lda, bucket_cnt, length, lda,
+                        dev_prefix);
+
+        update_prefix_and_k<float, 16, false, false>
+                <<<grid_dim_y, BLOCK_DIM, 0, stream>>>(
+                        bucket_cnt, dev_prefix, dev_k, k, grid_dim_x, nullptr);
+
+        compute_histogram<ctype, 8, true>
+                <<<grid_dim, BLOCK_DIM, 0, stream>>>(
+                        input + batch_idx * lda, bucket_cnt, length, lda,
+                        dev_prefix);
+
+        update_prefix_and_k<float, 8, false, false>
+                <<<grid_dim_y, BLOCK_DIM, 0, stream>>>(
+                        bucket_cnt, dev_prefix, dev_k, k, grid_dim_x, nullptr);
+
+        compute_histogram<ctype, 0, true>
+                <<<grid_dim, BLOCK_DIM, 0, stream>>>(
+                        input + batch_idx * lda, bucket_cnt, length, lda,
+                        dev_prefix);
+
+        update_prefix_and_k<ctype, 0, false, true>
+                <<<grid_dim_y, BLOCK_DIM, 0, stream>>>(bucket_cnt, dev_prefix,
+                                                       dev_k, k, grid_dim_x,
+                                                       output + batch_idx);
+
+        batch_idx += grid_dim_y;
+    }
+    return hipGetLastError();
+}
+
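+/*
+ * Algorithm sketch (summarizing the kernels above): find_kth_radix runs four
+ * histogram/update passes over the 8-bit digits of the radix-converted keys,
+ * at shifts 24, 16, 8 and 0. Each pass narrows the known prefix of the k'th
+ * smallest key by one digit (keys outside the prefix are treated as max),
+ * and the final pass converts the completed prefix back into a value of
+ * type ctype.
+ */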
+template <typename ctype>
+hipError_t topk::topk_select(const ctype* input, const ctype* thresh,
+                             ctype* output_value, int32_t* output_idx,
+                             void* workspace, uint32_t batch, uint32_t length,
+                             int32_t lda, int32_t k,
+                             uint32_t batch_upper_limit, hipStream_t stream) {
+    using namespace rocm_topk_impl;
+    using namespace rocm_topk_impl::select;
+
+    uint32_t length_split = DIVUP(length, REDUCE_SIZE);
+
+    void (*kptr_reduce_block_cnt)(const ctype*, const ctype*, uint32_t,
+                                  int32_t, uint64_t*, uint32_t);
+    void (*kptr_copy)(const ctype*, const ctype*, const uint64_t*, uint32_t,
+                      ctype*, int32_t*, uint32_t, uint32_t, int32_t);
+
+    int kern_copy_shard;
+    {
+        int grid, block;
+        hipError_t err = hipOccupancyMaxPotentialBlockSize(
+                &grid, &block, kern_copy<ctype, LessPred, 32>);
+        if (err) {
+            return err;
+        }
+        kern_copy_shard = block / (WARP_SIZE * 8) * 8;
+        if (!kern_copy_shard) {
+            fprintf(stderr, "topk: failed to launch: block=%d\n", block);
+            return hipErrorLaunchOutOfResources;
+        }
+    }
+
+#define CASE_SHARD_ON(pred, n)                 \
+    case n:                                    \
+        kptr_copy = kern_copy<ctype, pred, n>; \
+        break
+#define CASE_SHARD(pred)                                          \
+    switch (kern_copy_shard) {                                    \
+        CASE_SHARD_ON(pred, 8);                                   \
+        CASE_SHARD_ON(pred, 16);                                  \
+        CASE_SHARD_ON(pred, 24);                                  \
+        CASE_SHARD_ON(pred, 32);                                  \
+        default:                                                  \
+            fprintf(stderr, "topk: failed to launch: shard=%d\n", \
+                    kern_copy_shard);                             \
+            return hipErrorLaunchOutOfResources;                  \
+    }
+
+    if (k < 0) {
+        k = -k;
+        kptr_reduce_block_cnt = kern_reduce_block_cnt<ctype, GreaterPred>;
+        CASE_SHARD(GreaterPred);
+    } else {
+        kptr_reduce_block_cnt = kern_reduce_block_cnt<ctype, LessPred>;
+        CASE_SHARD(LessPred);
+    }
+
+#undef CASE_SHARD
+#undef CASE_SHARD_ON
+
+    uint32_t batch_idx = 0;
+    uint32_t batch_real = 1;
+
+    while (batch_idx < batch) {
+        if (batch - batch_idx >= batch_upper_limit) {
+            batch_real = batch_upper_limit;
+        } else {
+            batch_real = batch - batch_idx;
+        }
+
+        size_t scan_size = batch_real * length_split;
+        size_t scan_wk = get_scan_workspace(scan_size);
+        uint64_t *scan_inp = static_cast<uint64_t*>(workspace) +
+                             scan_wk / sizeof(uint64_t),
+                 *scan_out = scan_inp + scan_size;
+
+        // reduce to scan_inp
+        kptr_reduce_block_cnt<<<
+                dim3(DIVUP(length_split, REDUCE_SHARD), batch_real),
+                dim3(REDUCE_WARP_SIZE, REDUCE_SHARD), 0, stream>>>(
+                input + batch_idx * lda, thresh + batch_idx, length, lda,
+                scan_inp, length_split);
+
+        // scan to scan_out
+        scan_out += 1;  // set scan[-1] to 0
+        hipError_t err = invoke_cub_scan(scan_inp, scan_out, workspace,
+                                         scan_wk, scan_size, stream);
+        if (err != hipSuccess) {
+            return err;
+        }
+        kern_init_zero<<<1, 1, 0, stream>>>(scan_out - 1);
+
+        // copy result
+        kptr_copy<<<dim3(DIVUP(length_split, kern_copy_shard), batch_real),
+                    dim3(WARP_SIZE, kern_copy_shard), 0, stream>>>(
+                input + batch_idx * lda, thresh + batch_idx, scan_out,
+                length_split, output_value + std::abs(k) * batch_idx,
+                output_idx + std::abs(k) * batch_idx, length, k, lda);
+
+        batch_idx += batch_real;
+    }
+    return hipGetLastError();
+}
+
+uint32_t topk::topk_select_workspace(uint32_t batch, uint32_t length) {
+    using namespace rocm_topk_impl::select;
+    size_t scan_size = batch * DIVUP(length, REDUCE_SIZE);
+    return get_scan_workspace(scan_size) +
+           sizeof(uint64_t) * (scan_size * 2 + 1);
+}
+
+namespace megdnn {
+namespace rocm {
+namespace topk {
+#define INST(t)                                                           \
+    template hipError_t find_kth_radix<t>(const t*, t*, void*, uint32_t, \
+                                          uint32_t, int32_t, int32_t,    \
+                                          uint32_t, hipStream_t);        \
+    template hipError_t topk_select<t>(const t*, const t*, t*, int32_t*, \
+                                       void*, uint32_t, uint32_t,        \
+                                       int32_t, int32_t, uint32_t,       \
+                                       hipStream_t)
+INST(float);
+INST(int32_t);
+// DNN_INC_FLOAT16(INST(dt_float16));
+#undef INST
+
+} // namespace topk
+} // namespace rocm
+} // namespace megdnn
+
+// vim: ft=cpp syntax=cpp.doxygen
+
diff --git a/dnn/src/rocm/topk/topk_radix.h.hip b/dnn/src/rocm/topk/topk_radix.h.hip
new file mode 100644
index 00000000..7e5ba132
--- /dev/null
+++ b/dnn/src/rocm/topk/topk_radix.h.hip
@@ -0,0 +1,127 @@
+/**
+ * \file dnn/src/rocm/topk/topk_radix.h.hip
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+#include "src/rocm/utils.h.hip"
+#include <stdint.h>
+
+namespace megdnn {
+namespace rocm {
+namespace topk {
+namespace internal {
+template <typename ctype>
+struct RadixConverter;
+
+template <>
+struct RadixConverter<float> {
+    union FIunion {
+        float fv;
+        uint32_t iv;
+    };
+    static __forceinline__ __device__ __host__ uint32_t to_radix(float val) {
+        FIunion fi;
+        fi.fv = val;
+        return fi.iv ^ (((!(fi.iv >> 31u)) - 1u) | 0x80000000u);
+    }
+    static __forceinline__ __device__ __host__ float from_radix(uint32_t val) {
+        FIunion fi;
+        // do not write as to_radix() to work around a compiler bug in cuda-9.0
+        uint32_t m = 0x80000000u;
+        fi.iv = val ^ (m | (m - !(val >> 31u)));
+        return fi.fv;
+    }
+};
+
+template <>
+struct RadixConverter<int32_t> {
+    union SUUnion {
+        int32_t sv;
+        uint32_t uv;
+    };
+    static __forceinline__ __device__ __host__ uint32_t to_radix(int32_t val) {
+        SUUnion su;
+        su.sv = val;
+        return su.uv ^ (1u << 31u);
+    }
+    static __forceinline__ __device__ __host__ int32_t
+    from_radix(uint32_t val) {
+        SUUnion su;
+        su.uv = val;
+        return su.sv ^ (1u << 31u);
+    }
+};
+
+// #if !MEGDNN_DISABLE_FLOAT16
+// template <>
+// struct RadixConverter<dt_float16> {
+//     union FIunion {
+//         FIunion() {}
+//         dt_float16 fv;
+//         uint16_t iv;
+//     };
+//     static __forceinline__ __device__ __host__ uint16_t to_radix(dt_float16 val) {
+//         FIunion fi;
+//         fi.fv = val;
+//         return fi.iv ^ (((!(fi.iv >> 15u)) - 1u) | 0x8000u);
+//     }
+//     static __forceinline__ __device__ __host__ dt_float16 from_radix(uint16_t val) {
+//         FIunion fi;
+//         // do not write as to_radix() to work around a compiler bug in cuda-9.0
+//         uint16_t m = 0x8000u;
+//         fi.iv = val ^ (m | (m - !(val >> 15u)));
+//         return fi.fv;
+//     }
+// };
+// #endif
+
+} // namespace internal
+
+/*!
+ * \brief find the k'th values of a (batch, length) matrix along the length
+ * dimension
+ *
+ * \param input input matrix, shape [batch, length], contiguous
+ * \param lda distance of contiguous rows in \p input, measured in num of
+ *      elements in \p ctype
+ * \param k if positive, return the smallest top-k; otherwise return the
+ *      largest top-k
+ * \param output the k'th value of each batch, shape [batch]
+ */
+template <typename ctype>
+hipError_t find_kth_radix(const ctype* input, ctype* output, void* workspace,
+                          uint32_t batch, uint32_t length, int32_t lda,
+                          int32_t k, uint32_t grid_dim_y_limit,
+                          hipStream_t stream);
+
+//! get workspace in bytes
+uint32_t find_kth_radix_workspace(uint32_t batch, uint32_t length,
+                                  uint32_t grid_dim_y_limit);
+
+/*!
+ * \brief select values from rows of input that compare to thresh as specified
+ * \param k if k > 0, select values <= thresh; otherwise select values >=
+ *      thresh. Its absolute value specifies output width.
+ */
+template <typename ctype>
+hipError_t topk_select(const ctype* input, const ctype* thresh,
+                       ctype* output_value, int32_t* output_idx,
+                       void* workspace, uint32_t batch, uint32_t length,
+                       int32_t lda, int32_t k, uint32_t batch_upper_limit,
+                       hipStream_t stream);
+
+uint32_t topk_select_workspace(uint32_t batch, uint32_t length);
+
+} // namespace topk
+} // namespace rocm
+} // namespace megdnn
+
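+/*
+ * Note on the radix mapping (illustrative): RadixConverter turns ordered
+ * comparison into unsigned integer comparison. For float, non-negative
+ * values get the sign bit flipped and negative values get all bits flipped,
+ * e.g. to_radix(-1.0f) < to_radix(0.0f) < to_radix(1.0f) as uint32_t, which
+ * lets the histogram kernels bucket floats by their raw bytes.
+ */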
+// vim: ft=cpp syntax=cpp.doxygen
+
diff --git a/dnn/test/rocm/topk.cpp b/dnn/test/rocm/topk.cpp
new file mode 100644
index 00000000..072520dc
--- /dev/null
+++ b/dnn/test/rocm/topk.cpp
@@ -0,0 +1,38 @@
+/**
+ * \file dnn/test/rocm/topk.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+
+#include "test/common/topk.h"
+#include "test/rocm/fixture.h"
+
+using namespace megdnn;
+using namespace test;
+
+/*
+ * !!!!!!!!!!!!!!!! IMPORTANT NOTE !!!!!!!!!!!!!!!!
+ * The kernels are independently developed and tested in the
+ * MegDNN/expr/cuda_topk directory. Here we only check some common cases.
+ */
+
+TEST_F(ROCM, TOP_K) {
+    run_topk_test<dtype::Float32>(handle_rocm());
+}
+TEST_F(ROCM, TOP_K_I32) {
+    run_topk_test<dtype::Int32>(handle_rocm());
+}
+// #if !MEGDNN_DISABLE_FLOAT16
+// TEST_F(ROCM, TOP_K_F16) {
+//     run_topk_test<dtype::Float16>(handle_rocm());
+// }
+// #endif
+
+// vim: syntax=cpp.doxygen