GitOrigin-RevId: 5f66d51de4
tags/v0.5.0
@@ -32,9 +32,11 @@ void IndexingRemapBase::check_layout_fwd(const TensorLayout &src,
     }
     megdnn_assert(map.shape[dst.ndim] == src.ndim, "%s", errmsg_c);
-    megdnn_assert(src.dtype == dtype::Float32());
+    megdnn_assert(dst.dtype == src.dtype);
+    megdnn_assert(src.dtype == dtype::Float32() || src.dtype == dtype::Int32(),
+                  "indexing remap only supports float32/int32, got %s",
+                  src.dtype.name());
     megdnn_assert(map.dtype == dtype::Int32());
-    megdnn_assert(dst.dtype == dtype::Float32());
 }
 
 void IndexingRemapForward::deduce_layout(const TensorLayout &src,
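Note (not part of the patch): the layout contract asserted above, map.shape[dst.ndim] == src.ndim, means every destination coordinate owns a full source index tuple, so the operator computes dst[d] = src[map[d]]. A self-contained sketch of that semantics with made-up shapes (2x3 source, length-2 destination, so the map is 2x2):

    #include <cassert>
    #include <cstdio>
    #include <vector>

    int main() {
        // Hypothetical sizes: src is 2x3 (src.ndim == 2), dst has length 2
        // (dst.ndim == 1), so map must be 2x2: one (row, col) pair per slot.
        std::vector<int> src = {10, 11, 12, 20, 21, 22};  // row-major 2x3
        int map[2][2] = {{1, 2}, {0, 0}};  // full source coords per dst slot
        int dst[2];
        for (int d = 0; d < 2; ++d) {
            int r = map[d][0], c = map[d][1];
            dst[d] = src[r * 3 + c];  // dst[d] = src[map[d]]
        }
        assert(dst[0] == 22 && dst[1] == 10);
        std::printf("dst = {%d, %d}\n", dst[0], dst[1]);
        return 0;
    }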
@@ -36,13 +36,23 @@ void IndexingRemapForwardImpl::exec(_megdnn_tensor_in src,
     for (size_t i = 0_z; i < dst.layout.ndim; ++i) {
         dshape.data[i] = dst.layout.shape[i];
     }
-    // Invoke kernel
-    tensor_remap::forward(src.ptr<dt_float32>(),
-            map.ptr<dt_int32>(),
-            dst.ptr<dt_float32>(),
-            src.layout.ndim, dst.layout.ndim,
-            sstride, dstride, dshape,
-            cuda_stream(handle()));
+    // Invoke kernel
+#define cb(dt)                                                              \
+    if (src.layout.dtype.enumv() == DTypeTrait<dt>::enumv) {                \
+        using ctype = DTypeTrait<dt>::ctype;                                \
+        tensor_remap::forward<ctype>(src.ptr<ctype>(), map.ptr<dt_int32>(), \
+                                     dst.ptr<ctype>(), src.layout.ndim,     \
+                                     dst.layout.ndim, sstride, dstride,     \
+                                     dshape, cuda_stream(handle()));        \
+        return;                                                             \
+    }
+    cb(dtype::Float32)
+    cb(dtype::Int32)
+#undef cb
+    megdnn_throw(
+            ssprintf("cuda indexing remap forward only supports "
+                     "float32/int32 dtype, got %s",
+                     src.layout.dtype.name()));
 }
 
 void IndexingRemapBackwardImpl::exec(_megdnn_tensor_in diff,
@@ -69,18 +79,27 @@ void IndexingRemapBackwardImpl::exec(_megdnn_tensor_in diff,
     for (size_t i = 0_z; i < diff.layout.ndim; ++i) {
         dshape.data[i] = diff.layout.shape[i];
     }
-    // Invoke kernel
-    tensor_remap::backward(diff.ptr<dt_float32>(),
-            map.ptr<dt_int32>(),
-            grad.ptr<dt_float32>(),
-            grad.layout.ndim, diff.layout.ndim,
-            sstride, dstride, sshape, dshape,
-            param().is_non_overlapping,
-            cuda_stream(handle()));
+    // Invoke kernel
+#define cb(dt)                                                                \
+    if (diff.layout.dtype.enumv() == DTypeTrait<dt>::enumv) {                 \
+        using ctype = DTypeTrait<dt>::ctype;                                  \
+        tensor_remap::backward<ctype>(                                        \
+                diff.ptr<ctype>(), map.ptr<dt_int32>(), grad.ptr<ctype>(),    \
+                grad.layout.ndim, diff.layout.ndim, sstride, dstride, sshape, \
+                dshape, param().is_non_overlapping, cuda_stream(handle()));   \
+        return;                                                               \
+    }
+    cb(dtype::Float32)
+    cb(dtype::Int32)
+#undef cb
+    megdnn_throw(
+            ssprintf("cuda indexing remap backward only supports "
+                     "float32/int32 dtype, got %s",
+                     diff.layout.dtype.name()));
 }
 
 } // namespace cuda
 } // namespace megdnn
 
 // vim: syntax=cpp.doxygen
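Note (not part of the patch): the cb() macro above is the standard runtime-enum to compile-time-ctype dispatch idiom. A compilable sketch of the same shape, using hypothetical stand-ins (MyEnum, MyTrait, my_kernel) rather than the real DTypeTrait machinery:

    #include <cstddef>
    #include <cstdio>
    #include <cstdlib>

    // Stand-ins for DTypeEnum / DTypeTrait (hypothetical, not MegDNN types).
    enum class MyEnum { Float32, Int32 };
    template <MyEnum e>
    struct MyTrait;
    template <>
    struct MyTrait<MyEnum::Float32> { using ctype = float; };
    template <>
    struct MyTrait<MyEnum::Int32> { using ctype = int; };

    // Stand-in for a typed kernel such as tensor_remap::forward<ctype>.
    template <typename ctype>
    void my_kernel(const ctype* src, ctype* dst, std::size_t n) {
        for (std::size_t i = 0; i < n; ++i)
            dst[i] = src[i];
    }

    void dispatch(MyEnum dt, const void* src, void* dst, std::size_t n) {
    #define cb(e)                                                \
        if (dt == e) {                                           \
            using ctype = MyTrait<e>::ctype;                     \
            my_kernel<ctype>(static_cast<const ctype*>(src),     \
                             static_cast<ctype*>(dst), n);       \
            return;                                              \
        }
        cb(MyEnum::Float32)
        cb(MyEnum::Int32)
    #undef cb
        std::fprintf(stderr, "unsupported dtype\n");
        std::abort();
    }

    int main() {
        float a[3] = {1.f, 2.f, 3.f}, b[3] = {};
        dispatch(MyEnum::Float32, a, b, 3);
        std::printf("%g %g %g\n", b[0], b[1], b[2]);
        return 0;
    }

Each cb() expansion tests one enum value, runs the kernel instantiated for the matching C type, and returns; falling off the end of the chain is the unsupported-dtype error path.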
@@ -6,28 +6,29 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
-#include "src/cuda/tensor_remap/tensor_remap.cuh"
 #include "src/cuda/query_blocksize.cuh"
+#include "src/cuda/tensor_remap/tensor_remap.cuh"
 
 namespace megdnn {
 namespace cuda {
-namespace tensor_remap {
+namespace {
 
-__global__ void forward_kernel(const float *src, const int *map, float *dst,
-        uint32_t sdim, uint32_t ddim,
-        array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
-        array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
-        array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
-        uint32_t total)
-{
+template <typename ctype>
+__global__ void forward_kernel(const ctype* src, const int* map, ctype* dst,
+                               uint32_t sdim, uint32_t ddim,
+                               array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
+                               array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
+                               array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
+                               uint32_t total) {
     uint32_t didx_cont = threadIdx.x + blockIdx.x * blockDim.x;
     if (didx_cont < total) {
         uint32_t midx = didx_cont * sdim;
         uint32_t didx = 0u;
         for (uint32_t j = ddim; j > 0u; --j) {
-            uint32_t i = j-1u;
+            uint32_t i = j - 1u;
             uint32_t didx_cur = didx_cont % dshape.data[i];
             didx_cont /= dshape.data[i];
             didx += didx_cur * dstride.data[i];
@@ -41,34 +42,16 @@ __global__ void forward_kernel(const float *src, const int *map, float *dst,
     }
 }
 
-void forward(const float *src, const int *map, float *dst,
-        uint32_t sdim, uint32_t ddim,
-        const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride,
-        const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride,
-        const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape,
-        cudaStream_t stream)
-{
-    uint32_t total = 1u;
-    for (uint32_t i = 0u; i < ddim; ++i) total *= dshape.data[i];
-    uint32_t threads = query_blocksize_for_kernel((void *)&forward_kernel);
-    uint32_t blocks = DIVUP(total, threads);
-    forward_kernel<<<blocks, threads, 0, stream>>>(src, map, dst,
-            sdim, ddim,
-            sstride, dstride, dshape,
-            total);
-    after_kernel_launch();
-}
-
-__global__ void fill_zero_kernel(float *a, uint32_t dim,
-        array_wrapper<int, MEGDNN_MAX_NDIM> stride,
-        array_wrapper<uint32_t, MEGDNN_MAX_NDIM> shape,
-        uint32_t total)
-{
+template <typename ctype>
+__global__ void fill_zero_kernel(ctype* a, uint32_t dim,
+                                 array_wrapper<int, MEGDNN_MAX_NDIM> stride,
+                                 array_wrapper<uint32_t, MEGDNN_MAX_NDIM> shape,
+                                 uint32_t total) {
     uint32_t idx_cont = threadIdx.x + blockIdx.x * blockDim.x;
     if (idx_cont < total) {
         uint32_t idx = 0u;
         for (uint32_t j = dim; j > 0u; --j) {
-            uint32_t i = j-1u;
+            uint32_t i = j - 1u;
             uint32_t idx_cur = idx_cont % shape.data[i];
             idx_cont /= shape.data[i];
             idx += idx_cur * stride.data[i];
@@ -77,19 +60,19 @@ __global__ void fill_zero_kernel(float *a, uint32_t dim,
     }
 }
 
-__global__ void backward_kernel(const float *diff, const int *map, float *grad,
-        uint32_t sdim, uint32_t ddim,
-        array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
-        array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
-        array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
-        uint32_t total)
-{
+template <typename ctype>
+__global__ void backward_kernel(const ctype* diff, const int* map, ctype* grad,
+                                uint32_t sdim, uint32_t ddim,
+                                array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
+                                array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
+                                array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
+                                uint32_t total) {
     uint32_t didx_cont = threadIdx.x + blockIdx.x * blockDim.x;
     if (didx_cont < total) {
         uint32_t midx = didx_cont * sdim;
         uint32_t didx = 0u;
         for (uint32_t j = ddim; j > 0u; --j) {
-            uint32_t i = j-1u;
+            uint32_t i = j - 1u;
             uint32_t didx_cur = didx_cont % dshape.data[i];
             didx_cont /= dshape.data[i];
             didx += didx_cur * dstride.data[i];
@@ -103,20 +86,18 @@ __global__ void backward_kernel(const float *diff, const int *map, float *grad,
     }
 }
 
+template <typename ctype>
 __global__ void backward_kernel_non_overlapping(
-        const float *diff, const int *map, float *grad,
-        uint32_t sdim, uint32_t ddim,
-        array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
+        const ctype* diff, const int* map, ctype* grad, uint32_t sdim,
+        uint32_t ddim, array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
         array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
-        array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
-        uint32_t total)
-{
+        array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape, uint32_t total) {
     uint32_t didx_cont = threadIdx.x + blockIdx.x * blockDim.x;
     if (didx_cont < total) {
         uint32_t midx = didx_cont * sdim;
         uint32_t didx = 0u;
         for (uint32_t j = ddim; j > 0u; --j) {
-            uint32_t i = j-1u;
+            uint32_t i = j - 1u;
             uint32_t didx_cur = didx_cont % dshape.data[i];
             didx_cont /= dshape.data[i];
             didx += didx_cur * dstride.data[i];
@@ -130,55 +111,91 @@ __global__ void backward_kernel_non_overlapping(
     }
 }
 
-void backward(const float *diff, const int *map, float *grad,
-        uint32_t sdim, uint32_t ddim,
-        const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride,
-        const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride,
-        const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &sshape,
-        const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape,
-        bool is_non_overlapping,
-        cudaStream_t stream)
-{
+} // anonymous namespace
+
+namespace tensor_remap {
+
+template <typename ctype>
+void forward(const ctype* src, const int* map, ctype* dst, uint32_t sdim,
+             uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride,
+             const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,
+             const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,
+             cudaStream_t stream) {
+    uint32_t total = 1u;
+    for (uint32_t i = 0u; i < ddim; ++i)
+        total *= dshape.data[i];
+    uint32_t threads =
+            query_blocksize_for_kernel((void*)&forward_kernel<ctype>);
+    uint32_t blocks = DIVUP(total, threads);
+    forward_kernel<ctype><<<blocks, threads, 0, stream>>>(
+            src, map, dst, sdim, ddim, sstride, dstride, dshape, total);
+    after_kernel_launch();
+}
+
+template <typename ctype>
+void backward(const ctype* diff, const int* map, ctype* grad, uint32_t sdim,
+              uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride,
+              const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,
+              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& sshape,
+              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,
+              bool is_non_overlapping, cudaStream_t stream) {
     {
         // Fill grad with zeros.
         uint32_t total = 1u;
-        for (uint32_t i = 0u; i < sdim; ++i) total *= sshape.data[i];
-        uint32_t threads = query_blocksize_for_kernel((void *)&fill_zero_kernel);
+        for (uint32_t i = 0u; i < sdim; ++i)
+            total *= sshape.data[i];
+        uint32_t threads =
+                query_blocksize_for_kernel((void*)&fill_zero_kernel<ctype>);
         uint32_t blocks = DIVUP(total, threads);
-        fill_zero_kernel<<<blocks, threads, 0, stream>>>(
+        fill_zero_kernel<ctype><<<blocks, threads, 0, stream>>>(
                 grad, sdim, sstride, sshape, total);
         after_kernel_launch();
     }
     {
         // Update grad.
         uint32_t total = 1u;
-        for (uint32_t i = 0u; i < ddim; ++i) total *= dshape.data[i];
+        for (uint32_t i = 0u; i < ddim; ++i)
+            total *= dshape.data[i];
         if (is_non_overlapping) {
             uint32_t threads = query_blocksize_for_kernel(
-                    (void *)&backward_kernel_non_overlapping);
+                    (void*)&backward_kernel_non_overlapping<ctype>);
             uint32_t blocks = DIVUP(total, threads);
-            backward_kernel_non_overlapping<<<blocks, threads, 0, stream>>>(
-                    diff, map, grad,
-                    sdim, ddim,
-                    sstride, dstride, dshape,
-                    total);
+            backward_kernel_non_overlapping<ctype>
+                    <<<blocks, threads, 0, stream>>>(diff, map, grad, sdim,
                                                     ddim, sstride, dstride,
                                                     dshape, total);
         } else {
-            uint32_t threads = query_blocksize_for_kernel(
-                    (void *)&backward_kernel);
+            uint32_t threads =
+                    query_blocksize_for_kernel((void*)&backward_kernel<ctype>);
             uint32_t blocks = DIVUP(total, threads);
-            backward_kernel<<<blocks, threads, 0, stream>>>(diff, map, grad,
-                    sdim, ddim,
-                    sstride, dstride, dshape,
+            backward_kernel<ctype><<<blocks, threads, 0, stream>>>(
+                    diff, map, grad, sdim, ddim, sstride, dstride, dshape,
                     total);
         }
         after_kernel_launch();
     }
 }
 
+#define INST(T)                                                                \
+    template void forward<T>(                                                  \
+            const T* src, const int* map, T* dst, uint32_t sdim,               \
+            uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride, \
+            const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,                \
+            const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,            \
+            cudaStream_t stream);                                              \
+    template void backward<T>(                                                 \
+            const T* diff, const int* map, T* grad, uint32_t sdim,             \
+            uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride, \
+            const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,                \
+            const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& sshape,            \
+            const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,            \
+            bool is_non_overlapping, cudaStream_t stream);
+
+INST(dt_float32)
+INST(dt_int32)
+#undef INST
+
 } // namespace tensor_remap
 } // namespace cuda
 } // namespace megdnn
 
 // vim: syntax=cpp.doxygen
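Note (not part of the patch): forward<T> and backward<T> are now templates whose definitions live only in this .cu translation unit, so the INST() explicit instantiations are what keep the opr_impl.cpp callers linking. A minimal sketch of that pattern under hypothetical names (copy_n, copy_kernel):

    // lib.cuh: what a consumer TU sees, declaration only, no definition.
    template <typename T>
    void copy_n(const T* dev_src, T* dev_dst, unsigned n);

    // lib.cu: definition next to the kernel, plus explicit instantiations.
    template <typename T>
    __global__ void copy_kernel(const T* src, T* dst, unsigned n) {
        unsigned i = threadIdx.x + blockIdx.x * blockDim.x;
        if (i < n)
            dst[i] = src[i];
    }

    template <typename T>
    void copy_n(const T* dev_src, T* dev_dst, unsigned n) {
        copy_kernel<T><<<(n + 255) / 256, 256>>>(dev_src, dev_dst, n);
    }

    // Without these lines, a .cpp caller of copy_n<float> links against
    // nothing: the template body never left this translation unit.
    template void copy_n<float>(const float*, float*, unsigned);  // cf. INST(dt_float32)
    template void copy_n<int>(const int*, int*, unsigned);        // cf. INST(dt_int32)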
@@ -17,25 +17,23 @@ namespace megdnn {
 namespace cuda {
 namespace tensor_remap {
 
-void forward(const float *src, const int *map, float *dst,
-        uint32_t sdim, uint32_t ddim,
-        const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride,
-        const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride,
-        const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape,
-        cudaStream_t stream);
+template <typename ctype>
+void forward(const ctype* src, const int* map, ctype* dst, uint32_t sdim,
+             uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride,
+             const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,
+             const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,
+             cudaStream_t stream);
 
-void backward(const float *diff, const int *map, float *grad,
-        uint32_t sdim, uint32_t ddim,
-        const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride,
-        const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride,
-        const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &sshape,
-        const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape,
-        bool is_non_overlapping,
-        cudaStream_t stream);
+template <typename ctype>
+void backward(const ctype* diff, const int* map, ctype* grad, uint32_t sdim,
+              uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride,
+              const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,
+              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& sshape,
+              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,
+              bool is_non_overlapping, cudaStream_t stream);
 
 } // namespace tensor_remap
 } // namespace cuda
 } // namespace megdnn
 
 // vim: syntax=cpp.doxygen
@@ -6,75 +6,107 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #include "src/naive/tensor_remap/opr_impl.h"
 #include "src/common/utils.h"
 #include "src/naive/handle.h"
 
-namespace megdnn {
-namespace naive {
+using namespace megdnn;
+using namespace naive;
+
+namespace {
+
+template <typename ctype>
+void forward(const TensorND& src, const TensorND& map, const TensorND& dst) {
+    auto&& sshape = src.layout;
+    auto&& mshape = map.layout;
+    auto&& dshape = dst.layout;
+    // Last element is zero to facilitate maddr calculation.
+    std::vector<size_t> didx(dshape.ndim + 1, 0_z);
+    do {
+        auto maddr = get_linear_addr(didx.data(), mshape.shape, mshape.ndim);
+        std::vector<size_t> sidx(sshape.ndim);
+        for (size_t i = 0_z; i < sshape.ndim; ++i) {
+            sidx[i] = map.ptr<dt_int32>()[maddr + i];
+        }
+        auto saddr = get_linear_addr_noncont(sidx.data(), src.layout);
+        auto daddr = get_linear_addr_noncont(didx.data(), dst.layout);
+        dst.ptr<ctype>()[daddr] = src.ptr<ctype>()[saddr];
+    } while (get_next_addr(didx.data(), dshape.shape, dshape.ndim));
+}
+
+template <typename ctype>
+void backward(const TensorND& diff, const TensorND& map, const TensorND& grad) {
+    auto&& sshape = grad.layout;
+    auto&& mshape = map.layout;
+    auto&& dshape = diff.layout;
+    std::vector<size_t> sidx(sshape.ndim, 0_z);
+    {
+        // Set grad to zero.
+        do {
+            auto saddr = get_linear_addr_noncont(sidx.data(), grad.layout);
+            grad.ptr<ctype>()[saddr] = 0.0f;
+        } while (get_next_addr(sidx.data(), sshape.shape, sshape.ndim));
+    }
+    std::vector<size_t> didx(dshape.ndim + 1, 0_z);
+    do {
+        auto maddr = get_linear_addr(didx.data(), mshape.shape, mshape.ndim);
+        std::vector<size_t> sidx(sshape.ndim);
+        for (size_t i = 0_z; i < sshape.ndim; ++i) {
+            sidx[i] = map.ptr<dt_int32>()[maddr + i];
+        }
+        auto saddr = get_linear_addr_noncont(sidx.data(), grad.layout);
+        auto daddr = get_linear_addr_noncont(didx.data(), diff.layout);
+        grad.ptr<ctype>()[saddr] += diff.ptr<ctype>()[daddr];
+    } while (get_next_addr(didx.data(), dshape.shape, dshape.ndim));
+}
+
+} // anonymous namespace
 
 void IndexingRemapForwardImpl::exec(_megdnn_tensor_in src,
-        _megdnn_tensor_in map,
-        _megdnn_tensor_out dst,
-        _megdnn_workspace workspace)
-{
+                                    _megdnn_tensor_in map,
+                                    _megdnn_tensor_out dst,
+                                    _megdnn_workspace workspace) {
     check_exec(src.layout, map.layout, dst.layout, workspace.size);
-    auto kern = [=]() {
-        auto &&sshape = src.layout;
-        auto &&mshape = map.layout;
-        auto &&dshape = dst.layout;
-        // Last element is zero to facilitate maddr calculation.
-        std::vector<size_t> didx(dshape.ndim+1, 0_z);
-        do {
-            auto maddr = get_linear_addr(didx.data(), mshape.shape, mshape.ndim);
-            std::vector<size_t> sidx(sshape.ndim);
-            for (size_t i = 0_z; i < sshape.ndim; ++i) {
-                sidx[i] = map.ptr<dt_int32>()[maddr+i];
-            }
-            auto saddr = get_linear_addr_noncont(sidx.data(), src.layout);
-            auto daddr = get_linear_addr_noncont(didx.data(), dst.layout);
-            dst.ptr<dt_float32>()[daddr] = src.ptr<dt_float32>()[saddr];
-        } while (get_next_addr(didx.data(), dshape.shape, dshape.ndim));
-    };
-    MEGDNN_DISPATCH_CPU_KERN_OPR(kern());
+    switch (src.layout.dtype.enumv()) {
+#define cb(dt)                                                  \
+    case DTypeTrait<dt>::enumv:                                 \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(                           \
+                forward<DTypeTrait<dt>::ctype>(src, map, dst)); \
+        return;
+        cb(dtype::Float32)
+        cb(dtype::Int32)
+#undef cb
+        default:
+            megdnn_throw(
+                    ssprintf("unsupported dtype %s in indexing "
+                             "remap forward naive\n",
+                             src.layout.dtype.name()));
+    }
 }
 
 void IndexingRemapBackwardImpl::exec(_megdnn_tensor_in diff,
-        _megdnn_tensor_in map,
-        _megdnn_tensor_out grad,
-        _megdnn_workspace workspace)
-{
+                                     _megdnn_tensor_in map,
+                                     _megdnn_tensor_out grad,
+                                     _megdnn_workspace workspace) {
     check_exec(diff.layout, map.layout, grad.layout, workspace.size);
-    auto kern = [=]() {
-        auto &&sshape = grad.layout;
-        auto &&mshape = map.layout;
-        auto &&dshape = diff.layout;
-        std::vector<size_t> sidx(sshape.ndim, 0_z);
-        {
-            // Set grad to zero.
-            do {
-                auto saddr = get_linear_addr_noncont(sidx.data(), grad.layout);
-                grad.ptr<dt_float32>()[saddr] = 0.0f;
-            } while (get_next_addr(sidx.data(), sshape.shape, sshape.ndim));
-        }
-        std::vector<size_t> didx(dshape.ndim+1, 0_z);
-        do {
-            auto maddr = get_linear_addr(didx.data(), mshape.shape, mshape.ndim);
-            std::vector<size_t> sidx(sshape.ndim);
-            for (size_t i = 0_z; i < sshape.ndim; ++i) {
-                sidx[i] = map.ptr<dt_int32>()[maddr+i];
-            }
-            auto saddr = get_linear_addr_noncont(sidx.data(), grad.layout);
-            auto daddr = get_linear_addr_noncont(didx.data(), diff.layout);
-            grad.ptr<dt_float32>()[saddr] += diff.ptr<dt_float32>()[daddr];
-        } while (get_next_addr(didx.data(), dshape.shape, dshape.ndim));
-    };
-    MEGDNN_DISPATCH_CPU_KERN_OPR(kern());
+    switch (diff.layout.dtype.enumv()) {
+#define cb(dt)                                                     \
+    case DTypeTrait<dt>::enumv:                                    \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(                              \
+                backward<DTypeTrait<dt>::ctype>(diff, map, grad)); \
+        return;
+        cb(dtype::Float32)
+        cb(dtype::Int32)
+#undef cb
+        default:
+            megdnn_throw(ssprintf(
+                    "unsupported dtype %s in indexing remap backward naive\n",
+                    diff.layout.dtype.name()));
+    }
 }
 
-} // namespace naive
-} // namespace megdnn
-
 // vim: syntax=cpp.doxygen
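Note (not part of the patch): both backends implement backward as a plain scatter-add: zero-fill grad, then grad[map[d]] += diff[d] for every diff element. Overlapping map entries must therefore accumulate, which is also why the CUDA side keeps a separate non-overlapping fast path. A tiny worked example of the rule, in a hypothetical 1-d setting:

    #include <cassert>
    #include <cstdio>

    int main() {
        // dst indices 0 and 2 both map to src slot 0, so their grads sum.
        float diff[3] = {1.f, 2.f, 4.f};
        int map[3] = {0, 2, 0};
        float grad[4] = {};           // zero-fill, like fill_zero_kernel
        for (int d = 0; d < 3; ++d)
            grad[map[d]] += diff[d];  // accumulate, like backward_kernel
        assert(grad[0] == 5.f && grad[2] == 2.f);
        std::printf("grad = {%g, %g, %g, %g}\n",
                    grad[0], grad[1], grad[2], grad[3]);
        return 0;
    }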
@@ -16,39 +16,42 @@
 namespace megdnn {
 namespace test {
 
-TEST_F(CUDA, TENSOR_REMAP_FORWARD)
-{
+TEST_F(CUDA, TENSOR_REMAP_FORWARD) {
     Checker<IndexingRemapForward> checker(handle_cuda());
+    TensorShape src{11, 13, 17}, map{3, 5, 7, 3}, dst{3, 5, 7};
     checker.set_dtype(1, dtype::Int32());
-    TensorShape src{11, 13, 17},
-                map{3, 5, 7, 3},
-                dst{3, 5, 7};
-    using namespace tensor_remap;
-    {
-        MapRNG rng(src);
-        checker.set_rng(1, &rng).execs({src, map, {}});
-    }
-    {
-        NonoverlappingMapRNG rng(src);
-        checker.set_rng(1, &rng).execs({src, map, {}});
+    for (auto dt : std::vector<DType>{dtype::Float32(), dtype::Int32()}) {
+        checker.set_dtype(0, dt);
+        checker.set_dtype(2, dt);
+        using namespace tensor_remap;
+        {
+            MapRNG rng(src);
+            checker.set_rng(1, &rng).execs({src, map, {}});
+        }
+        {
+            NonoverlappingMapRNG rng(src);
+            checker.set_rng(1, &rng).execs({src, map, {}});
+        }
     }
 }
 
-TEST_F(CUDA, TENSOR_REMAP_BACKWARD)
-{
+TEST_F(CUDA, TENSOR_REMAP_BACKWARD) {
     Checker<IndexingRemapBackward> checker(handle_cuda());
     checker.set_dtype(1, dtype::Int32());
-    TensorShape src{11, 13, 17},
-                map{3, 5, 7, 3},
-                dst{3, 5, 7};
-    using namespace tensor_remap;
-    {
-        MapRNG rng(src);
-        checker.set_rng(1, &rng).execs({dst, map, src});
-    }
-    {
-        NonoverlappingMapRNG rng(src);
-        checker.set_rng(1, &rng).execs({dst, map, src});
+    TensorShape src{11, 13, 17}, map{3, 5, 7, 3}, dst{3, 5, 7};
+    for (auto dt : std::vector<DType>{dtype::Float32(), dtype::Int32()}) {
+        checker.set_dtype(0, dt);
+        checker.set_dtype(2, dt);
+        using namespace tensor_remap;
+        {
+            MapRNG rng(src);
+            checker.set_rng(1, &rng).execs({dst, map, src});
+        }
+        {
+            NonoverlappingMapRNG rng(src);
+            checker.set_rng(1, &rng).execs({dst, map, src});
+        }
     }
 }
@@ -56,5 +59,3 @@ TEST_F(CUDA, TENSOR_REMAP_BACKWARD)
 } // namespace megdnn
 
 // vim: syntax=cpp.doxygen