GitOrigin-RevId: 6aa35928c8
release-1.2
@@ -22,20 +22,25 @@ template <typename ctype> | |||||
void TopKImpl::dispatch_with_ctype(int k, size_t m, size_t n, ptrdiff_t lda, | void TopKImpl::dispatch_with_ctype(int k, size_t m, size_t n, ptrdiff_t lda, | ||||
const ctype* data, ctype* values, | const ctype* data, ctype* values, | ||||
int* indices, void* workspace) { | int* indices, void* workspace) { | ||||
auto stream = concrete_handle(handle())->stream(); | |||||
auto _handle = concrete_handle(handle()); | |||||
auto stream = _handle->stream(); | |||||
size_t grid_dim_y_limit = _handle->device_prop().maxGridSize[1]; | |||||
switch (param().mode) { | switch (param().mode) { | ||||
case Param::Mode::KTH_ONLY: | case Param::Mode::KTH_ONLY: | ||||
cuda_check(topk::find_kth_radix<ctype>(data, values, workspace, m, | cuda_check(topk::find_kth_radix<ctype>(data, values, workspace, m, | ||||
n, lda, k, stream)); | |||||
n, lda, k, grid_dim_y_limit, | |||||
stream)); | |||||
return; | return; | ||||
case Param::Mode::VALUE_IDX_NOSORT: { | case Param::Mode::VALUE_IDX_NOSORT: { | ||||
WorkspaceBundle wk_bundle{workspace, {m * sizeof(ctype), 1}}; | WorkspaceBundle wk_bundle{workspace, {m * sizeof(ctype), 1}}; | ||||
auto thresh = static_cast<ctype*>(wk_bundle.get(0)); | auto thresh = static_cast<ctype*>(wk_bundle.get(0)); | ||||
auto real_wk = wk_bundle.get(1); | auto real_wk = wk_bundle.get(1); | ||||
cuda_check(topk::find_kth_radix<ctype>(data, thresh, real_wk, m, n, | cuda_check(topk::find_kth_radix<ctype>(data, thresh, real_wk, m, n, | ||||
lda, k, stream)); | |||||
lda, k, grid_dim_y_limit, | |||||
stream)); | |||||
cuda_check(topk::topk_select<ctype>(data, thresh, values, indices, | cuda_check(topk::topk_select<ctype>(data, thresh, values, indices, | ||||
real_wk, m, n, lda, k, stream)); | |||||
real_wk, m, n, lda, k, | |||||
grid_dim_y_limit, stream)); | |||||
return; | return; | ||||
} | } | ||||
case Param::Mode::VALUE_IDX_SORTED: { | case Param::Mode::VALUE_IDX_SORTED: { | ||||
@@ -48,10 +53,11 @@ void TopKImpl::dispatch_with_ctype(int k, size_t m, size_t n, ptrdiff_t lda, | |||||
auto nosort_idx = static_cast<int32_t*>(wk_bundle.get(2)); | auto nosort_idx = static_cast<int32_t*>(wk_bundle.get(2)); | ||||
auto real_wk = wk_bundle.get(3); | auto real_wk = wk_bundle.get(3); | ||||
cuda_check(topk::find_kth_radix<ctype>(data, thresh, real_wk, m, n, | cuda_check(topk::find_kth_radix<ctype>(data, thresh, real_wk, m, n, | ||||
lda, k, stream)); | |||||
lda, k, grid_dim_y_limit, | |||||
stream)); | |||||
cuda_check(topk::topk_select<ctype>(data, thresh, nosort_values, | cuda_check(topk::topk_select<ctype>(data, thresh, nosort_values, | ||||
nosort_idx, real_wk, m, n, lda, | nosort_idx, real_wk, m, n, lda, | ||||
k, stream)); | |||||
k, grid_dim_y_limit, stream)); | |||||
argsort::forward(nosort_values, values, indices, real_wk, m, | argsort::forward(nosort_values, values, indices, real_wk, m, | ||||
std::abs(k), k > 0, stream, nosort_idx); | std::abs(k), k > 0, stream, nosort_idx); | ||||
return; | return; | ||||
@@ -89,9 +95,11 @@ size_t TopKImpl::get_workspace_in_bytes(int k, const TensorLayout& data, | |||||
MEGDNN_MARK_USED_VAR(indices); | MEGDNN_MARK_USED_VAR(indices); | ||||
size_t m = data[0], n = data[1]; | size_t m = data[0], n = data[1]; | ||||
size_t kabs = std::abs(k); | size_t kabs = std::abs(k); | ||||
size_t grid_dim_y_limit = | |||||
concrete_handle(handle())->device_prop().maxGridSize[1]; | |||||
megdnn_assert(std::max(m, n) <= | megdnn_assert(std::max(m, n) <= | ||||
static_cast<size_t>(std::numeric_limits<int>::max())); | static_cast<size_t>(std::numeric_limits<int>::max())); | ||||
size_t kth = topk::find_kth_radix_workspace(m, n), | |||||
size_t kth = topk::find_kth_radix_workspace(m, n, grid_dim_y_limit), | |||||
sel = topk::topk_select_workspace(m, n); | sel = topk::topk_select_workspace(m, n); | ||||
auto ctsize = data.dtype.size(); | auto ctsize = data.dtype.size(); | ||||
switch (param().mode) { | switch (param().mode) { | ||||
@@ -468,17 +468,9 @@ static size_t get_scan_workspace(uint32_t size) { | |||||
} // namespace select | } // namespace select | ||||
} // namespace cuda_topk_impl | } // namespace cuda_topk_impl | ||||
uint32_t topk::find_kth_radix_workspace(uint32_t batch, uint32_t length) { | |||||
uint32_t topk::find_kth_radix_workspace(uint32_t batch, uint32_t length, | |||||
uint32_t grid_dim_y_limit) { | |||||
using namespace cuda_topk_impl::kth; | using namespace cuda_topk_impl::kth; | ||||
int device_id; | |||||
if (cudaGetDevice(&device_id) != cudaSuccess) { | |||||
megdnn_trap(); | |||||
} | |||||
cudaDeviceProp prop; | |||||
if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) { | |||||
megdnn_trap(); | |||||
} | |||||
uint32_t grid_dim_y_limit = prop.maxGridSize[1]; | |||||
uint32_t limit = batch > grid_dim_y_limit ? grid_dim_y_limit : batch; | uint32_t limit = batch > grid_dim_y_limit ? grid_dim_y_limit : batch; | ||||
return (limit * get_grid_dim_x(length) * NR_BUCKET + limit * 2) * | return (limit * get_grid_dim_x(length) * NR_BUCKET + limit * 2) * | ||||
sizeof(uint32_t); | sizeof(uint32_t); | ||||
@@ -488,6 +480,7 @@ template <typename ctype> | |||||
cudaError_t topk::find_kth_radix(const ctype* input, ctype* output, | cudaError_t topk::find_kth_radix(const ctype* input, ctype* output, | ||||
void* workspace, uint32_t batch, | void* workspace, uint32_t batch, | ||||
uint32_t length, int32_t lda, int32_t k, | uint32_t length, int32_t lda, int32_t k, | ||||
uint32_t grid_dim_y_limit, | |||||
cudaStream_t stream) { | cudaStream_t stream) { | ||||
using namespace cuda_topk_impl::kth; | using namespace cuda_topk_impl::kth; | ||||
if (!k) { | if (!k) { | ||||
@@ -502,16 +495,6 @@ cudaError_t topk::find_kth_radix(const ctype* input, ctype* output, | |||||
megdnn_trap(); | megdnn_trap(); | ||||
} | } | ||||
int device_id; | |||||
if (cudaGetDevice(&device_id) != cudaSuccess) { | |||||
megdnn_trap(); | |||||
} | |||||
cudaDeviceProp prop; | |||||
if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) { | |||||
megdnn_trap(); | |||||
} | |||||
uint32_t grid_dim_y_limit = prop.maxGridSize[1]; | |||||
uint32_t batch_idx = 0; | uint32_t batch_idx = 0; | ||||
uint32_t grid_dim_x = get_grid_dim_x(length); | uint32_t grid_dim_x = get_grid_dim_x(length); | ||||
uint32_t grid_dim_y = 1; | uint32_t grid_dim_y = 1; | ||||
@@ -567,20 +550,11 @@ template <typename ctype> | |||||
cudaError_t topk::topk_select(const ctype* input, const ctype* thresh, | cudaError_t topk::topk_select(const ctype* input, const ctype* thresh, | ||||
ctype* output_value, int32_t* output_idx, | ctype* output_value, int32_t* output_idx, | ||||
void* workspace, uint32_t batch, uint32_t length, | void* workspace, uint32_t batch, uint32_t length, | ||||
int32_t lda, int32_t k, cudaStream_t stream) { | |||||
int32_t lda, int32_t k, | |||||
uint32_t batch_upper_limit, cudaStream_t stream) { | |||||
using namespace cuda_topk_impl; | using namespace cuda_topk_impl; | ||||
using namespace cuda_topk_impl::select; | using namespace cuda_topk_impl::select; | ||||
int device_id; | |||||
if (cudaGetDevice(&device_id) != cudaSuccess) { | |||||
megdnn_trap(); | |||||
} | |||||
cudaDeviceProp prop; | |||||
if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) { | |||||
megdnn_trap(); | |||||
} | |||||
uint32_t batch_upper_limit = prop.maxGridSize[1]; | |||||
uint32_t length_split = DIVUP(length, REDUCE_SIZE); | uint32_t length_split = DIVUP(length, REDUCE_SIZE); | ||||
void (*kptr_reduce_block_cnt)(const ctype*, const ctype*, uint32_t, int32_t, | void (*kptr_reduce_block_cnt)(const ctype*, const ctype*, uint32_t, int32_t, | ||||
@@ -688,10 +662,10 @@ namespace topk { | |||||
#define INST(t) \ | #define INST(t) \ | ||||
template cudaError_t find_kth_radix<t>(const t*, t*, void*, uint32_t, \ | template cudaError_t find_kth_radix<t>(const t*, t*, void*, uint32_t, \ | ||||
uint32_t, int32_t, int32_t, \ | uint32_t, int32_t, int32_t, \ | ||||
cudaStream_t); \ | |||||
uint32_t, cudaStream_t); \ | |||||
template cudaError_t topk_select<t>(const t*, const t*, t*, int32_t*, \ | template cudaError_t topk_select<t>(const t*, const t*, t*, int32_t*, \ | ||||
void*, uint32_t, uint32_t, int32_t, \ | void*, uint32_t, uint32_t, int32_t, \ | ||||
int32_t, cudaStream_t) | |||||
int32_t, uint32_t, cudaStream_t) | |||||
INST(float); | INST(float); | ||||
INST(int32_t); | INST(int32_t); | ||||
#undef INST | #undef INST | ||||
@@ -76,10 +76,12 @@ struct RadixConverter<int32_t> { | |||||
template <typename ctype> | template <typename ctype> | ||||
cudaError_t find_kth_radix(const ctype* input, ctype* output, void* workspace, | cudaError_t find_kth_radix(const ctype* input, ctype* output, void* workspace, | ||||
uint32_t batch, uint32_t length, int32_t lda, | uint32_t batch, uint32_t length, int32_t lda, | ||||
int32_t k, cudaStream_t stream); | |||||
int32_t k, uint32_t grid_dim_y_limit, | |||||
cudaStream_t stream); | |||||
//! get workspace in bytes | //! get workspace in bytes | ||||
uint32_t find_kth_radix_workspace(uint32_t batch, uint32_t length); | |||||
uint32_t find_kth_radix_workspace(uint32_t batch, uint32_t length, | |||||
uint32_t grid_dim_y_limit); | |||||
/*! | /*! | ||||
* \brief select values from rows of input that compare to thresh as specified | * \brief select values from rows of input that compare to thresh as specified | ||||
@@ -90,7 +92,8 @@ template <typename ctype> | |||||
cudaError_t topk_select(const ctype* input, const ctype* thresh, | cudaError_t topk_select(const ctype* input, const ctype* thresh, | ||||
ctype* output_value, int32_t* output_idx, | ctype* output_value, int32_t* output_idx, | ||||
void* workspace, uint32_t batch, uint32_t length, | void* workspace, uint32_t batch, uint32_t length, | ||||
int32_t lda, int32_t k, cudaStream_t stream); | |||||
int32_t lda, int32_t k, uint32_t batch_upper_limit, | |||||
cudaStream_t stream); | |||||
uint32_t topk_select_workspace(uint32_t batch, uint32_t length); | uint32_t topk_select_workspace(uint32_t batch, uint32_t length); | ||||