|
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
|
|
#include "src/cuda/group_local/opr_impl.h"

#include <memory>

#include "megdnn/opr_param_defs.h"
#include "src/common/utils.h"
#include "src/cuda/utils.h"

#include "src/cuda/group_local/forward/kern.cuh"
#include "src/cuda/local/local.cuh"
#include "src/cuda/local/opr_impl.h"
|
using namespace megdnn;
using namespace cuda;

namespace {
|
std::unique_ptr<LocalForward> get_opr(Handle* handle,
                                      param::Convolution param) {
    auto&& opr = handle->create_operator<LocalForward>();
    opr->param() = param;
    return std::move(opr);
}
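// Note: get_opr() builds a fresh LocalForward on each call; exec() and
// get_workspace_in_bytes() below construct the operator on demand with the
// current convolution params, so nothing is cached across calls.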
|
template <typename T>
void incr_ptr(T*& dst, ptrdiff_t delta) {
    dst = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(dst) + delta);
}
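// incr_ptr() advances a typed pointer by a *byte* offset, not an element
// count. Illustrative usage: given float* p, incr_ptr(p, 2 * sizeof(float))
// skips exactly two floats. exec() below relies on this to walk raw tensor
// pointers from one group's data to the next.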
|
TensorLayout prepare_src_dst(const TensorLayout& input, size_t g) {
    TensorLayout ret = input;
    megdnn_assert(ret[1] % g == 0);
    ret[1] = ret[1] / g;
    ret.init_contiguous_stride();
    //! change stride of batch
    ret.stride[0] = input.stride[0];
    return ret;
}
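// Illustrative example (assuming NCHW): for a src of shape (N, G*ICg, IH, IW)
// and g == G, prepare_src_dst() returns an (N, ICg, IH, IW) layout whose inner
// strides are contiguous but whose batch stride still spans all G groups, so
// batch n of each per-group view stays inside batch n of the full tensor.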
|
TensorLayout prepare_filter(const TensorLayout& filter) {
    //! group, OH, OW, ICg, FH, FW, OCg -> OH, OW, ICg, FH, FW, OCg
    return {{filter[1], filter[2], filter[3], filter[4], filter[5], filter[6]},
            filter.dtype};
}
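// Example: a grouped filter of shape (G, OH, OW, ICg, FH, FW, OCg) collapses
// to a single group's (OH, OW, ICg, FH, FW, OCg) layout; the exec() loop then
// advances the filter pointer by one whole filter span per group.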
|
} // namespace
|
void GroupLocalForwardImpl::exec(_megdnn_tensor_in src,
                                 _megdnn_tensor_in filter,
                                 _megdnn_tensor_out dst,
                                 _megdnn_workspace workspace) {
    megdnn_assert(src.layout.dtype == dtype::Float32(),
                  "cuda do not support fp16 group local operator");
    check_exec(src.layout, filter.layout, dst.layout, workspace.size);
    auto handle = concrete_handle(this->handle());
    auto G = filter.layout[0];
    auto IH = src.layout.shape[2], IW = src.layout.shape[3],
         OH = dst.layout.shape[2], OW = dst.layout.shape[3];
    if (prefer_inference_kernel(src.layout, filter.layout, dst.layout)) {
        auto N = src.layout.shape[0], ICg = src.layout.shape[1] / G,
             OCg = dst.layout.shape[1] / G;
        auto FH = filter.layout.shape[4], FW = filter.layout.shape[5];
        auto PH = param().pad_h, PW = param().pad_w;
        auto SH = param().stride_h, SW = param().stride_w;
        const float* sptr = src.ptr<dt_float32>();
        const float* fptr = filter.ptr<dt_float32>();
        float* dptr = dst.ptr<dt_float32>();
        float* wptr = workspace.ptr<dt_float32>();
        auto stream = cuda_stream(this->handle());
        group_local::exec(sptr, fptr, dptr, wptr, N, ICg, IH, IW, OCg, OH, OW,
                          FH, FW, G, PH, PW, SH, SW, stream);
    } else {
        auto&& opr = get_opr(handle, param());
        TensorND src_g = {src.raw_ptr, prepare_src_dst(src.layout, G)};
        TensorND dst_g = {dst.raw_ptr, prepare_src_dst(dst.layout, G)};
        TensorND filter_g = {filter.raw_ptr, prepare_filter(filter.layout)};
        for (size_t g = 0; g < G; ++g) {
            opr->exec(src_g, filter_g, dst_g, workspace);
            incr_ptr(src_g.raw_ptr, src_g.layout.stride[1] *
                                            src_g.layout.shape[1] *
                                            src_g.layout.dtype.size());
            incr_ptr(dst_g.raw_ptr, dst_g.layout.stride[1] *
                                            dst_g.layout.shape[1] *
                                            dst_g.layout.dtype.size());
            incr_ptr(filter_g.raw_ptr, filter_g.layout.span().dist_byte());
        }
    }
}
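// Dispatch summary for exec(): when prefer_inference_kernel() accepts the
// shapes, a single fused kernel launch covers every group; otherwise each
// group runs as an ordinary LocalForward over per-group views whose raw
// pointers advance by one group's channel span (src/dst) or filter span.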
|
GroupLocalForwardImpl::GroupLocalForwardImpl(Handle* handle)
        : GroupLocalForward(handle) {}
|
size_t GroupLocalForwardImpl::get_workspace_in_bytes(const TensorLayout& src,
                                                     const TensorLayout& filter,
                                                     const TensorLayout& dst) {
    if (prefer_inference_kernel(src, filter, dst)) {
        return 0;
    } else {
        auto G = filter[0];
        TensorLayout src_g = prepare_src_dst(src, G);
        TensorLayout dst_g = prepare_src_dst(dst, G);
        TensorLayout filter_g = prepare_filter(filter);
        return get_opr(handle(), param())
                ->get_workspace_in_bytes(src_g, filter_g, dst_g);
    }
}
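// The fused inference kernel needs no scratch memory, hence the zero above;
// the per-group fallback just forwards the query to LocalForward with the
// same single-group layouts that exec() uses.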
|
bool GroupLocalForwardImpl::prefer_inference_kernel(const TensorLayout& src,
                                                    const TensorLayout& filter,
                                                    const TensorLayout& dst) {
    MEGDNN_MARK_USED_VAR(filter);
    MEGDNN_MARK_USED_VAR(dst);
    auto handle = concrete_handle(this->handle());
    //! assumption: the fused kernel pays off only for small batches, and its
    //! shared-memory demand must fit within a block's budget on this device
    auto N = src.shape[0], IH = src.shape[2], IW = src.shape[3];
    return N <= 8 && handle->device_prop().sharedMemPerBlock >=
                             group_local::get_share_mem_in_bytes(IH, IW);
}
|
// vim: syntax=cpp.doxygen