GitOrigin-RevId: 6aa35928c8
release-1.2
@@ -22,20 +22,25 @@ template <typename ctype> | |||
void TopKImpl::dispatch_with_ctype(int k, size_t m, size_t n, ptrdiff_t lda, | |||
const ctype* data, ctype* values, | |||
int* indices, void* workspace) { | |||
auto stream = concrete_handle(handle())->stream(); | |||
auto _handle = concrete_handle(handle()); | |||
auto stream = _handle->stream(); | |||
size_t grid_dim_y_limit = _handle->device_prop().maxGridSize[1]; | |||
switch (param().mode) { | |||
case Param::Mode::KTH_ONLY: | |||
cuda_check(topk::find_kth_radix<ctype>(data, values, workspace, m, | |||
n, lda, k, stream)); | |||
n, lda, k, grid_dim_y_limit, | |||
stream)); | |||
return; | |||
case Param::Mode::VALUE_IDX_NOSORT: { | |||
WorkspaceBundle wk_bundle{workspace, {m * sizeof(ctype), 1}}; | |||
auto thresh = static_cast<ctype*>(wk_bundle.get(0)); | |||
auto real_wk = wk_bundle.get(1); | |||
cuda_check(topk::find_kth_radix<ctype>(data, thresh, real_wk, m, n, | |||
lda, k, stream)); | |||
lda, k, grid_dim_y_limit, | |||
stream)); | |||
cuda_check(topk::topk_select<ctype>(data, thresh, values, indices, | |||
real_wk, m, n, lda, k, stream)); | |||
real_wk, m, n, lda, k, | |||
grid_dim_y_limit, stream)); | |||
return; | |||
} | |||
case Param::Mode::VALUE_IDX_SORTED: { | |||
@@ -48,10 +53,11 @@ void TopKImpl::dispatch_with_ctype(int k, size_t m, size_t n, ptrdiff_t lda, | |||
auto nosort_idx = static_cast<int32_t*>(wk_bundle.get(2)); | |||
auto real_wk = wk_bundle.get(3); | |||
cuda_check(topk::find_kth_radix<ctype>(data, thresh, real_wk, m, n, | |||
lda, k, stream)); | |||
lda, k, grid_dim_y_limit, | |||
stream)); | |||
cuda_check(topk::topk_select<ctype>(data, thresh, nosort_values, | |||
nosort_idx, real_wk, m, n, lda, | |||
k, stream)); | |||
k, grid_dim_y_limit, stream)); | |||
argsort::forward(nosort_values, values, indices, real_wk, m, | |||
std::abs(k), k > 0, stream, nosort_idx); | |||
return; | |||
@@ -89,9 +95,11 @@ size_t TopKImpl::get_workspace_in_bytes(int k, const TensorLayout& data, | |||
MEGDNN_MARK_USED_VAR(indices); | |||
size_t m = data[0], n = data[1]; | |||
size_t kabs = std::abs(k); | |||
size_t grid_dim_y_limit = | |||
concrete_handle(handle())->device_prop().maxGridSize[1]; | |||
megdnn_assert(std::max(m, n) <= | |||
static_cast<size_t>(std::numeric_limits<int>::max())); | |||
size_t kth = topk::find_kth_radix_workspace(m, n), | |||
size_t kth = topk::find_kth_radix_workspace(m, n, grid_dim_y_limit), | |||
sel = topk::topk_select_workspace(m, n); | |||
auto ctsize = data.dtype.size(); | |||
switch (param().mode) { | |||
@@ -468,17 +468,9 @@ static size_t get_scan_workspace(uint32_t size) { | |||
} // namespace select | |||
} // namespace cuda_topk_impl | |||
uint32_t topk::find_kth_radix_workspace(uint32_t batch, uint32_t length) { | |||
uint32_t topk::find_kth_radix_workspace(uint32_t batch, uint32_t length, | |||
uint32_t grid_dim_y_limit) { | |||
using namespace cuda_topk_impl::kth; | |||
int device_id; | |||
if (cudaGetDevice(&device_id) != cudaSuccess) { | |||
megdnn_trap(); | |||
} | |||
cudaDeviceProp prop; | |||
if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) { | |||
megdnn_trap(); | |||
} | |||
uint32_t grid_dim_y_limit = prop.maxGridSize[1]; | |||
uint32_t limit = batch > grid_dim_y_limit ? grid_dim_y_limit : batch; | |||
return (limit * get_grid_dim_x(length) * NR_BUCKET + limit * 2) * | |||
sizeof(uint32_t); | |||
@@ -488,6 +480,7 @@ template <typename ctype> | |||
cudaError_t topk::find_kth_radix(const ctype* input, ctype* output, | |||
void* workspace, uint32_t batch, | |||
uint32_t length, int32_t lda, int32_t k, | |||
uint32_t grid_dim_y_limit, | |||
cudaStream_t stream) { | |||
using namespace cuda_topk_impl::kth; | |||
if (!k) { | |||
@@ -502,16 +495,6 @@ cudaError_t topk::find_kth_radix(const ctype* input, ctype* output, | |||
megdnn_trap(); | |||
} | |||
int device_id; | |||
if (cudaGetDevice(&device_id) != cudaSuccess) { | |||
megdnn_trap(); | |||
} | |||
cudaDeviceProp prop; | |||
if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) { | |||
megdnn_trap(); | |||
} | |||
uint32_t grid_dim_y_limit = prop.maxGridSize[1]; | |||
uint32_t batch_idx = 0; | |||
uint32_t grid_dim_x = get_grid_dim_x(length); | |||
uint32_t grid_dim_y = 1; | |||
@@ -567,20 +550,11 @@ template <typename ctype> | |||
cudaError_t topk::topk_select(const ctype* input, const ctype* thresh, | |||
ctype* output_value, int32_t* output_idx, | |||
void* workspace, uint32_t batch, uint32_t length, | |||
int32_t lda, int32_t k, cudaStream_t stream) { | |||
int32_t lda, int32_t k, | |||
uint32_t batch_upper_limit, cudaStream_t stream) { | |||
using namespace cuda_topk_impl; | |||
using namespace cuda_topk_impl::select; | |||
int device_id; | |||
if (cudaGetDevice(&device_id) != cudaSuccess) { | |||
megdnn_trap(); | |||
} | |||
cudaDeviceProp prop; | |||
if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) { | |||
megdnn_trap(); | |||
} | |||
uint32_t batch_upper_limit = prop.maxGridSize[1]; | |||
uint32_t length_split = DIVUP(length, REDUCE_SIZE); | |||
void (*kptr_reduce_block_cnt)(const ctype*, const ctype*, uint32_t, int32_t, | |||
@@ -688,10 +662,10 @@ namespace topk { | |||
#define INST(t) \ | |||
template cudaError_t find_kth_radix<t>(const t*, t*, void*, uint32_t, \ | |||
uint32_t, int32_t, int32_t, \ | |||
cudaStream_t); \ | |||
uint32_t, cudaStream_t); \ | |||
template cudaError_t topk_select<t>(const t*, const t*, t*, int32_t*, \ | |||
void*, uint32_t, uint32_t, int32_t, \ | |||
int32_t, cudaStream_t) | |||
int32_t, uint32_t, cudaStream_t) | |||
INST(float); | |||
INST(int32_t); | |||
#undef INST | |||
@@ -76,10 +76,12 @@ struct RadixConverter<int32_t> { | |||
template <typename ctype> | |||
cudaError_t find_kth_radix(const ctype* input, ctype* output, void* workspace, | |||
uint32_t batch, uint32_t length, int32_t lda, | |||
int32_t k, cudaStream_t stream); | |||
int32_t k, uint32_t grid_dim_y_limit, | |||
cudaStream_t stream); | |||
//! get workspace in bytes | |||
uint32_t find_kth_radix_workspace(uint32_t batch, uint32_t length); | |||
uint32_t find_kth_radix_workspace(uint32_t batch, uint32_t length, | |||
uint32_t grid_dim_y_limit); | |||
/*! | |||
* \brief select values from rows of input that compare to thresh as specified | |||
@@ -90,7 +92,8 @@ template <typename ctype> | |||
cudaError_t topk_select(const ctype* input, const ctype* thresh, | |||
ctype* output_value, int32_t* output_idx, | |||
void* workspace, uint32_t batch, uint32_t length, | |||
int32_t lda, int32_t k, cudaStream_t stream); | |||
int32_t lda, int32_t k, uint32_t batch_upper_limit, | |||
cudaStream_t stream); | |||
uint32_t topk_select_workspace(uint32_t batch, uint32_t length); | |||