@@ -0,0 +1,183 @@ | |||
/** | |||
* \file dnn/src/rocm/argsort/argsort.cpp.hip | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "hcc_detail/hcc_defs_prologue.h" | |||
#include "src/rocm/utils.h.hip" | |||
#include "./argsort.h.hip" | |||
#include "./bitonic_sort.h.hip" | |||
#include "megdnn/basic_types.h" | |||
#include "hipcub/device/device_radix_sort.hpp" | |||
#include "hipcub/device/device_segmented_radix_sort.hpp" | |||
using namespace megdnn; | |||
using namespace rocm; | |||
namespace { | |||
struct StridedOffsetIterator { | |||
int bias, stride; | |||
StridedOffsetIterator(int bias_, int stride_) | |||
: bias(bias_), stride(stride_) {} | |||
__device__ __forceinline__ int operator[](int i) const { | |||
return stride * i + bias; | |||
} | |||
}; | |||
bool use_bitonic(uint32_t /*M*/, uint32_t N) { | |||
// bitonic sort is preferred when N is small (alwyas faster than radix sort) | |||
return N <= BITONIC_SORT_MAX_LENGTH; | |||
} | |||
bool use_segmented(uint32_t M, uint32_t /*N*/) { | |||
// an empirical value: | |||
// sort(1, 1e6): 0.574ms | |||
// segsort({1,2,8,16}, 1e6): 7-8ms | |||
// sort(1, 1e7): 3.425ms | |||
// segsort({1,2,8,16}, 1e7): 71-84ms | |||
// | |||
// segsort is about 7x-10x slower than sort on small batches, so we can | |||
// expect it to be faster than sort when batch is large enough. | |||
return M >= 8; | |||
} | |||
__global__ void kern_arange(int* dst, uint32_t n, uint32_t mod) { | |||
uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; | |||
if (i < n) { | |||
dst[i] = i % mod; | |||
} | |||
} | |||
template <typename ctype> | |||
size_t get_sort_workspace(uint32_t M, uint32_t N, bool is_ascending) { | |||
if (use_bitonic(M, N)) { | |||
return 0; | |||
} | |||
return argsort::cub_sort_pairs<ctype, int>(is_ascending, NULL, 0, NULL, NULL, NULL, NULL, | |||
M, N, 0, sizeof(float)*8, NULL); | |||
} | |||
} // anonymous namespace | |||
template <typename KeyType, typename ValueType> | |||
MEGDNN_NOINLINE size_t argsort::cub_sort_pairs( | |||
bool is_ascending, void* workspace, size_t workspace_size, | |||
const KeyType* keys_in, KeyType* keys_out, const ValueType* values_in, | |||
ValueType* values_out, uint32_t M, uint32_t N, int begin_bit, int end_bit,hipStream_t stream){ | |||
hipError_t err; | |||
if (use_segmented(M, N)) { | |||
if (is_ascending) { | |||
err = hipcub::DeviceSegmentedRadixSort::SortPairs( | |||
workspace, workspace_size, keys_in, keys_out, values_in, | |||
values_out, N * M, M, StridedOffsetIterator(0, N), | |||
StridedOffsetIterator(N, N), begin_bit, end_bit, stream); | |||
hip_check(err); | |||
} else { | |||
err = hipcub::DeviceSegmentedRadixSort::SortPairsDescending( | |||
workspace, workspace_size, keys_in, keys_out, values_in, | |||
values_out, N * M, M, StridedOffsetIterator(0, N), | |||
StridedOffsetIterator(N, N), begin_bit, end_bit, stream); | |||
hip_check(err); | |||
} | |||
} else { | |||
if (is_ascending) { | |||
for (size_t i = 0; i < M; ++i) { | |||
err = hipcub::DeviceRadixSort::SortPairs( | |||
workspace, workspace_size, keys_in + N * i, | |||
keys_out + N * i, values_in + N * i, values_out + N * i, | |||
N, begin_bit, end_bit, stream); | |||
hip_check(err); | |||
if (!keys_in) { | |||
return workspace_size; | |||
} | |||
} | |||
} else { | |||
for (size_t i = 0; i < M; ++i) { | |||
err = hipcub::DeviceRadixSort::SortPairsDescending( | |||
workspace, workspace_size, keys_in + N * i, | |||
keys_out + N * i, values_in + N * i, values_out + N * i, | |||
N, begin_bit, end_bit, stream); | |||
hip_check(err); | |||
if (!keys_in) { | |||
return workspace_size; | |||
} | |||
} | |||
} | |||
} | |||
return workspace_size; | |||
} | |||
size_t argsort::get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype, | |||
bool is_ascending, | |||
bool iptr_src_given) { | |||
size_t size = 0; | |||
switch (dtype.enumv().ev) { | |||
#define cb(ctype) \ | |||
case DTypeTrait<ctype>::enumv: \ | |||
size = get_sort_workspace<ctype>(M, N, is_ascending); \ | |||
break; | |||
ARGSORT_FOREACH_CTYPE(cb) | |||
#undef cb | |||
default: | |||
megdnn_throw("argsort only supports float, int32 and float16"); | |||
} | |||
if (!iptr_src_given) { | |||
size = DIVUP(size, sizeof(float)) * sizeof(float) + M * N * sizeof(int); | |||
} | |||
return size; | |||
} | |||
template <typename dtype> | |||
void argsort::forward(const dtype* sptr, dtype* dptr, int* iptr, | |||
void* workspace, uint32_t M, uint32_t N, | |||
bool is_ascending, hipStream_t stream, | |||
const int* iptr_src) { | |||
size_t wk_size = get_sort_workspace<dtype>(M, N, is_ascending); | |||
if (!iptr_src) { | |||
int* ptr = reinterpret_cast<int*>(static_cast<uint8_t*>(workspace) + | |||
DIVUP(wk_size, sizeof(float)) * | |||
sizeof(float)); | |||
kern_arange<<<DIVUP(N * M, 512), 512, 0, stream>>>(ptr, M * N, N); | |||
iptr_src = ptr; | |||
} | |||
if (use_bitonic(M, N)) { | |||
hip_check(bitonic_sort(M, N, sptr, iptr_src, dptr, iptr, is_ascending, | |||
stream)); | |||
} else { | |||
cub_sort_pairs(is_ascending, workspace, wk_size, sptr, dptr, iptr_src, | |||
iptr, M, N, 0, sizeof(float)*8, stream); | |||
} | |||
} | |||
namespace megdnn { | |||
namespace rocm { | |||
#define INST_CUB_SORT(dtype) \ | |||
template MEGDNN_NOINLINE size_t argsort::cub_sort_pairs<dtype, dtype>(bool, \ | |||
void*, size_t, const dtype*, dtype*, \ | |||
const dtype*, dtype*, uint32_t, uint32_t,\ | |||
int, int, hipStream_t); | |||
#define INST_FORWARD(dtype) \ | |||
template void argsort::forward<dtype>(const dtype*, dtype*, int*, void*, \ | |||
uint32_t, uint32_t, bool, hipStream_t, \ | |||
const int*); | |||
ARGSORT_FOREACH_CTYPE(INST_FORWARD) | |||
INST_CUB_SORT(uint32_t) | |||
// INST_CUB_SORT(uint64_t) | |||
#undef INST_CUB_SORT | |||
#undef INST_FORWARD | |||
} | |||
} // namespace megdnn | |||
// vim: ft=rocm syntax=rocm.doxygen | |||
@@ -0,0 +1,50 @@ | |||
/** | |||
* \file dnn/src/rocm/argsort/argsort.h.hip | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "hcc_detail/hcc_defs_prologue.h" | |||
#include "hip_header.h" | |||
#include <stddef.h> | |||
#include <stdint.h> | |||
#include "megdnn/dtype.h" | |||
namespace megdnn { | |||
namespace rocm { | |||
namespace argsort { | |||
size_t get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype, | |||
bool is_ascending, | |||
bool iptr_src_given = false); | |||
template <typename KeyType, typename ValueType> | |||
size_t cub_sort_pairs( | |||
bool is_ascending, void* workspace, size_t workspace_size, | |||
const KeyType* keys_in, KeyType* keys_out, const ValueType* values_in, | |||
ValueType* values_out, uint32_t M, uint32_t N, int begin_bit, int end_bit,hipStream_t stream); | |||
/*! | |||
* \param iptr_src pointer to indices; a range would be generated if it is null | |||
*/ | |||
template <typename dtype> | |||
void forward(const dtype* sptr, dtype* dptr, int* iptr, void* workspace, | |||
uint32_t M, uint32_t N, bool is_ascending, hipStream_t stream, | |||
const int* iptr_src = NULL); | |||
//! iterate over all supported data types | |||
#define ARGSORT_FOREACH_CTYPE(cb) \ | |||
cb(float) cb(int32_t) // DNN_INC_FLOAT16(cb(dt_float16)) | |||
} // namespace argsort | |||
} // namespace rocm | |||
} // namespace megdnn | |||
// vim: ft=cpp syntax=cpp.doxygen |
@@ -0,0 +1,67 @@ | |||
/** | |||
* \file dnn/src/rocm/argsort/backward.cpp.hip | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "hcc_detail/hcc_defs_prologue.h" | |||
#include "src/rocm/utils.h.hip" | |||
#include "./argsort.h.hip" | |||
#include "./backward.h.hip" | |||
// #include "src/rocm/utils.h" | |||
using namespace megdnn; | |||
using namespace rocm; | |||
using namespace argsort; | |||
namespace { | |||
template <typename T> | |||
__global__ void backward_kernel(uint32_t dst_w, uint32_t src_w, | |||
uint32_t src_size, T* dst, const T* src_data, | |||
const int* src_idx) { | |||
uint32_t idx = threadIdx.x + blockIdx.x * blockDim.x; | |||
if (idx < src_size) { | |||
uint32_t r = idx / src_w; | |||
dst[r * dst_w + src_idx[idx]] = src_data[idx]; | |||
} | |||
} | |||
} // namespace | |||
template <typename T> | |||
void argsort::backward_proxy(uint32_t dst_h, uint32_t dst_w, uint32_t src_w, | |||
T* dst, const T* src_data, const int* src_idx, | |||
hipStream_t stream) { | |||
if (dst_w != src_w) { | |||
hipMemsetAsync(dst, 0, dst_h * dst_w * sizeof(T), stream); | |||
} | |||
uint32_t src_size = dst_h * src_w; | |||
backward_kernel<<<DIVUP(src_size, 512), 512, 0, stream>>>( | |||
dst_w, src_w, src_size, dst, src_data, src_idx); | |||
after_kernel_launch(); | |||
} | |||
namespace megdnn { | |||
namespace rocm { | |||
namespace argsort { | |||
#define INST(T) \ | |||
template void backward_proxy(uint32_t dst_h, uint32_t dst_w, \ | |||
uint32_t src_w, T* dst, const T* src_data, \ | |||
const int* src_idx, hipStream_t stream); | |||
ARGSORT_FOREACH_CTYPE(INST) | |||
#undef INST | |||
} // namespace argsort | |||
} // namespace rocm | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,29 @@ | |||
/** | |||
* \file dnn/src/rocm/argsort/backward.h.hip | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "hip_header.h" | |||
#include <stdint.h> | |||
namespace megdnn { | |||
namespace rocm { | |||
namespace argsort { | |||
template <typename T> | |||
void backward_proxy(uint32_t dst_h, uint32_t dst_w, uint32_t src_w, T* dst, | |||
const T* src_data, const int* src_idx, hipStream_t stream); | |||
} // namespace argsort | |||
} // namespace rocm | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen | |||
@@ -0,0 +1,320 @@ | |||
/** | |||
* \file dnn/src/rocm/argsort/bitonic_sort.cpp.hip | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "hcc_detail/hcc_defs_prologue.h" | |||
#include "./bitonic_sort.h.hip" | |||
// #include "src/cuda/query_blocksize.cuh" | |||
// #include "megdnn/dtype.h" | |||
// #if __CUDACC_VER_MAJOR__ < 9 | |||
// #pragma message "warp sync disabled due to insufficient cuda version" | |||
#define __syncwarp __syncthreads | |||
// #endif | |||
#include <algorithm> | |||
#include <cmath> | |||
using namespace megdnn; | |||
using namespace rocm; | |||
namespace bitonic_sort_impl { | |||
//! load keys and init idx | |||
template <class CompareLess, typename T> | |||
__device__ __forceinline__ void safe_load0(T* dst, uint16_t* idx, const T* src, | |||
uint32_t id, uint32_t size) { | |||
dst[id] = id < size ? src[id] : CompareLess::template max<T>(); | |||
idx[id] = id; | |||
} | |||
//! load values | |||
template <typename T> | |||
__device__ __forceinline__ void safe_load1(T* dst, const T* src, uint32_t id, | |||
uint32_t size) { | |||
// broadcast last value to avoid out-of-bound values (for example, when | |||
// input contains NaN) | |||
dst[id] = src[min(id, size - 1)]; | |||
} | |||
//! write keys | |||
template <typename T> | |||
__device__ __forceinline__ void safe_write0(T* dst, const T* src, uint32_t id, | |||
uint32_t size) { | |||
if (id < size) { | |||
dst[id] = src[id]; | |||
} | |||
} | |||
//! write values | |||
template <typename T> | |||
__device__ __forceinline__ void safe_write1(T* dst, const T* src, | |||
const uint16_t* remap, uint32_t id, | |||
uint32_t size) { | |||
if (id < size) { | |||
dst[id] = src[remap[id]]; | |||
} | |||
} | |||
struct SyncWarp { | |||
static __device__ __forceinline__ void s() { __syncwarp(); } | |||
}; | |||
struct SyncBlock { | |||
static __device__ __forceinline__ void s() { __syncthreads(); } | |||
}; | |||
template <typename T> | |||
struct NumTrait; | |||
template <> | |||
struct NumTrait<float> { | |||
static __device__ __forceinline__ float max() { return INFINITY; } | |||
static __device__ __forceinline__ float min() { return -INFINITY; } | |||
}; | |||
template <> | |||
struct NumTrait<int32_t> { | |||
static __device__ __forceinline__ int32_t max() { return INT_MAX; } | |||
static __device__ __forceinline__ int32_t min() { return INT_MIN; } | |||
}; | |||
// #if !MEGDNN_DISABLE_FLOAT16 | |||
// template <> | |||
// struct NumTrait<dt_float16> { | |||
// static __device__ __forceinline__ dt_float16 max() { | |||
// return std::numeric_limits<dt_float16>::max(); | |||
// } | |||
// static __device__ __forceinline__ dt_float16 min() { | |||
// return std::numeric_limits<dt_float16>::lowest(); | |||
// } | |||
// }; | |||
// #endif | |||
struct LessThan { | |||
template <typename Key, typename Value> | |||
static __device__ __forceinline__ bool cmp(Key k0, Value v0, Key k1, | |||
Value v1) { | |||
return (k0 < k1) | ((k0 == k1) & (v0 < v1)); | |||
} | |||
template <typename T> | |||
static __device__ __forceinline__ T max() { | |||
return NumTrait<T>::max(); | |||
} | |||
}; | |||
struct GreaterThan { | |||
template <typename Key, typename Value> | |||
static __device__ __forceinline__ bool cmp(Key k0, Value v0, Key k1, | |||
Value v1) { | |||
return (k0 > k1) | ((k0 == k1) & (v0 < v1)); | |||
} | |||
template <typename T> | |||
static __device__ __forceinline__ T max() { | |||
return NumTrait<T>::min(); | |||
} | |||
}; | |||
template <typename Key, typename Value> | |||
union KVUnion { | |||
Key key; | |||
Value value; | |||
}; | |||
template <typename Key, typename Value> | |||
static int get_shmem(int block_size, void* = NULL) { | |||
return (sizeof(KVUnion<Key, Value>) + sizeof(uint16_t)) * block_size * 4; | |||
} | |||
/*! | |||
* \brief batched bitonic sort (M, N) for small N | |||
* | |||
* launch configuration: | |||
* grid(X) | |||
* block(N/4, Y) | |||
* | |||
* where N / 4 == 1 << nr_th_log2 | |||
*/ | |||
template <class Sync, typename Key, typename Value, class CompareLess, | |||
uint32_t nr_th_log2> | |||
static __global__ void kern(uint32_t batch, uint32_t length, const Key* key_inp, | |||
const Value* value_inp, Key* key_out, | |||
Value* value_out) { | |||
const uint32_t nr_th = 1 << nr_th_log2; | |||
// 24KiB shared memory for 4-byte keys for 1024 threads | |||
extern __shared__ uint8_t smem_storage[]; | |||
uint16_t* idx_storage = reinterpret_cast<uint16_t*>(smem_storage); | |||
KVUnion<Key, Value>* keys_storage = reinterpret_cast<KVUnion<Key, Value>*>( | |||
idx_storage + blockDim.y * (nr_th * 4)); | |||
uint32_t cur_batch = blockIdx.x * blockDim.y + threadIdx.y, | |||
off = cur_batch * length; | |||
key_inp += off; | |||
key_out += off; | |||
value_inp += off; | |||
value_out += off; | |||
uint32_t storage_offset = threadIdx.y * (nr_th * 4); | |||
uint16_t* values = idx_storage + storage_offset; | |||
Key* keys = reinterpret_cast<Key*>(keys_storage + storage_offset); | |||
uint32_t tid0 = threadIdx.x, tid1 = tid0 + nr_th, | |||
cur_length = cur_batch < batch ? length : 0; | |||
safe_load0<CompareLess>(keys, values, key_inp, tid0, cur_length); | |||
safe_load0<CompareLess>(keys, values, key_inp, tid0 + nr_th, cur_length); | |||
safe_load0<CompareLess>(keys, values, key_inp, tid0 + nr_th * 2, | |||
cur_length); | |||
safe_load0<CompareLess>(keys, values, key_inp, tid0 + nr_th * 3, | |||
cur_length); | |||
Sync::s(); | |||
#define WORK(_idx, _asc) \ | |||
do { \ | |||
uint32_t _id0 = (_idx), _id1 = _id0 + step; \ | |||
Key _k0 = keys[_id0], _k1 = keys[_id1]; \ | |||
uint16_t _v0 = values[_id0], _v1 = values[_id1]; \ | |||
if (CompareLess::cmp(_k0, _v0, _k1, _v1) != _asc) { \ | |||
keys[_id0] = _k1; \ | |||
keys[_id1] = _k0; \ | |||
values[_id0] = _v1; \ | |||
values[_id1] = _v0; \ | |||
} \ | |||
} while (0) | |||
#pragma unroll | |||
for (uint32_t slen_log = 0; slen_log <= (nr_th_log2 + 1); ++slen_log) { | |||
// log2 of half of current bitonic sequence (i.e. length of its | |||
// monotonic part) | |||
uint32_t asc0 = !((tid0 >> slen_log) & 1), | |||
asc1 = !((tid1 >> slen_log) & 1); | |||
#pragma unroll | |||
for (uint32_t j = 0; j <= slen_log; ++j) { | |||
uint32_t step = 1 << (slen_log - j), xmask = step - 1, | |||
ymask = ~xmask; | |||
WORK((tid0 & xmask) + ((tid0 & ymask) << 1), asc0); | |||
WORK((tid1 & xmask) + ((tid1 & ymask) << 1), asc1); | |||
Sync::s(); | |||
} | |||
} | |||
#undef WORK | |||
if (cur_batch < batch) { | |||
safe_write0(key_out, keys, tid0, length); | |||
safe_write0(key_out, keys, tid0 + nr_th, length); | |||
safe_write0(key_out, keys, tid0 + nr_th * 2, length); | |||
safe_write0(key_out, keys, tid0 + nr_th * 3, length); | |||
// permute values according to sorted indices | |||
Value* copied_values = reinterpret_cast<Value*>(keys); | |||
safe_load1(copied_values, value_inp, tid0, cur_length); | |||
safe_load1(copied_values, value_inp, tid0 + nr_th, cur_length); | |||
safe_load1(copied_values, value_inp, tid0 + nr_th * 2, cur_length); | |||
safe_load1(copied_values, value_inp, tid0 + nr_th * 3, cur_length); | |||
Sync::s(); | |||
safe_write1(value_out, copied_values, values, tid0, length); | |||
safe_write1(value_out, copied_values, values, tid0 + nr_th, length); | |||
safe_write1(value_out, copied_values, values, tid0 + nr_th * 2, length); | |||
safe_write1(value_out, copied_values, values, tid0 + nr_th * 3, length); | |||
} | |||
} | |||
} // namespace bitonic_sort_impl | |||
template <typename Key, typename Value> | |||
hipError_t rocm::bitonic_sort(uint32_t batch, uint32_t length, | |||
const Key* key_inp, const Value* value_inp, | |||
Key* key_out, Value* value_out, bool ascending, | |||
hipStream_t stream) { | |||
using namespace bitonic_sort_impl; | |||
if (length == 1) { | |||
if (key_inp != key_out) { | |||
hipMemcpyAsync(key_out, key_inp, sizeof(Key) * batch, | |||
hipMemcpyDeviceToDevice, stream); | |||
} | |||
if (value_inp != value_out) { | |||
hipMemcpyAsync(value_out, value_inp, sizeof(Value) * batch, | |||
hipMemcpyDeviceToDevice, stream); | |||
} | |||
return hipGetLastError(); | |||
} | |||
void (*kptr)(uint32_t, uint32_t, const Key*, const Value*, Key*, Value*) = | |||
NULL; | |||
uint32_t l4 = (length + 3) / 4; | |||
dim3 block; | |||
#define chk(s) \ | |||
do { \ | |||
if (!kptr && l4 <= (1 << s)) { \ | |||
block.x = 1 << s; \ | |||
if ((1 << s) <= 32) { \ | |||
if (ascending) { \ | |||
kptr = kern<SyncWarp, Key, Value, LessThan, s>; \ | |||
} else { \ | |||
kptr = kern<SyncWarp, Key, Value, GreaterThan, s>; \ | |||
} \ | |||
} else { \ | |||
if (ascending) { \ | |||
kptr = kern<SyncBlock, Key, Value, LessThan, s>; \ | |||
} else { \ | |||
kptr = kern<SyncBlock, Key, Value, GreaterThan, s>; \ | |||
} \ | |||
} \ | |||
} \ | |||
} while (0) | |||
chk(0); | |||
chk(1); | |||
chk(2); | |||
chk(3); | |||
chk(4); | |||
chk(5); | |||
chk(6); | |||
chk(7); | |||
chk(8); | |||
chk(9); | |||
if (!kptr) { | |||
return hipErrorInvalidConfiguration; | |||
} | |||
// TODO: this is randomly choosed | |||
int suggested_block_size = 128; | |||
// query_launch_config_for_kernel(reinterpret_cast<void*>(kptr), | |||
// get_shmem<Key, Value>) | |||
// .block_size; | |||
block.y = std::max<int>(suggested_block_size / block.x, 1); | |||
int shmem = get_shmem<Key, Value>(block.y * block.x); | |||
kptr<<<(batch - 1) / block.y + 1, block, shmem, stream>>>( | |||
batch, length, key_inp, value_inp, key_out, value_out); | |||
return hipGetLastError(); | |||
} | |||
namespace megdnn { | |||
namespace rocm { | |||
#define INST(k, v) \ | |||
template hipError_t bitonic_sort<k, v>(uint32_t, uint32_t, const k*, \ | |||
const v*, k*, v*, bool, \ | |||
hipStream_t) | |||
INST(float, int); | |||
INST(int32_t, int); | |||
// DNN_INC_FLOAT16(INST(dt_float16, int)); | |||
#undef INST | |||
} // namespace megdnn | |||
} // namespace megdnn | |||
// vim: ft=rocm syntax=rocm.doxygen | |||
@@ -0,0 +1,38 @@ | |||
/** | |||
* \file dnn/src/rocm/argsort/bitonic_sort.h.hip | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "hip_header.h" | |||
#include <stdint.h> | |||
namespace megdnn { | |||
namespace rocm { | |||
const uint32_t BITONIC_SORT_MAX_LENGTH = 1024; | |||
// cub radix sort seems to be faster with lengths > 1024 | |||
/*! | |||
* \brief bitonic sort for k/v pairs | |||
* | |||
* Requires \p length no larger than 4 times of cuda thread num. \p key_inp | |||
* and \p key_out can be identical, and so are \p value_inp and \p value_out. | |||
*/ | |||
template <typename Key, typename Value> | |||
hipError_t bitonic_sort(uint32_t batch, uint32_t length, const Key* key_inp, | |||
const Value* value_inp, Key* key_out, Value* value_out, | |||
bool ascending, hipStream_t stream); | |||
} // namespace rocm | |||
} // namespace megdnn | |||
// vim: ft=cpp syntax=cpp.doxygen | |||
@@ -0,0 +1,79 @@ | |||
/** | |||
* \file dnn/src/rocm/argsort/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "./opr_impl.h" | |||
#include "./argsort.h.hip" | |||
#include "./backward.h.hip" | |||
#include "src/common/utils.h" | |||
#include "src/rocm/utils.h" | |||
using namespace megdnn; | |||
using namespace rocm; | |||
void ArgsortForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_tensor_out indices, | |||
_megdnn_workspace workspace) { | |||
check_exec(src.layout, dst.layout, indices.layout, workspace.size); | |||
auto M = src.layout.shape[0], N = src.layout.shape[1]; | |||
auto iptr = indices.ptr<dt_int32>(); | |||
auto wptr = static_cast<void*>(workspace.raw_ptr); | |||
bool is_ascending = (param().order == Order::ASCENDING); | |||
auto stream = hip_stream(handle()); | |||
switch (src.layout.dtype.enumv()) { | |||
#define cb(t) \ | |||
case DTypeTrait<t>::enumv: \ | |||
argsort::forward(src.ptr<t>(), dst.ptr<t>(), iptr, wptr, M, N, \ | |||
is_ascending, stream); \ | |||
break; | |||
ARGSORT_FOREACH_CTYPE(cb); | |||
#undef cb | |||
default: | |||
megdnn_throw(ssprintf("unsupported argsort dtype on cuda: %s", | |||
src.layout.dtype.name())); | |||
} | |||
} | |||
size_t ArgsortForwardImpl::get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout&, | |||
const TensorLayout&) { | |||
megdnn_assert(src.ndim == 2, "invalid src layout: %s", | |||
src.to_string().c_str()); | |||
auto M = src.shape[0], N = src.shape[1]; | |||
auto&& dtype = src.dtype; | |||
megdnn_assert(std::max(M, N) <= | |||
static_cast<size_t>(std::numeric_limits<int>::max())); | |||
return argsort::get_fwd_workspace_in_bytes( | |||
M, N, dtype, param().order == Param::Order::ASCENDING); | |||
} | |||
void ArgsortBackwardImpl::exec(_megdnn_tensor_in diff, | |||
_megdnn_tensor_in indices, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) { | |||
check_exec(diff.layout, indices.layout, grad.layout, workspace.size); | |||
auto stream = hip_stream(handle()); | |||
switch (diff.layout.dtype.enumv()) { | |||
#define cb(t) \ | |||
case DTypeTrait<t>::enumv: \ | |||
argsort::backward_proxy(grad.layout[0], grad.layout[1], \ | |||
diff.layout[1], grad.ptr<t>(), diff.ptr<t>(), \ | |||
indices.ptr<int>(), stream); \ | |||
break; | |||
ARGSORT_FOREACH_CTYPE(cb); | |||
#undef cb | |||
default: | |||
megdnn_throw(ssprintf("unsupported argsort dtype on cuda: %s", | |||
diff.layout.dtype.name())); | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,47 @@ | |||
/** | |||
* \file dnn/src/rocm/argsort/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
namespace megdnn { | |||
namespace rocm { | |||
class ArgsortForwardImpl final: public ArgsortForward { | |||
public: | |||
using ArgsortForward::ArgsortForward; | |||
void exec(_megdnn_tensor_in src, | |||
_megdnn_tensor_out dst, | |||
_megdnn_tensor_out indices, | |||
_megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout &src, | |||
const TensorLayout &dst, | |||
const TensorLayout &indices) override; | |||
}; | |||
class ArgsortBackwardImpl final: public ArgsortBackward { | |||
public: | |||
using ArgsortBackward::ArgsortBackward; | |||
void exec(_megdnn_tensor_in diff, | |||
_megdnn_tensor_in indices, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout &, | |||
const TensorLayout &, | |||
const TensorLayout &) override { | |||
return 0; | |||
} | |||
}; | |||
} // namespace rocm | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen | |||
@@ -33,6 +33,7 @@ | |||
#include "src/rocm/powc/opr_impl.h" | |||
#include "src/rocm/indexing_multi_axis_vec/opr_impl.h" | |||
#include "src/rocm/linspace/opr_impl.h" | |||
#include "src/rocm/argsort/opr_impl.h" | |||
#include "src/rocm/argmxx/opr_impl.h" | |||
#include "src/rocm/sleep/opr_impl.h" | |||
#include "src/rocm/batch_normalization/opr_impl.h" | |||
@@ -148,6 +149,8 @@ bool HandleImpl::check_cross_dev_copy_constraint(const TensorLayout& src) { | |||
return src.is_contiguous() || src.stride[src.ndim - 1] == 1; | |||
} | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgsortForward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgsortBackward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionForward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardData); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardFilter); | |||
@@ -0,0 +1,124 @@ | |||
/** | |||
* \file dnn/test/rocm/argsort.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "test/rocm/fixture.h" | |||
#include "test/common/checker.h" | |||
#include "test/common/rng.h" | |||
#include "test/common/tensor.h" | |||
#include "../src/rocm/argsort/opr_impl.h" | |||
using namespace megdnn; | |||
using namespace test; | |||
namespace { | |||
class ArgsortRNG final : public RNG { | |||
bool m_rev_order = false; | |||
DType m_dtype; | |||
template <typename T> | |||
void fill(T* ptr, int n) { | |||
if (m_rev_order) { | |||
for (int i = 0; i < n; ++i) | |||
ptr[i] = static_cast<T>(n / 2 - i); | |||
} else { | |||
for (int i = 0; i < n; ++i) | |||
ptr[i] = static_cast<T>(i - n / 2); | |||
COMPAT_RANDOM(ptr, ptr + n); | |||
} | |||
} | |||
void gen(const TensorND& tensor) override { | |||
auto n = tensor.layout.total_nr_elems(); | |||
if (m_dtype == dtype::Float32{}) { | |||
fill(tensor.ptr<dt_float32>(), n); | |||
} else { | |||
megdnn_assert(m_dtype == dtype::Int32{}); | |||
fill(tensor.ptr<dt_int32>(), n); | |||
} | |||
} | |||
public: | |||
ArgsortRNG(DType dt) : m_dtype{dt} {} | |||
void set_rev_order(bool flag) { m_rev_order = flag; } | |||
}; | |||
void run_forward_test(Handle* handle, DType dtype) { | |||
Checker<ArgsortForward> checker(handle); | |||
using Param = Argsort::Param; | |||
using Order = Param::Order; | |||
ArgsortRNG rng{dtype}; | |||
checker.set_dtype(2, dtype::Int32()); | |||
checker.set_dtype(0, dtype).set_rng(0, &rng); | |||
for (size_t i = 3; i < 10240; i *= 2) { | |||
Param param; | |||
param.order = Order::ASCENDING; | |||
checker.set_param(param).execs({{3, i + 1}, {}, {}}); | |||
param.order = Order::DESCENDING; | |||
checker.set_param(param).execs({{3, i - 1}, {}, {}}); | |||
checker.set_param(param).execs({{13, i + 3}, {}, {}}); | |||
} | |||
{ | |||
// reverse sort large array | |||
constexpr size_t N = 200003; | |||
rng.set_rev_order(true); | |||
Param param; | |||
param.order = Order::ASCENDING; | |||
checker.set_param(param).execs({{1, N}, {}, {}}); | |||
} | |||
} | |||
void run_backward_test(Handle* handle, DType dtype) { | |||
class IdxRng final : public RNG { | |||
void gen(const TensorND& tensor) override { | |||
auto ptr = tensor.ptr<dt_int32>(); | |||
auto m = tensor.layout[0], n = tensor.layout[1]; | |||
for (size_t i = 0; i < m; ++i) { | |||
for (size_t j = 0; j < n; ++j) { | |||
ptr[j] = j; | |||
} | |||
COMPAT_RANDOM(ptr, ptr + n); | |||
ptr += n; | |||
} | |||
} | |||
} rng; | |||
Checker<ArgsortBackward> checker(handle); | |||
checker.set_dtype(1, dtype::Int32()).set_rng(1, &rng); | |||
checker.set_dtype(0, dtype); | |||
checker.set_dtype(2, dtype); | |||
for (size_t i = 16; i < 4096; i *= 2) { | |||
checker.execs({{3, i}, {3, i}, {3, i}}); | |||
checker.execs({{3, i + 3}, {3, i + 3}, {3, i + 3}}); | |||
checker.execs({{3, i + 3}, {3, i + 3}, {3, i + 7}}); | |||
} | |||
} | |||
} // anonymous namespace | |||
TEST_F(ROCM, ARGSORT_FORWARD_F32) { | |||
run_forward_test(handle_rocm(), dtype::Float32{}); | |||
} | |||
TEST_F(ROCM, ARGSORT_FORWARD_I32) { | |||
run_forward_test(handle_rocm(), dtype::Int32{}); | |||
} | |||
TEST_F(ROCM, ARGSORT_BACKWARD_F32) { | |||
run_backward_test(handle_rocm(), dtype::Float32{}); | |||
} | |||
TEST_F(ROCM, ARGSORT_BACKWARD_I32) { | |||
run_backward_test(handle_rocm(), dtype::Int32{}); | |||
} | |||
// vim: syntax=cpp.doxygen | |||