GitOrigin-RevId: 8a4789852e
release-1.11
@@ -16,10 +16,18 @@ protected: | |||||
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) { | const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) { | ||||
check_layout_fwd(src, mat, {}, dst); | check_layout_fwd(src, mat, {}, dst); | ||||
} | } | ||||
void check_layout_fwd( | |||||
const TensorLayoutArray& srcs, const TensorLayout& mat, | |||||
const TensorLayout& dst) { | |||||
check_layout_fwd(srcs, mat, {}, dst); | |||||
} | |||||
void check_layout_fwd( | void check_layout_fwd( | ||||
const TensorLayout& src, const TensorLayout& mat, | const TensorLayout& src, const TensorLayout& mat, | ||||
const TensorLayout& mat_idx, const TensorLayout& dst); | const TensorLayout& mat_idx, const TensorLayout& dst); | ||||
void check_layout_fwd( | |||||
const TensorLayoutArray& srcs, const TensorLayout& mat, | |||||
const TensorLayout& mat_idx, const TensorLayout& dst); | |||||
std::string param_msg() const; | std::string param_msg() const; | ||||
int get_real_coord(int p, int len); | int get_real_coord(int p, int len); | ||||
}; | }; | ||||
@@ -49,6 +57,12 @@ public: | |||||
exec(src, mat, {}, dst, workspace); | exec(src, mat, {}, dst, workspace); | ||||
} | } | ||||
void exec( | |||||
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_in mat, | |||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||||
exec(srcs, mat, {}, dst, workspace); | |||||
} | |||||
/** | /** | ||||
* \p src should have batch size m, and \p mat and \p mat_idx should | * \p src should have batch size m, and \p mat and \p mat_idx should | ||||
* both have batch size n. Each item in \p mat_idx must be in the range | * both have batch size n. Each item in \p mat_idx must be in the range | ||||
@@ -62,15 +76,30 @@ public: | |||||
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, | _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, | ||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | ||||
virtual void exec( | |||||
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_in mat, | |||||
_megdnn_tensor_in mat_idx, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
size_t get_workspace_in_bytes( | size_t get_workspace_in_bytes( | ||||
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) { | const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) { | ||||
return get_workspace_in_bytes(src, mat, {}, dst); | return get_workspace_in_bytes(src, mat, {}, dst); | ||||
} | } | ||||
size_t get_workspace_in_bytes( | |||||
const TensorLayoutArray& srcs, const TensorLayout& mat, | |||||
const TensorLayout& dst) { | |||||
return get_workspace_in_bytes(srcs, mat, {}, dst); | |||||
} | |||||
virtual size_t get_workspace_in_bytes( | virtual size_t get_workspace_in_bytes( | ||||
const TensorLayout& src, const TensorLayout& mat, | const TensorLayout& src, const TensorLayout& mat, | ||||
const TensorLayout& mat_idx, const TensorLayout& dst) = 0; | const TensorLayout& mat_idx, const TensorLayout& dst) = 0; | ||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayoutArray& srcs, const TensorLayout& mat, | |||||
const TensorLayout& mat_idx, const TensorLayout& dst) = 0; | |||||
protected: | protected: | ||||
void check_exec( | void check_exec( | ||||
const TensorLayout& src, const TensorLayout& mat, | const TensorLayout& src, const TensorLayout& mat, | ||||
@@ -81,6 +110,10 @@ protected: | |||||
const TensorLayout& src, const TensorLayout& mat, | const TensorLayout& src, const TensorLayout& mat, | ||||
const TensorLayout& mat_idx, const TensorLayout& dst, | const TensorLayout& mat_idx, const TensorLayout& dst, | ||||
size_t workspace_in_bytes); | size_t workspace_in_bytes); | ||||
void check_exec_allow_nhwc_mat_idx( | |||||
const TensorLayoutArray& srcs, const TensorLayout& mat, | |||||
const TensorLayout& mat_idx, const TensorLayout& dst, | |||||
size_t workspace_in_bytes); | |||||
}; | }; | ||||
using WarpPerspective = WarpPerspectiveForward; | using WarpPerspective = WarpPerspectiveForward; | ||||
@@ -22,4 +22,11 @@ bool warp::is_dnn_available( | |||||
return imode == param::WarpAffine::InterpolationMode::LINEAR; | return imode == param::WarpAffine::InterpolationMode::LINEAR; | ||||
} | } | ||||
bool warp::is_dnn_available( | |||||
const TensorLayoutArray& /*src*/, const TensorLayout& /*mat*/, | |||||
const TensorLayout& /*dst*/, param::WarpAffine::InterpolationMode imode, | |||||
param::WarpAffine::Format /*format*/) { | |||||
return imode == param::WarpAffine::InterpolationMode::LINEAR; | |||||
} | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -90,6 +90,10 @@ bool is_dnn_available( | |||||
const TensorLayout&, const TensorLayout&, const TensorLayout&, | const TensorLayout&, const TensorLayout&, const TensorLayout&, | ||||
param::WarpAffine::InterpolationMode imode, param::WarpAffine::Format format); | param::WarpAffine::InterpolationMode imode, param::WarpAffine::Format format); | ||||
bool is_dnn_available( | |||||
const TensorLayoutArray&, const TensorLayout&, const TensorLayout&, | |||||
param::WarpAffine::InterpolationMode imode, param::WarpAffine::Format format); | |||||
using namespace megcv; | using namespace megcv; | ||||
using IMode = InterpolationMode; | using IMode = InterpolationMode; | ||||
using BMode = BorderMode; | using BMode = BorderMode; | ||||
@@ -3,7 +3,97 @@ | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
namespace megdnn { | namespace megdnn { | ||||
void WarpPerspectiveBase::check_layout_fwd( | |||||
const TensorLayoutArray& srcs, const TensorLayout& mat, | |||||
const TensorLayout& mat_idx, const TensorLayout& dst) { | |||||
megdnn_assert(srcs.size() > 0); | |||||
auto s = srcs.front(); | |||||
for (auto&& src : srcs) { | |||||
megdnn_assert_contiguous(src); | |||||
megdnn_assert(src.dtype == s.dtype); | |||||
megdnn_assert(src.ndim == s.ndim); | |||||
megdnn_assert(src.shape[0] == 1); | |||||
for (size_t i = 0; i < s.ndim; i++) { | |||||
megdnn_assert(src.shape[i] == s.shape[i]); | |||||
} | |||||
megdnn_assert(src.format == s.format); | |||||
} | |||||
megdnn_assert_contiguous(mat); | |||||
megdnn_assert_contiguous(dst); | |||||
auto errmsg = [&]() { | |||||
std::string msg = "{"; | |||||
for (auto&& src : srcs) { | |||||
msg.append(megdnn_layout_msg(src) + ", "); | |||||
} | |||||
return msg + "} " + megdnn_layout_msg(mat) + ", " + megdnn_layout_msg(mat_idx) + | |||||
", " + megdnn_layout_msg(dst) + ", " + param_msg(); | |||||
}; | |||||
MEGDNN_MARK_USED_VAR(errmsg); | |||||
megdnn_assert( | |||||
param().format == param::WarpPerspective::Format::NHWC || | |||||
param().format == param::WarpPerspective::Format::NCHW); | |||||
megdnn_assert(s.ndim == 4_z, "%s", errmsg().c_str()); | |||||
megdnn_assert(dst.ndim == 4_z, "%s", errmsg().c_str()); | |||||
megdnn_assert(mat.ndim == 3_z, "%s", errmsg().c_str()); | |||||
megdnn_assert(dst.shape[0] == mat.shape[0], "%s", errmsg().c_str()); | |||||
if (mat_idx.ndim) { | |||||
megdnn_assert( | |||||
mat_idx.dtype == dtype::Int32() && mat_idx.ndim == 1, "%s", | |||||
errmsg().c_str()); | |||||
megdnn_assert(mat.shape[0] == mat_idx.shape[0], "%s", errmsg().c_str()); | |||||
megdnn_assert_contiguous(mat_idx); | |||||
} else { | |||||
megdnn_assert(s.shape[0] * srcs.size() == dst.shape[0], "%s", errmsg().c_str()); | |||||
} | |||||
megdnn_assert(mat.shape[1] == 3_z, "%s", errmsg().c_str()); | |||||
megdnn_assert(mat.shape[2] == 3_z, "%s", errmsg().c_str()); | |||||
if (s.format == dst.format && dst.dtype == s.dtype) { | |||||
if (param().format == param::WarpPerspective::Format::NCHW) { | |||||
megdnn_assert( | |||||
s.dtype.enumv() == DTypeEnum::Float32 || | |||||
DNN_FLOAT16_SELECT( | |||||
(s.dtype.enumv() == DTypeEnum::Float16 || | |||||
s.dtype.enumv() == DTypeEnum::BFloat16), | |||||
false), | |||||
"WarpPerspective multi src NCHW input dtype should be " | |||||
"Float32" DNN_FLOAT16_SELECT("/Float16/BFloat16", "") "."); | |||||
megdnn_assert( | |||||
(s.dtype.category() == DTypeCategory::FLOAT && | |||||
(s.dtype == mat.dtype || mat.dtype.enumv() == DTypeEnum::Float32)), | |||||
"The input to WarpPerspective multi src is in NCHW format, in this " | |||||
"case, if the input dtype is floating point, the " | |||||
"transformation matrix should have same dtype as the " | |||||
"input, otherwise, it should be in Float32, %s given.", | |||||
mat.dtype.name()); | |||||
megdnn_assert(s.shape[1] == dst.shape[1], "%s", errmsg().c_str()); | |||||
megdnn_assert( | |||||
param().imode == param::WarpPerspective::InterpolationMode::LINEAR); | |||||
megdnn_assert( | |||||
param().bmode != param::WarpPerspective::BorderMode::TRANSPARENT); | |||||
megdnn_assert( | |||||
param().bmode != param::WarpPerspective::BorderMode::ISOLATED); | |||||
} else { | |||||
megdnn_assert(param().format == param::WarpPerspective::Format::NHWC); | |||||
megdnn_assert( | |||||
s.dtype.enumv() == DTypeEnum::Float32 || | |||||
DNN_FLOAT16_SELECT( | |||||
(s.dtype.enumv() == DTypeEnum::Float16 || | |||||
s.dtype.enumv() == DTypeEnum::BFloat16), | |||||
false), | |||||
"WarpPerspective multi src NHWC input dtype should be " | |||||
"Float32" DNN_FLOAT16_SELECT("/Float16/BFloat16", "") "."); | |||||
megdnn_assert(s.shape[3] == dst.shape[3], "%s", errmsg().c_str()); | |||||
} | |||||
} else { | |||||
megdnn_assert( | |||||
0, | |||||
"WarpPerspective multi src only support format NHWC/NCHW, dtype " | |||||
"Float32" DNN_FLOAT16_SELECT("/Float16/BFloat16", "") "."); | |||||
} | |||||
} | |||||
void WarpPerspectiveBase::check_layout_fwd( | void WarpPerspectiveBase::check_layout_fwd( | ||||
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& mat_idx, | const TensorLayout& src, const TensorLayout& mat, const TensorLayout& mat_idx, | ||||
const TensorLayout& dst) { | const TensorLayout& dst) { | ||||
@@ -295,6 +385,19 @@ void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx( | |||||
} | } | ||||
} | } | ||||
void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx( | |||||
const TensorLayoutArray& srcs, const TensorLayout& mat, | |||||
const TensorLayout& mat_idx, const TensorLayout& dst, | |||||
size_t workspace_in_bytes) { | |||||
check_layout_fwd(srcs, mat, mat_idx, dst); | |||||
auto required_workspace_in_bytes = get_workspace_in_bytes(srcs, mat, mat_idx, dst); | |||||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||||
if (param().format != Param::Format::NHWC && | |||||
param().format != Param::Format::NCHW) { | |||||
megdnn_assert(!mat_idx.ndim, "mat_idx not supported for current format"); | |||||
} | |||||
} | |||||
void WarpPerspectiveBackwardData::check_exec( | void WarpPerspectiveBackwardData::check_exec( | ||||
const TensorLayout& mat, const TensorLayout& mat_idx, const TensorLayout& diff, | const TensorLayout& mat, const TensorLayout& mat_idx, const TensorLayout& diff, | ||||
const TensorLayout& grad, size_t workspace_in_bytes) { | const TensorLayout& grad, size_t workspace_in_bytes) { | ||||
@@ -17,6 +17,13 @@ void forward_proxy( | |||||
ctype bval, BorderMode bmode, megcore::AsyncErrorInfo* error_info, | ctype bval, BorderMode bmode, megcore::AsyncErrorInfo* error_info, | ||||
void* error_tracker, cudaStream_t stream); | void* error_tracker, cudaStream_t stream); | ||||
template <typename ctype> | |||||
void forward_proxy_multi_src( | |||||
bool is_nhwc, const ctype** srcs, const float* mat, const int* mat_idx, | |||||
ctype* dst, int N_SRC, int N_MAT, int C, int IH, int IW, int OH, int OW, | |||||
ctype bval, BorderMode bmode, megcore::AsyncErrorInfo* error_info, | |||||
void* error_tracker, cudaStream_t stream); | |||||
template <typename ctype, int pack_c> | template <typename ctype, int pack_c> | ||||
void forward_proxy_nhwc_bit4( | void forward_proxy_nhwc_bit4( | ||||
const ctype* src, const float* mat, const int* mat_idx, ctype* dst, int N_SRC, | const ctype* src, const float* mat, const int* mat_idx, ctype* dst, int N_SRC, | ||||
@@ -143,6 +143,34 @@ WorkspaceBundle WarpPerspectiveForwardImpl::get_workspace_bundle( | |||||
return {ptr, std::move(sizes)}; | return {ptr, std::move(sizes)}; | ||||
} | } | ||||
WorkspaceBundle WarpPerspectiveForwardImpl::get_workspace_bundle( | |||||
void* ptr, const TensorLayoutArray& srcs, const TensorLayout& mat, | |||||
const TensorLayout& mat_idx, const TensorLayout& dst) const { | |||||
MEGDNN_MARK_USED_VAR(mat_idx); | |||||
SmallVector<size_t> sizes; | |||||
TensorLayoutArray fsrcs = srcs; | |||||
TensorLayout fmat = mat; | |||||
TensorLayout fdst = dst; | |||||
auto get_workspace = [&sizes](TensorLayout& layout) { | |||||
if (layout.dtype == dtype::BFloat16()) { | |||||
layout.dtype = dtype::Float32(); | |||||
sizes.push_back(layout.span().dist_byte()); | |||||
} | |||||
}; | |||||
for (auto&& fsrc : fsrcs) { | |||||
get_workspace(fsrc); | |||||
} | |||||
get_workspace(fmat); | |||||
get_workspace(fdst); | |||||
sizes.push_back(sizeof(dt_float32*) * srcs.size()); | |||||
if (param().format == param::WarpPerspective::Format::NHWC) { | |||||
//! use double for the workspace dtype as float may cause | |||||
//! accuracy problems | |||||
sizes.push_back(mat.total_nr_elems() * sizeof(double)); | |||||
} | |||||
return {ptr, std::move(sizes)}; | |||||
} | |||||
void WarpPerspectiveForwardImpl::exec( | void WarpPerspectiveForwardImpl::exec( | ||||
_megdnn_tensor_in ssrc, _megdnn_tensor_in smat, _megdnn_tensor_in smat_idx, | _megdnn_tensor_in ssrc, _megdnn_tensor_in smat, _megdnn_tensor_in smat_idx, | ||||
_megdnn_tensor_out sdst, _megdnn_workspace sworkspace) { | _megdnn_tensor_out sdst, _megdnn_workspace sworkspace) { | ||||
@@ -453,6 +481,124 @@ void WarpPerspectiveForwardImpl::exec( | |||||
} | } | ||||
} | } | ||||
void WarpPerspectiveForwardImpl::exec( | |||||
_megdnn_in const TensorNDArray& ssrcs, _megdnn_tensor_in smat, | |||||
_megdnn_tensor_in smat_idx, _megdnn_tensor_out sdst, | |||||
_megdnn_workspace sworkspace) { | |||||
TensorLayoutArray ssrcs_layout; | |||||
for (auto&& s : ssrcs) { | |||||
ssrcs_layout.push_back(s.layout); | |||||
} | |||||
check_exec_allow_nhwc_mat_idx( | |||||
ssrcs_layout, smat.layout, smat_idx.layout, sdst.layout, sworkspace.size); | |||||
TensorNDArray srcs = ssrcs; | |||||
TensorND mat = smat; | |||||
TensorND mat_idx = smat_idx; | |||||
TensorND dst = sdst; | |||||
Param::Format inner_format = param().format; | |||||
auto bundle = get_workspace_bundle( | |||||
sworkspace.raw_ptr, ssrcs_layout, smat.layout, smat_idx.layout, | |||||
sdst.layout); | |||||
auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>( | |||||
concrete_handle(this->handle()), &bundle); | |||||
if (ssrcs.front().layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) { | |||||
for (size_t i = 0; i < ssrcs.size(); i++) { | |||||
ctypecvt.src_to_comp_type(ssrcs[i], srcs[i]); | |||||
} | |||||
ctypecvt.src_to_comp_type(smat, mat).src_to_comp_type(sdst, dst); | |||||
} | |||||
{ | |||||
auto stream = cuda_stream(this->handle()); | |||||
bool is_nhwc = inner_format == param::WarpPerspective::Format::NHWC; | |||||
TensorND src = srcs.front(); | |||||
megdnn_assert(warp::is_dnn_available( | |||||
ssrcs_layout, mat.layout, dst.layout, param().imode, inner_format)); | |||||
size_t C, IH, IW, OH, OW; | |||||
if (is_nhwc) { | |||||
C = src.layout.shape[3]; | |||||
IH = src.layout.shape[1]; | |||||
IW = src.layout.shape[2]; | |||||
OH = dst.layout.shape[1]; | |||||
OW = dst.layout.shape[2]; | |||||
} else { | |||||
megdnn_assert( | |||||
inner_format == param::WarpPerspective::Format::NCHW, | |||||
"invalid warp_perspective format"); | |||||
C = src.layout.shape[1]; | |||||
IH = src.layout.shape[2]; | |||||
IW = src.layout.shape[3]; | |||||
OH = dst.layout.shape[2]; | |||||
OW = dst.layout.shape[3]; | |||||
} | |||||
megdnn_assert( | |||||
param().imode == Param::InterpolationMode::LINEAR, | |||||
"unsupported interpolation mode form NCHW format"); | |||||
auto bval = param().border_val; | |||||
auto bmode = warp_perspective::get_bmode(param().bmode); | |||||
if (src.layout.dtype == dst.layout.dtype) { | |||||
if (src.layout.dtype == dtype::Float32{}) { | |||||
SmallVector<size_t> workspace_sizes{sizeof(dt_float32*) * srcs.size()}; | |||||
WorkspaceBundle workspace_cpu(nullptr, workspace_sizes); | |||||
auto total_workspace_size = workspace_cpu.total_size_in_bytes(); | |||||
void* workspace_cpu_raw = malloc(total_workspace_size); | |||||
workspace_cpu = WorkspaceBundle(workspace_cpu_raw, workspace_sizes); | |||||
auto srcs_cpu = static_cast<const dt_float32**>(workspace_cpu.get(0)); | |||||
size_t i = | |||||
is_nhwc ? bundle.nr_workspace() - 2 : bundle.nr_workspace() - 1; | |||||
auto srcs_gpu = static_cast<const dt_float32**>(bundle.get(i)); | |||||
for (size_t i = 0; i < srcs.size(); ++i) { | |||||
srcs_cpu[i] = srcs[i].ptr<dt_float32>(); | |||||
} | |||||
cuda_check(cudaMemcpyAsync( | |||||
bundle.get(i), workspace_cpu.get(0), workspace_cpu.get_size(0), | |||||
cudaMemcpyHostToDevice, stream)); | |||||
cuda_check(cudaStreamAddCallback( | |||||
stream, callback_free, static_cast<void*>(workspace_cpu_raw), | |||||
0)); | |||||
warp_perspective::forward_proxy_multi_src( | |||||
is_nhwc, srcs_gpu, mat.ptr<dt_float32>(), | |||||
mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr, | |||||
dst.ptr<dt_float32>(), srcs.size(), mat.layout[0], C, IH, IW, | |||||
OH, OW, bval, bmode, async_error_info(handle()), | |||||
m_error_tracker, stream); | |||||
} else if (DNN_FLOAT16_SELECT( | |||||
src.layout.dtype == dtype::Float16(), false)) { | |||||
#ifndef MEGDNN_DISABLE_FLOAT16 | |||||
SmallVector<size_t> workspace_sizes{sizeof(dt_float16*) * srcs.size()}; | |||||
WorkspaceBundle workspace_cpu(nullptr, workspace_sizes); | |||||
auto total_workspace_size = workspace_cpu.total_size_in_bytes(); | |||||
void* workspace_cpu_raw = malloc(total_workspace_size); | |||||
workspace_cpu = WorkspaceBundle(workspace_cpu_raw, workspace_sizes); | |||||
auto srcs_cpu = static_cast<const dt_float16**>(workspace_cpu.get(0)); | |||||
auto srcs_gpu = static_cast<const dt_float16**>(bundle.get(0)); | |||||
for (size_t i = 0; i < srcs.size(); ++i) { | |||||
srcs_cpu[i] = srcs[i].ptr<dt_float16>(); | |||||
} | |||||
cuda_check(cudaMemcpyAsync( | |||||
bundle.get(0), workspace_cpu.get(0), workspace_cpu.get_size(0), | |||||
cudaMemcpyHostToDevice, stream)); | |||||
cuda_check(cudaStreamAddCallback( | |||||
stream, callback_free, static_cast<void*>(workspace_cpu_raw), | |||||
0)); | |||||
warp_perspective::forward_proxy_multi_src( | |||||
is_nhwc, srcs_gpu, mat.ptr<dt_float32>(), | |||||
mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr, | |||||
dst.ptr<dt_float16>(), srcs.size(), mat.layout[0], C, IH, IW, | |||||
OH, OW, static_cast<dt_float16>(bval), bmode, | |||||
async_error_info(handle()), m_error_tracker, stream); | |||||
#endif | |||||
} | |||||
} else { | |||||
megdnn_throw(ssprintf("unsupported dtype: %s", src.layout.dtype.name())); | |||||
} | |||||
} | |||||
if (ssrcs.front().layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) { | |||||
ctypecvt.comp_to_dst_type(dst, sdst); | |||||
} | |||||
} | |||||
} // namespace cuda | } // namespace cuda | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -47,11 +47,16 @@ struct CtypeHelper<dt_quint4> { | |||||
template <typename ctype> | template <typename ctype> | ||||
struct DirectSrcVisitor { | struct DirectSrcVisitor { | ||||
const void* ptr; | const void* ptr; | ||||
const void** ptrs; | |||||
__device__ __forceinline__ const ctype* get(int batch, int im_size) { | __device__ __forceinline__ const ctype* get(int batch, int im_size) { | ||||
return (ctype*)((char*)ptr + static_cast<int64_t>(batch) * static_cast<int64_t>(im_size) * CtypeHelper<ctype>::bit_width / 8); | return (ctype*)((char*)ptr + static_cast<int64_t>(batch) * static_cast<int64_t>(im_size) * CtypeHelper<ctype>::bit_width / 8); | ||||
} | } | ||||
__device__ __forceinline__ const ctype* get(int batch) { | |||||
return (ctype*)(ptrs[batch]); | |||||
} | |||||
void move_batch(size_t batch, size_t im_size) { | void move_batch(size_t batch, size_t im_size) { | ||||
ptr = (char*)ptr + batch * im_size * CtypeHelper<ctype>::bit_width / 8; | ptr = (char*)ptr + batch * im_size * CtypeHelper<ctype>::bit_width / 8; | ||||
} | } | ||||
@@ -60,6 +65,7 @@ struct DirectSrcVisitor { | |||||
template <typename ctype> | template <typename ctype> | ||||
struct IndexedSrcVisitor { | struct IndexedSrcVisitor { | ||||
const void* ptr; | const void* ptr; | ||||
const void** ptrs; | |||||
const int* idx; | const int* idx; | ||||
int N_SRC; | int N_SRC; | ||||
@@ -79,11 +85,60 @@ struct IndexedSrcVisitor { | |||||
return (ctype*)((char*)ptr + static_cast<int64_t>(batch) * static_cast<int64_t>(im_size) * CtypeHelper<ctype>::bit_width / 8); | return (ctype*)((char*)ptr + static_cast<int64_t>(batch) * static_cast<int64_t>(im_size) * CtypeHelper<ctype>::bit_width / 8); | ||||
} | } | ||||
__device__ __forceinline__ const ctype* get(int batch) { | |||||
int orig_batch = batch; | |||||
batch = idx[batch]; | |||||
if (batch < 0 || batch >= N_SRC) { | |||||
set_async_error_info( | |||||
error_info, error_tracker, | |||||
"mat_idx out of bound: mat_idx[%d]=%d src_batch=%d", orig_batch, | |||||
batch, N_SRC); | |||||
batch = 0; | |||||
} | |||||
return (ctype*)(ptrs[batch]); | |||||
} | |||||
void move_batch(size_t batch, size_t) { idx += batch; } | void move_batch(size_t batch, size_t) { idx += batch; } | ||||
}; | }; | ||||
template < | template < | ||||
typename ctype, typename Getter, typename SrcVisitor, typename OutputConverter> | typename ctype, typename Getter, typename SrcVisitor, typename OutputConverter> | ||||
__global__ void kern_general_multi_src( | |||||
SrcVisitor srcs, const float* __restrict mat, ctype* __restrict dst, int C, | |||||
int IH, int IW, int OH, int OW) { | |||||
Getter getter; | |||||
OutputConverter output_converter; | |||||
int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||||
int oh = blockIdx.y * blockDim.y + threadIdx.y; | |||||
const ctype* __restrict sptr = srcs.get(blockIdx.z); | |||||
dst += blockIdx.z * C * OH * OW; | |||||
mat += blockIdx.z * 3 * 3; | |||||
if (ow < OW && oh < OH) { | |||||
float denominator = mat[6] * ow + mat[7] * oh + mat[8]; | |||||
float iw = (mat[0] * ow + mat[1] * oh + mat[2]) / denominator; | |||||
float ih = (mat[3] * ow + mat[4] * oh + mat[5]) / denominator; | |||||
int iw0 = getter(floor(iw) + 0, IW); | |||||
int iw1 = getter(floor(iw) + 1, IW); | |||||
int ih0 = getter(floor(ih) + 0, IH); | |||||
int ih1 = getter(floor(ih) + 1, IH); | |||||
float palpha = ih - floor(ih); | |||||
float pbeta = iw - floor(iw); | |||||
float nalpha = 1.0f - palpha; | |||||
float nbeta = 1.0f - pbeta; | |||||
for (int c = 0; c < C; ++c) { | |||||
dst[oh * OW + ow] = output_converter( | |||||
sptr[ih0 * IW + iw0] * nalpha * nbeta + | |||||
sptr[ih0 * IW + iw1] * nalpha * pbeta + | |||||
sptr[ih1 * IW + iw0] * palpha * nbeta + | |||||
sptr[ih1 * IW + iw1] * palpha * pbeta); | |||||
sptr += IH * IW; | |||||
dst += OH * OW; | |||||
} | |||||
} | |||||
} | |||||
template < | |||||
typename ctype, typename Getter, typename SrcVisitor, typename OutputConverter> | |||||
__global__ void kern_general( | __global__ void kern_general( | ||||
SrcVisitor src, const float* __restrict mat, ctype* __restrict dst, int C, | SrcVisitor src, const float* __restrict mat, ctype* __restrict dst, int C, | ||||
int IH, int IW, int OH, int OW) { | int IH, int IW, int OH, int OW) { | ||||
@@ -262,6 +317,47 @@ __global__ void kern_general_nchw64( | |||||
} | } | ||||
template <typename ctype, typename SrcVisitor, typename OutputConverter> | template <typename ctype, typename SrcVisitor, typename OutputConverter> | ||||
__global__ void kern_const_border_multi_src( | |||||
SrcVisitor srcs, const float* __restrict mat, ctype* __restrict dst, int C, | |||||
int IH, int IW, int OH, int OW, ctype bval) { | |||||
OutputConverter output_converter; | |||||
int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||||
int oh = blockIdx.y * blockDim.y + threadIdx.y; | |||||
const ctype* __restrict sptr = srcs.get(blockIdx.z); | |||||
dst += blockIdx.z * C * OH * OW; | |||||
mat += blockIdx.z * 3 * 3; | |||||
if (ow < OW && oh < OH) { | |||||
float denominator = mat[6] * ow + mat[7] * oh + mat[8]; | |||||
float iw = (mat[0] * ow + mat[1] * oh + mat[2]) / denominator; | |||||
float ih = (mat[3] * ow + mat[4] * oh + mat[5]) / denominator; | |||||
int iw0 = floor(iw) + 0; | |||||
int iw1 = floor(iw) + 1; | |||||
int ih0 = floor(ih) + 0; | |||||
int ih1 = floor(ih) + 1; | |||||
bool okw0 = (iw0 >= 0 && iw0 < IW); | |||||
bool okw1 = (iw1 >= 0 && iw1 < IW); | |||||
bool okh0 = (ih0 >= 0 && ih0 < IH); | |||||
bool okh1 = (ih1 >= 0 && ih1 < IH); | |||||
float palpha = ih - floor(ih); | |||||
float pbeta = iw - floor(iw); | |||||
float nalpha = 1.0f - palpha; | |||||
float nbeta = 1.0f - pbeta; | |||||
for (int c = 0; c < C; ++c) { | |||||
ctype v00 = (okh0 && okw0 ? sptr[ih0 * IW + iw0] : bval); | |||||
ctype v01 = (okh0 && okw1 ? sptr[ih0 * IW + iw1] : bval); | |||||
ctype v10 = (okh1 && okw0 ? sptr[ih1 * IW + iw0] : bval); | |||||
ctype v11 = (okh1 && okw1 ? sptr[ih1 * IW + iw1] : bval); | |||||
ctype val = output_converter( | |||||
v00 * nalpha * nbeta + v01 * nalpha * pbeta + v10 * palpha * nbeta + | |||||
v11 * palpha * pbeta); | |||||
dst[oh * OW + ow] = val; | |||||
sptr += IH * IW; | |||||
dst += OH * OW; | |||||
} | |||||
} | |||||
} | |||||
template <typename ctype, typename SrcVisitor, typename OutputConverter> | |||||
__global__ void kern_const_border( | __global__ void kern_const_border( | ||||
SrcVisitor src, const float* __restrict mat, ctype* __restrict dst, int C, | SrcVisitor src, const float* __restrict mat, ctype* __restrict dst, int C, | ||||
int IH, int IW, int OH, int OW, ctype bval) { | int IH, int IW, int OH, int OW, ctype bval) { | ||||
@@ -556,6 +652,51 @@ struct KernCoreNHWC<ctype, OutputConverter, 16> { | |||||
template < | template < | ||||
typename ctype, typename Getter, typename SrcVisitor, typename OutputConverter, | typename ctype, typename Getter, typename SrcVisitor, typename OutputConverter, | ||||
int pack_c> | int pack_c> | ||||
__global__ void kern_general_nhwc_multi_src( | |||||
SrcVisitor srcs, const float* __restrict mat, ctype* __restrict dst, int C, | |||||
int IH, int IW, int OH, int OW) { | |||||
Getter getter; | |||||
OutputConverter output_converter; | |||||
constexpr int bit_width = CtypeHelper<ctype>::bit_width; | |||||
int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||||
int oh = blockIdx.y * blockDim.y + threadIdx.y; | |||||
const ctype* __restrict sptr = srcs.get(blockIdx.z); | |||||
dst = (ctype*)((char*)dst + blockIdx.z * C * OH * OW * bit_width / 8); | |||||
mat += blockIdx.z * 3 * 3; | |||||
if (ow < OW && oh < OH) { | |||||
float denominator = mat[6] * ow + mat[7] * oh + mat[8]; | |||||
float iw = (mat[0] * ow + mat[1] * oh + mat[2]) / denominator; | |||||
float ih = (mat[3] * ow + mat[4] * oh + mat[5]) / denominator; | |||||
int iw0 = getter(floor(iw) + 0, IW); | |||||
int iw1 = getter(floor(iw) + 1, IW); | |||||
int ih0 = getter(floor(ih) + 0, IH); | |||||
int ih1 = getter(floor(ih) + 1, IH); | |||||
float palpha = ih - floor(ih); | |||||
float pbeta = iw - floor(iw); | |||||
float nalpha = 1.0f - palpha; | |||||
float nbeta = 1.0f - pbeta; | |||||
float w00 = nalpha * nbeta; | |||||
float w01 = nalpha * pbeta; | |||||
float w10 = palpha * nbeta; | |||||
float w11 = palpha * pbeta; | |||||
const char* src_ptr0 = (char*)sptr + (ih0 * IW + iw0) * C * bit_width / 8; | |||||
const char* src_ptr1 = (char*)sptr + (ih0 * IW + iw1) * C * bit_width / 8; | |||||
const char* src_ptr2 = (char*)sptr + (ih1 * IW + iw0) * C * bit_width / 8; | |||||
const char* src_ptr3 = (char*)sptr + (ih1 * IW + iw1) * C * bit_width / 8; | |||||
char* dst_ptr = (char*)dst + (oh * OW + ow) * C * bit_width / 8; | |||||
for (int c = 0; c < C; c += pack_c) { | |||||
KernCoreNHWC<ctype, OutputConverter, pack_c>::func( | |||||
dst_ptr, src_ptr0, src_ptr1, src_ptr2, src_ptr3, c * bit_width / 8, | |||||
w00, w01, w10, w11, output_converter, true, true, true, true, | |||||
(ctype)0); | |||||
} | |||||
} | |||||
} | |||||
template < | |||||
typename ctype, typename Getter, typename SrcVisitor, typename OutputConverter, | |||||
int pack_c> | |||||
__global__ void kern_general_nhwc( | __global__ void kern_general_nhwc( | ||||
SrcVisitor src, const float* __restrict mat, ctype* __restrict dst, int C, | SrcVisitor src, const float* __restrict mat, ctype* __restrict dst, int C, | ||||
int IH, int IW, int OH, int OW) { | int IH, int IW, int OH, int OW) { | ||||
@@ -601,6 +742,58 @@ __global__ void kern_general_nhwc( | |||||
template < | template < | ||||
typename ctype, typename Getter, typename SrcVisitor, typename OutputConverter, | typename ctype, typename Getter, typename SrcVisitor, typename OutputConverter, | ||||
int pack_c> | int pack_c> | ||||
__global__ void kern_general_nhwc_const_multi_src( | |||||
SrcVisitor srcs, const float* __restrict mat, ctype* __restrict dst, int C, | |||||
int IH, int IW, int OH, int OW, ctype bval) { | |||||
Getter getter; | |||||
OutputConverter output_converter; | |||||
constexpr int bit_width = CtypeHelper<ctype>::bit_width; | |||||
int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||||
int oh = blockIdx.y * blockDim.y + threadIdx.y; | |||||
const ctype* __restrict sptr = srcs.get(blockIdx.z); | |||||
dst = (ctype*)((char*)dst + blockIdx.z * C * OH * OW * bit_width / 8); | |||||
mat += blockIdx.z * 3 * 3; | |||||
if (ow < OW && oh < OH) { | |||||
float denominator = mat[6] * ow + mat[7] * oh + mat[8]; | |||||
float iw = (mat[0] * ow + mat[1] * oh + mat[2]) / denominator; | |||||
float ih = (mat[3] * ow + mat[4] * oh + mat[5]) / denominator; | |||||
int iw0 = getter(floor(iw) + 0, IW); | |||||
int iw1 = getter(floor(iw) + 1, IW); | |||||
int ih0 = getter(floor(ih) + 0, IH); | |||||
int ih1 = getter(floor(ih) + 1, IH); | |||||
float palpha = ih - floor(ih); | |||||
float pbeta = iw - floor(iw); | |||||
float nalpha = 1.0f - palpha; | |||||
float nbeta = 1.0f - pbeta; | |||||
float w00 = nalpha * nbeta; | |||||
float w01 = nalpha * pbeta; | |||||
float w10 = palpha * nbeta; | |||||
float w11 = palpha * pbeta; | |||||
const char* src_ptr0 = (char*)sptr + (ih0 * IW + iw0) * C * bit_width / 8; | |||||
const char* src_ptr1 = (char*)sptr + (ih0 * IW + iw1) * C * bit_width / 8; | |||||
const char* src_ptr2 = (char*)sptr + (ih1 * IW + iw0) * C * bit_width / 8; | |||||
const char* src_ptr3 = (char*)sptr + (ih1 * IW + iw1) * C * bit_width / 8; | |||||
char* dst_ptr = (char*)dst + (oh * OW + ow) * C * bit_width / 8; | |||||
bool okw0 = (iw0 >= 0 && iw0 < IW); | |||||
bool okw1 = (iw1 >= 0 && iw1 < IW); | |||||
bool okh0 = (ih0 >= 0 && ih0 < IH); | |||||
bool okh1 = (ih1 >= 0 && ih1 < IH); | |||||
bool src0_ok = okh0 && okw0; | |||||
bool src1_ok = okh0 && okw1; | |||||
bool src2_ok = okh1 && okw0; | |||||
bool src3_ok = okh1 && okw1; | |||||
for (int c = 0; c < C; c += pack_c) { | |||||
KernCoreNHWC<ctype, OutputConverter, pack_c>::func( | |||||
dst_ptr, src_ptr0, src_ptr1, src_ptr2, src_ptr3, c * bit_width / 8, | |||||
w00, w01, w10, w11, output_converter, src0_ok, src1_ok, src2_ok, | |||||
src3_ok, bval); | |||||
} | |||||
} | |||||
} | |||||
template < | |||||
typename ctype, typename Getter, typename SrcVisitor, typename OutputConverter, | |||||
int pack_c> | |||||
__global__ void kern_general_nhwc_const( | __global__ void kern_general_nhwc_const( | ||||
SrcVisitor src, const float* __restrict mat, ctype* __restrict dst, int C, | SrcVisitor src, const float* __restrict mat, ctype* __restrict dst, int C, | ||||
int IH, int IW, int OH, int OW, ctype bval) { | int IH, int IW, int OH, int OW, ctype bval) { | ||||
@@ -651,6 +844,73 @@ __global__ void kern_general_nhwc_const( | |||||
} | } | ||||
template <typename ctype, typename SrcVisitor> | template <typename ctype, typename SrcVisitor> | ||||
void dispatch_with_visitor_multi_src( | |||||
bool is_nhwc, SrcVisitor srcs, const float* mat, ctype* dst, int N, int C, | |||||
int IH, int IW, int OH, int OW, ctype bval, BorderMode bmode, | |||||
cudaStream_t stream) { | |||||
constexpr int pack_c = 1; | |||||
const int BY = 16, BX = 32; | |||||
#define DISPATCH(Getter) \ | |||||
do { \ | |||||
if (is_nhwc) { \ | |||||
kern_general_nhwc_multi_src< \ | |||||
ctype, Getter, SrcVisitor, rounding::RoundingConverter<ctype>, \ | |||||
pack_c><<<blocks, threads, 0, stream>>>( \ | |||||
srcs, mat, dst, C, IH, IW, OH, OW); \ | |||||
} else { \ | |||||
kern_general_multi_src< \ | |||||
ctype, Getter, SrcVisitor, rounding::RoundingConverter<ctype>> \ | |||||
<<<blocks, threads, 0, stream>>>( \ | |||||
srcs, mat, dst, C, IH, IW, OH, OW); \ | |||||
} \ | |||||
} while (0) | |||||
const int max_batch_size = 65535; | |||||
while (N) { | |||||
size_t curr_batch_size = N < max_batch_size ? N : max_batch_size; | |||||
dim3 threads(BX, BY); | |||||
dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size); | |||||
switch (bmode) { | |||||
case BORDER_REPLICATE: | |||||
DISPATCH(ReplicateGetter); | |||||
break; | |||||
case BORDER_REFLECT: | |||||
DISPATCH(ReflectGetter); | |||||
break; | |||||
case BORDER_REFLECT_101: | |||||
DISPATCH(Reflect101Getter); | |||||
break; | |||||
case BORDER_WRAP: | |||||
DISPATCH(WrapGetter); | |||||
break; | |||||
#undef DISPATCH | |||||
case BORDER_CONSTANT: | |||||
if (is_nhwc) { | |||||
kern_general_nhwc_const_multi_src< | |||||
ctype, ConstGetter, SrcVisitor, | |||||
rounding::RoundingConverter<ctype>, pack_c> | |||||
<<<blocks, threads, 0, stream>>>( | |||||
srcs, mat, dst, C, IH, IW, OH, OW, bval); | |||||
} else { | |||||
kern_const_border_multi_src< | |||||
ctype, SrcVisitor, rounding::RoundingConverter<ctype>> | |||||
<<<blocks, threads, 0, stream>>>( | |||||
srcs, mat, dst, C, IH, IW, OH, OW, bval); | |||||
} | |||||
break; | |||||
default: | |||||
break; | |||||
} | |||||
N -= curr_batch_size; | |||||
srcs.move_batch(curr_batch_size, C * IH * IW); | |||||
mat += curr_batch_size * 3 * 3; | |||||
dst += curr_batch_size * C * OH * OW; | |||||
} | |||||
} | |||||
template <typename ctype, typename SrcVisitor> | |||||
void dispatch_with_visitor( | void dispatch_with_visitor( | ||||
bool is_nhwc, SrcVisitor src, const float* mat, ctype* dst, int N, int C, | bool is_nhwc, SrcVisitor src, const float* mat, ctype* dst, int N, int C, | ||||
int IH, int IW, int OH, int OW, ctype bval, BorderMode bmode, | int IH, int IW, int OH, int OW, ctype bval, BorderMode bmode, | ||||
@@ -1534,6 +1794,33 @@ void dispatch_with_visitor_quint8_dimshuffle_typecvt_nchw( | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace cuda { | namespace cuda { | ||||
namespace warp_perspective { | namespace warp_perspective { | ||||
template <typename ctype> | |||||
void forward_proxy_multi_src( | |||||
bool is_nhwc, const ctype** srcs, const float* mat, const int* mat_idx, | |||||
ctype* dst, int N_SRC, int N_MAT, int C, int IH, int IW, int OH, int OW, | |||||
ctype bval, BorderMode bmode, megcore::AsyncErrorInfo* error_info, | |||||
void* error_tracker, cudaStream_t stream) { | |||||
if (mat_idx) { | |||||
IndexedSrcVisitor<ctype> visitor; | |||||
visitor.ptrs = reinterpret_cast<const void**>(srcs); | |||||
visitor.ptr = srcs; | |||||
visitor.idx = mat_idx; | |||||
visitor.N_SRC = N_SRC; | |||||
visitor.error_info = error_info; | |||||
visitor.error_tracker = error_tracker; | |||||
dispatch_with_visitor_multi_src( | |||||
is_nhwc, visitor, mat, dst, N_MAT, C, IH, IW, OH, OW, bval, bmode, | |||||
stream); | |||||
} else { | |||||
DirectSrcVisitor<ctype> visitor; | |||||
visitor.ptrs = reinterpret_cast<const void**>(srcs); | |||||
visitor.ptr = srcs; | |||||
dispatch_with_visitor_multi_src( | |||||
is_nhwc, visitor, mat, dst, N_MAT, C, IH, IW, OH, OW, bval, bmode, | |||||
stream); | |||||
} | |||||
after_kernel_launch(); | |||||
} | |||||
template <typename ctype> | template <typename ctype> | ||||
void forward_proxy( | void forward_proxy( | ||||
@@ -1643,6 +1930,17 @@ INST(dt_float16) | |||||
INST(int8_t) | INST(int8_t) | ||||
#undef INST | #undef INST | ||||
#define INST(ctype) \ | |||||
template void forward_proxy_multi_src( \ | |||||
bool, const ctype**, const float*, const int*, ctype*, int, int, int, int, \ | |||||
int, int, int, ctype, BorderMode, megcore::AsyncErrorInfo*, void*, \ | |||||
cudaStream_t); | |||||
INST(float) | |||||
#ifndef MEGDNN_DISABLE_FLOAT16 | |||||
INST(dt_float16) | |||||
#endif | |||||
#undef INST | |||||
#define INST(ctype) \ | #define INST(ctype) \ | ||||
template void forward_proxy_nchw4( \ | template void forward_proxy_nchw4( \ | ||||
const ctype*, const float*, const int*, ctype*, int, int, int, int, int, \ | const ctype*, const float*, const int*, ctype*, int, int, int, int, int, \ | ||||
@@ -15,12 +15,22 @@ public: | |||||
void exec( | void exec( | ||||
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, | _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, | ||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | _megdnn_tensor_out dst, _megdnn_workspace workspace) override; | ||||
void exec( | |||||
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_in mat, | |||||
_megdnn_tensor_in mat_idx, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes( | size_t get_workspace_in_bytes( | ||||
const TensorLayout& src, const TensorLayout& mat, | const TensorLayout& src, const TensorLayout& mat, | ||||
const TensorLayout& mat_idx, const TensorLayout& dst) override { | const TensorLayout& mat_idx, const TensorLayout& dst) override { | ||||
return get_workspace_bundle(nullptr, src, mat, mat_idx, dst) | return get_workspace_bundle(nullptr, src, mat, mat_idx, dst) | ||||
.total_size_in_bytes(); | .total_size_in_bytes(); | ||||
} | } | ||||
size_t get_workspace_in_bytes( | |||||
const TensorLayoutArray& srcs, const TensorLayout& mat, | |||||
const TensorLayout& mat_idx, const TensorLayout& dst) override { | |||||
return get_workspace_bundle(nullptr, srcs, mat, mat_idx, dst) | |||||
.total_size_in_bytes(); | |||||
} | |||||
void set_error_tracker(void* tracker) override { m_error_tracker = tracker; } | void set_error_tracker(void* tracker) override { m_error_tracker = tracker; } | ||||
@@ -28,6 +38,9 @@ private: | |||||
WorkspaceBundle get_workspace_bundle( | WorkspaceBundle get_workspace_bundle( | ||||
void* ptr, const TensorLayout& src, const TensorLayout& mat, | void* ptr, const TensorLayout& src, const TensorLayout& mat, | ||||
const TensorLayout& mat_idx, const TensorLayout& dst) const; | const TensorLayout& mat_idx, const TensorLayout& dst) const; | ||||
WorkspaceBundle get_workspace_bundle( | |||||
void* ptr, const TensorLayoutArray& srcs, const TensorLayout& mat, | |||||
const TensorLayout& mat_idx, const TensorLayout& dst) const; | |||||
}; | }; | ||||
class WarpPerspectiveBackwardDataImpl final : public WarpPerspectiveBackwardData { | class WarpPerspectiveBackwardDataImpl final : public WarpPerspectiveBackwardData { | ||||
@@ -51,6 +51,56 @@ size_t WarpPerspectiveImpl::get_workspace_in_bytes( | |||||
} | } | ||||
} | } | ||||
size_t WarpPerspectiveImpl::get_workspace_in_bytes( | |||||
const TensorLayoutArray&, const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout& dst) { | |||||
if (param().format == param::WarpPerspective::Format::NCHW) { | |||||
size_t OH = dst.shape[2], OW = dst.shape[3]; | |||||
return get_bundle(OH, OW).total_size_in_bytes(); | |||||
} else { | |||||
return 0; | |||||
} | |||||
} | |||||
void WarpPerspectiveImpl::exec( | |||||
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_in mat, | |||||
_megdnn_tensor_in mat_idx, _megdnn_tensor_in dst, _megdnn_workspace workspace) { | |||||
TensorLayoutArray srcs_layout; | |||||
for (auto&& src : srcs) { | |||||
srcs_layout.push_back(src.layout); | |||||
} | |||||
check_exec_allow_nhwc_mat_idx( | |||||
srcs_layout, mat.layout, mat_idx.layout, dst.layout, workspace.size); | |||||
size_t nr_threads = static_cast<naive::HandleImpl*>(handle()) | |||||
->megcore_dispatcher() | |||||
->nr_threads(); | |||||
if (param().format == Format::NCHW && nr_threads == 1_z) { | |||||
#define cb(dt, ct, mct) \ | |||||
case DTypeTrait<dt>::enumv: { \ | |||||
auto kparam = KernParam<ct, mct>::from_tensors( \ | |||||
param().format, param().bmode, param().border_val, srcs, mat, mat_idx, \ | |||||
dst, workspace); \ | |||||
MIDOUT_BEGIN(megdnn_fallback_warpperspective, midout_iv(0), dt, ct, mct) { \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR(kern_fallback_multi_src(kparam)); \ | |||||
return; \ | |||||
} \ | |||||
MIDOUT_END(); \ | |||||
} | |||||
switch (srcs.front().layout.dtype.enumv()) { | |||||
cb(dtype::Float32, float, float); | |||||
DNN_INC_FLOAT16(cb(dtype::Float16, dt_float16, float)); | |||||
default: | |||||
megdnn_throw(ssprintf( | |||||
"Unsupported input DType in " | |||||
"WarpPerspective: %s", | |||||
srcs.front().layout.dtype.name()) | |||||
.c_str()); | |||||
} | |||||
#undef cb | |||||
} | |||||
naive::WarpPerspectiveForwardImpl::exec(srcs, mat, mat_idx, dst, workspace); | |||||
} | |||||
void WarpPerspectiveImpl::exec( | void WarpPerspectiveImpl::exec( | ||||
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, | _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, | ||||
_megdnn_tensor_in dst, _megdnn_workspace workspace) { | _megdnn_tensor_in dst, _megdnn_workspace workspace) { | ||||
@@ -96,6 +146,69 @@ void WarpPerspectiveImpl::exec( | |||||
} | } | ||||
template <typename ctype, typename mtype> | template <typename ctype, typename mtype> | ||||
void WarpPerspectiveImpl::kern_fallback_multi_src( | |||||
const KernParam<ctype, mtype>& kern_param) { | |||||
UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM(kern_param); | |||||
// cause error if accidentally used | |||||
sptr = nullptr; | |||||
mptr = nullptr; | |||||
dptr = nullptr; | |||||
MEGDNN_MARK_USED_VAR(sptr); | |||||
MEGDNN_MARK_USED_VAR(mptr); | |||||
MEGDNN_MARK_USED_VAR(dptr); | |||||
MEGDNN_MARK_USED_VAR(border_val); | |||||
MEGDNN_MARK_USED_VAR(IH); | |||||
MEGDNN_MARK_USED_VAR(IW); | |||||
KernParam<ctype, mtype> sub_param = kern_param; | |||||
sub_param.n_src = 1; | |||||
sub_param.n_mat = 1; | |||||
sub_param.midx_ptr = RefPtr(); | |||||
sub_param.src_ptr = RefPtr(kern_param.srcs_ptr.front().get_ptr()); | |||||
sub_param.mat_ptr = RefPtr(kern_param.mat_ptr.get_ptr()); | |||||
sub_param.dst_ptr = RefPtr(kern_param.dst_ptr.get_ptr()); | |||||
sub_param.srcs_ptr = kern_param.srcs_ptr; | |||||
rep(n, N_MAT) { | |||||
if (midx_ptr) { | |||||
size_t idx = midx_ptr[n]; | |||||
megdnn_assert( | |||||
idx < N_SRC, "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", | |||||
n, idx, N_SRC); | |||||
sub_param.src_ptr.reset( | |||||
static_cast<ctype*>(kern_param.srcs_ptr[idx].get_ptr())); | |||||
} else if (n) { | |||||
sub_param.src_ptr.reset( | |||||
static_cast<ctype*>(kern_param.srcs_ptr[n].get_ptr())); | |||||
} | |||||
if (is_resize_optimizable(static_cast<mtype*>(sub_param.mat_ptr.get_ptr()))) { | |||||
if (bmode == BorderMode::CONSTANT) { | |||||
MIDOUT_BEGIN( | |||||
megdnn_fallback_warpperspective, midout_iv(1), midout_iv(true), | |||||
ctype, mtype) { | |||||
kern_resize<true, ctype, mtype>(sub_param); | |||||
} | |||||
MIDOUT_END(); | |||||
} else { | |||||
MIDOUT_BEGIN( | |||||
megdnn_fallback_warpperspective, midout_iv(1), midout_iv(false), | |||||
ctype, mtype) { | |||||
kern_resize<false, ctype, mtype>(sub_param); | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
} else { | |||||
MIDOUT_BEGIN(megdnn_fallback_warpperspective, midout_iv(2), ctype, mtype) { | |||||
rep(oh, OH) kern_naive<ctype, mtype>(sub_param, oh); | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
sub_param.mat_ptr += 3 * 3 * sizeof(mtype); | |||||
sub_param.dst_ptr += C * OH * OW * sizeof(ctype); | |||||
} | |||||
} | |||||
template <typename ctype, typename mtype> | |||||
void WarpPerspectiveImpl::kern_fallback(const KernParam<ctype, mtype>& kern_param) { | void WarpPerspectiveImpl::kern_fallback(const KernParam<ctype, mtype>& kern_param) { | ||||
UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM(kern_param); | UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM(kern_param); | ||||
@@ -9,14 +9,24 @@ protected: | |||||
template <typename ctype, typename mtype> | template <typename ctype, typename mtype> | ||||
void kern_fallback(const KernParam<ctype, mtype>& kern_param); | void kern_fallback(const KernParam<ctype, mtype>& kern_param); | ||||
template <typename ctype, typename mtype> | |||||
void kern_fallback_multi_src(const KernParam<ctype, mtype>& kern_param); | |||||
public: | public: | ||||
using naive::WarpPerspectiveForwardImpl::WarpPerspectiveForwardImpl; | using naive::WarpPerspectiveForwardImpl::WarpPerspectiveForwardImpl; | ||||
size_t get_workspace_in_bytes( | size_t get_workspace_in_bytes( | ||||
const TensorLayout& src, const TensorLayout& mat, | const TensorLayout& src, const TensorLayout& mat, | ||||
const TensorLayout& mat_idx, const TensorLayout& dst) override; | const TensorLayout& mat_idx, const TensorLayout& dst) override; | ||||
size_t get_workspace_in_bytes( | |||||
const TensorLayoutArray& srcs, const TensorLayout& mat, | |||||
const TensorLayout& mat_idx, const TensorLayout& dst) override; | |||||
void exec( | void exec( | ||||
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, | _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, | ||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | _megdnn_tensor_out dst, _megdnn_workspace workspace) override; | ||||
void exec( | |||||
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_in mat, | |||||
_megdnn_tensor_in mat_idx, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) override; | |||||
private: | private: | ||||
template <typename ctype> | template <typename ctype> | ||||
@@ -15,6 +15,119 @@ using namespace megdnn; | |||||
using namespace naive; | using namespace naive; | ||||
template <typename ctype, typename mtype> | template <typename ctype, typename mtype> | ||||
void WarpPerspectiveForwardImpl::kern_naive_multi_src( | |||||
const KernParam<ctype, mtype>& kern_param, size_t task_id) { | |||||
MEGDNN_MARK_USED_VAR(kern_param); | |||||
MIDOUT_BEGIN(megdnn_naive_warpperspective, ctype, mtype, midout_iv(0)) { | |||||
UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM(kern_param); | |||||
MEGDNN_MARK_USED_VAR(N_MAT); | |||||
//! strides of C, H, W on src and dst | |||||
size_t sstrd[3], dstrd[3]; | |||||
auto set_sstrd = [&](size_t s0, size_t s1, size_t s2) { | |||||
sstrd[0] = s0; | |||||
sstrd[1] = s1; | |||||
sstrd[2] = s2; | |||||
}; | |||||
auto set_dstrd = [&](size_t s0, size_t s1, size_t s2) { | |||||
dstrd[0] = s0; | |||||
dstrd[1] = s1; | |||||
dstrd[2] = s2; | |||||
}; | |||||
switch (kern_param.format) { | |||||
case Format::NCHW: | |||||
set_sstrd(IH * IW, IW, 1); | |||||
set_dstrd(OH * OW, OW, 1); | |||||
break; | |||||
case Format::NHWC: | |||||
set_sstrd(1, IW * C, C); | |||||
set_dstrd(1, OW * C, C); | |||||
break; | |||||
default: | |||||
megdnn_throw("bad format"); | |||||
} | |||||
auto visit_src = [&sptr, sstrd](size_t c, int h, int w) -> float { | |||||
return sptr[sstrd[0] * c + sstrd[1] * h + sstrd[2] * w]; | |||||
}; | |||||
auto visit_src_bd = [&sptr, sstrd, border_val]( | |||||
size_t c, int h, int w) -> float { | |||||
if (h != -1 && w != -1) { | |||||
return sptr[sstrd[0] * c + sstrd[1] * h + sstrd[2] * w]; | |||||
} else | |||||
return border_val; | |||||
}; | |||||
auto visit_dst = [&dptr, dstrd](size_t c, int h, int w) -> ctype& { | |||||
return dptr[dstrd[0] * c + dstrd[1] * h + dstrd[2] * w]; | |||||
}; | |||||
rounding::RoundingConverter<ctype> output_converter; | |||||
sptr = static_cast<const ctype*>(kern_param.srcs_ptr.front().get_ptr()); | |||||
size_t n = task_id / OH; | |||||
size_t oh = task_id % OH; | |||||
mptr = mptr + n * 3 * 3; | |||||
dptr = dptr + n * C * OH * OW; | |||||
if (midx_ptr) { | |||||
size_t idx = midx_ptr[n]; | |||||
megdnn_assert( | |||||
idx < N_SRC, "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", | |||||
n, idx, N_SRC); | |||||
sptr = sptrs[idx]; | |||||
} else if (n) { | |||||
sptr = sptrs[n]; | |||||
} | |||||
rep(ow, OW) { | |||||
float numeratorw = mptr[0] * ow + mptr[1] * oh + mptr[2]; | |||||
float numeratorh = mptr[3] * ow + mptr[4] * oh + mptr[5]; | |||||
float denominator = mptr[6] * ow + mptr[7] * oh + mptr[8]; | |||||
float alphaw = numeratorw / denominator; | |||||
float alphah = numeratorh / denominator; | |||||
int iw0 = get_real_coord(std::floor(alphaw) + 0, IW); | |||||
int iw1 = get_real_coord(std::floor(alphaw) + 1, IW); | |||||
int ih0 = get_real_coord(std::floor(alphah) + 0, IH); | |||||
int ih1 = get_real_coord(std::floor(alphah) + 1, IH); | |||||
alphaw -= floor(alphaw); | |||||
alphah -= floor(alphah); | |||||
if (bmode != BorderMode::CONSTANT) { | |||||
rep(c, C) { | |||||
visit_dst(c, oh, ow) = output_converter( | |||||
visit_src(c, ih0, iw0) * (1.0f - alphaw) * (1.0f - alphah) + | |||||
visit_src(c, ih0, iw1) * alphaw * (1.0f - alphah) + | |||||
visit_src(c, ih1, iw0) * (1.0f - alphaw) * alphah + | |||||
visit_src(c, ih1, iw1) * alphaw * alphah); | |||||
} | |||||
} else { | |||||
rep(c, C) { | |||||
auto val = visit_src_bd(c, ih0, iw0) * (1.0f - alphaw) * | |||||
(1.0f - alphah) + | |||||
visit_src_bd(c, ih0, iw1) * alphaw * (1.0f - alphah) + | |||||
visit_src_bd(c, ih1, iw0) * (1.0f - alphaw) * alphah + | |||||
visit_src_bd(c, ih1, iw1) * alphaw * alphah; | |||||
visit_dst(c, oh, ow) = | |||||
output_converter(std::isfinite(val) ? val : border_val); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
#define INST(ctype, mtype) \ | |||||
template void WarpPerspectiveForwardImpl::kern_naive_multi_src<ctype, mtype>( \ | |||||
const KernParam<ctype, mtype>&, size_t); | |||||
INST(float, float); | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
INST(dt_float16, float); | |||||
INST(dt_float16, dt_float16); | |||||
INST(dt_bfloat16, float); | |||||
INST(dt_bfloat16, dt_bfloat16); | |||||
#endif | |||||
#undef INST | |||||
template <typename ctype, typename mtype> | |||||
void WarpPerspectiveForwardImpl::kern_naive( | void WarpPerspectiveForwardImpl::kern_naive( | ||||
const KernParam<ctype, mtype>& kern_param, size_t task_id) { | const KernParam<ctype, mtype>& kern_param, size_t task_id) { | ||||
MEGDNN_MARK_USED_VAR(kern_param); | MEGDNN_MARK_USED_VAR(kern_param); | ||||
@@ -505,6 +618,71 @@ INST(uint8_t, float, float); | |||||
#undef INST | #undef INST | ||||
void WarpPerspectiveForwardImpl::exec( | void WarpPerspectiveForwardImpl::exec( | ||||
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_in mat, | |||||
_megdnn_tensor_in mat_idx, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) { | |||||
TensorLayoutArray srcs_layout; | |||||
for (auto&& src : srcs) { | |||||
srcs_layout.push_back(src.layout); | |||||
} | |||||
check_exec_allow_nhwc_mat_idx( | |||||
srcs_layout, mat.layout, mat_idx.layout, dst.layout, workspace.size); | |||||
size_t batch = dst.layout[0]; | |||||
#define KERN_NAIVE_MULTI_SRC(ct, mct) \ | |||||
auto kparam = KernParam<ct, mct>::from_tensors( \ | |||||
param().format, param().bmode, param().border_val, srcs, mat, mat_idx, \ | |||||
dst, workspace); \ | |||||
auto run = [kparam, this](size_t index, size_t) { \ | |||||
kern_naive_multi_src(kparam, index); \ | |||||
}; \ | |||||
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN_OPR(run, kparam.oh* batch); | |||||
#define DISPATCH_ST_MULTI_SRC(dt, ct, mct, kern) \ | |||||
if (srcs.front().layout.dtype.enumv() == DTypeTrait<dt>::enumv) { \ | |||||
kern(ct, mct); \ | |||||
return; \ | |||||
} | |||||
#define DISPATCH_ST_MT_MULTI_SRC(dt, ct, kern) \ | |||||
if (srcs.front().layout.dtype.enumv() == DTypeTrait<dt>::enumv) { \ | |||||
if (mat.layout.dtype.enumv() == DTypeTrait<dtype::Float32>::enumv) { \ | |||||
kern(ct, float); \ | |||||
return; \ | |||||
} else { \ | |||||
kern(ct, ct); \ | |||||
return; \ | |||||
} \ | |||||
} | |||||
megdnn_assert(warp::is_dnn_available( | |||||
srcs_layout, mat.layout, dst.layout, param().imode, param().format)); | |||||
/*! | |||||
* We currently use floating point for all WarpPerspective | |||||
* computation, so even if the input ctype is one of the integer | |||||
* type, mtype should always be float32. | |||||
* | |||||
* \warning It's different with \c WarpAffine, with mtype be float16 | |||||
* if input type is float16. | |||||
*/ | |||||
DISPATCH_ST_MULTI_SRC(dtype::Float32, float, float, KERN_NAIVE_MULTI_SRC); | |||||
DNN_INC_FLOAT16( | |||||
DISPATCH_ST_MT_MULTI_SRC(dtype::Float16, dt_float16, KERN_NAIVE_MULTI_SRC)); | |||||
DNN_INC_FLOAT16(DISPATCH_ST_MT_MULTI_SRC( | |||||
dtype::BFloat16, dt_bfloat16, KERN_NAIVE_MULTI_SRC)); | |||||
megdnn_throw(ssprintf( | |||||
"Unsupported input DType in " | |||||
"WarpPerspective: %s", | |||||
srcs.front().layout.dtype.name()) | |||||
.c_str()); | |||||
#undef KERN_NAIVE_MULTI_SRC | |||||
#undef DISPATCH_ST_MT_MULTI_SRC | |||||
#undef DISPATCH_ST_MULTI_SRC | |||||
} | |||||
void WarpPerspectiveForwardImpl::exec( | |||||
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, | _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, | ||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) { | _megdnn_tensor_out dst, _megdnn_workspace workspace) { | ||||
check_exec_allow_nhwc_mat_idx( | check_exec_allow_nhwc_mat_idx( | ||||
@@ -17,10 +17,72 @@ protected: | |||||
DType src_dtype, dst_dtype; | DType src_dtype, dst_dtype; | ||||
RefPtr src_ptr, mat_ptr, dst_ptr; | RefPtr src_ptr, mat_ptr, dst_ptr; | ||||
RefPtr midx_ptr; //!< can be null | RefPtr midx_ptr; //!< can be null | ||||
SmallVector<RefPtr> srcs_ptr; | |||||
Workspace workspace; | Workspace workspace; | ||||
static KernParam from_tensors( | static KernParam from_tensors( | ||||
Format format, BorderMode bmode, float border_val, | Format format, BorderMode bmode, float border_val, | ||||
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_in mat, | |||||
_megdnn_tensor_in mat_idx, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) { | |||||
auto src = srcs.front(); | |||||
KernParam ret; | |||||
ret.format = format; | |||||
ret.bmode = bmode; | |||||
ret.border_val = border_val; | |||||
ret.n_src = srcs.size(); | |||||
ret.src_dtype = src.layout.dtype; | |||||
ret.dst_dtype = dst.layout.dtype; | |||||
if (mat_idx.raw_ptr()) { | |||||
megdnn_assert(mat_idx.layout.ndim == 1); | |||||
ret.n_mat = mat_idx.layout.shape[0]; | |||||
ret.midx_ptr = mat_idx.get_ref_ptr(); | |||||
} else { | |||||
megdnn_assert(mat_idx.layout.ndim == 0); | |||||
ret.n_mat = ret.n_src; | |||||
ret.midx_ptr = nullptr; | |||||
} | |||||
if (format == Format::NCHW) { | |||||
ret.c = src.layout.shape[1]; | |||||
ret.ih = src.layout.shape[2]; | |||||
ret.iw = src.layout.shape[3]; | |||||
ret.oh = dst.layout.shape[2]; | |||||
ret.ow = dst.layout.shape[3]; | |||||
} else { | |||||
megdnn_assert(format == Format::NHWC); | |||||
ret.c = src.layout.shape[3]; | |||||
ret.ih = src.layout.shape[1]; | |||||
ret.iw = src.layout.shape[2]; | |||||
ret.oh = dst.layout.shape[1]; | |||||
ret.ow = dst.layout.shape[2]; | |||||
} | |||||
if ((src.layout.dtype.enumv() == DTypeEnum::Float32 || | |||||
DNN_FLOAT16_SELECT( | |||||
(src.layout.dtype.enumv() == DTypeEnum::Float16 || | |||||
src.layout.dtype.enumv() == DTypeEnum::BFloat16), | |||||
false)) && | |||||
(src.layout.dtype == dst.layout.dtype)) { | |||||
for (auto&& s : srcs) { | |||||
ret.srcs_ptr.push_back(s.get_ref_ptr()); | |||||
} | |||||
ret.mat_ptr = mat.get_ref_ptr(); | |||||
ret.dst_ptr = dst.get_ref_ptr(); | |||||
} else { | |||||
for (size_t i = 0; i < srcs.size(); i++) { | |||||
ret.srcs_ptr.push_back(nullptr); | |||||
} | |||||
ret.mat_ptr = nullptr; | |||||
ret.dst_ptr = nullptr; | |||||
} | |||||
ret.src_ptr = nullptr; | |||||
ret.workspace = workspace; | |||||
return ret; | |||||
} | |||||
static KernParam from_tensors( | |||||
Format format, BorderMode bmode, float border_val, | |||||
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, | _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, | ||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) { | _megdnn_tensor_out dst, _megdnn_workspace workspace) { | ||||
KernParam ret; | KernParam ret; | ||||
@@ -124,16 +186,29 @@ protected: | |||||
template <typename ctype, typename mtype> | template <typename ctype, typename mtype> | ||||
void kern_naive(const KernParam<ctype, mtype>& kern_param, size_t task_id); | void kern_naive(const KernParam<ctype, mtype>& kern_param, size_t task_id); | ||||
template <typename ctype, typename mtype> | |||||
void kern_naive_multi_src( | |||||
const KernParam<ctype, mtype>& kern_param, size_t task_id); | |||||
public: | public: | ||||
using WarpPerspectiveForward::WarpPerspectiveForward; | using WarpPerspectiveForward::WarpPerspectiveForward; | ||||
void exec( | void exec( | ||||
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, | _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, | ||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | _megdnn_tensor_out dst, _megdnn_workspace workspace) override; | ||||
void exec( | |||||
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_in mat, | |||||
_megdnn_tensor_in mat_idx, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes( | size_t get_workspace_in_bytes( | ||||
const TensorLayout&, const TensorLayout&, const TensorLayout&, | const TensorLayout&, const TensorLayout&, const TensorLayout&, | ||||
const TensorLayout&) override { | const TensorLayout&) override { | ||||
return 0; | return 0; | ||||
} | } | ||||
size_t get_workspace_in_bytes( | |||||
const TensorLayoutArray&, const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&) override { | |||||
return 0; | |||||
} | |||||
private: | private: | ||||
template <typename ctype, typename mtype> | template <typename ctype, typename mtype> | ||||
@@ -253,6 +328,10 @@ private: | |||||
auto mptr = static_cast<const mtype*>(p.mat_ptr.get_ptr()); \ | auto mptr = static_cast<const mtype*>(p.mat_ptr.get_ptr()); \ | ||||
auto dptr = static_cast<ctype*>(p.dst_ptr.get_ptr()); \ | auto dptr = static_cast<ctype*>(p.dst_ptr.get_ptr()); \ | ||||
auto midx_ptr = static_cast<int*>(p.midx_ptr.get_ptr()); \ | auto midx_ptr = static_cast<int*>(p.midx_ptr.get_ptr()); \ | ||||
SmallVector<const ctype*> sptrs; \ | |||||
for (auto&& s_ptr : p.srcs_ptr) { \ | |||||
sptrs.push_back(static_cast<const ctype*>(s_ptr.get_ptr())); \ | |||||
} \ | |||||
auto bmode = p.bmode; \ | auto bmode = p.bmode; \ | ||||
float border_val = p.border_val | float border_val = p.border_val | ||||
@@ -50,6 +50,54 @@ void WarpPerspectiveMatIdxProxy::exec( | |||||
tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], W.workspace()); | tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], W.workspace()); | ||||
} | } | ||||
void WarpPerspectiveMultiSrcProxy::deduce_layout( | |||||
WarpPerspectiveForward*, TensorLayoutArray&) {} | |||||
void WarpPerspectiveMultiSrcProxy::exec( | |||||
WarpPerspectiveForward* opr, const TensorNDArray& tensors) { | |||||
if (!W.valid()) { | |||||
W = WorkspaceWrapper(opr->handle(), 0); | |||||
} | |||||
megdnn_assert(tensors.size() >= 3); | |||||
bool has_mat_idx = false; | |||||
TensorLayout mat_idx_layout; | |||||
TensorND mat_idx_tensor; | |||||
TensorLayoutArray layouts(tensors.size()); | |||||
std::transform( | |||||
tensors.begin(), tensors.end(), layouts.begin(), | |||||
[](const TensorND& tensor) { return tensor.layout; }); | |||||
auto srcs_layouts = layouts; | |||||
srcs_layouts.pop_back(); // dst | |||||
if (srcs_layouts.back().ndim == 1) { | |||||
has_mat_idx = true; | |||||
mat_idx_layout = srcs_layouts.back(); | |||||
srcs_layouts.pop_back(); // mat_idx; | |||||
} | |||||
auto mat_layout = srcs_layouts.back(); | |||||
srcs_layouts.pop_back(); // mat | |||||
if (has_mat_idx) | |||||
W.update(opr->get_workspace_in_bytes( | |||||
srcs_layouts, mat_layout, mat_idx_layout, layouts.back())); | |||||
else | |||||
W.update(opr->get_workspace_in_bytes(srcs_layouts, mat_layout, layouts.back())); | |||||
auto srcs_tensors = tensors; | |||||
srcs_tensors.pop_back(); // dst | |||||
if (has_mat_idx) { | |||||
mat_idx_tensor = srcs_tensors.back(); | |||||
srcs_tensors.pop_back(); // mat_idx; | |||||
} | |||||
auto mat_tensor = srcs_tensors.back(); | |||||
srcs_tensors.pop_back(); // mat | |||||
if (has_mat_idx) | |||||
opr->exec( | |||||
srcs_tensors, mat_tensor, mat_idx_tensor, tensors.back(), | |||||
W.workspace()); | |||||
else | |||||
opr->exec(srcs_tensors, mat_tensor, tensors.back(), W.workspace()); | |||||
} | |||||
std::vector<TestArg> warp_perspective::get_cv_args() { | std::vector<TestArg> warp_perspective::get_cv_args() { | ||||
std::vector<TestArg> args; | std::vector<TestArg> args; | ||||
@@ -19,6 +19,12 @@ struct WarpPerspectiveMatIdxProxy { | |||||
void exec(WarpPerspectiveBackwardMat* opr, const TensorNDArray& tensors); | void exec(WarpPerspectiveBackwardMat* opr, const TensorNDArray& tensors); | ||||
}; | }; | ||||
struct WarpPerspectiveMultiSrcProxy { | |||||
WorkspaceWrapper W; | |||||
static void deduce_layout(WarpPerspectiveForward*, TensorLayoutArray&); | |||||
void exec(WarpPerspectiveForward* opr, const TensorNDArray& tensors); | |||||
}; | |||||
class WarpPerspectiveMatRNG final : public IIDRNG { | class WarpPerspectiveMatRNG final : public IIDRNG { | ||||
public: | public: | ||||
WarpPerspectiveMatRNG() : idx(0) {} | WarpPerspectiveMatRNG() : idx(0) {} | ||||
@@ -887,6 +887,194 @@ TEST_F(CUDA, WARP_PERSPECTIVE_NCHW64_QUINT4) { | |||||
} | } | ||||
} | } | ||||
TEST_F(CUDA, WARP_PERSPECTIVE_MULTI_SRC_NCHW) { | |||||
using Param = WarpPerspective::Param; | |||||
Param param; | |||||
WarpPerspectiveMatRNG rng; | |||||
for (auto bmode : | |||||
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT, | |||||
WarpPerspective::BorderMode::REPLICATE, | |||||
WarpPerspective::BorderMode::CONSTANT}) { | |||||
param.border_val = 0.3f; | |||||
param.bmode = bmode; | |||||
param.imode = Param::InterpolationMode::LINEAR; | |||||
param.format = Param::Format::NCHW; | |||||
auto run = [¶m, &rng, this]( | |||||
size_t bs, size_t ih, size_t iw, size_t c, size_t oh, | |||||
size_t ow, DType dtype) { | |||||
Checker<WarpPerspectiveForward, WarpPerspectiveMultiSrcProxy> checker( | |||||
handle_cuda()); | |||||
checker.set_param(param); | |||||
TensorShapeArray shapes; | |||||
// src | |||||
for (size_t i = 0; i < bs; i++) { | |||||
shapes.emplace_back(TensorShape{{1, c, ih, iw}}); | |||||
checker.set_dtype(i, dtype); | |||||
} | |||||
// mat | |||||
shapes.emplace_back(TensorShape{{bs, 3, 3}}); | |||||
checker.set_rng(bs, &rng); | |||||
// dst | |||||
shapes.emplace_back(TensorShape{{bs, c, oh, ow}}); | |||||
checker.set_dtype(bs + 1, dtype); | |||||
checker.execs(shapes); | |||||
}; | |||||
for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) { | |||||
run(1, 20, 18, 4, 6, 6, dtype); | |||||
run(2, 100, 110, 10, 50, 50, dtype); | |||||
run(20, 10, 11, 123, 15, 16, dtype); | |||||
run(2200, 10, 11, 3, 11, 12, dtype); | |||||
} | |||||
} | |||||
} | |||||
TEST_F(CUDA, WARP_PERSPECTIVE_MULTI_SRC_NHWC) { | |||||
using Param = WarpPerspective::Param; | |||||
Param param; | |||||
WarpPerspectiveMatRNG rng; | |||||
for (auto bmode : | |||||
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT, | |||||
WarpPerspective::BorderMode::REPLICATE, | |||||
WarpPerspective::BorderMode::CONSTANT}) { | |||||
param.border_val = 0.3f; | |||||
param.bmode = bmode; | |||||
param.imode = Param::InterpolationMode::LINEAR; | |||||
param.format = Param::Format::NHWC; | |||||
auto run = [¶m, &rng, this]( | |||||
size_t bs, size_t ih, size_t iw, size_t c, size_t oh, | |||||
size_t ow, DType dtype) { | |||||
Checker<WarpPerspectiveForward, WarpPerspectiveMultiSrcProxy> checker( | |||||
handle_cuda()); | |||||
checker.set_param(param); | |||||
TensorShapeArray shapes; | |||||
// src | |||||
for (size_t i = 0; i < bs; i++) { | |||||
shapes.emplace_back(TensorShape{{1, ih, iw, c}}); | |||||
checker.set_dtype(i, dtype); | |||||
} | |||||
// mat | |||||
shapes.emplace_back(TensorShape{{bs, 3, 3}}); | |||||
checker.set_rng(bs, &rng); | |||||
// dst | |||||
shapes.emplace_back(TensorShape{{bs, oh, ow, c}}); | |||||
checker.set_dtype(bs + 1, dtype); | |||||
checker.execs(shapes); | |||||
}; | |||||
for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) { | |||||
run(1, 20, 18, 4, 6, 6, dtype); | |||||
run(2, 100, 110, 10, 50, 50, dtype); | |||||
run(20, 10, 11, 123, 15, 16, dtype); | |||||
run(2200, 10, 11, 3, 11, 12, dtype); | |||||
} | |||||
} | |||||
} | |||||
TEST_F(CUDA, WARP_PERSPECTIVE_MULTI_SRC_WITH_IDX_NCHW) { | |||||
using Param = WarpPerspective::Param; | |||||
Param param; | |||||
WarpPerspectiveMatRNG rng; | |||||
UniformIntRNG idx_rng{0, 0}; | |||||
for (auto bmode : | |||||
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT, | |||||
WarpPerspective::BorderMode::REPLICATE, | |||||
WarpPerspective::BorderMode::CONSTANT}) { | |||||
param.border_val = 0.3f; | |||||
param.bmode = bmode; | |||||
param.imode = Param::InterpolationMode::LINEAR; | |||||
param.format = Param::Format::NCHW; | |||||
auto run = [¶m, &rng, &idx_rng, this]( | |||||
size_t bs, size_t ih, size_t iw, size_t c, size_t oh, | |||||
size_t ow, size_t idx, DType dtype) { | |||||
Checker<WarpPerspectiveForward, WarpPerspectiveMultiSrcProxy> checker( | |||||
handle_cuda()); | |||||
checker.set_param(param); | |||||
TensorShapeArray shapes; | |||||
// src | |||||
for (size_t i = 0; i < bs; i++) { | |||||
shapes.emplace_back(TensorShape{{1, c, ih, iw}}); | |||||
checker.set_dtype(i, dtype); | |||||
} | |||||
// mat | |||||
shapes.emplace_back(TensorShape{{idx, 3, 3}}); | |||||
checker.set_rng(bs, &rng); | |||||
// mat_idx | |||||
shapes.emplace_back(TensorShape{{idx}}); | |||||
checker.set_dtype(bs + 1, dtype::Int32()); | |||||
idx_rng = UniformIntRNG{0, (int)bs - 1}; | |||||
checker.set_rng(bs + 1, &idx_rng); | |||||
// dst | |||||
shapes.emplace_back(TensorShape{{idx, c, oh, ow}}); | |||||
checker.set_dtype(bs + 2, dtype); | |||||
checker.execs(shapes); | |||||
}; | |||||
for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) { | |||||
run(1, 20, 18, 4, 6, 6, 1, dtype); | |||||
run(2, 100, 110, 10, 50, 50, 1, dtype); | |||||
run(20, 10, 11, 123, 15, 16, 10, dtype); | |||||
run(2200, 10, 11, 3, 11, 12, 100, dtype); | |||||
} | |||||
} | |||||
} | |||||
TEST_F(CUDA, WARP_PERSPECTIVE_MULTI_SRC_WITH_IDX_NHWC) { | |||||
using Param = WarpPerspective::Param; | |||||
Param param; | |||||
WarpPerspectiveMatRNG rng; | |||||
UniformIntRNG idx_rng{0, 0}; | |||||
for (auto bmode : | |||||
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT, | |||||
WarpPerspective::BorderMode::REPLICATE, | |||||
WarpPerspective::BorderMode::CONSTANT}) { | |||||
param.border_val = 0.3f; | |||||
param.bmode = bmode; | |||||
param.imode = Param::InterpolationMode::LINEAR; | |||||
param.format = Param::Format::NHWC; | |||||
auto run = [¶m, &rng, &idx_rng, this]( | |||||
size_t bs, size_t ih, size_t iw, size_t c, size_t oh, | |||||
size_t ow, size_t idx, DType dtype) { | |||||
Checker<WarpPerspectiveForward, WarpPerspectiveMultiSrcProxy> checker( | |||||
handle_cuda()); | |||||
checker.set_param(param); | |||||
TensorShapeArray shapes; | |||||
// src | |||||
for (size_t i = 0; i < bs; i++) { | |||||
shapes.emplace_back(TensorShape{{1, ih, iw, c}}); | |||||
checker.set_dtype(i, dtype); | |||||
} | |||||
// mat | |||||
shapes.emplace_back(TensorShape{{idx, 3, 3}}); | |||||
checker.set_rng(bs, &rng); | |||||
// mat_idx | |||||
shapes.emplace_back(TensorShape{{idx}}); | |||||
checker.set_dtype(bs + 1, dtype::Int32()); | |||||
idx_rng = UniformIntRNG{0, (int)bs - 1}; | |||||
checker.set_rng(bs + 1, &idx_rng); | |||||
// dst | |||||
shapes.emplace_back(TensorShape{{idx, oh, ow, c}}); | |||||
checker.set_dtype(bs + 2, dtype); | |||||
checker.execs(shapes); | |||||
}; | |||||
for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) { | |||||
run(1, 20, 18, 4, 6, 6, 1, dtype); | |||||
run(2, 100, 110, 10, 50, 50, 1, dtype); | |||||
run(20, 10, 11, 123, 15, 16, 10, dtype); | |||||
run(2200, 10, 11, 3, 11, 12, 100, dtype); | |||||
} | |||||
} | |||||
} | |||||
#if MEGDNN_WITH_BENCHMARK | #if MEGDNN_WITH_BENCHMARK | ||||
TEST_F(CUDA, BENCHMARK_WARP_PERSPECTIVE_NCHW4) { | TEST_F(CUDA, BENCHMARK_WARP_PERSPECTIVE_NCHW4) { | ||||
@@ -172,6 +172,190 @@ TEST_F(FALLBACK, WARP_PERSPECTIFVE_NCHW_QUINT8) { | |||||
warp_perspective::run_quint8_test(handle()); | warp_perspective::run_quint8_test(handle()); | ||||
} | } | ||||
TEST_F(FALLBACK, WARP_PERSPECTIVE_MULTI_SRC_NCHW) { | |||||
using Param = WarpPerspective::Param; | |||||
Param param; | |||||
WarpPerspectiveMatRNG rng; | |||||
for (auto bmode : | |||||
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT, | |||||
WarpPerspective::BorderMode::REPLICATE, | |||||
WarpPerspective::BorderMode::CONSTANT}) { | |||||
param.border_val = 0.3f; | |||||
param.bmode = bmode; | |||||
param.imode = Param::InterpolationMode::LINEAR; | |||||
param.format = Param::Format::NCHW; | |||||
auto run = [¶m, &rng, this]( | |||||
size_t bs, size_t ih, size_t iw, size_t c, size_t oh, | |||||
size_t ow, DType dtype) { | |||||
Checker<WarpPerspectiveForward, WarpPerspectiveMultiSrcProxy> checker( | |||||
handle()); | |||||
checker.set_param(param); | |||||
TensorShapeArray shapes; | |||||
// src | |||||
for (size_t i = 0; i < bs; i++) { | |||||
shapes.emplace_back(TensorShape{{1, c, ih, iw}}); | |||||
checker.set_dtype(i, dtype); | |||||
} | |||||
// mat | |||||
shapes.emplace_back(TensorShape{{bs, 3, 3}}); | |||||
checker.set_rng(bs, &rng); | |||||
// dst | |||||
shapes.emplace_back(TensorShape{{bs, c, oh, ow}}); | |||||
checker.set_dtype(bs + 1, dtype); | |||||
checker.execs(shapes); | |||||
}; | |||||
for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) { | |||||
run(1, 20, 18, 4, 6, 6, dtype); | |||||
run(20, 10, 11, 123, 15, 16, dtype); | |||||
run(100, 10, 11, 3, 11, 12, dtype); | |||||
} | |||||
} | |||||
} | |||||
TEST_F(FALLBACK, WARP_PERSPECTIVE_MULTI_SRC_NHWC) { | |||||
using Param = WarpPerspective::Param; | |||||
Param param; | |||||
WarpPerspectiveMatRNG rng; | |||||
for (auto bmode : | |||||
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT, | |||||
WarpPerspective::BorderMode::REPLICATE, | |||||
WarpPerspective::BorderMode::CONSTANT}) { | |||||
param.border_val = 0.3f; | |||||
param.bmode = bmode; | |||||
param.imode = Param::InterpolationMode::LINEAR; | |||||
param.format = Param::Format::NHWC; | |||||
auto run = [¶m, &rng, this]( | |||||
size_t bs, size_t ih, size_t iw, size_t c, size_t oh, | |||||
size_t ow, DType dtype) { | |||||
Checker<WarpPerspectiveForward, WarpPerspectiveMultiSrcProxy> checker( | |||||
handle()); | |||||
checker.set_param(param); | |||||
TensorShapeArray shapes; | |||||
// src | |||||
for (size_t i = 0; i < bs; i++) { | |||||
shapes.emplace_back(TensorShape{{1, ih, iw, c}}); | |||||
checker.set_dtype(i, dtype); | |||||
} | |||||
// mat | |||||
shapes.emplace_back(TensorShape{{bs, 3, 3}}); | |||||
checker.set_rng(bs, &rng); | |||||
// dst | |||||
shapes.emplace_back(TensorShape{{bs, oh, ow, c}}); | |||||
checker.set_dtype(bs + 1, dtype); | |||||
checker.execs(shapes); | |||||
}; | |||||
for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) { | |||||
run(1, 20, 18, 4, 6, 6, dtype); | |||||
run(20, 10, 11, 123, 15, 16, dtype); | |||||
run(100, 10, 11, 3, 11, 12, dtype); | |||||
} | |||||
} | |||||
} | |||||
TEST_F(FALLBACK, WARP_PERSPECTIVE_MULTI_SRC_WITH_IDX_NCHW) { | |||||
using Param = WarpPerspective::Param; | |||||
Param param; | |||||
WarpPerspectiveMatRNG rng; | |||||
UniformIntRNG idx_rng{0, 0}; | |||||
for (auto bmode : | |||||
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT, | |||||
WarpPerspective::BorderMode::REPLICATE, | |||||
WarpPerspective::BorderMode::CONSTANT}) { | |||||
param.border_val = 0.3f; | |||||
param.bmode = bmode; | |||||
param.imode = Param::InterpolationMode::LINEAR; | |||||
param.format = Param::Format::NCHW; | |||||
auto run = [¶m, &rng, &idx_rng, this]( | |||||
size_t bs, size_t ih, size_t iw, size_t c, size_t oh, | |||||
size_t ow, size_t idx, DType dtype) { | |||||
Checker<WarpPerspectiveForward, WarpPerspectiveMultiSrcProxy> checker( | |||||
handle()); | |||||
checker.set_param(param); | |||||
TensorShapeArray shapes; | |||||
// src | |||||
for (size_t i = 0; i < bs; i++) { | |||||
shapes.emplace_back(TensorShape{{1, c, ih, iw}}); | |||||
checker.set_dtype(i, dtype); | |||||
} | |||||
// mat | |||||
shapes.emplace_back(TensorShape{{idx, 3, 3}}); | |||||
checker.set_rng(bs, &rng); | |||||
// mat_idx | |||||
shapes.emplace_back(TensorShape{{idx}}); | |||||
checker.set_dtype(bs + 1, dtype::Int32()); | |||||
idx_rng = UniformIntRNG{0, (int)bs - 1}; | |||||
checker.set_rng(bs + 1, &idx_rng); | |||||
// dst | |||||
shapes.emplace_back(TensorShape{{idx, c, oh, ow}}); | |||||
checker.set_dtype(bs + 2, dtype); | |||||
checker.execs(shapes); | |||||
}; | |||||
for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) { | |||||
run(1, 20, 18, 4, 6, 6, 1, dtype); | |||||
run(20, 10, 11, 123, 15, 16, 10, dtype); | |||||
run(100, 10, 11, 3, 11, 12, 100, dtype); | |||||
} | |||||
} | |||||
} | |||||
TEST_F(FALLBACK, WARP_PERSPECTIVE_MULTI_SRC_WITH_IDX_NHWC) { | |||||
using Param = WarpPerspective::Param; | |||||
Param param; | |||||
WarpPerspectiveMatRNG rng; | |||||
UniformIntRNG idx_rng{0, 0}; | |||||
for (auto bmode : | |||||
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT, | |||||
WarpPerspective::BorderMode::REPLICATE, | |||||
WarpPerspective::BorderMode::CONSTANT}) { | |||||
param.border_val = 0.3f; | |||||
param.bmode = bmode; | |||||
param.imode = Param::InterpolationMode::LINEAR; | |||||
param.format = Param::Format::NHWC; | |||||
auto run = [¶m, &rng, &idx_rng, this]( | |||||
size_t bs, size_t ih, size_t iw, size_t c, size_t oh, | |||||
size_t ow, size_t idx, DType dtype) { | |||||
Checker<WarpPerspectiveForward, WarpPerspectiveMultiSrcProxy> checker( | |||||
handle()); | |||||
checker.set_param(param); | |||||
TensorShapeArray shapes; | |||||
// src | |||||
for (size_t i = 0; i < bs; i++) { | |||||
shapes.emplace_back(TensorShape{{1, ih, iw, c}}); | |||||
checker.set_dtype(i, dtype); | |||||
} | |||||
// mat | |||||
shapes.emplace_back(TensorShape{{idx, 3, 3}}); | |||||
checker.set_rng(bs, &rng); | |||||
// mat_idx | |||||
shapes.emplace_back(TensorShape{{idx}}); | |||||
checker.set_dtype(bs + 1, dtype::Int32()); | |||||
idx_rng = UniformIntRNG{0, (int)bs - 1}; | |||||
checker.set_rng(bs + 1, &idx_rng); | |||||
// dst | |||||
shapes.emplace_back(TensorShape{{idx, oh, ow, c}}); | |||||
checker.set_dtype(bs + 2, dtype); | |||||
checker.execs(shapes); | |||||
}; | |||||
for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) { | |||||
run(1, 20, 18, 4, 6, 6, 1, dtype); | |||||
run(20, 10, 11, 123, 15, 16, 10, dtype); | |||||
run(100, 10, 11, 3, 11, 12, 100, dtype); | |||||
} | |||||
} | |||||
} | |||||
} // namespace test | } // namespace test | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -55,6 +55,282 @@ class NanMatRNG : public RNG { | |||||
}; | }; | ||||
} // namespace | } // namespace | ||||
TEST_F(NAIVE, WARP_PERSPECTIVE_MULTI_SRC) { | |||||
using Param = WarpPerspective::Param; | |||||
WarpPerspective::Param param; | |||||
auto extra_impl = [¶m, this](const TensorNDArray& tensors) { | |||||
//! split src | |||||
TensorND src = tensors[0]; // n h w c | |||||
size_t n = src.layout[0]; | |||||
TensorNDArray srcs; // n 个 1 h w c | |||||
TensorLayoutArray srcs_layouts; | |||||
for (size_t i = 0; i < n; i++) { | |||||
TensorLayout ly; | |||||
ly = TensorLayout{ | |||||
{1, src.layout[1], src.layout[2], src.layout[3]}, src.layout.dtype}; | |||||
srcs.emplace_back(malloc(ly.span().dist_byte()), ly); | |||||
srcs_layouts.emplace_back(ly); | |||||
} | |||||
auto split = handle()->create_operator<SplitForward>(); | |||||
split->param().axis = 0; | |||||
auto split_ws_size = split->get_workspace_in_bytes(src.layout, srcs_layouts); | |||||
dt_byte* split_ws_ptr = static_cast<dt_byte*>(malloc(split_ws_size)); | |||||
Workspace split_ws{split_ws_ptr, split_ws_size}; | |||||
split->exec(src, srcs, split_ws); | |||||
auto warp_perspective = handle()->create_operator<WarpPerspective>(); | |||||
warp_perspective->param() = param; | |||||
auto warp_ws_size = warp_perspective->get_workspace_in_bytes( | |||||
srcs_layouts, tensors[1].layout, tensors[2].layout); | |||||
dt_byte* warp_ws_ptr = static_cast<dt_byte*>(malloc(warp_ws_size)); | |||||
Workspace warp_ws{warp_ws_ptr, warp_ws_size}; | |||||
warp_perspective->exec(srcs, tensors[1], tensors[2], warp_ws); | |||||
free(split_ws_ptr); | |||||
free(warp_ws_ptr); | |||||
for (auto&& s : srcs) { | |||||
free(s.raw_ptr()); | |||||
} | |||||
}; | |||||
{ | |||||
// Float32 | |||||
Checker<WarpPerspectiveForward> checker(handle()); | |||||
WarpPerspectiveMatRNG rng; | |||||
checker.set_rng(1, &rng); | |||||
checker.set_extra_opr_impl(extra_impl); | |||||
// NHWC | |||||
for (auto bmode : | |||||
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT, | |||||
WarpPerspective::BorderMode::REPLICATE, | |||||
WarpPerspective::BorderMode::CONSTANT}) { | |||||
param.border_val = 0.3f; | |||||
param.bmode = bmode; | |||||
param.imode = Param::InterpolationMode::LINEAR; | |||||
param.format = Param::Format::NHWC; | |||||
checker.set_param(param); | |||||
checker.execs({{1, 2, 2, 4}, {1, 3, 3}, {1, 2, 2, 4}}); | |||||
checker.execs({{2, 10, 10, 4}, {2, 3, 3}, {2, 10, 12, 4}}); | |||||
checker.execs({{3, 25, 24, 8}, {3, 3, 3}, {3, 12, 10, 8}}); | |||||
checker.execs({{4, 33, 22, 16}, {4, 3, 3}, {4, 9, 12, 16}}); | |||||
} | |||||
// NCHW | |||||
for (auto bmode : | |||||
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT, | |||||
WarpPerspective::BorderMode::REPLICATE, | |||||
WarpPerspective::BorderMode::CONSTANT}) { | |||||
param.border_val = 0.3f; | |||||
param.bmode = bmode; | |||||
param.imode = Param::InterpolationMode::LINEAR; | |||||
param.format = Param::Format::NCHW; | |||||
checker.set_param(param); | |||||
checker.execs({{1, 4, 2, 2}, {1, 3, 3}, {1, 4, 2, 2}}); | |||||
checker.execs({{2, 4, 10, 10}, {2, 3, 3}, {2, 4, 10, 12}}); | |||||
checker.execs({{3, 8, 25, 24}, {3, 3, 3}, {3, 8, 12, 10}}); | |||||
checker.execs({{4, 16, 33, 22}, {4, 3, 3}, {4, 16, 9, 12}}); | |||||
} | |||||
} | |||||
{ | |||||
// Float16 | |||||
Checker<WarpPerspectiveForward> checker(handle()); | |||||
WarpPerspectiveMatRNG rng; | |||||
checker.set_rng(1, &rng); | |||||
checker.set_dtype(0, dtype::Float16()); | |||||
checker.set_dtype(2, dtype::Float16()); | |||||
checker.set_extra_opr_impl(extra_impl); | |||||
// NHWC | |||||
for (auto bmode : | |||||
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT, | |||||
WarpPerspective::BorderMode::REPLICATE, | |||||
WarpPerspective::BorderMode::CONSTANT}) { | |||||
param.border_val = 0.3f; | |||||
param.bmode = bmode; | |||||
param.imode = Param::InterpolationMode::LINEAR; | |||||
param.format = Param::Format::NHWC; | |||||
checker.set_param(param); | |||||
checker.execs({{1, 2, 2, 4}, {1, 3, 3}, {1, 2, 2, 4}}); | |||||
checker.execs({{2, 10, 10, 4}, {2, 3, 3}, {2, 10, 12, 4}}); | |||||
checker.execs({{3, 25, 24, 8}, {3, 3, 3}, {3, 12, 10, 8}}); | |||||
checker.execs({{4, 33, 22, 16}, {4, 3, 3}, {4, 9, 12, 16}}); | |||||
} | |||||
// NCHW | |||||
for (auto bmode : | |||||
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT, | |||||
WarpPerspective::BorderMode::REPLICATE, | |||||
WarpPerspective::BorderMode::CONSTANT}) { | |||||
param.border_val = 0.3f; | |||||
param.bmode = bmode; | |||||
param.imode = Param::InterpolationMode::LINEAR; | |||||
param.format = Param::Format::NCHW; | |||||
checker.set_param(param); | |||||
checker.execs({{1, 4, 2, 2}, {1, 3, 3}, {1, 4, 2, 2}}); | |||||
checker.execs({{2, 4, 10, 10}, {2, 3, 3}, {2, 4, 10, 12}}); | |||||
checker.execs({{3, 8, 25, 24}, {3, 3, 3}, {3, 8, 12, 10}}); | |||||
checker.execs({{4, 16, 33, 22}, {4, 3, 3}, {4, 16, 9, 12}}); | |||||
} | |||||
} | |||||
} | |||||
TEST_F(NAIVE, WARP_PERSPECTIVE_MULTI_SRC_WITH_IDX) { | |||||
using Param = WarpPerspective::Param; | |||||
WarpPerspective::Param param; | |||||
auto extra_impl = [¶m, this](const TensorNDArray& tensors) { | |||||
//! split src | |||||
TensorND src = tensors[0]; // n h w c | |||||
size_t n = src.layout[0]; | |||||
TensorNDArray srcs; // n 个 1 h w c | |||||
TensorLayoutArray srcs_layouts; | |||||
for (size_t i = 0; i < n; i++) { | |||||
TensorLayout ly; | |||||
ly = TensorLayout{ | |||||
{1, src.layout[1], src.layout[2], src.layout[3]}, src.layout.dtype}; | |||||
srcs.emplace_back(malloc(ly.span().dist_byte()), ly); | |||||
srcs_layouts.emplace_back(ly); | |||||
} | |||||
auto split = handle()->create_operator<SplitForward>(); | |||||
split->param().axis = 0; | |||||
auto split_ws_size = split->get_workspace_in_bytes(src.layout, srcs_layouts); | |||||
dt_byte* split_ws_ptr = static_cast<dt_byte*>(malloc(split_ws_size)); | |||||
Workspace split_ws{split_ws_ptr, split_ws_size}; | |||||
split->exec(src, srcs, split_ws); | |||||
auto warp_perspective = handle()->create_operator<WarpPerspective>(); | |||||
warp_perspective->param() = param; | |||||
auto warp_ws_size = warp_perspective->get_workspace_in_bytes( | |||||
srcs_layouts, tensors[1].layout, tensors[2].layout, tensors[3].layout); | |||||
dt_byte* warp_ws_ptr = static_cast<dt_byte*>(malloc(warp_ws_size)); | |||||
Workspace warp_ws{warp_ws_ptr, warp_ws_size}; | |||||
warp_perspective->exec(srcs, tensors[1], tensors[2], tensors[3], warp_ws); | |||||
free(split_ws_ptr); | |||||
free(warp_ws_ptr); | |||||
for (auto&& s : srcs) { | |||||
free(s.raw_ptr()); | |||||
} | |||||
}; | |||||
{ | |||||
// Float32 | |||||
Checker<WarpPerspectiveForward, WarpPerspectiveMatIdxProxy> checker(handle()); | |||||
WarpPerspectiveMatRNG rng; | |||||
checker.set_rng(1, &rng); | |||||
checker.set_dtype(0, dtype::Float32()); | |||||
checker.set_dtype(1, dtype::Float32()); | |||||
checker.set_dtype(2, dtype::Int32()); | |||||
checker.set_dtype(3, dtype::Float32()); | |||||
checker.set_extra_opr_impl(extra_impl); | |||||
// NHWC | |||||
for (auto bmode : | |||||
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT, | |||||
WarpPerspective::BorderMode::REPLICATE, | |||||
WarpPerspective::BorderMode::CONSTANT}) { | |||||
param.border_val = 0.3f; | |||||
param.bmode = bmode; | |||||
param.imode = Param::InterpolationMode::LINEAR; | |||||
param.format = Param::Format::NHWC; | |||||
checker.set_param(param); | |||||
UniformIntRNG idx_rng{0, 0}; | |||||
checker.set_rng(2, &idx_rng); | |||||
checker.execs({{1, 2, 2, 4}, {1, 3, 3}, {1}, {1, 2, 2, 4}}); | |||||
idx_rng = UniformIntRNG{0, 1}; | |||||
checker.set_rng(2, &idx_rng); | |||||
checker.execs({{2, 10, 10, 4}, {1, 3, 3}, {1}, {1, 10, 12, 4}}); | |||||
idx_rng = UniformIntRNG{0, 2}; | |||||
checker.set_rng(2, &idx_rng); | |||||
checker.execs({{3, 25, 24, 8}, {2, 3, 3}, {2}, {2, 12, 10, 8}}); | |||||
checker.execs({{4, 33, 22, 16}, {2, 3, 3}, {2}, {2, 9, 12, 16}}); | |||||
} | |||||
// NCHW | |||||
for (auto bmode : | |||||
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT, | |||||
WarpPerspective::BorderMode::REPLICATE, | |||||
WarpPerspective::BorderMode::CONSTANT}) { | |||||
param.border_val = 0.3f; | |||||
param.bmode = bmode; | |||||
param.imode = Param::InterpolationMode::LINEAR; | |||||
param.format = Param::Format::NCHW; | |||||
checker.set_param(param); | |||||
UniformIntRNG idx_rng{0, 0}; | |||||
checker.set_rng(2, &idx_rng); | |||||
checker.execs({{1, 4, 2, 2}, {1, 3, 3}, {1}, {1, 4, 2, 2}}); | |||||
idx_rng = UniformIntRNG{0, 1}; | |||||
checker.set_rng(2, &idx_rng); | |||||
checker.execs({{2, 4, 10, 10}, {1, 3, 3}, {1}, {1, 4, 10, 12}}); | |||||
idx_rng = UniformIntRNG{0, 2}; | |||||
checker.set_rng(2, &idx_rng); | |||||
checker.execs({{3, 8, 25, 24}, {2, 3, 3}, {2}, {2, 8, 12, 10}}); | |||||
checker.execs({{4, 16, 33, 22}, {2, 3, 3}, {2}, {2, 16, 9, 12}}); | |||||
} | |||||
} | |||||
{ | |||||
// Float16 | |||||
Checker<WarpPerspectiveForward, WarpPerspectiveMatIdxProxy> checker(handle()); | |||||
WarpPerspectiveMatRNG rng; | |||||
checker.set_rng(1, &rng); | |||||
checker.set_dtype(0, dtype::Float16()); | |||||
checker.set_dtype(1, dtype::Float32()); | |||||
checker.set_dtype(2, dtype::Int32()); | |||||
checker.set_dtype(3, dtype::Float16()); | |||||
checker.set_extra_opr_impl(extra_impl); | |||||
// NHWC | |||||
for (auto bmode : | |||||
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT, | |||||
WarpPerspective::BorderMode::REPLICATE, | |||||
WarpPerspective::BorderMode::CONSTANT}) { | |||||
param.border_val = 0.3f; | |||||
param.bmode = bmode; | |||||
param.imode = Param::InterpolationMode::LINEAR; | |||||
param.format = Param::Format::NHWC; | |||||
checker.set_param(param); | |||||
UniformIntRNG idx_rng{0, 0}; | |||||
checker.set_rng(2, &idx_rng); | |||||
checker.execs({{1, 2, 2, 4}, {1, 3, 3}, {1}, {1, 2, 2, 4}}); | |||||
idx_rng = UniformIntRNG{0, 1}; | |||||
checker.set_rng(2, &idx_rng); | |||||
checker.execs({{2, 10, 10, 4}, {1, 3, 3}, {1}, {1, 10, 12, 4}}); | |||||
idx_rng = UniformIntRNG{0, 2}; | |||||
checker.set_rng(2, &idx_rng); | |||||
checker.execs({{3, 25, 24, 8}, {2, 3, 3}, {2}, {2, 12, 10, 8}}); | |||||
checker.execs({{4, 33, 22, 16}, {2, 3, 3}, {2}, {2, 9, 12, 16}}); | |||||
} | |||||
// NCHW | |||||
for (auto bmode : | |||||
{WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT, | |||||
WarpPerspective::BorderMode::REPLICATE, | |||||
WarpPerspective::BorderMode::CONSTANT}) { | |||||
param.border_val = 0.3f; | |||||
param.bmode = bmode; | |||||
param.imode = Param::InterpolationMode::LINEAR; | |||||
param.format = Param::Format::NCHW; | |||||
checker.set_param(param); | |||||
UniformIntRNG idx_rng{0, 0}; | |||||
checker.set_rng(2, &idx_rng); | |||||
checker.execs({{1, 4, 2, 2}, {1, 3, 3}, {1}, {1, 4, 2, 2}}); | |||||
idx_rng = UniformIntRNG{0, 1}; | |||||
checker.set_rng(2, &idx_rng); | |||||
checker.execs({{2, 4, 10, 10}, {1, 3, 3}, {1}, {1, 4, 10, 12}}); | |||||
idx_rng = UniformIntRNG{0, 2}; | |||||
checker.set_rng(2, &idx_rng); | |||||
checker.execs({{3, 8, 25, 24}, {2, 3, 3}, {2}, {2, 8, 12, 10}}); | |||||
checker.execs({{4, 16, 33, 22}, {2, 3, 3}, {2}, {2, 16, 9, 12}}); | |||||
} | |||||
} | |||||
} | |||||
TEST_F(NAIVE, WARP_PERSPECTIVE_NCHW4) { | TEST_F(NAIVE, WARP_PERSPECTIVE_NCHW4) { | ||||
using Param = WarpPerspective::Param; | using Param = WarpPerspective::Param; | ||||