GitOrigin-RevId: 5f66d51de4
tags/v0.5.0
@@ -32,9 +32,11 @@ void IndexingRemapBase::check_layout_fwd(const TensorLayout &src, | |||
} | |||
megdnn_assert(map.shape[dst.ndim] == src.ndim, "%s", errmsg_c); | |||
megdnn_assert(src.dtype == dtype::Float32()); | |||
megdnn_assert(dst.dtype == src.dtype); | |||
megdnn_assert(src.dtype == dtype::Float32() || src.dtype == dtype::Int32(), | |||
"indexing remap only support float32/int32, got %s", | |||
src.dtype.name()); | |||
megdnn_assert(map.dtype == dtype::Int32()); | |||
megdnn_assert(dst.dtype == dtype::Float32()); | |||
} | |||
void IndexingRemapForward::deduce_layout(const TensorLayout &src, | |||
@@ -36,13 +36,23 @@ void IndexingRemapForwardImpl::exec(_megdnn_tensor_in src, | |||
for (size_t i = 0_z; i < dst.layout.ndim; ++i) { | |||
dshape.data[i] = dst.layout.shape[i]; | |||
} | |||
// Invoke kernel | |||
tensor_remap::forward(src.ptr<dt_float32>(), | |||
map.ptr<dt_int32>(), | |||
dst.ptr<dt_float32>(), | |||
src.layout.ndim, dst.layout.ndim, | |||
sstride, dstride, dshape, | |||
cuda_stream(handle())); | |||
// Dispatch on src dtype (float32/int32) and invoke the templated kernel | |||
#define cb(dt) \ | |||
if (src.layout.dtype.enumv() == DTypeTrait<dt>::enumv) { \ | |||
using ctype = DTypeTrait<dt>::ctype; \ | |||
tensor_remap::forward<ctype>(src.ptr<ctype>(), map.ptr<dt_int32>(), \ | |||
dst.ptr<ctype>(), src.layout.ndim, \ | |||
dst.layout.ndim, sstride, dstride, \ | |||
dshape, cuda_stream(handle())); \ | |||
return; \ | |||
} | |||
cb(dtype::Float32) | |||
cb(dtype::Int32) | |||
#undef cb | |||
megdnn_throw( | |||
ssprintf("cuda indexing remap forward only support " | |||
"float32/int32 dtype, got %s", | |||
src.layout.dtype.name())); | |||
} | |||
void IndexingRemapBackwardImpl::exec(_megdnn_tensor_in diff, | |||
@@ -69,18 +79,27 @@ void IndexingRemapBackwardImpl::exec(_megdnn_tensor_in diff, | |||
for (size_t i = 0_z; i < diff.layout.ndim; ++i) { | |||
dshape.data[i] = diff.layout.shape[i]; | |||
} | |||
// Invoke kernel | |||
tensor_remap::backward(diff.ptr<dt_float32>(), | |||
map.ptr<dt_int32>(), | |||
grad.ptr<dt_float32>(), | |||
grad.layout.ndim, diff.layout.ndim, | |||
sstride, dstride, sshape, dshape, | |||
param().is_non_overlapping, | |||
cuda_stream(handle())); | |||
// Dispatch on diff dtype (float32/int32) and invoke the templated kernel | |||
#define cb(dt) \ | |||
if (diff.layout.dtype.enumv() == DTypeTrait<dt>::enumv) { \ | |||
using ctype = DTypeTrait<dt>::ctype; \ | |||
tensor_remap::backward<ctype>( \ | |||
diff.ptr<ctype>(), map.ptr<dt_int32>(), grad.ptr<ctype>(), \ | |||
grad.layout.ndim, diff.layout.ndim, sstride, dstride, sshape, \ | |||
dshape, param().is_non_overlapping, cuda_stream(handle())); \ | |||
return; \ | |||
} | |||
cb(dtype::Float32) | |||
cb(dtype::Int32) | |||
#undef cb | |||
megdnn_throw( | |||
ssprintf("cuda indexing remap forward only support " | |||
"float32/int32 dtype, got %s", | |||
diff.layout.dtype.name())); | |||
} | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen | |||
@@ -6,28 +6,29 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/cuda/tensor_remap/tensor_remap.cuh" | |||
#include "src/cuda/query_blocksize.cuh" | |||
#include "src/cuda/tensor_remap/tensor_remap.cuh" | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace tensor_remap { | |||
namespace { | |||
__global__ void forward_kernel(const float *src, const int *map, float *dst, | |||
uint32_t sdim, uint32_t ddim, | |||
array_wrapper<int, MEGDNN_MAX_NDIM> sstride, | |||
array_wrapper<int, MEGDNN_MAX_NDIM> dstride, | |||
array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape, | |||
uint32_t total) | |||
{ | |||
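// The kernels below are templated on the element type so float32 and int32 share one implementation. | |||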
template <typename ctype> | |||
__global__ void forward_kernel(const ctype* src, const int* map, ctype* dst, | |||
uint32_t sdim, uint32_t ddim, | |||
array_wrapper<int, MEGDNN_MAX_NDIM> sstride, | |||
array_wrapper<int, MEGDNN_MAX_NDIM> dstride, | |||
array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape, | |||
uint32_t total) { | |||
uint32_t didx_cont = threadIdx.x + blockIdx.x * blockDim.x; | |||
if (didx_cont < total) { | |||
uint32_t midx = didx_cont * sdim; | |||
uint32_t didx = 0u; | |||
for (uint32_t j = ddim; j > 0u; --j) { | |||
uint32_t i = j-1u; | |||
uint32_t i = j - 1u; | |||
uint32_t didx_cur = didx_cont % dshape.data[i]; | |||
didx_cont /= dshape.data[i]; | |||
didx += didx_cur * dstride.data[i]; | |||
@@ -41,34 +42,16 @@ __global__ void forward_kernel(const float *src, const int *map, float *dst, | |||
} | |||
} | |||
void forward(const float *src, const int *map, float *dst, | |||
uint32_t sdim, uint32_t ddim, | |||
const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride, | |||
const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride, | |||
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape, | |||
cudaStream_t stream) | |||
{ | |||
uint32_t total = 1u; | |||
for (uint32_t i = 0u; i < ddim; ++i) total *= dshape.data[i]; | |||
uint32_t threads = query_blocksize_for_kernel((void *)&forward_kernel); | |||
uint32_t blocks = DIVUP(total, threads); | |||
forward_kernel<<<blocks, threads, 0, stream>>>(src, map, dst, | |||
sdim, ddim, | |||
sstride, dstride, dshape, | |||
total); | |||
after_kernel_launch(); | |||
} | |||
__global__ void fill_zero_kernel(float *a, uint32_t dim, | |||
array_wrapper<int, MEGDNN_MAX_NDIM> stride, | |||
array_wrapper<uint32_t, MEGDNN_MAX_NDIM> shape, | |||
uint32_t total) | |||
{ | |||
template <typename ctype> | |||
__global__ void fill_zero_kernel(ctype* a, uint32_t dim, | |||
array_wrapper<int, MEGDNN_MAX_NDIM> stride, | |||
array_wrapper<uint32_t, MEGDNN_MAX_NDIM> shape, | |||
uint32_t total) { | |||
uint32_t idx_cont = threadIdx.x + blockIdx.x * blockDim.x; | |||
if (idx_cont < total) { | |||
uint32_t idx = 0u; | |||
for (uint32_t j = dim; j > 0u; --j) { | |||
uint32_t i = j-1u; | |||
uint32_t i = j - 1u; | |||
uint32_t idx_cur = idx_cont % shape.data[i]; | |||
idx_cont /= shape.data[i]; | |||
idx += idx_cur * stride.data[i]; | |||
@@ -77,19 +60,19 @@ __global__ void fill_zero_kernel(float *a, uint32_t dim, | |||
} | |||
} | |||
__global__ void backward_kernel(const float *diff, const int *map, float *grad, | |||
uint32_t sdim, uint32_t ddim, | |||
array_wrapper<int, MEGDNN_MAX_NDIM> sstride, | |||
array_wrapper<int, MEGDNN_MAX_NDIM> dstride, | |||
array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape, | |||
uint32_t total) | |||
{ | |||
template <typename ctype> | |||
__global__ void backward_kernel(const ctype* diff, const int* map, ctype* grad, | |||
uint32_t sdim, uint32_t ddim, | |||
array_wrapper<int, MEGDNN_MAX_NDIM> sstride, | |||
array_wrapper<int, MEGDNN_MAX_NDIM> dstride, | |||
array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape, | |||
uint32_t total) { | |||
uint32_t didx_cont = threadIdx.x + blockIdx.x * blockDim.x; | |||
if (didx_cont < total) { | |||
uint32_t midx = didx_cont * sdim; | |||
uint32_t didx = 0u; | |||
for (uint32_t j = ddim; j > 0u; --j) { | |||
uint32_t i = j-1u; | |||
uint32_t i = j - 1u; | |||
uint32_t didx_cur = didx_cont % dshape.data[i]; | |||
didx_cont /= dshape.data[i]; | |||
didx += didx_cur * dstride.data[i]; | |||
@@ -103,20 +86,18 @@ __global__ void backward_kernel(const float *diff, const int *map, float *grad, | |||
} | |||
} | |||
template <typename ctype> | |||
__global__ void backward_kernel_non_overlapping( | |||
const float *diff, const int *map, float *grad, | |||
uint32_t sdim, uint32_t ddim, | |||
array_wrapper<int, MEGDNN_MAX_NDIM> sstride, | |||
const ctype* diff, const int* map, ctype* grad, uint32_t sdim, | |||
uint32_t ddim, array_wrapper<int, MEGDNN_MAX_NDIM> sstride, | |||
array_wrapper<int, MEGDNN_MAX_NDIM> dstride, | |||
array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape, | |||
uint32_t total) | |||
{ | |||
array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape, uint32_t total) { | |||
uint32_t didx_cont = threadIdx.x + blockIdx.x * blockDim.x; | |||
if (didx_cont < total) { | |||
uint32_t midx = didx_cont * sdim; | |||
uint32_t didx = 0u; | |||
for (uint32_t j = ddim; j > 0u; --j) { | |||
uint32_t i = j-1u; | |||
uint32_t i = j - 1u; | |||
uint32_t didx_cur = didx_cont % dshape.data[i]; | |||
didx_cont /= dshape.data[i]; | |||
didx += didx_cur * dstride.data[i]; | |||
@@ -130,55 +111,91 @@ __global__ void backward_kernel_non_overlapping( | |||
} | |||
} | |||
void backward(const float *diff, const int *map, float *grad, | |||
uint32_t sdim, uint32_t ddim, | |||
const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride, | |||
const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride, | |||
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &sshape, | |||
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape, | |||
bool is_non_overlapping, | |||
cudaStream_t stream) | |||
{ | |||
} // anonymous namespace | |||
namespace tensor_remap { | |||
template <typename ctype> | |||
void forward(const ctype* src, const int* map, ctype* dst, uint32_t sdim, | |||
uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride, | |||
const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride, | |||
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape, | |||
cudaStream_t stream) { | |||
uint32_t total = 1u; | |||
for (uint32_t i = 0u; i < ddim; ++i) | |||
total *= dshape.data[i]; | |||
uint32_t threads = | |||
query_blocksize_for_kernel((void*)&forward_kernel<ctype>); | |||
uint32_t blocks = DIVUP(total, threads); | |||
forward_kernel<ctype><<<blocks, threads, 0, stream>>>( | |||
src, map, dst, sdim, ddim, sstride, dstride, dshape, total); | |||
after_kernel_launch(); | |||
} | |||
template <typename ctype> | |||
void backward(const ctype* diff, const int* map, ctype* grad, uint32_t sdim, | |||
uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride, | |||
const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride, | |||
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& sshape, | |||
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape, | |||
bool is_non_overlapping, cudaStream_t stream) { | |||
{ | |||
// Fill grad with zeros. | |||
uint32_t total = 1u; | |||
for (uint32_t i = 0u; i < sdim; ++i) total *= sshape.data[i]; | |||
uint32_t threads = query_blocksize_for_kernel((void *)&fill_zero_kernel); | |||
for (uint32_t i = 0u; i < sdim; ++i) | |||
total *= sshape.data[i]; | |||
uint32_t threads = | |||
query_blocksize_for_kernel((void*)&fill_zero_kernel<ctype>); | |||
uint32_t blocks = DIVUP(total, threads); | |||
fill_zero_kernel<<<blocks, threads, 0, stream>>>( | |||
fill_zero_kernel<ctype><<<blocks, threads, 0, stream>>>( | |||
grad, sdim, sstride, sshape, total); | |||
after_kernel_launch(); | |||
} | |||
{ | |||
// Update grad. | |||
uint32_t total = 1u; | |||
for (uint32_t i = 0u; i < ddim; ++i) total *= dshape.data[i]; | |||
for (uint32_t i = 0u; i < ddim; ++i) | |||
total *= dshape.data[i]; | |||
if (is_non_overlapping) { | |||
uint32_t threads = query_blocksize_for_kernel( | |||
(void *)&backward_kernel_non_overlapping); | |||
(void*)&backward_kernel_non_overlapping<ctype>); | |||
uint32_t blocks = DIVUP(total, threads); | |||
backward_kernel_non_overlapping<<<blocks, threads, 0, stream>>>( | |||
diff, map, grad, | |||
sdim, ddim, | |||
sstride, dstride, dshape, | |||
total); | |||
backward_kernel_non_overlapping<ctype> | |||
<<<blocks, threads, 0, stream>>>(diff, map, grad, sdim, | |||
ddim, sstride, dstride, | |||
dshape, total); | |||
} else { | |||
uint32_t threads = query_blocksize_for_kernel( | |||
(void *)&backward_kernel); | |||
uint32_t threads = | |||
query_blocksize_for_kernel((void*)&backward_kernel<ctype>); | |||
uint32_t blocks = DIVUP(total, threads); | |||
backward_kernel<<<blocks, threads, 0, stream>>>(diff, map, grad, | |||
sdim, ddim, | |||
sstride, dstride, dshape, | |||
backward_kernel<ctype><<<blocks, threads, 0, stream>>>( | |||
diff, map, grad, sdim, ddim, sstride, dstride, dshape, | |||
total); | |||
} | |||
after_kernel_launch(); | |||
} | |||
} | |||
} // namespace tensor_remap | |||
} // namespace cuda | |||
} // namespace megdnn | |||
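// Explicit instantiations for the element types dispatched from the CUDA opr_impl. | |||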
#define INST(T) \ | |||
template void forward<T>( \ | |||
const T* src, const int* map, T* dst, uint32_t sdim, \ | |||
uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride, \ | |||
const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride, \ | |||
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape, \ | |||
cudaStream_t stream); \ | |||
template void backward<T>( \ | |||
const T* diff, const int* map, T* grad, uint32_t sdim, \ | |||
uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride, \ | |||
const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride, \ | |||
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& sshape, \ | |||
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape, \ | |||
bool is_non_overlapping, cudaStream_t stream); | |||
INST(dt_float32) | |||
INST(dt_int32) | |||
// vim: syntax=cpp.doxygen | |||
#undef INST | |||
} // namespace tensor_remap | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -17,25 +17,23 @@ namespace megdnn { | |||
namespace cuda { | |||
namespace tensor_remap { | |||
void forward(const float *src, const int *map, float *dst, | |||
uint32_t sdim, uint32_t ddim, | |||
const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride, | |||
const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride, | |||
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape, | |||
cudaStream_t stream); | |||
template <typename ctype> | |||
void forward(const ctype* src, const int* map, ctype* dst, uint32_t sdim, | |||
uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride, | |||
const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride, | |||
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape, | |||
cudaStream_t stream); | |||
void backward(const float *diff, const int *map, float *grad, | |||
uint32_t sdim, uint32_t ddim, | |||
const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride, | |||
const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride, | |||
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &sshape, | |||
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape, | |||
bool is_non_overlapping, | |||
cudaStream_t stream); | |||
template <typename ctype> | |||
void backward(const ctype* diff, const int* map, ctype* grad, uint32_t sdim, | |||
uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride, | |||
const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride, | |||
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& sshape, | |||
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape, | |||
bool is_non_overlapping, cudaStream_t stream); | |||
} // namespace tensor_remap | |||
} // namespace tensor_remap | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen | |||
@@ -6,75 +6,107 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/naive/tensor_remap/opr_impl.h" | |||
#include "src/common/utils.h" | |||
#include "src/naive/handle.h" | |||
namespace megdnn { | |||
namespace naive { | |||
using namespace megdnn; | |||
using namespace naive; | |||
namespace { | |||
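// dtype-generic reference implementations; exec() dispatches to them by tensor dtype. | |||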
template <typename ctype> | |||
void forward(const TensorND& src, const TensorND& map, const TensorND& dst) { | |||
auto&& sshape = src.layout; | |||
auto&& mshape = map.layout; | |||
auto&& dshape = dst.layout; | |||
// Last element is zero to facilitate maddr calculation. | |||
std::vector<size_t> didx(dshape.ndim + 1, 0_z); | |||
do { | |||
auto maddr = get_linear_addr(didx.data(), mshape.shape, mshape.ndim); | |||
std::vector<size_t> sidx(sshape.ndim); | |||
for (size_t i = 0_z; i < sshape.ndim; ++i) { | |||
sidx[i] = map.ptr<dt_int32>()[maddr + i]; | |||
} | |||
auto saddr = get_linear_addr_noncont(sidx.data(), src.layout); | |||
auto daddr = get_linear_addr_noncont(didx.data(), dst.layout); | |||
dst.ptr<ctype>()[daddr] = src.ptr<ctype>()[saddr]; | |||
} while (get_next_addr(didx.data(), dshape.shape, dshape.ndim)); | |||
} | |||
template <typename ctype> | |||
void backward(const TensorND& diff, const TensorND& map, const TensorND& grad) { | |||
auto&& sshape = grad.layout; | |||
auto&& mshape = map.layout; | |||
auto&& dshape = diff.layout; | |||
std::vector<size_t> sidx(sshape.ndim, 0_z); | |||
{ | |||
// Set grad to zero. | |||
do { | |||
auto saddr = get_linear_addr_noncont(sidx.data(), grad.layout); | |||
grad.ptr<ctype>()[saddr] = 0.0f; | |||
} while (get_next_addr(sidx.data(), sshape.shape, sshape.ndim)); | |||
} | |||
std::vector<size_t> didx(dshape.ndim + 1, 0_z); | |||
do { | |||
auto maddr = get_linear_addr(didx.data(), mshape.shape, mshape.ndim); | |||
std::vector<size_t> sidx(sshape.ndim); | |||
for (size_t i = 0_z; i < sshape.ndim; ++i) { | |||
sidx[i] = map.ptr<dt_int32>()[maddr + i]; | |||
} | |||
auto saddr = get_linear_addr_noncont(sidx.data(), grad.layout); | |||
auto daddr = get_linear_addr_noncont(didx.data(), diff.layout); | |||
grad.ptr<ctype>()[saddr] += diff.ptr<ctype>()[daddr]; | |||
} while (get_next_addr(didx.data(), dshape.shape, dshape.ndim)); | |||
} | |||
} // anonymous namespace | |||
void IndexingRemapForwardImpl::exec(_megdnn_tensor_in src, | |||
_megdnn_tensor_in map, | |||
_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) | |||
{ | |||
_megdnn_tensor_in map, | |||
_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) { | |||
check_exec(src.layout, map.layout, dst.layout, workspace.size); | |||
auto kern = [=]() { | |||
auto &&sshape = src.layout; | |||
auto &&mshape = map.layout; | |||
auto &&dshape = dst.layout; | |||
// Last element is zero to facilitate maddr calculation. | |||
std::vector<size_t> didx(dshape.ndim+1, 0_z); | |||
do { | |||
auto maddr = get_linear_addr(didx.data(), mshape.shape, mshape.ndim); | |||
std::vector<size_t> sidx(sshape.ndim); | |||
for (size_t i = 0_z; i < sshape.ndim; ++i) { | |||
sidx[i] = map.ptr<dt_int32>()[maddr+i]; | |||
} | |||
auto saddr = get_linear_addr_noncont(sidx.data(), src.layout); | |||
auto daddr = get_linear_addr_noncont(didx.data(), dst.layout); | |||
dst.ptr<dt_float32>()[daddr] = src.ptr<dt_float32>()[saddr]; | |||
} while (get_next_addr(didx.data(), dshape.shape, dshape.ndim)); | |||
}; | |||
MEGDNN_DISPATCH_CPU_KERN_OPR(kern()); | |||
switch (src.layout.dtype.enumv()) { | |||
#define cb(dt) \ | |||
case DTypeTrait<dt>::enumv: \ | |||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
forward<DTypeTrait<dt>::ctype>(src, map, dst)); \ | |||
return; | |||
cb(dtype::Float32) | |||
cb(dtype::Int32) | |||
#undef cb | |||
default: | |||
megdnn_throw( | |||
ssprintf("unsupported dtype %s in indexing " | |||
"remap forward naive\n", | |||
src.layout.dtype.name())); | |||
} | |||
} | |||
void IndexingRemapBackwardImpl::exec(_megdnn_tensor_in diff, | |||
_megdnn_tensor_in map, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) | |||
{ | |||
_megdnn_tensor_in map, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) { | |||
check_exec(diff.layout, map.layout, grad.layout, workspace.size); | |||
auto kern = [=]() { | |||
auto &&sshape = grad.layout; | |||
auto &&mshape = map.layout; | |||
auto &&dshape = diff.layout; | |||
std::vector<size_t> sidx(sshape.ndim, 0_z); | |||
{ | |||
// Set grad to zero. | |||
do { | |||
auto saddr = get_linear_addr_noncont(sidx.data(), grad.layout); | |||
grad.ptr<dt_float32>()[saddr] = 0.0f; | |||
} while (get_next_addr(sidx.data(), sshape.shape, sshape.ndim)); | |||
} | |||
std::vector<size_t> didx(dshape.ndim+1, 0_z); | |||
do { | |||
auto maddr = get_linear_addr(didx.data(), mshape.shape, mshape.ndim); | |||
std::vector<size_t> sidx(sshape.ndim); | |||
for (size_t i = 0_z; i < sshape.ndim; ++i) { | |||
sidx[i] = map.ptr<dt_int32>()[maddr+i]; | |||
} | |||
auto saddr = get_linear_addr_noncont(sidx.data(), grad.layout); | |||
auto daddr = get_linear_addr_noncont(didx.data(), diff.layout); | |||
grad.ptr<dt_float32>()[saddr] += diff.ptr<dt_float32>()[daddr]; | |||
} while (get_next_addr(didx.data(), dshape.shape, dshape.ndim)); | |||
}; | |||
MEGDNN_DISPATCH_CPU_KERN_OPR(kern()); | |||
switch (diff.layout.dtype.enumv()) { | |||
#define cb(dt) \ | |||
case DTypeTrait<dt>::enumv: \ | |||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
backward<DTypeTrait<dt>::ctype>(diff, map, grad)); \ | |||
return; | |||
cb(dtype::Float32) | |||
cb(dtype::Int32) | |||
#undef cb | |||
default: | |||
megdnn_throw(ssprintf( | |||
"unsupported dtype %s in indexing remap backward naive\n", | |||
diff.layout.dtype.name())); | |||
} | |||
} | |||
} // namespace naive | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -16,39 +16,42 @@ | |||
namespace megdnn { | |||
namespace test { | |||
TEST_F(CUDA, TENSOR_REMAP_FORWARD) | |||
{ | |||
TEST_F(CUDA, TENSOR_REMAP_FORWARD) { | |||
Checker<IndexingRemapForward> checker(handle_cuda()); | |||
TensorShape src{11, 13, 17}, map{3, 5, 7, 3}, dst{3, 5, 7}; | |||
checker.set_dtype(1, dtype::Int32()); | |||
TensorShape src{11, 13, 17}, | |||
map{3, 5, 7, 3}, | |||
dst{3, 5, 7}; | |||
using namespace tensor_remap; | |||
{ | |||
MapRNG rng(src); | |||
checker.set_rng(1, &rng).execs({src, map, {}}); | |||
} | |||
{ | |||
NonoverlappingMapRNG rng(src); | |||
checker.set_rng(1, &rng).execs({src, map, {}}); | |||
for (auto dt : std::vector<DType>{dtype::Float32(), dtype::Int32()}) { | |||
checker.set_dtype(0, dt); | |||
checker.set_dtype(2, dt); | |||
using namespace tensor_remap; | |||
{ | |||
MapRNG rng(src); | |||
checker.set_rng(1, &rng).execs({src, map, {}}); | |||
} | |||
{ | |||
NonoverlappingMapRNG rng(src); | |||
checker.set_rng(1, &rng).execs({src, map, {}}); | |||
} | |||
} | |||
} | |||
TEST_F(CUDA, TENSOR_REMAP_BACKWARD) | |||
{ | |||
TEST_F(CUDA, TENSOR_REMAP_BACKWARD) { | |||
Checker<IndexingRemapBackward> checker(handle_cuda()); | |||
checker.set_dtype(1, dtype::Int32()); | |||
TensorShape src{11, 13, 17}, | |||
map{3, 5, 7, 3}, | |||
dst{3, 5, 7}; | |||
using namespace tensor_remap; | |||
{ | |||
MapRNG rng(src); | |||
checker.set_rng(1, &rng).execs({dst, map, src}); | |||
} | |||
{ | |||
NonoverlappingMapRNG rng(src); | |||
checker.set_rng(1, &rng).execs({dst, map, src}); | |||
TensorShape src{11, 13, 17}, map{3, 5, 7, 3}, dst{3, 5, 7}; | |||
checker.set_dtype(1, dtype::Int32()); | |||
for (auto dt : std::vector<DType>{dtype::Float32(), dtype::Int32()}) { | |||
checker.set_dtype(0, dt); | |||
checker.set_dtype(2, dt); | |||
using namespace tensor_remap; | |||
{ | |||
MapRNG rng(src); | |||
checker.set_rng(1, &rng).execs({dst, map, src}); | |||
} | |||
{ | |||
NonoverlappingMapRNG rng(src); | |||
checker.set_rng(1, &rng).execs({dst, map, src}); | |||
} | |||
} | |||
} | |||
@@ -56,5 +59,3 @@ TEST_F(CUDA, TENSOR_REMAP_BACKWARD) | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen | |||