@@ -470,7 +470,17 @@ static size_t get_scan_workspace(uint32_t size) {
 
 uint32_t topk::find_kth_radix_workspace(uint32_t batch, uint32_t length) {
     using namespace cuda_topk_impl::kth;
-    return (batch * get_grid_dim_x(length) * NR_BUCKET + batch * 2) *
+    int device_id;
+    if (cudaGetDevice(&device_id) != cudaSuccess) {
+        megdnn_trap();
+    }
+    cudaDeviceProp prop;
+    if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
+        megdnn_trap();
+    }
+    uint32_t grid_dim_y_limit = prop.maxGridSize[1];
+    uint32_t limit = batch > grid_dim_y_limit ? grid_dim_y_limit : batch;
+    return (limit * get_grid_dim_x(length) * NR_BUCKET + limit * 2) *
            sizeof(uint32_t);
 }
 
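Note on the new workspace bound: the radix kernels in find_kth_radix are now launched for at most maxGridSize[1] rows per pass (see the next hunk), so the workspace only has to hold counters for min(batch, grid_dim_y_limit) rows instead of the whole batch. A minimal sketch of the same arithmetic, with nr_bucket and grid_dim_x standing in for NR_BUCKET and get_grid_dim_x(length) from cuda_topk_impl::kth (the helper name is illustrative, not part of the patch):

    #include <cstddef>
    #include <cstdint>

    // Sketch only: capped workspace size, mirroring find_kth_radix_workspace().
    static size_t capped_workspace_bytes(uint32_t batch, uint32_t grid_dim_y_limit,
                                         uint32_t grid_dim_x, uint32_t nr_bucket) {
        // at most grid_dim_y_limit rows are resident per kernel launch
        uint32_t limit = batch > grid_dim_y_limit ? grid_dim_y_limit : batch;
        // per row: grid_dim_x partial histograms of nr_bucket counters,
        // plus one dev_k slot and one dev_prefix slot
        return (static_cast<size_t>(limit) * grid_dim_x * nr_bucket + limit * 2) *
               sizeof(uint32_t);
    }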
@@ -491,35 +501,65 @@ cudaError_t topk::find_kth_radix(const ctype* input, ctype* output,
         // assert
         megdnn_trap();
     }
 
-    uint32_t grid_dim_x = get_grid_dim_x(length);
-    dim3 grid_dim(grid_dim_x, batch);
-    uint32_t* dev_k = static_cast<uint32_t*>(workspace);
-    uint32_t* dev_prefix = dev_k + batch;
-    uint32_t* bucket_cnt = dev_prefix + batch;
-
-    compute_histogram<ctype, false, 24><<<grid_dim, BLOCK_DIM, 0, stream>>>(
-            input, bucket_cnt, length, lda, nullptr);
-    // use float to make compiler happy; it is not used since last == false
-    update_prefix_and_k<true, false, 24, float>
-            <<<batch, NR_BUCKET, 0, stream>>>(bucket_cnt, dev_prefix, dev_k, k,
-                                              grid_dim_x, nullptr);
-
-    compute_histogram<ctype, true, 16><<<grid_dim, BLOCK_DIM, 0, stream>>>(
-            input, bucket_cnt, length, lda, dev_prefix);
-    update_prefix_and_k<false, false, 16, float>
-            <<<batch, NR_BUCKET, 0, stream>>>(bucket_cnt, dev_prefix, dev_k, k,
-                                              grid_dim_x, nullptr);
-    compute_histogram<ctype, true, 8><<<grid_dim, BLOCK_DIM, 0, stream>>>(
-            input, bucket_cnt, length, lda, dev_prefix);
-    update_prefix_and_k<false, false, 8, float>
-            <<<batch, NR_BUCKET, 0, stream>>>(bucket_cnt, dev_prefix, dev_k, k,
-                                              grid_dim_x, nullptr);
-
-    compute_histogram<ctype, true, 0><<<grid_dim, BLOCK_DIM, 0, stream>>>(
-            input, bucket_cnt, length, lda, dev_prefix);
-    update_prefix_and_k<false, true, 0, ctype><<<batch, NR_BUCKET, 0, stream>>>(
-            bucket_cnt, dev_prefix, dev_k, k, grid_dim_x, output);
+    int device_id;
+    if (cudaGetDevice(&device_id) != cudaSuccess) {
+        megdnn_trap();
+    }
+    cudaDeviceProp prop;
+    if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
+        megdnn_trap();
+    }
+    uint32_t grid_dim_y_limit = prop.maxGridSize[1];
+
+    uint32_t batch_idx = 0;
+    uint32_t grid_dim_x = get_grid_dim_x(length);
+    uint32_t grid_dim_y = 1;
+
+    while (batch_idx < batch) {
+        if (batch - batch_idx >= grid_dim_y_limit) {
+            grid_dim_y = grid_dim_y_limit;
+        } else {
+            grid_dim_y = batch - batch_idx;
+        }
+
+        dim3 grid_dim(grid_dim_x, grid_dim_y);
+        uint32_t* dev_k = static_cast<uint32_t*>(workspace);
+        uint32_t* dev_prefix = dev_k + grid_dim_y;
+        uint32_t* bucket_cnt = dev_prefix + grid_dim_y;
+
+        compute_histogram<ctype, false, 24><<<grid_dim, BLOCK_DIM, 0, stream>>>(
+                input + batch_idx * lda, bucket_cnt, length, lda, nullptr);
+
+        // use float to make compiler happy; it is not used since last == false
+        update_prefix_and_k<true, false, 24, float>
+                <<<grid_dim_y, NR_BUCKET, 0, stream>>>(
+                        bucket_cnt, dev_prefix, dev_k, k, grid_dim_x, nullptr);
+
+        compute_histogram<ctype, true, 16><<<grid_dim, BLOCK_DIM, 0, stream>>>(
+                input + batch_idx * lda, bucket_cnt, length, lda, dev_prefix);
+
+        update_prefix_and_k<false, false, 16, float>
+                <<<grid_dim_y, NR_BUCKET, 0, stream>>>(
+                        bucket_cnt, dev_prefix, dev_k, k, grid_dim_x, nullptr);
+
+        compute_histogram<ctype, true, 8><<<grid_dim, BLOCK_DIM, 0, stream>>>(
+                input + batch_idx * lda, bucket_cnt, length, lda, dev_prefix);
+
+        update_prefix_and_k<false, false, 8, float>
+                <<<grid_dim_y, NR_BUCKET, 0, stream>>>(
+                        bucket_cnt, dev_prefix, dev_k, k, grid_dim_x, nullptr);
+
+        compute_histogram<ctype, true, 0><<<grid_dim, BLOCK_DIM, 0, stream>>>(
+                input + batch_idx * lda, bucket_cnt, length, lda, dev_prefix);
+
+        update_prefix_and_k<false, true, 0, ctype>
+                <<<grid_dim_y, NR_BUCKET, 0, stream>>>(bucket_cnt, dev_prefix,
+                                                       dev_k, k, grid_dim_x,
+                                                       output + batch_idx);
+
+        batch_idx += grid_dim_y;
+    }
 
     return cudaGetLastError();
 }
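The loop above is the usual way to respect CUDA's grid.y limit: walk the batch in chunks of at most maxGridSize[1] rows, offsetting the input by batch_idx * lda and the per-row output by batch_idx on each pass, while the workspace pointers (dev_k, dev_prefix, bucket_cnt) are re-derived from the same buffer every time. A self-contained sketch of the same pattern with a placeholder kernel (the kernel body and all names are illustrative, not taken from this file):

    #include <cuda_runtime.h>
    #include <cstdint>

    // Stand-in for the real per-row work (compute_histogram /
    // update_prefix_and_k above); blockIdx.y selects the row inside
    // the current chunk.
    __global__ void process_row(const float* rows, float* out, int32_t lda) {
        uint32_t row = blockIdx.y;
        if (threadIdx.x == 0) {
            out[row] = rows[static_cast<size_t>(row) * lda];
        }
    }

    cudaError_t launch_in_chunks(const float* rows, float* out, uint32_t batch,
                                 int32_t lda, cudaStream_t stream) {
        int device_id;
        cudaDeviceProp prop;
        if (cudaGetDevice(&device_id) != cudaSuccess ||
            cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
            return cudaErrorUnknown;
        }
        uint32_t grid_dim_y_limit = prop.maxGridSize[1];

        uint32_t batch_idx = 0;
        while (batch_idx < batch) {
            uint32_t grid_dim_y = batch - batch_idx >= grid_dim_y_limit
                                          ? grid_dim_y_limit
                                          : batch - batch_idx;
            // grid.y never exceeds the device limit
            dim3 grid_dim(1, grid_dim_y);
            process_row<<<grid_dim, 32, 0, stream>>>(
                    rows + static_cast<size_t>(batch_idx) * lda,
                    out + batch_idx, lda);
            batch_idx += grid_dim_y;
        }
        return cudaGetLastError();
    }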
@@ -530,12 +570,18 @@ cudaError_t topk::topk_select(const ctype* input, const ctype* thresh,
                               int32_t lda, int32_t k, cudaStream_t stream) {
     using namespace cuda_topk_impl;
     using namespace cuda_topk_impl::select;
-    uint32_t length_split = DIVUP(length, REDUCE_SIZE),
-             scan_size = batch * length_split;
-    size_t scan_wk = get_scan_workspace(scan_size);
-    uint64_t *scan_inp = static_cast<uint64_t*>(workspace) +
-                         scan_wk / sizeof(uint64_t),
-             *scan_out = scan_inp + scan_size;
+
+    int device_id;
+    if (cudaGetDevice(&device_id) != cudaSuccess) {
+        megdnn_trap();
+    }
+    cudaDeviceProp prop;
+    if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
+        megdnn_trap();
+    }
+    uint32_t batch_upper_limit = prop.maxGridSize[1];
+
+    uint32_t length_split = DIVUP(length, REDUCE_SIZE);
 
     void (*kptr_reduce_block_cnt)(const ctype*, const ctype*, uint32_t, int32_t,
                                   uint64_t*, uint32_t);
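batch_upper_limit here is the same device cap queried in find_kth_radix: prop.maxGridSize[1] is the largest allowed grid.y, typically 65535 on current NVIDIA GPUs, so a batch of up to 65535 rows is still handled in a single pass, while for example batch = 100000 takes two passes of 65535 and 34465 rows. megdnn_trap() (the library's abort) is used on query failure, matching how other unexpected conditions are already handled in this file. The scan buffers that used to be carved out here are now sized per chunk inside the loop in the next hunk.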
@@ -585,25 +631,47 @@ cudaError_t topk::topk_select(const ctype* input, const ctype* thresh,
 #undef CASE_SHARD
 #undef CASE_SHARD_ON
 
-    // reduce to scan_inp
-    kptr_reduce_block_cnt<<<dim3(DIVUP(length_split, REDUCE_SHARD), batch),
-                            dim3(REDUCE_WARP_SIZE, REDUCE_SHARD), 0, stream>>>(
-            input, thresh, length, lda, scan_inp, length_split);
-
-    // scan to scan_out
-    scan_out += 1;  // set scan[-1] to 0
-    cudaError_t err = invoke_cub_scan(scan_inp, scan_out, workspace, scan_wk,
-                                      scan_size, stream);
-    if (err != cudaSuccess) {
-        return err;
-    }
-    kern_init_zero<<<1, 1, 0, stream>>>(scan_out - 1);
-
-    // copy result
-    kptr_copy<<<dim3(DIVUP(length_split, kern_copy_shard), batch),
-                dim3(WARP_SIZE, kern_copy_shard), 0, stream>>>(
-            input, thresh, scan_out, length_split, output_value, output_idx,
-            length, k, lda);
+    uint32_t batch_idx = 0;
+    uint32_t batch_real = 1;
+
+    while (batch_idx < batch) {
+        if (batch - batch_idx >= batch_upper_limit) {
+            batch_real = batch_upper_limit;
+        } else {
+            batch_real = batch - batch_idx;
+        }
+
+        size_t scan_size = batch_real * length_split;
+        size_t scan_wk = get_scan_workspace(scan_size);
+        uint64_t *scan_inp = static_cast<uint64_t*>(workspace) +
+                             scan_wk / sizeof(uint64_t),
+                 *scan_out = scan_inp + scan_size;
+
+        // reduce to scan_inp
+        kptr_reduce_block_cnt<<<
+                dim3(DIVUP(length_split, REDUCE_SHARD), batch_real),
+                dim3(REDUCE_WARP_SIZE, REDUCE_SHARD), 0, stream>>>(
+                input + batch_idx * lda, thresh + batch_idx, length, lda,
+                scan_inp, length_split);
+
+        // scan to scan_out
+        scan_out += 1;  // set scan[-1] to 0
+        cudaError_t err = invoke_cub_scan(scan_inp, scan_out, workspace,
+                                          scan_wk, scan_size, stream);
+        if (err != cudaSuccess) {
+            return err;
+        }
+        kern_init_zero<<<1, 1, 0, stream>>>(scan_out - 1);
+
+        // copy result
+        kptr_copy<<<dim3(DIVUP(length_split, kern_copy_shard), batch_real),
+                    dim3(WARP_SIZE, kern_copy_shard), 0, stream>>>(
+                input + batch_idx * lda, thresh + batch_idx, scan_out,
+                length_split, output_value + std::abs(k) * batch_idx,
+                output_idx + std::abs(k) * batch_idx, length, k, lda);
+
+        batch_idx += batch_real;
+    }
 
     return cudaGetLastError();
 }
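In the chunked loop above, every row yields std::abs(k) output values, so the chunk starting at row batch_idx writes to output_value + std::abs(k) * batch_idx and output_idx + std::abs(k) * batch_idx, while scan_inp/scan_out are re-derived from the start of the workspace on each pass (scan_size now covers only batch_real rows). As a worked example: with batch = 200000, k = -10 and a grid.y limit of 65535, the passes cover rows [0, 65535), [65535, 131070), [131070, 196605) and [196605, 200000), and their outputs begin at elements 0, 655350, 1310700 and 1966050 respectively.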