GitOrigin-RevId: 07b1248bdc
release-1.7
@@ -1115,7 +1115,7 @@ public: | |||
* access *data*; stride of layout on that axis would be zero, and | |||
* strides on other axes correspond to the strides in *data* | |||
*/ | |||
static std::pair<TensorLayout, size_t> get_value_iter_optimized_layout( | |||
static std::tuple<TensorLayout, size_t, TensorShape> get_value_iter_optimized_layout( | |||
const TensorLayout& data, const TensorLayout& value, const IndexDesc& index, | |||
size_t idx_axis); | |||
@@ -1159,7 +1159,8 @@ public: | |||
* \brief get workspace size based on output shape and indexing axes | |||
*/ | |||
size_t get_workspace_in_bytes( | |||
const TensorShape& dst, const size_t* axes, size_t nr_axes); | |||
const TensorShape& dst, const size_t* axes, size_t nr_axes, | |||
size_t idx_ndim); | |||
static void deduce_layout( | |||
const TensorLayout& data, const IndexDescLayoutOnly& index, | |||
@@ -1193,7 +1194,8 @@ public: | |||
* axes | |||
*/ | |||
size_t get_workspace_in_bytes( | |||
const TensorShape& value, const size_t* axes, size_t nr_axes); | |||
const TensorShape& value, const size_t* axes, size_t nr_axes, | |||
size_t idx_ndim); | |||
protected: | |||
ExecInfo check_exec( | |||
@@ -1223,7 +1225,7 @@ public: | |||
using AxisIndexerLayoutOnly = IndexingMultiAxisVecBase::AxisIndexerLayoutOnly; | |||
using IndexDescLayoutOnly = IndexingMultiAxisVecBase::IndexDescLayoutOnly; | |||
size_t get_workspace_in_bytes(const TensorShape&, const size_t*, size_t) { | |||
size_t get_workspace_in_bytes(const TensorShape&, const size_t*, size_t, size_t) { | |||
return 0; | |||
} | |||
@@ -15,8 +15,10 @@ | |||
using namespace megdnn; | |||
namespace { | |||
// we need a workspace to store the offset base table, which has the same size as the index
size_t get_index_size_for_workspace( | |||
const TensorShape& shp, const size_t* axes, size_t nr_axes) { | |||
const TensorShape& shp, const size_t* axes, size_t nr_axes, size_t idx_ndim) { | |||
size_t idx_axis = axes[0]; | |||
megdnn_assert(shp.ndim && nr_axes); | |||
for (size_t i = 1; i < nr_axes; ++i) { | |||
@@ -29,7 +31,11 @@ size_t get_index_size_for_workspace( | |||
megdnn_assert( | |||
shp.ndim > idx_axis, "index on the %zuth axis; but shape is %s", idx_axis, | |||
shp.to_string().c_str()); | |||
return shp.shape[idx_axis]; | |||
size_t idx_size = 1; | |||
for (size_t i = 0; i < idx_ndim; ++i) { | |||
idx_size *= shp.shape[idx_axis + i]; | |||
} | |||
return idx_size; | |||
} | |||
} // anonymous namespace | |||
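For context: the helper above used to return shp.shape[idx_axis] (one offset per row of a 1-d index); with an N-d index it has to multiply the idx_ndim output dimensions occupied by the broadcast index. A minimal sketch of the sizing rule, with hypothetical shapes (not part of the patch):

#include <cstddef>

// e.g. dst shape (5, 2, 3, 4, 3) with indexed axes {1, 2, 3} and idx_ndim = 3:
// idx_axis = 1, so the offset table needs 2 * 3 * 4 = 24 entries.
size_t offset_table_entries(const size_t* dst_shape, size_t idx_axis, size_t idx_ndim) {
    size_t n = 1;
    for (size_t i = 0; i < idx_ndim; ++i)
        n *= dst_shape[idx_axis + i];
    return n;
}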
@@ -47,23 +53,17 @@ size_t IndexingMultiAxisVecBase::deduce_layout_fwd( | |||
const TensorLayout& data, const IndexDescLayoutOnly& index, TensorLayout& dst) { | |||
megdnn_assert(!index.empty()); | |||
megdnn_assert(data.ndim >= index.size()); | |||
dst.ndim = data.ndim - index.size() + 1; | |||
dst.shape[0] = 1; | |||
dst.ndim = data.ndim - index.size(); | |||
dst.dtype = data.dtype; | |||
TensorShapeArray index_shapes; | |||
auto brdcast = [&](const TensorLayout& ly) { | |||
if (ly.ndim != 1) | |||
return false; | |||
if (dst.shape[0] == ly.shape[0]) | |||
return true; | |||
if (dst.shape[0] == 1) { | |||
dst.shape[0] = ly.shape[0]; | |||
return true; | |||
} | |||
return ly.shape[0] == 1; | |||
megdnn_assert(ly.dtype == dtype::Int32{}); | |||
index_shapes.push_back(ly); | |||
}; | |||
size_t dst_axis = 1; | |||
size_t dst_axis = 0; | |||
ptrdiff_t prev_axis = -1; | |||
for (size_t axis = 0; axis < index.size(); ++axis) { | |||
auto&& idx = index[axis]; | |||
@@ -73,10 +73,7 @@ size_t IndexingMultiAxisVecBase::deduce_layout_fwd( | |||
megdnn_assert( | |||
idx.axis < data.ndim && static_cast<ptrdiff_t>(idx.axis) > prev_axis, | |||
"index %zu requests invalid axis %zu", axis, idx.axis); | |||
auto brd_succ = brdcast(idx.layout); | |||
megdnn_assert( | |||
brd_succ, "invalid layout at index %zu: %s", axis, | |||
idx.layout.to_string().c_str()); | |||
brdcast(idx.layout); | |||
for (size_t i = prev_axis + 1; i < idx.axis; ++i) { | |||
dst.shape[dst_axis++] = data.shape[i]; | |||
@@ -99,15 +96,18 @@ size_t IndexingMultiAxisVecBase::deduce_layout_fwd( | |||
} | |||
} | |||
if (contig_idx) { | |||
auto shp0 = dst.shape[0]; | |||
idx_axis = index[0].axis; | |||
for (size_t i = 0; i < idx_axis; ++i) { | |||
dst.shape[i] = dst.shape[i + 1]; | |||
} | |||
dst.shape[idx_axis] = shp0; | |||
} | |||
} | |||
TensorShape index_shape; | |||
Elemwise::deduce_shape(index_shapes, index_shape); | |||
for (size_t i = 0; i < index_shape.ndim; ++i) { | |||
dst.add_axis_inplace(idx_axis + i, 1, 0); | |||
dst.shape[idx_axis + i] = index_shape.shape[i]; | |||
} | |||
dst.init_contiguous_stride(); | |||
return idx_axis; | |||
} | |||
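The net effect of the new deduction: the indexed axes are replaced by the broadcast index shape (deduced via Elemwise::deduce_shape) rather than by a single axis. A self-contained sketch of that rule for the contiguous-axes case, using assumed helper names and NumPy-style broadcasting:

#include <algorithm>
#include <cstddef>
#include <vector>

// Broadcast the index shapes right-aligned, then splice the result into the
// data shape in place of the indexed axes [first_axis, first_axis + nr_indexed).
std::vector<size_t> deduce_nd_indexed_shape(
        const std::vector<size_t>& data, size_t first_axis, size_t nr_indexed,
        const std::vector<std::vector<size_t>>& index_shapes) {
    size_t ndim = 0;
    for (auto& s : index_shapes)
        ndim = std::max(ndim, s.size());
    std::vector<size_t> idx(ndim, 1);
    for (auto& s : index_shapes)
        for (size_t i = 0; i < s.size(); ++i)
            if (s[i] != 1)
                idx[ndim - s.size() + i] = s[i];  // assumes the shapes are broadcast-compatible
    std::vector<size_t> out(data.begin(), data.begin() + first_axis);
    out.insert(out.end(), idx.begin(), idx.end());
    out.insert(out.end(), data.begin() + first_axis + nr_indexed, data.end());
    return out;
}

// data (5, 5, 6, 7, 3), axes {1, 2, 3}, index shapes (3, 1), (2, 1, 1), (1, 4)
// -> broadcast index shape (2, 3, 4) -> output (5, 2, 3, 4, 3)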
@@ -145,15 +145,26 @@ IndexingMultiAxisVecBase::ExecInfo IndexingMultiAxisVecBase::check_exec_noworksp | |||
return ret; | |||
} | |||
std::pair<TensorLayout, size_t> IndexingMultiAxisVecBase:: | |||
std::tuple<TensorLayout, size_t, TensorShape> IndexingMultiAxisVecBase:: | |||
get_value_iter_optimized_layout( | |||
const TensorLayout& data, const TensorLayout& value, | |||
const IndexDesc& index, size_t idx_axis) { | |||
size_t data_axes[TensorLayout::MAX_NDIM], | |||
nr_axes = get_nonindex_axes(data.ndim, index, data_axes); | |||
// broadcast index shapes | |||
TensorLayout index_shape; | |||
{ | |||
TensorShapeArray index_shapes; | |||
for (auto& idx : index) { | |||
megdnn_assert(idx.vec.layout.dtype == dtype::Int32{}); | |||
index_shapes.push_back(idx.vec.layout); | |||
} | |||
Elemwise::deduce_shape(index_shapes, index_shape); | |||
} | |||
megdnn_assert( | |||
nr_axes == value.ndim - 1 && idx_axis < value.ndim && | |||
nr_axes == value.ndim - index_shape.ndim && idx_axis < value.ndim && | |||
nr_axes + index.size() == data.ndim); | |||
TensorLayout ret; | |||
@@ -165,10 +176,13 @@ std::pair<TensorLayout, size_t> IndexingMultiAxisVecBase:: | |||
} | |||
ret = ret.collapse_contiguous(); | |||
} | |||
ret.shape[ret.ndim] = value.shape[idx_axis]; | |||
ret.stride[ret.ndim] = 0; | |||
size_t ret_idx_axis = ret.ndim; | |||
++ret.ndim; | |||
for (size_t i = 0; i < index_shape.ndim; ++i) { | |||
ret.shape[ret.ndim] = value.shape[idx_axis + i]; | |||
ret.stride[ret.ndim] = 0; | |||
++ret.ndim; | |||
} | |||
if (idx_axis < nr_axes) { | |||
TensorLayout tail; | |||
@@ -185,12 +199,13 @@ std::pair<TensorLayout, size_t> IndexingMultiAxisVecBase:: | |||
} | |||
} | |||
return {ret, ret_idx_axis}; | |||
return std::make_tuple(ret, ret_idx_axis, index_shape); | |||
} | |||
size_t IndexingMultiAxisVec::get_workspace_in_bytes( | |||
const TensorShape& dst, const size_t* axes, size_t nr_axes) { | |||
return get_workspace_in_bytes(get_index_size_for_workspace(dst, axes, nr_axes)); | |||
const TensorShape& dst, const size_t* axes, size_t nr_axes, size_t idx_ndim) { | |||
return get_workspace_in_bytes( | |||
get_index_size_for_workspace(dst, axes, nr_axes, idx_ndim)); | |||
} | |||
IndexingMultiAxisVec::ExecInfo IndexingMultiAxisVec::check_exec( | |||
@@ -205,8 +220,9 @@ IndexingMultiAxisVec::ExecInfo IndexingMultiAxisVec::check_exec( | |||
} | |||
size_t IndexingModifyMultiAxisVecBase::get_workspace_in_bytes( | |||
const TensorShape& value, const size_t* axes, size_t nr_axes) { | |||
return get_workspace_in_bytes(get_index_size_for_workspace(value, axes, nr_axes)); | |||
const TensorShape& value, const size_t* axes, size_t nr_axes, size_t idx_ndim) { | |||
return get_workspace_in_bytes( | |||
get_index_size_for_workspace(value, axes, nr_axes, idx_ndim)); | |||
} | |||
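Both public entry points now take the rank of the (broadcast) index as a trailing argument. A hypothetical call site under that assumption (opr and dst_shape are assumed names, not from the patch):

#include <cstddef>
#include "megdnn/oprs.h"

// Workspace query for a gather whose indexed axes are {1, 2} and whose
// broadcast index is 2-d.
size_t query_workspace(megdnn::IndexingMultiAxisVec* opr, const megdnn::TensorShape& dst_shape) {
    size_t axes[] = {1, 2};
    return opr->get_workspace_in_bytes(dst_shape, axes, /*nr_axes=*/2, /*idx_ndim=*/2);
}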
IndexingModifyMultiAxisVecBase::ExecInfo IndexingModifyMultiAxisVecBase::check_exec( | |||
@@ -21,17 +21,24 @@ namespace cuda { | |||
namespace indexing_multi_axis_vec { | |||
//! AxisIndexer equiv in kernel | |||
template <int idx_ndim> | |||
struct KAxisIndexer { | |||
int stride; | |||
int stride[idx_ndim]; | |||
#ifdef WIN32 | |||
Uint32Fastdiv shape[idx_ndim]; | |||
#else | |||
// the original shape[0] is not stored
Uint32Fastdiv shape[idx_ndim - 1]; | |||
#endif | |||
const int* ptr; | |||
}; | |||
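The shifted storage exists because the row-major decomposition never divides by the size of the outermost index dimension, so only idx_ndim - 1 divisors are needed; the WIN32 branch presumably keeps a full-size array only because MSVC rejects zero-length arrays when idx_ndim == 1. A host-side sketch of the decomposition that consumes this layout (plain integer division standing in for Uint32Fastdiv):

// Recover the coordinates of flat index oidx in a row-major shape of rank ndim,
// where shape_shifted[j - 1] stores the size of dimension j (dimension 0 unused).
void unflatten(int oidx, const int* shape_shifted, int ndim, int* coord_out) {
    int coidx = oidx;
    for (int j = ndim - 1; j >= 0; --j) {
        if (j) {
            int next = coidx / shape_shifted[j - 1];
            coord_out[j] = coidx - next * shape_shifted[j - 1];
            coidx = next;
        } else {
            coord_out[j] = coidx;  // outermost dimension: no divisor needed
        }
    }
}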
//! param for gen_offset_base | |||
template <int nidx> | |||
template <int nidx, int idx_ndim> | |||
struct GenOffsetBaseParam { | |||
uint32_t size; //!< number of outputs; also size of each index | |||
int* output; //!< output ptr | |||
KAxisIndexer indexer[nidx]; | |||
KAxisIndexer<idx_ndim> indexer[nidx]; | |||
uint32_t data_shape[nidx]; | |||
int data_stride[nidx]; | |||
@@ -59,7 +66,12 @@ struct ApplyOprParam { | |||
const int* offset_base; | |||
ctype *data, *value; | |||
// first idx axis | |||
int idx_axis; | |||
// last idx axis + 1 | |||
int idx_axis_end; | |||
// number of elements in the (broadcast) index shape
int idx_nelems; | |||
int value_stride; | |||
@@ -68,8 +80,9 @@ struct ApplyOprParam { | |||
}; | |||
//! generate offset bases for first axis in the output | |||
template <int nidx> | |||
void gen_offset_base(const GenOffsetBaseParam<nidx>& param, cudaStream_t stream); | |||
template <int nidx, int idx_ndim> | |||
void gen_offset_base( | |||
const GenOffsetBaseParam<nidx, idx_ndim>& param, cudaStream_t stream); | |||
struct OprAtomicIncr { | |||
#if MEGDNN_CC_CUDA | |||
@@ -29,11 +29,23 @@ namespace { | |||
uint32_t oidx = threadIdx.x + blockDim.x * blockIdx.x; | |||
if (oidx < param.tot_size) { | |||
int offset = 0, coidx = oidx; | |||
int all_ax_idx[ndim]; | |||
// flattened position within the indexed axes, used to address the offset table
int idx_flat = 0;
// accumulate the offset from the non-indexed axes (indexed axes have stride 0 here)
#pragma unroll | |||
for (int i = ndim - 1; i >= 0; -- i) { | |||
int next_coidx, ax_idx; | |||
// [..., indexed_axes... |, ...] | |||
if (i + 1 == param.idx_axis_end) { | |||
idx_flat = coidx; | |||
} | |||
// [... |, indexed_axes..., ...] | |||
if (i + 1 == param.idx_axis) { | |||
idx_flat -= coidx * param.idx_nelems; | |||
} | |||
// shape[i] is stored at shape[i-1]
if (i) { | |||
// fast divide | |||
next_coidx = coidx / param.value_ly_on_data.shape[i - 1]; | |||
ax_idx = | |||
coidx - | |||
@@ -44,9 +56,9 @@ namespace { | |||
ax_idx = coidx; | |||
} | |||
offset += param.value_ly_on_data.stride[i] * ax_idx; | |||
all_ax_idx[i] = ax_idx; | |||
} | |||
offset += param.offset_base[all_ax_idx[param.idx_axis]]; | |||
// add the offset looked up from the precomputed offset table
offset += param.offset_base[idx_flat]; | |||
Opr::apply( | |||
param.data[offset], | |||
param.value[oidx * param.value_stride]); | |||
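What the new bookkeeping computes: as the loop peels axes off the flattened output index, idx_flat ends up as the row-major position inside the indexed-axes block [idx_axis, idx_axis_end), which is exactly the entry to read from the precomputed offset table. A host-side sketch of the same arithmetic with plain division (hypothetical shapes):

// shape[] is the row-major shape of the value tensor iterated by the kernel.
int flat_index_in_indexed_block(
        int oidx, const int* shape, int ndim, int idx_axis, int idx_axis_end,
        int idx_nelems) {
    int coidx = oidx, idx_flat = 0;
    for (int i = ndim - 1; i >= 0; --i) {
        if (i + 1 == idx_axis_end)
            idx_flat = coidx;                // flattened over axes [0, idx_axis_end)
        if (i + 1 == idx_axis)
            idx_flat -= coidx * idx_nelems;  // strip the axes in front of the block
        coidx /= shape[i];                   // drop axis i for the next iteration
    }
    return idx_flat;                         // position within the indexed axes only
}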
@@ -18,14 +18,29 @@ using namespace cuda; | |||
using namespace indexing_multi_axis_vec; | |||
namespace { | |||
template <int nidx> | |||
__global__ void kgen_offset_base(GenOffsetBaseParam<nidx> param) { | |||
template <int nidx, int idx_ndim> | |||
__global__ void kgen_offset_base(GenOffsetBaseParam<nidx, idx_ndim> param) { | |||
int oidx = threadIdx.x + blockDim.x * blockIdx.x; | |||
if (oidx < param.size) { | |||
int offset = 0; | |||
#pragma unroll | |||
for (int i = 0; i < nidx; ++i) { | |||
int data_idx = param.indexer[i].ptr[param.indexer[i].stride * oidx]; | |||
auto& indexer = param.indexer[i]; | |||
// offset of the oidx-th index element within index tensor i (broadcast layout)
int idx_flat = 0, coidx = oidx; | |||
#pragma unroll | |||
for (int j = idx_ndim - 1; j >= 0; --j) { | |||
int ax_idx; | |||
if (j) { | |||
int next_coidx = coidx / indexer.shape[j - 1]; | |||
ax_idx = coidx - (next_coidx * indexer.shape[j - 1].divisor()); | |||
coidx = next_coidx; | |||
} else { | |||
ax_idx = coidx; | |||
} | |||
idx_flat += indexer.stride[j] * ax_idx; | |||
} | |||
int data_idx = indexer.ptr[idx_flat]; | |||
data_idx += (data_idx < 0 ? param.data_shape[i] : 0); | |||
if (static_cast<uint32_t>(data_idx) >= param.data_shape[i]) { | |||
// cast to uint32 to handle both negative and overflow | |||
@@ -36,17 +51,19 @@ __global__ void kgen_offset_base(GenOffsetBaseParam<nidx> param) { | |||
i, data_idx, param.data_shape[i]); | |||
data_idx = 0; | |||
} | |||
// calculate offset from current index | |||
offset += data_idx * param.data_stride[i]; | |||
} | |||
// store the accumulated offset in the offset table
param.output[oidx] = offset; | |||
} | |||
} | |||
} // namespace | |||
template <int nidx> | |||
template <int nidx, int idx_ndim> | |||
void indexing_multi_axis_vec::gen_offset_base( | |||
const GenOffsetBaseParam<nidx>& param, cudaStream_t stream) { | |||
void (*kptr)(GenOffsetBaseParam<nidx>) = kgen_offset_base<nidx>; | |||
const GenOffsetBaseParam<nidx, idx_ndim>& param, cudaStream_t stream) { | |||
void (*kptr)(GenOffsetBaseParam<nidx, idx_ndim>) = kgen_offset_base<nidx, idx_ndim>; | |||
int bsize = query_blocksize_for_kernel(kptr); | |||
(*kptr)<<<DIVUP(param.size, bsize), bsize, 0, stream>>>(param); | |||
} | |||
@@ -55,9 +72,17 @@ namespace megdnn { | |||
namespace cuda { | |||
namespace indexing_multi_axis_vec { | |||
#define INST(_n) \ | |||
template void gen_offset_base(const GenOffsetBaseParam<_n>&, cudaStream_t); | |||
MEGDNN_FOREACH_TENSOR_NDIM(INST) | |||
#define INST(_m, _n) \ | |||
template void gen_offset_base(const GenOffsetBaseParam<_m, _n>&, cudaStream_t); | |||
MEGDNN_FOREACH_TENSOR_NDIM(INST, 1) | |||
MEGDNN_FOREACH_TENSOR_NDIM(INST, 2) | |||
MEGDNN_FOREACH_TENSOR_NDIM(INST, 3) | |||
MEGDNN_FOREACH_TENSOR_NDIM(INST, 4) | |||
MEGDNN_FOREACH_TENSOR_NDIM(INST, 5) | |||
MEGDNN_FOREACH_TENSOR_NDIM(INST, 6) | |||
MEGDNN_FOREACH_TENSOR_NDIM(INST, 7) | |||
#undef INST | |||
} // namespace indexing_multi_axis_vec | |||
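Assuming MEGDNN_FOREACH_TENSOR_NDIM enumerates 1 through 7 as the macro's first argument, each of the seven invocations above emits the instantiations for one index rank; e.g. MEGDNN_FOREACH_TENSOR_NDIM(INST, 3) expands to lines of the form:

template void gen_offset_base(const GenOffsetBaseParam<1, 3>&, cudaStream_t);
// ... up to ...
template void gen_offset_base(const GenOffsetBaseParam<7, 3>&, cudaStream_t);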
@@ -21,9 +21,10 @@ using namespace indexing_multi_axis_vec; | |||
namespace { | |||
class ExecImplHelper { | |||
template <int nidx, int idx_ndim> | |||
void dispatch_gen_offset_base_nidx_ndim(); | |||
template <int nidx> | |||
void dispatch_gen_offset_base_nidx(); | |||
void dispatch_gen_offset_base(); | |||
protected: | |||
@@ -38,6 +39,7 @@ protected: | |||
int* const m_offset_base; | |||
TensorLayout m_value_layout_on_data; | |||
size_t m_idx_axis; | |||
TensorShape m_idx_shape; | |||
int m_value_stride; | |||
public: | |||
@@ -76,28 +78,30 @@ ExecImplHelper::ExecImplHelper( | |||
m_exec_info{&exec_info}, | |||
m_offset_base{workspace.ptr<int>()} { | |||
safe_size_in_kern(data.layout.total_nr_elems()); | |||
dispatch_gen_offset_base(); | |||
std::tie(m_value_layout_on_data, m_idx_axis) = | |||
std::tie(m_value_layout_on_data, m_idx_axis, m_idx_shape) = | |||
IndexingMultiAxisVec::get_value_iter_optimized_layout( | |||
data.layout, value.layout, index, exec_info.idx_axis); | |||
dispatch_gen_offset_base(); | |||
m_value_stride = exec_info.value_stride; | |||
} | |||
template <int nidx> | |||
void ExecImplHelper::dispatch_gen_offset_base_nidx() { | |||
GenOffsetBaseParam<nidx> param; | |||
param.size = m_value->layout.shape[m_exec_info->idx_axis]; | |||
template <int nidx, int idx_ndim> | |||
void ExecImplHelper::dispatch_gen_offset_base_nidx_ndim() { | |||
GenOffsetBaseParam<nidx, idx_ndim> param; | |||
param.size = m_idx_shape.total_nr_elems(); | |||
param.output = m_offset_base; | |||
param.error_tracker = m_exec_info->error_tracker; | |||
param.error_info = m_exec_info->error_info; | |||
megdnn_assert(m_idx_shape.ndim == idx_ndim); | |||
for (int i = 0; i < nidx; ++i) { | |||
auto&& dst = param.indexer[i]; | |||
auto&& src = m_index->operator[](i); | |||
megdnn_assert(src.vec.layout.ndim == 1); | |||
dst.stride = src.vec.layout.stride[0]; | |||
if (src.vec.layout.shape[0] == 1) { | |||
dst.stride = 0; | |||
auto&& src = m_index->at(i); | |||
auto src_layout = src.vec.layout.broadcast(m_idx_shape); | |||
for (size_t i = 0; i < idx_ndim; ++i) { | |||
if (i) { | |||
dst.shape[i - 1] = src_layout.shape[i]; | |||
} | |||
dst.stride[i] = src_layout.stride[i]; | |||
} | |||
dst.ptr = src.vec.ptr<int>(); | |||
param.data_shape[i] = m_data->layout.shape[src.axis]; | |||
@@ -106,6 +110,18 @@ void ExecImplHelper::dispatch_gen_offset_base_nidx() { | |||
gen_offset_base(param, m_stream); | |||
} | |||
template <int nidx> | |||
void ExecImplHelper::dispatch_gen_offset_base_nidx() { | |||
switch (m_idx_shape.ndim) { | |||
#define cb(_n) \ | |||
case _n: \ | |||
return dispatch_gen_offset_base_nidx_ndim<nidx, _n>(); | |||
MEGDNN_FOREACH_TENSOR_NDIM(cb) | |||
#undef cb | |||
} | |||
megdnn_throw("bad index ndim"); | |||
} | |||
void ExecImplHelper::dispatch_gen_offset_base() { | |||
switch (m_index->size()) { | |||
#define cb(_n) \ | |||
@@ -153,6 +169,8 @@ void ExecImpl<Opr>::dispatch_exec_ctype_ndim() { | |||
param.data = m_data->ptr<ctype>(); | |||
param.value = m_value->ptr<ctype>(); | |||
param.idx_axis = m_idx_axis; | |||
param.idx_axis_end = m_idx_axis + m_idx_shape.ndim; | |||
param.idx_nelems = m_idx_shape.total_nr_elems(); | |||
param.value_stride = m_value_stride; | |||
for (int i = 0; i < ndim; ++i) { | |||
param.value_ly_on_data.stride[i] = m_value_layout_on_data.stride[i]; | |||
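Dispatch in the CUDA backend is now two levels deep: the existing switch over the number of indexed axes picks nidx, and the new switch picks idx_ndim, so the kernel is launched with both ranks known at compile time. A stripped-down sketch of the pattern with hypothetical names:

#include <stdexcept>

template <int NIDX, int NDIM>
void launch() { /* launch the kernel specialized for <NIDX, NDIM> */ }

template <int NIDX>
void dispatch_ndim(int idx_ndim) {
    switch (idx_ndim) {
        case 1: return launch<NIDX, 1>();
        case 2: return launch<NIDX, 2>();
        case 3: return launch<NIDX, 3>();
        default: throw std::runtime_error("bad index ndim");
    }
}

void dispatch(int nidx, int idx_ndim) {
    switch (nidx) {
        case 1: return dispatch_ndim<1>(idx_ndim);
        case 2: return dispatch_ndim<2>(idx_ndim);
        default: throw std::runtime_error("bad nr_idx");
    }
}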
@@ -33,37 +33,46 @@ void do_exec( | |||
auto data_layout = data.layout; | |||
auto data_ptr = data.ptr<data_type>(); | |||
std::tuple<size_t, const idx_type*, ptrdiff_t> index_raw[TensorLayout::MAX_NDIM]; | |||
std::tuple<size_t, const idx_type*, TensorLayout> index_raw[TensorLayout::MAX_NDIM]; | |||
size_t nr_index = index.size(); | |||
TensorShape idx_shape; | |||
{ | |||
TensorShapeArray idx_shapes; | |||
for (size_t i = 0; i < nr_index; ++i) { | |||
idx_shapes.push_back(index[i].vec.layout); | |||
} | |||
Elemwise::deduce_shape(idx_shapes, idx_shape); | |||
} | |||
for (size_t i = 0; i < nr_index; ++i) { | |||
auto&& s = index[i]; | |||
index_raw[i] = | |||
std::make_tuple(s.axis, s.vec.ptr<idx_type>(), s.vec.layout.stride[0]); | |||
if (s.vec.layout.shape[0] == 1) | |||
std::get<2>(index_raw[i]) = 0; | |||
index_raw[i] = std::make_tuple( | |||
s.axis, s.vec.ptr<idx_type>(), s.vec.layout.broadcast(idx_shape)); | |||
} | |||
auto value_iter = tensor_iter<data_type>(value).begin(); | |||
for (size_t _ = 0, _t = value.layout.total_nr_elems(); _ < _t; ++_) { | |||
ptrdiff_t offset = 0; | |||
auto index_idx = value_iter.idx()[exec_info.idx_axis]; | |||
auto* index_idx = value_iter.idx() + exec_info.idx_axis; | |||
for (size_t i = 0; i < nr_index; ++i) { | |||
size_t axis = std::get<0>(index_raw[i]), | |||
data_shape = data_layout.shape[axis]; | |||
ptrdiff_t data_stride = data_layout.stride[axis]; | |||
idx_type data_idx = | |||
std::get<1>(index_raw[i])[std::get<2>(index_raw[i]) * index_idx]; | |||
size_t index_offset = 0; | |||
TensorLayout& index_layout = std::get<2>(index_raw[i]); | |||
for (size_t i = 0; i < index_layout.ndim; ++i) { | |||
index_offset += index_idx[i] * index_layout.stride[i]; | |||
} | |||
idx_type data_idx = std::get<1>(index_raw[i])[index_offset]; | |||
if (data_idx < 0) | |||
data_idx += data_shape; | |||
megdnn_assert( | |||
data_idx >= 0 && static_cast<size_t>(data_idx) < data_shape, | |||
"bad index value for index %zu at output %zu", i, index_idx); | |||
"bad index value for index %zu at output %zu", i, *index_idx); | |||
offset += data_stride * data_idx; | |||
} | |||
for (size_t i = 0; i < nr_nonidx_axes; ++i) { | |||
auto stride = data_layout.stride[nonidx_axes[i]]; | |||
auto idx = value_iter.idx()[i + (i >= exec_info.idx_axis)]; | |||
auto idx = value_iter.idx()[i + (i >= exec_info.idx_axis) * idx_shape.ndim]; | |||
offset += stride * idx; | |||
} | |||
Opr::apply(data_ptr[offset], *value_iter); | |||
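On the naive backend the single per-index stride is replaced by a full layout broadcast to the deduced index shape, so a dimension the index does not actually vary over gets stride 0; e.g. an index of shape (3, 1) broadcast to (2, 3, 4) ends up with strides (0, 1, 0), assuming the original index is contiguous. The offset computation then reduces to the usual coordinate-times-stride sum, sketched below:

#include <cstddef>

// Offset of an N-d coordinate under a (possibly broadcast) layout: dimensions
// with stride 0 never advance the offset, which is how the broadcast index
// tensors are addressed in the loop above.
ptrdiff_t layout_offset(const size_t* coord, const ptrdiff_t* stride, size_t ndim) {
    ptrdiff_t off = 0;
    for (size_t i = 0; i < ndim; ++i)
        off += static_cast<ptrdiff_t>(coord[i]) * stride[i];
    return off;
}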
@@ -21,17 +21,23 @@ namespace rocm { | |||
namespace indexing_multi_axis_vec { | |||
//! AxisIndexer equiv in kernel | |||
template <int idx_ndim> | |||
struct KAxisIndexer { | |||
int stride; | |||
int stride[idx_ndim]; | |||
#ifdef WIN32 | |||
Uint32Fastdiv shape[idx_ndim]; | |||
#else | |||
Uint32Fastdiv shape[idx_ndim - 1]; | |||
#endif | |||
const int *ptr; | |||
}; | |||
//! param for gen_offset_base | |||
template<int nidx> | |||
template<int nidx, int idx_ndim> | |||
struct GenOffsetBaseParam { | |||
uint32_t size; //!< number of outputs; also size of each index | |||
int *output; //!< output ptr | |||
KAxisIndexer indexer[nidx]; | |||
KAxisIndexer<idx_ndim> indexer[nidx]; | |||
uint32_t data_shape[nidx]; | |||
int data_stride[nidx]; | |||
@@ -60,6 +66,8 @@ namespace indexing_multi_axis_vec { | |||
ctype *data, *value; | |||
int idx_axis; | |||
int idx_axis_end; | |||
int idx_nelems; | |||
int value_stride; | |||
@@ -68,8 +76,8 @@ namespace indexing_multi_axis_vec { | |||
}; | |||
//! generate offset bases for first axis in the output | |||
template<int nidx> | |||
void gen_offset_base(const GenOffsetBaseParam<nidx> ¶m, | |||
template<int nidx, int idx_ndim> | |||
void gen_offset_base(const GenOffsetBaseParam<nidx, idx_ndim> ¶m, | |||
hipStream_t stream); | |||
struct OprAtomicIncr { | |||
@@ -30,10 +30,17 @@ namespace { | |||
uint32_t oidx = threadIdx.x + blockDim.x * blockIdx.x; | |||
if (oidx < param.tot_size) { | |||
int offset = 0, coidx = oidx; | |||
int all_ax_idx[ndim]; | |||
int idx_flat = 0; | |||
#pragma unroll | |||
for (int i = ndim - 1; i >= 0; -- i) { | |||
int next_coidx, ax_idx; | |||
if (i + 1 == param.idx_axis_end) { | |||
idx_flat = coidx; | |||
} | |||
// never taken when idx_axis == 0
if (i + 1 == param.idx_axis) { | |||
idx_flat -= coidx * param.idx_nelems; | |||
} | |||
if (i) { | |||
next_coidx = coidx / param.value_ly_on_data.shape[i - 1]; | |||
ax_idx = | |||
@@ -45,9 +52,8 @@ namespace { | |||
ax_idx = coidx; | |||
} | |||
offset += param.value_ly_on_data.stride[i] * ax_idx; | |||
all_ax_idx[i] = ax_idx; | |||
} | |||
offset += param.offset_base[all_ax_idx[param.idx_axis]]; | |||
offset += param.offset_base[idx_flat]; | |||
Opr::apply( | |||
param.data[offset], | |||
param.value[oidx * param.value_stride]); | |||
@@ -21,15 +21,28 @@ using namespace rocm; | |||
using namespace indexing_multi_axis_vec; | |||
namespace { | |||
template<int nidx> | |||
__global__ void kgen_offset_base(GenOffsetBaseParam<nidx> param) { | |||
template<int nidx, int idx_ndim> | |||
__global__ void kgen_offset_base(GenOffsetBaseParam<nidx, idx_ndim> param) { | |||
int oidx = threadIdx.x + blockDim.x * blockIdx.x; | |||
if (oidx < param.size) { | |||
int offset = 0; | |||
#pragma unroll | |||
for (int i = 0; i < nidx; ++ i) { | |||
int data_idx = param.indexer[i].ptr[ | |||
param.indexer[i].stride * oidx]; | |||
auto& indexer = param.indexer[i]; | |||
int offset2 = 0, coidx = oidx; | |||
#pragma unroll | |||
for (int j = idx_ndim-1; j >= 0; --j) { | |||
int ax_idx; | |||
if (j) { | |||
int next_coidx = coidx / indexer.shape[j-1]; | |||
ax_idx = coidx - (next_coidx * indexer.shape[j-1].divisor()); | |||
coidx = next_coidx; | |||
} else { | |||
ax_idx = coidx; | |||
} | |||
offset2 += indexer.stride[j] * ax_idx; | |||
} | |||
int data_idx = indexer.ptr[offset2]; | |||
data_idx += (data_idx < 0 ? param.data_shape[i] : 0); | |||
if (static_cast<uint32_t>(data_idx) >= param.data_shape[i]) { | |||
// cast to uint32 to handle both negative and overflow | |||
@@ -50,20 +63,28 @@ namespace megdnn { | |||
namespace rocm { | |||
namespace indexing_multi_axis_vec { | |||
#define INST(_n) \ | |||
#define INST(_m, _n) \ | |||
template void gen_offset_base( \ | |||
const GenOffsetBaseParam<_n> &, hipStream_t); | |||
MEGDNN_FOREACH_TENSOR_NDIM(INST) | |||
const GenOffsetBaseParam<_m, _n> &, hipStream_t); | |||
MEGDNN_FOREACH_TENSOR_NDIM(INST, 1) | |||
MEGDNN_FOREACH_TENSOR_NDIM(INST, 2) | |||
MEGDNN_FOREACH_TENSOR_NDIM(INST, 3) | |||
MEGDNN_FOREACH_TENSOR_NDIM(INST, 4) | |||
MEGDNN_FOREACH_TENSOR_NDIM(INST, 5) | |||
MEGDNN_FOREACH_TENSOR_NDIM(INST, 6) | |||
MEGDNN_FOREACH_TENSOR_NDIM(INST, 7) | |||
#undef INST | |||
} // namespace indexing_multi_axis_vec | |||
} // namespace rocm | |||
} // namespace megdnn | |||
template<int nidx> | |||
template<int nidx, int idx_ndim> | |||
void indexing_multi_axis_vec::gen_offset_base( | |||
const GenOffsetBaseParam<nidx> ¶m, hipStream_t stream) { | |||
void (*kptr)(GenOffsetBaseParam<nidx>) = kgen_offset_base<nidx>; | |||
const GenOffsetBaseParam<nidx, idx_ndim> ¶m, hipStream_t stream) { | |||
void (*kptr)(GenOffsetBaseParam<nidx, idx_ndim>) = kgen_offset_base<nidx, idx_ndim>; | |||
int bsize = 256; | |||
hipLaunchKernelGGL(kptr, | |||
DIVUP(param.size, bsize), bsize, 0, stream, | |||
@@ -22,9 +22,10 @@ using namespace indexing_multi_axis_vec; | |||
namespace { | |||
class ExecImplHelper { | |||
template <int nidx, int idx_ndim> | |||
void dispatch_gen_offset_base_nidx_ndim(); | |||
template <int nidx> | |||
void dispatch_gen_offset_base_nidx(); | |||
void dispatch_gen_offset_base(); | |||
protected: | |||
@@ -39,6 +40,7 @@ protected: | |||
int* const m_offset_base; | |||
TensorLayout m_value_layout_on_data; | |||
size_t m_idx_axis; | |||
TensorShape m_idx_shape; | |||
int m_value_stride; | |||
public: | |||
@@ -77,18 +79,17 @@ ExecImplHelper::ExecImplHelper( | |||
m_exec_info{&exec_info}, | |||
m_offset_base{workspace.ptr<int>()} { | |||
safe_size_in_kern(data.layout.total_nr_elems()); | |||
dispatch_gen_offset_base(); | |||
std::tie(m_value_layout_on_data, m_idx_axis) = | |||
std::tie(m_value_layout_on_data, m_idx_axis, m_idx_shape) = | |||
IndexingMultiAxisVec::get_value_iter_optimized_layout( | |||
data.layout, value.layout, index, exec_info.idx_axis); | |||
dispatch_gen_offset_base(); | |||
m_value_stride = exec_info.value_stride; | |||
} | |||
template <int nidx> | |||
void ExecImplHelper::dispatch_gen_offset_base_nidx() { | |||
GenOffsetBaseParam<nidx> param; | |||
param.size = m_value->layout.shape[m_exec_info->idx_axis]; | |||
template <int nidx, int idx_ndim> | |||
void ExecImplHelper::dispatch_gen_offset_base_nidx_ndim() { | |||
GenOffsetBaseParam<nidx, idx_ndim> param; | |||
param.size = m_idx_shape.total_nr_elems(); | |||
param.output = m_offset_base; | |||
param.error_tracker = m_exec_info->error_tracker; | |||
param.error_info = m_exec_info->error_info; | |||
@@ -96,9 +97,12 @@ void ExecImplHelper::dispatch_gen_offset_base_nidx() { | |||
auto&& dst = param.indexer[i]; | |||
auto&& src = m_index->operator[](i); | |||
megdnn_assert(src.vec.layout.ndim == 1); | |||
dst.stride = src.vec.layout.stride[0]; | |||
if (src.vec.layout.shape[0] == 1) { | |||
dst.stride = 0; | |||
auto src_layout = src.vec.layout.broadcast(m_idx_shape); | |||
for (size_t i = 0; i < idx_ndim; ++i) { | |||
if (i) { | |||
dst.shape[i - 1] = src_layout.shape[i]; | |||
} | |||
dst.stride[i] = src_layout.stride[i]; | |||
} | |||
dst.ptr = src.vec.ptr<int>(); | |||
param.data_shape[i] = m_data->layout.shape[src.axis]; | |||
@@ -107,6 +111,18 @@ void ExecImplHelper::dispatch_gen_offset_base_nidx() { | |||
gen_offset_base(param, m_stream); | |||
} | |||
template <int nidx> | |||
void ExecImplHelper::dispatch_gen_offset_base_nidx() { | |||
switch (m_idx_shape.ndim) { | |||
#define cb(_n) \ | |||
case _n: \ | |||
return dispatch_gen_offset_base_nidx_ndim<nidx, _n>(); | |||
MEGDNN_FOREACH_TENSOR_NDIM(cb) | |||
#undef cb | |||
} | |||
megdnn_throw("bad index ndim"); | |||
} | |||
void ExecImplHelper::dispatch_gen_offset_base() { | |||
switch (m_index->size()) { | |||
#define cb(_n) \ | |||
@@ -154,6 +170,8 @@ void ExecImpl<Opr>::dispatch_exec_ctype_ndim() { | |||
param.data = m_data->ptr<ctype>(); | |||
param.value = m_value->ptr<ctype>(); | |||
param.idx_axis = m_idx_axis; | |||
param.idx_axis_end = m_idx_axis + m_idx_shape.ndim; | |||
param.idx_nelems = m_idx_shape.total_nr_elems(); | |||
param.value_stride = m_value_stride; | |||
for (int i = 0; i < ndim; ++i) { | |||
param.value_ly_on_data.stride[i] = m_value_layout_on_data.stride[i]; | |||
@@ -46,6 +46,15 @@ struct OprProxyIndexingMultiAxisVecHelper { | |||
return ret; | |||
} | |||
size_t get_index_ndim(const TensorNDArray& tensors) const { | |||
megdnn_assert(tensors.size() >= 3); | |||
size_t ndim = 0; | |||
for (size_t i = 2; i < tensors.size(); ++i) { | |||
ndim = std::max(tensors[i].layout.ndim, ndim); | |||
} | |||
return ndim; | |||
} | |||
IndexingMultiAxisVec::IndexDescLayoutOnly make_index_layout( | |||
const TensorLayoutArray& layouts) const { | |||
megdnn_assert(layouts.size() >= 3); | |||
@@ -65,7 +74,8 @@ struct OprProxy<IndexingMultiAxisVec> : public OprProxyIndexingMultiAxisVecHelpe | |||
void exec(IndexingMultiAxisVec* opr, const TensorNDArray& tensors) const { | |||
WorkspaceWrapper W( | |||
opr->handle(), opr->get_workspace_in_bytes( | |||
tensors[1].layout, axes, tensors.size() - 2)); | |||
tensors[1].layout, axes, tensors.size() - 2, | |||
get_index_ndim(tensors))); | |||
opr->exec(tensors[0], make_index_desc(tensors), tensors[1], W.workspace()); | |||
} | |||
@@ -81,7 +91,8 @@ struct OprProxy<IndexingIncrMultiAxisVec> : public OprProxyIndexingMultiAxisVecH | |||
void exec(IndexingIncrMultiAxisVec* opr, const TensorNDArray& tensors) const { | |||
WorkspaceWrapper W( | |||
opr->handle(), opr->get_workspace_in_bytes( | |||
tensors[1].layout, axes, tensors.size() - 2)); | |||
tensors[1].layout, axes, tensors.size() - 2, | |||
get_index_ndim(tensors))); | |||
opr->exec(tensors[0], tensors[1], make_index_desc(tensors), W.workspace()); | |||
} | |||
@@ -95,7 +106,8 @@ struct OprProxy<IndexingSetMultiAxisVec> : public OprProxyIndexingMultiAxisVecHe | |||
void exec(IndexingSetMultiAxisVec* opr, const TensorNDArray& tensors) const { | |||
WorkspaceWrapper W( | |||
opr->handle(), opr->get_workspace_in_bytes( | |||
tensors[1].layout, axes, tensors.size() - 2)); | |||
tensors[1].layout, axes, tensors.size() - 2, | |||
get_index_ndim(tensors))); | |||
opr->exec(tensors[0], tensors[1], make_index_desc(tensors), W.workspace()); | |||
} | |||
@@ -27,7 +27,7 @@ namespace test { | |||
WorkspaceWrapper W( \ | |||
opr->handle(), \ | |||
opr->get_workspace_in_bytes( \ | |||
tensors[1].layout, axes, tensors.size() - 2)); \ | |||
tensors[1].layout, axes, tensors.size() - 2, 1)); \ | |||
opr->exec( \ | |||
tensors[0], make_index_desc(tensors), tensors[1], W.workspace()); \ | |||
} \ | |||
@@ -46,7 +46,7 @@ namespace test { | |||
WorkspaceWrapper W( \ | |||
opr->handle(), \ | |||
opr->get_workspace_in_bytes( \ | |||
tensors[1].layout, axes, tensors.size() - 2)); \ | |||
tensors[1].layout, axes, tensors.size() - 2, 1)); \ | |||
opr->exec( \ | |||
tensors[0], tensors[1], make_index_desc(tensors), W.workspace()); \ | |||
} \ | |||
@@ -132,6 +132,25 @@ TEST_F(CUDA, INDEXING_MULTI_AXIS_VEC) { | |||
TensorLayout{TensorShape{9}, {-1}, dtype::Int32()}}); | |||
} | |||
TEST_F(CUDA, INDEXING_MULTI_AXIS_VEC_ND_INDEX) { | |||
run_check<IndexingMultiAxisVec>(handle_cuda()); | |||
Checker<IndexingMultiAxisVec> checker(handle_cuda()); | |||
OrderedRNG rng; | |||
checker.set_dtype(0, dtype::Float32()) | |||
.set_dtype(1, dtype::Float32()) | |||
.set_dtype(2, dtype::Int32()) | |||
.set_dtype(3, dtype::Int32()) | |||
.set_dtype(4, dtype::Int32()) | |||
.set_rng(0, &rng) | |||
.set_rng(1, &rng) | |||
.set_rng(2, &rng) | |||
.set_rng(3, &rng) | |||
.set_rng(4, &rng); | |||
checker.set_proxy({{1, 2, 3}}) | |||
.execs({{5, 5, 6, 7, 3}, {5, 2, 3, 4, 3}, {3, 1}, {2, 1, 1}, {1, 4}}); | |||
} | |||
TEST_F(CUDA, INDEXING_INCR_MULTI_AXIS_VEC) { | |||
run_check<IndexingIncrMultiAxisVec>(handle_cuda()); | |||
Checker<IndexingIncrMultiAxisVec> checker(handle_cuda()); | |||
@@ -708,3 +708,19 @@ def test_indexingSetMultiAxisVec_on_empty_tensor(symbolic): | |||
run_test((10, 10, 0), test4) | |||
run_test((10, 10, 10), test3) | |||
run_test((10, 10, 10), test4) | |||
@pytest.mark.parametrize("symbolic", [True, False, None]) | |||
def test_nd_int_indexing(symbolic): | |||
inp = np.arange(11) | |||
idx = np.random.randint(11, size=(5, 7)) | |||
def run_test(args, fn): | |||
npy_out = fn(*args) | |||
if symbolic: | |||
fn = jit.trace(symbolic=symbolic)(fn) | |||
for _ in range(3): | |||
out = fn(*[Tensor(arg) for arg in args]) | |||
np.testing.assert_equal(out.numpy(), npy_out) | |||
run_test([inp, idx], lambda inp, idx: inp[idx]) |
@@ -197,9 +197,15 @@ Opr& mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::megdnn_opr( | |||
template <class Opr> | |||
void mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::register_workspace_infer( | |||
const indexing::IndexDesc& index_desc, cg::SingleCNOperatorNodeBase& opr, | |||
VarNode* data, VarNode* value) { | |||
VarNode* data, VarNode* value, VarNodeArray idx_arr) { | |||
using namespace cg::static_infer; | |||
auto infer_shape = [this, &index_desc, &opr](TensorShape& dest, const InpVal& inp) { | |||
DepVal deps = {{data, DepType::SHAPE}, {value, DepType::SHAPE}}; | |||
for (auto&& idx : idx_arr) { | |||
deps.push_back({idx, DepType::SHAPE}); | |||
} | |||
auto infer_shape = [this, &index_desc, &opr, nr_idx = idx_arr.size()]( | |||
TensorShape& dest, const InpVal& inp) { | |||
size_t axes[TensorShape::MAX_NDIM], nr_axes = 0; | |||
auto ndim = inp.val[0].shape().ndim; | |||
for (auto&& i : reverse_adaptor(index_desc)) { | |||
@@ -207,18 +213,22 @@ void mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::register_workspace_infer( | |||
axes[nr_axes++] = i.axis.get(ndim); | |||
} | |||
} | |||
mgb_assert(nr_axes == nr_idx); | |||
if (!nr_axes) { | |||
dest = {0}; | |||
} else { | |||
size_t idx_ndim = 0; | |||
for (size_t i = 0; i < nr_idx; ++i) { | |||
idx_ndim = std::max(idx_ndim, inp.val[2 + i].shape().ndim); | |||
} | |||
mgb_assert(idx_ndim > 0); | |||
dest = {megdnn_opr(opr).get_workspace_in_bytes( | |||
inp.val[1].shape(), axes, nr_axes)}; | |||
inp.val[1].shape(), axes, nr_axes, idx_ndim)}; | |||
} | |||
return true; | |||
}; | |||
opr.owner_graph()->static_infer_manager().register_shape_infer( | |||
opr.output(1), {SourceType::DEP, | |||
{{data, DepType::SHAPE}, {value, DepType::SHAPE}}, | |||
infer_shape}); | |||
opr.output(1), {SourceType::DEP, deps, infer_shape}); | |||
} | |||
template <class Opr> | |||
@@ -342,8 +352,13 @@ void IndexingMultiAxisVecBase<Opr>::init_output_static_infer_desc() { | |||
}; | |||
owner_graph()->static_infer_manager().register_shape_infer( | |||
output(0), {SourceType::DEP, deps, infer_shape}); | |||
this->register_workspace_infer(index_desc(), *this, input(0), output(0)); | |||
VarNodeArray idx_arr; | |||
for (size_t i = 1; i < m_input2idxonly_axis_indexer.size(); ++i) { | |||
if (m_input2idxonly_axis_indexer[i]) { | |||
idx_arr.push_back(input(i)); | |||
} | |||
} | |||
this->register_workspace_infer(index_desc(), *this, input(0), output(0), idx_arr); | |||
} | |||
template <class Opr> | |||
@@ -401,7 +416,13 @@ void intl::IndexingModifyMultiAxisVecHelper<Opr>::init_output_static_infer_desc( | |||
this->owner_graph()->static_infer_manager().register_shape_infer( | |||
this->output(0), ShapeInferDesc::make_identity(this->input(0))); | |||
this->register_workspace_infer(index_desc(), *this, input(0), input(1)); | |||
VarNodeArray idx_arr; | |||
for (size_t i = 1; i < m_input2idxonly_axis_indexer.size(); ++i) { | |||
if (m_input2idxonly_axis_indexer[i]) { | |||
idx_arr.push_back(input(i)); | |||
} | |||
} | |||
this->register_workspace_infer(index_desc(), *this, input(0), input(1), idx_arr); | |||
} | |||
template <class Opr> | |||
@@ -96,7 +96,7 @@ protected: | |||
void register_workspace_infer( | |||
const indexing::IndexDesc& index_desc, cg::SingleCNOperatorNodeBase& opr, | |||
VarNode* data, VarNode* value); | |||
VarNode* data, VarNode* value, VarNodeArray idx_arr); | |||
void record_megdnn_opr(mgb::cg::GraphExecutable::ExecDependencyArray& deps); | |||
}; | |||