diff --git a/dnn/src/common/local/opr_impl.cpp b/dnn/src/common/local/opr_impl.cpp
index 355cb465..ab44be3c 100644
--- a/dnn/src/common/local/opr_impl.cpp
+++ b/dnn/src/common/local/opr_impl.cpp
@@ -29,7 +29,11 @@ void LocalBase::deduce_layout_fwd(const TensorLayout &src,
     auto errmsg_c = errmsg.c_str();
     MEGDNN_MARK_USED_VAR(errmsg_c);
-    megdnn_assert_contiguous(src);
+    //! the batch dim does not need to be contiguous
+    TensorLayout src_contig = src;
+    src_contig.init_contiguous_stride();
+    src_contig.stride[0] = src.stride[0];
+    megdnn_assert_eq_layout(src_contig, src);
     megdnn_assert_contiguous(filter);
     megdnn_assert(src.ndim == 4_z, "%s", errmsg_c);
     megdnn_assert(filter.ndim == 6_z, "%s", errmsg_c);
@@ -67,6 +71,8 @@ void LocalBase::check_layout_fwd(const TensorLayout &src,
     megdnn_assert_eq_dtype(src, filter);
     megdnn_assert_eq_dtype(src, dst);
     deduce_layout_fwd(src, filter, dst_expected);
+    //! the batch dim does not need to be contiguous
+    dst_expected.stride[0] = dst.stride[0];
     megdnn_assert_eq_layout(dst_expected, dst);
 
     megdnn_assert(src.dtype == filter.dtype && src.dtype == dst.dtype);
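Note on the first hunk: instead of asserting full contiguity of `src`, the check now recomputes dense strides, copies the real batch stride back into dim 0, and asserts layout equality, so every dim except the batch dim must be dense. A minimal standalone sketch of that check (toy `Layout4D` type and illustrative names, not MegDNN's `TensorLayout`):

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Toy NCHW layout: shape plus per-dim strides, in elements.
struct Layout4D {
    std::array<std::size_t, 4> shape;
    std::array<std::ptrdiff_t, 4> stride;

    // Dense row-major strides, innermost dim (W) fastest.
    void init_contiguous_stride() {
        std::ptrdiff_t s = 1;
        for (int i = 3; i >= 0; --i) {
            stride[i] = s;
            s *= static_cast<std::ptrdiff_t>(shape[i]);
        }
    }
};

// Same relaxation as the patch: C/H/W strides must be dense, but any
// batch stride is accepted (e.g. samples padded apart in memory).
bool contiguous_except_batch(const Layout4D& ly) {
    Layout4D expect = ly;
    expect.init_contiguous_stride();
    expect.stride[0] = ly.stride[0];  // exempt the batch dim, as in the patch
    return expect.stride == ly.stride;
}

int main() {
    Layout4D ly{{2, 3, 4, 4}, {}};
    ly.init_contiguous_stride();      // stride {48, 16, 4, 1}
    assert(contiguous_except_batch(ly));
    ly.stride[0] = 128;               // batch-strided view still passes
    assert(contiguous_except_batch(ly));
    ly.stride[2] = 8;                 // non-dense H stride now fails
    assert(!contiguous_except_batch(ly));
}
```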
diff --git a/dnn/src/cuda/group_local/forward/opr_impl.cpp b/dnn/src/cuda/group_local/forward/opr_impl.cpp
index 9bde4f76..23d10e04 100644
--- a/dnn/src/cuda/group_local/forward/opr_impl.cpp
+++ b/dnn/src/cuda/group_local/forward/opr_impl.cpp
@@ -6,141 +6,121 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #include "src/cuda/group_local/opr_impl.h"
 
+#include <memory>
+#include "megdnn/opr_param_defs.h"
 #include "src/common/utils.h"
-#include "src/cuda/local/local.cuh"
 #include "src/cuda/utils.h"
 #include "src/cuda/group_local/forward/kern.cuh"
+#include "src/cuda/local/opr_impl.h"
+
+#include "src/cuda/local/local.cuh"
+
+using namespace megdnn;
+using namespace cuda;
+
+namespace {
+
+std::unique_ptr<LocalForward> get_opr(Handle* handle,
+                                      param::Convolution param) {
+    auto&& opr = handle->create_operator<LocalForward>();
+    opr->param() = param;
+    return std::move(opr);
+}
+
+template <typename T>
+void incr_ptr(T*& dst, ptrdiff_t delta) {
+    dst = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(dst) + delta);
+}
+
+TensorLayout prepare_src_dst(const TensorLayout& input, size_t g) {
+    TensorLayout ret = input;
+    megdnn_assert(ret[1] % g == 0);
+    ret[1] = ret[1] / g;
+    ret.init_contiguous_stride();
+    //! keep the batch stride of the original tensor
+    ret.stride[0] = input.stride[0];
+    return ret;
+}
+
+TensorLayout prepare_filter(const TensorLayout& filter) {
+    //! group, OH, OW, ICg, FH, FW, OCg -> OH, OW, ICg, FH, FW, OCg
+    return {{filter[1], filter[2], filter[3], filter[4], filter[5], filter[6]},
+            filter.dtype};
+}
 
-namespace megdnn {
-namespace cuda {
+}  // namespace
 
 void GroupLocalForwardImpl::exec(_megdnn_tensor_in src,
-        _megdnn_tensor_in filter,
-        _megdnn_tensor_out dst,
-        _megdnn_workspace workspace)
-{
+                                 _megdnn_tensor_in filter,
+                                 _megdnn_tensor_out dst,
+                                 _megdnn_workspace workspace) {
     megdnn_assert(src.layout.dtype == dtype::Float32(),
                   "cuda do not support fp16 group local operator");
     check_exec(src.layout, filter.layout, dst.layout, workspace.size);
+    auto handle = concrete_handle(this->handle());
     auto G = filter.layout[0];
-    auto N = src.layout.shape[0], IC = src.layout.shape[1]/G,
-         IH = src.layout.shape[2], IW = src.layout.shape[3],
-         OC = dst.layout.shape[1]/G,
+    auto IH = src.layout.shape[2], IW = src.layout.shape[3],
          OH = dst.layout.shape[2], OW = dst.layout.shape[3];
-    auto FH = filter.layout.shape[4], FW = filter.layout.shape[5];
-    auto PH = param().pad_h, PW = param().pad_w;
-    auto SH = param().stride_h, SW = param().stride_w;
-    const float *sptr = src.ptr<dt_float32>();
-    const float *fptr = filter.ptr<dt_float32>();
-    float *dptr = dst.ptr<dt_float32>();
-    float *wptr = workspace.ptr<dt_float32>();
-    auto handle = concrete_handle(this->handle());
-    auto stream = cuda_stream(this->handle());
-    auto cublas = cublas_handle(this->handle());
-    auto one = handle->one_device();
-    auto zero = handle->zero_device();
     if (prefer_inference_kernel(src.layout, filter.layout, dst.layout)) {
-        group_local::exec(sptr, fptr, dptr, wptr,
-                N, IC, IH, IW,
-                OC, OH, OW,
-                FH, FW,
-                G,
-                PH, PW,
-                SH, SW,
-                stream
-                );
-    } else if (local::can_forward_proxy_convnet(N, IC, IH, IW,
-                OC, OH, OW,
-                FH, FW,
-                G*IC*IH*IW, G*OC*OH*OW,
-                PH, PW,
-                SH, SW))
-    {
-        // use convnet
-        for (size_t g = 0; g < G; ++g) {
-            local::forward_proxy_convnet(sptr + g*IC*IH*IW,
-                    fptr + g*OH*OW*IC*FH*FW*OC,
-                    dptr + g*OC*OH*OW,
-                    wptr,
-                    N, IC, IH, IW,
-                    OC, OH, OW,
-                    FH, FW,
-                    G*IC*IH*IW, G*OC*OH*OW,
-                    PH, PW,
-                    SH, SW,
-                    cublas, stream, one, zero);
-        }
+        auto N = src.layout.shape[0], ICg = src.layout.shape[1] / G,
+             OCg = dst.layout.shape[1] / G;
+        auto FH = filter.layout.shape[4], FW = filter.layout.shape[5];
+        auto PH = param().pad_h, PW = param().pad_w;
+        auto SH = param().stride_h, SW = param().stride_w;
+        const float* sptr = src.ptr<dt_float32>();
+        const float* fptr = filter.ptr<dt_float32>();
+        float* dptr = dst.ptr<dt_float32>();
+        float* wptr = workspace.ptr<dt_float32>();
+        auto stream = cuda_stream(this->handle());
+
+        group_local::exec(sptr, fptr, dptr, wptr, N, ICg, IH, IW, OCg, OH, OW,
+                          FH, FW, G, PH, PW, SH, SW, stream);
     } else {
-        local::check_input(N, IC, IH, IW, OC, OH, OW, FH, FW,
-                G*IC*IH*IW, G*OC*OH*OW,
-                PH, PW,
-                SH, SW,
-                true);
-        // do not use convnet
+        auto&& opr = get_opr(handle, param());
+        TensorND src_g = {src.raw_ptr, prepare_src_dst(src.layout, G)};
+        TensorND dst_g = {dst.raw_ptr, prepare_src_dst(dst.layout, G)};
+        TensorND filter_g = {filter.raw_ptr, prepare_filter(filter.layout)};
         for (size_t g = 0; g < G; ++g) {
-            local::forward_proxy_weiming(sptr + g*IC*IH*IW,
-                    fptr + g*OH*OW*IC*FH*FW*OC,
-                    dptr + g*OC*OH*OW,
-                    N, IC, IH, IW,
-                    OC, OH, OW,
-                    FH, FW,
-                    G*IC*IH*IW, G*OC*OH*OW,
-                    PH, PW,
-                    SH, SW,
-                    true, stream);
+            opr->exec(src_g, filter_g, dst_g, workspace);
+            incr_ptr(src_g.raw_ptr, src_g.layout.stride[1] *
+                                            src_g.layout.shape[1] *
+                                            src_g.layout.dtype.size());
+            incr_ptr(dst_g.raw_ptr, dst_g.layout.stride[1] *
+                                            dst_g.layout.shape[1] *
+                                            dst_g.layout.dtype.size());
+            incr_ptr(filter_g.raw_ptr, filter_g.layout.span().dist_byte());
         }
     }
 }
 
-GroupLocalForwardImpl::GroupLocalForwardImpl(Handle *handle):
-    GroupLocalForward(handle)
-{
-}
+GroupLocalForwardImpl::GroupLocalForwardImpl(Handle* handle)
+        : GroupLocalForward(handle) {}
 
-size_t GroupLocalForwardImpl::get_workspace_in_bytes(const TensorLayout &src,
-        const TensorLayout &filter,
-        const TensorLayout &dst)
-{
-    auto G = filter[0];
-    auto N = src.shape[0], IC = src.shape[1]/G,
-         IH = src.shape[2], IW = src.shape[3],
-         OC = dst.shape[1]/G,
-         OH = dst.shape[2], OW = dst.shape[3];
-    auto FH = filter.shape[4], FW = filter.shape[5];
-    auto PH = param().pad_h, PW = param().pad_w;
-    auto SH = param().stride_h, SW = param().stride_w;
+size_t GroupLocalForwardImpl::get_workspace_in_bytes(const TensorLayout& src,
+                                                     const TensorLayout& filter,
+                                                     const TensorLayout& dst) {
    if (prefer_inference_kernel(src, filter, dst)) {
         return 0;
-    } else if (local::can_forward_proxy_convnet(N, IC, IH, IW,
-                OC, OH, OW,
-                FH, FW,
-                G*IC*IH*IW, G*OC*OH*OW,
-                PH, PW,
-                SH, SW))
-    {
-        auto res = local::get_workspace_in_floats_forward_proxy_convnet(N,
-                IC, IH, IW,
-                OC, OH, OW,
-                FH, FW,
-                G*IC*IH*IW, G*OC*OH*OW,
-                PH, PW,
-                SH, SW) * sizeof(float);
-        return res;
     } else {
-        return 0;
+        auto G = filter[0];
+        TensorLayout src_g = prepare_src_dst(src, G);
+        TensorLayout dst_g = prepare_src_dst(dst, G);
+        TensorLayout filter_g = prepare_filter(filter);
+        return get_opr(handle(), param())
+                ->get_workspace_in_bytes(src_g, filter_g, dst_g);
     }
 }
 
-bool GroupLocalForwardImpl::prefer_inference_kernel(const TensorLayout &src,
-        const TensorLayout &filter,
-        const TensorLayout &dst)
-{
+bool GroupLocalForwardImpl::prefer_inference_kernel(const TensorLayout& src,
+                                                    const TensorLayout& filter,
+                                                    const TensorLayout& dst) {
     MEGDNN_MARK_USED_VAR(filter);
     MEGDNN_MARK_USED_VAR(dst);
     auto handle = concrete_handle(this->handle());
@@ -149,6 +129,4 @@ bool GroupLocalForwardImpl::prefer_inference_kernel(const TensorLayout &src,
            group_local::get_share_mem_in_bytes(IH, IW);
 }
 
-} // namespace cuda
-} // namespace megdnn
- // vim: syntax=cpp.doxygen
+// vim: syntax=cpp.doxygen
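In the rewritten `exec`, the non-inference path now drives the single-group `LocalForward` operator: `prepare_src_dst` shrinks the channel dim to C/G but keeps the full tensor's batch stride, and the loop advances the raw pointers by one channel block (`stride[1] * shape[1] * dtype.size()` bytes) for src/dst and by one filter span per group. A worked instance of that arithmetic, assuming a dense NCHW src with N=2, C=6, H=W=4 and G=3 (plain C++, element units, illustrative values):

```cpp
#include <cassert>
#include <cstddef>

int main() {
    // Full tensor: shape {2, 6, 4, 4}, dense strides {96, 16, 4, 1}.
    // prepare_src_dst view: shape {2, 2, 4, 4}; init_contiguous_stride()
    // would give {32, 16, 4, 1}, then stride[0] is restored to 96.
    const std::size_t C = 6, H = 4, W = 4, G = 3;
    const std::size_t Cg = C / G;                 // 2 channels per group
    const std::ptrdiff_t full_batch = C * H * W;  // 96, the kept batch stride
    const std::ptrdiff_t chan = H * W;            // 16, stride[1] of the view

    // Per-group advance used in the loop: stride[1] * shape[1] elements.
    const std::ptrdiff_t per_group = chan * static_cast<std::ptrdiff_t>(Cg);
    assert(per_group == 32);
    // G such blocks tile exactly one sample, so group g of sample n starts
    // at g * 32 + n * 96 -- the batch stride must stay 96, not 32.
    assert(per_group * static_cast<std::ptrdiff_t>(G) == full_batch);
}
```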
diff --git a/dnn/src/cuda/local/forward.cpp b/dnn/src/cuda/local/forward.cpp
index 24b2af6c..c8c4f38c 100644
--- a/dnn/src/cuda/local/forward.cpp
+++ b/dnn/src/cuda/local/forward.cpp
@@ -78,6 +78,8 @@ void LocalForwardImpl::exec(_megdnn_tensor_in src,
     auto cublas = cublas_handle(this->handle());
     auto one = handle->one_device();
     auto zero = handle->zero_device();
+    size_t src_batch_strd = src.layout.stride[0];
+    size_t dst_batch_strd = dst.layout.stride[0];
     if (use_cuda_convnet(src.layout, filter.layout, dst.layout)) {
         local::forward_proxy_convnet(src.ptr<dt_float32>(),
                 filter.ptr<dt_float32>(),
@@ -87,7 +89,7 @@ void LocalForwardImpl::exec(_megdnn_tensor_in src,
                 IC, IH, IW,
                 OC, OH, OW,
                 FH, FW,
-                IC*IH*IW, OC*OH*OW,
+                src_batch_strd, dst_batch_strd,
                 param().pad_h, param().pad_w,
                 param().stride_h, param().stride_w,
                 cublas, stream,
@@ -105,7 +107,7 @@ void LocalForwardImpl::exec(_megdnn_tensor_in src,
                 IC, IH, IW,
                 OC, OH, OW,
                 FH, FW,
-                IC*IH*IW, OC*OH*OW,
+                src_batch_strd, dst_batch_strd,
                 param().pad_h, param().pad_w,
                 param().stride_h, param().stride_w,
                 is_xcorr,
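These two hunks stop recomputing the batch stride as `IC*IH*IW`/`OC*OH*OW` and read it off the layout, which is exactly what the per-group views above rely on: a sample owns C*H*W elements of payload, but samples sit `stride[0]` elements apart. A small sketch of the addressing rule (`offset_nchw` is a hypothetical helper, not MegDNN API):

```cpp
#include <cassert>
#include <cstddef>

// Offset in elements of (n, c, h, w) in a layout whose inner dims are dense
// but whose batch stride may exceed C*H*W.
inline std::ptrdiff_t offset_nchw(std::ptrdiff_t batch_stride, std::size_t H,
                                  std::size_t W, std::size_t n, std::size_t c,
                                  std::size_t h, std::size_t w) {
    return static_cast<std::ptrdiff_t>(n) * batch_stride +
           static_cast<std::ptrdiff_t>((c * H + h) * W + w);
}

int main() {
    // Per-group view of the {2, 6, 4, 4} example: Cg = 2, batch stride 96.
    assert(offset_nchw(96, 4, 4, 1, 0, 0, 0) == 96);  // sample 1, correct
    // The old ICg*IH*IW formula (2*4*4 = 32) would land inside sample 0:
    assert(offset_nchw(32, 4, 4, 1, 0, 0, 0) == 32);
}
```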
@@ -124,12 +126,14 @@ size_t LocalForwardImpl::get_workspace_in_bytes(const TensorLayout &src,
          FH = filter.shape[3], FW = filter.shape[4];
     auto PH = param().pad_h, PW = param().pad_w,
          SH = param().stride_h, SW = param().stride_w;
+    size_t src_batch_strd = src.stride[0];
+    size_t dst_batch_strd = dst.stride[0];
     if (use_cuda_convnet(src, filter, dst)) {
         res = local::get_workspace_in_floats_forward_proxy_convnet(N,
                 IC, IH, IW,
                 OC, OH, OW,
                 FH, FW,
-                IC*IH*IW, OC*OH*OW,
+                src_batch_strd, dst_batch_strd,
                 PH, PW,
                 SH, SW) * sizeof(dt_float32);
     } else {
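The workspace query gets the same batch-stride fix on this path; on the group-local side (previous file), `get_workspace_in_bytes` instead delegates a single per-group query to the inner operator, since `exec` runs the groups sequentially and hands the same workspace to every per-group call rather than needing G of them. A sketch of that reuse-not-multiply decision (`grouped_workspace` and `QueryFn` are illustrative stand-ins, not MegDNN API):

```cpp
#include <cstddef>

// All groups share one per-group layout and run one after another on the same
// stream, so a single query suffices; summing G queries would over-allocate.
template <typename QueryFn>
std::size_t grouped_workspace(QueryFn per_group_bytes, std::size_t groups) {
    (void)groups;  // workspace is reused across the sequential launches
    return per_group_bytes();
}

int main() {
    auto query = [] { return std::size_t{4096}; };  // stand-in for the inner opr
    return grouped_workspace(query, 3) == 4096 ? 0 : 1;
}
```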