GitOrigin-RevId: f65feae5cc
release-1.7
@@ -588,7 +588,7 @@ if(MGE_WITH_CUDA)
   set(CMAKE_CUDA_FLAGS_MINSIZEREL "-Os")
   if(MSVC OR WIN32)
     set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xfatbin -compress-all")
-    set(CCBIN_FLAG "${CCBIN_FLAG} /wd4819 /wd4334 /wd4267 /wd4002 /wd4244 /wd4068 /std:c++14")
+    set(CCBIN_FLAG "${CCBIN_FLAG} /wd4819 /wd4334 /wd4267 /wd4002 /wd4244 /wd4068 /std:c++14 /bigobj")
     if(${CMAKE_BUILD_TYPE} STREQUAL "Debug")
       set(CCBIN_FLAG "${CCBIN_FLAG} -D_ITERATOR_DEBUG_LEVEL=2 -MTd")
     endif()
@@ -365,27 +365,22 @@ void aarch64::RelayoutForwardImpl::exec(
     relayout::TransposeParam trans_param;
     bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param, true);
     if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 1) {
-        auto sptr = static_cast<TransposeByte*>(src.raw_ptr),
-             dptr = static_cast<TransposeByte*>(dst.raw_ptr);
         MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose<TransposeByte>(
-                trans_param.batch, trans_param.m, trans_param.n, sptr, dptr,
-                trans_param.stride_m));
+                trans_param.batch, trans_param.m, trans_param.n,
+                static_cast<TransposeByte*>(src.raw_ptr()),
+                static_cast<TransposeByte*>(dst.raw_ptr()), trans_param.stride_m));
         return;
     } else if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 2) {
-        auto sptr = static_cast<Transpose2Byte*>(src.raw_ptr),
-             dptr = static_cast<Transpose2Byte*>(dst.raw_ptr);
         MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose<Transpose2Byte>(
-                trans_param.batch, trans_param.m, trans_param.n, sptr, dptr,
-                trans_param.stride_m));
+                trans_param.batch, trans_param.m, trans_param.n,
+                static_cast<Transpose2Byte*>(src.raw_ptr()),
+                static_cast<Transpose2Byte*>(dst.raw_ptr()), trans_param.stride_m));
         return;
     } else if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 4) {
-        auto sptr = static_cast<Transpose4Byte*>(src.raw_ptr),
-             dptr = static_cast<Transpose4Byte*>(dst.raw_ptr);
         MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose<Transpose4Byte>(
-                trans_param.batch, trans_param.m, trans_param.n, sptr, dptr,
-                trans_param.stride_m));
+                trans_param.batch, trans_param.m, trans_param.n,
+                static_cast<Transpose4Byte*>(src.raw_ptr()),
+                static_cast<Transpose4Byte*>(dst.raw_ptr()), trans_param.stride_m));
         return;
     }
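Note: the hunk above shows the pattern that repeats through the rest of this diff: `TensorND::raw_ptr` is no longer a plain `void*` member but an accessor, so call sites gain `()` and the address is resolved when the dispatched kernel actually runs rather than when it is enqueued. A minimal sketch of the idea with stand-in types (the real megdnn `RefPtr`/`TensorND` differ):

    #include <cstddef>
    #include <memory>

    // Stand-in for a reference-counted storage handle plus a byte offset.
    struct RefPtr {
        std::shared_ptr<void*> storage = std::make_shared<void*>(nullptr);
        size_t offset = 0;
        void* get_ptr() const { return static_cast<char*>(*storage) + offset; }
        RefPtr& operator+=(size_t delta) {
            offset += delta;  // offset-based addressing, see the conv_bias hunk below
            return *this;
        }
    };

    // Stand-in tensor: raw_ptr() resolves the current address on every call.
    struct TensorND {
        RefPtr ref;
        void* raw_ptr() const { return ref.get_ptr(); }
        const RefPtr& get_ref_ptr() const { return ref; }
    };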
@@ -358,11 +358,13 @@ void RotateImpl::exec(
         return fallback::RotateImpl::exec(src, dst, workspace);
     }
+    auto clockwise = param().clockwise;
     MEGDNN_DISPATCH_CPU_KERN_OPR({
         for (size_t i = 0; i < src.layout.shape[0]; ++i) {
             Mat<uchar> src_mat = TensorND2Mat<uchar>(src, i);
             Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, i);
-            rotate(src_mat, dst_mat, param().clockwise);
+            rotate(src_mat, dst_mat, clockwise);
         }
     });
 }
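Hoisting `param().clockwise` into a local before `MEGDNN_DISPATCH_CPU_KERN_OPR` is presumably about capture semantics: the dispatched block may run later on a worker thread, so copying the value at enqueue time avoids reading the operator's parameter state from inside the kernel. A toy illustration of that capture-by-value idea (hypothetical dispatcher, not the megdnn macro):

    #include <functional>
    #include <vector>

    // Toy async dispatcher: kernels are queued now and executed later.
    static std::vector<std::function<void()>> queue;
    void dispatch(std::function<void()> kern) { queue.push_back(std::move(kern)); }

    struct Rotator {
        bool clockwise_param = true;
        void exec() {
            bool clockwise = clockwise_param;  // snapshot the value at enqueue time
            dispatch([clockwise] { /* kernel body uses the captured copy */ });
        }
    };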
@@ -205,16 +205,16 @@ void megdnn::aarch64::warp_perspective_cv_exec(
     megdnn_assert(
             ch == 1 || ch == 3 || ch == 2,
             "unsupported src channel: %zu, avaiable channel size: 1/2/3", ch);
-    const float* trans_ptr = trans.ptr<dt_float32>();
-    const int* midx_ptr = nullptr;
-    if (mat_idx.raw_ptr) {
-        megdnn_assert(mat_idx.layout.ndim == 1);
-        midx_ptr = mat_idx.ptr<int>();
-    }
     if (dst.layout.dtype.enumv() == DTypeEnum::Float32) {
 #define cb(_imode, _bmode, _ch) \
-    auto task = [src, trans_ptr, midx_ptr, dst, border_value, parallelism_batch]( \
+    auto task = [src, trans, mat_idx, dst, border_value, parallelism_batch]( \
                         size_t index, size_t) { \
+        const float* trans_ptr = trans.ptr<dt_float32>(); \
+        const int* midx_ptr = nullptr; \
+        if (mat_idx.raw_ptr()) { \
+            megdnn_assert(mat_idx.layout.ndim == 1); \
+            midx_ptr = mat_idx.ptr<int>(); \
+        } \
         size_t batch_id = index / parallelism_batch; \
         size_t task_id = index % parallelism_batch; \
         size_t src_id = batch_id; \
@@ -240,8 +240,14 @@ void megdnn::aarch64::warp_perspective_cv_exec(
 #undef cb
     } else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) {
 #define cb(_imode, _bmode, _ch) \
-    auto task = [src, trans_ptr, midx_ptr, dst, border_value, parallelism_batch]( \
+    auto task = [src, trans, mat_idx, dst, border_value, parallelism_batch]( \
                         size_t index, size_t) { \
+        const float* trans_ptr = trans.ptr<dt_float32>(); \
+        const int* midx_ptr = nullptr; \
+        if (mat_idx.raw_ptr()) { \
+            megdnn_assert(mat_idx.layout.ndim == 1); \
+            midx_ptr = mat_idx.ptr<int>(); \
+        } \
         size_t batch_id = index / parallelism_batch; \
         size_t task_id = index % parallelism_batch; \
         size_t src_id = batch_id; \
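In both `cb` expansions the lambda now captures the `TensorND`s (`trans`, `mat_idx`) instead of pre-resolved raw pointers, and derives `trans_ptr`/`midx_ptr` inside the task body, so the lookup happens on the thread that runs the task and sees the tensor's current storage. A rough sketch of the capture change on a toy task type (hypothetical names, not the real `cb` macro):

    #include <cstddef>
    #include <functional>

    struct Tensor {
        const float* data = nullptr;
        const float* ptr() const { return data; }  // resolved at call time
    };

    // Before: the pointer was fixed when the task object was created.
    std::function<void(size_t)> make_task_old(const Tensor& trans) {
        const float* trans_ptr = trans.ptr();
        return [trans_ptr](size_t) { /* uses a possibly stale pointer */ };
    }

    // After: the tensor is captured by value; the pointer is taken per invocation.
    std::function<void(size_t)> make_task_new(Tensor trans) {
        return [trans](size_t) { const float* trans_ptr = trans.ptr(); (void)trans_ptr; };
    }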
@@ -531,10 +531,10 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoI8x8x16Stride2Filter2::
             megdnn_arm_common_conv_bias_int8816_kimpl,
             midout_iv("AlgoI8x8x16Stride2Filter2::dispatch_kerns"_hash)) {
         auto ncb_param = param;
-        ncb_param.src_ptr = param.src<void>(0, ncb_index.ndrange_id[0]);
-        ncb_param.dst_ptr = param.dst<void>(0, ncb_index.ndrange_id[0]);
-        ncb_param.filter_ptr = param.filter<void>(ncb_index.ndrange_id[0]);
-        ncb_param.bias_ptr = param.bias<void>(0, ncb_index.ndrange_id[0]);
+        ncb_param.src_ptr += param.src_offset(0, ncb_index.ndrange_id[0]);
+        ncb_param.dst_ptr += param.dst_offset(0, ncb_index.ndrange_id[0]);
+        ncb_param.filter_ptr += param.filter_offset(ncb_index.ndrange_id[0]);
+        ncb_param.bias_ptr += param.bias_offset(0, ncb_index.ndrange_id[0]);
         conv_bias::conv_int8x8x16_stride2_flt2(ncb_param);
     }
     MIDOUT_END();
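The conv-bias kernel no longer stores absolute pointers computed from the batch/group indices; it keeps the `RefPtr`-style members from the copied param and advances them by byte offsets (`src_offset`, `dst_offset`, ...), deferring the address computation to the point of use. A hedged sketch of that offset-accumulation pattern (stand-in names, not the megdnn API):

    #include <cstddef>

    struct Ref {
        void* base = nullptr;
        size_t offset = 0;
        Ref& operator+=(size_t delta) { offset += delta; return *this; }
        void* get() const { return static_cast<char*>(base) + offset; }
    };

    struct KernParam {
        Ref src_ptr, dst_ptr;
        // hypothetical offset helpers mirroring param.src_offset(...) etc.
        size_t src_offset(size_t batch, size_t group) const { return (batch + group) * 64; }
        size_t dst_offset(size_t batch, size_t group) const { return (batch + group) * 32; }
    };

    void dispatch_one(const KernParam& param, size_t group_id) {
        auto ncb_param = param;                              // copy, as in the hunk above
        ncb_param.src_ptr += param.src_offset(0, group_id);  // accumulate, don't resolve
        ncb_param.dst_ptr += param.dst_offset(0, group_id);
        // the kernel later calls ncb_param.src_ptr.get() to obtain the real address
    }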
@@ -133,7 +133,8 @@ static void pack_weight(
     constexpr int pack_oc = 8;
     if (kern_param.bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS && oc % pack_oc != 0) {
         auto packed_bias = reinterpret_cast<int16_t*>(bundle.get(2));
-        memcpy(packed_bias, kern_param.bias_ptr, round_up(oc, 8) * sizeof(int16_t));
+        memcpy(packed_bias, kern_param.bias_ptr.get_ptr(),
+               round_up(oc, 8) * sizeof(int16_t));
     }
 }
@@ -1657,4 +1657,4 @@ void CvtColorImpl::exec(
 }  // namespace arm_common
 }  // namespace megdnn
-// vim: syntax=cpp.doxygen
+// vim: syntax=cpp.doxygen
@@ -220,9 +220,9 @@ void ElemwiseImpl::AlgoBinaryVecVec::exec(const KernParam& kern_param) const {
         run = OpCallerBinary<_op<_type, _type>, BcastType::VEC_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, dst.layout.dtype, \
                     src0.layout.total_nr_elems())); \
     } \
@@ -254,9 +254,9 @@ void ElemwiseImpl::AlgoBinaryVecScalar::exec(const KernParam& kern_param) const
                 _op<_type, _type>, BcastType::VEC_SCALAR>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr)[0], \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr())[0], \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, dst.layout.dtype, \
                     src0.layout.total_nr_elems())); \
     } \
@@ -280,9 +280,9 @@ void ElemwiseImpl::AlgoBinaryVecScalar::exec(const KernParam& kern_param) const
                 _op<_type, _type>, BcastType::SCALAR_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr)[0], \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr())[0], \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, dst.layout.dtype, \
                     src1.layout.total_nr_elems())); \
     } \
@@ -318,9 +318,9 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec(const KernParam& kern_param) cons
                 _op<_type, _type>, BcastType::VEC_BCAST101>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
                     binfo.z)); \
     } \
@@ -347,9 +347,9 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec(const KernParam& kern_param) cons
                 _op<_type, _type>, BcastType::BCAST101_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
                     binfo.z)); \
     } \
@@ -384,9 +384,9 @@ void ElemwiseImpl::AlgoBinaryVecBcastX0X::exec(const KernParam& kern_param) cons
                 _op<_type, _type>, BcastType::VEC_BCASTX0X>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
                     binfo.z)); \
     } \
@@ -413,9 +413,9 @@ void ElemwiseImpl::AlgoBinaryVecBcastX0X::exec(const KernParam& kern_param) cons
                 _op<_type, _type>, BcastType::BCASTX0X_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
                     binfo.z)); \
     } \
@@ -450,9 +450,9 @@ void ElemwiseImpl::AlgoBinaryVecBcast111C::exec(const KernParam& kern_param) con
                 _op<_type, _type>, BcastType::VEC_BCAST111C>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
                     binfo.z)); \
     } \
@@ -479,9 +479,9 @@ void ElemwiseImpl::AlgoBinaryVecBcast111C::exec(const KernParam& kern_param) con
                 _op<_type, _type>, BcastType::BCAST111C_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
                     binfo.z)); \
     } \
@@ -519,9 +519,9 @@ void ElemwiseImpl::AlgoBinaryVecBcast101xX::exec(const KernParam& kern_param) co
                 _op<_type, _type>, BcastType::VEC_BCAST101xX>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, dst.layout.dtype, batch_size, binfo.x, \
                     binfo.y, binfo.z)); \
     } \
@@ -551,9 +551,9 @@ void ElemwiseImpl::AlgoBinaryVecBcast101xX::exec(const KernParam& kern_param) co
                 _op<_type, _type>, BcastType::BCAST101xX_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, dst.layout.dtype, batch_size, binfo.x, \
                     binfo.y, binfo.z)); \
     } \
@@ -79,10 +79,10 @@ void ElemwiseImpl::AlgoTernaryFma3VecVecVec::exec(const KernParam& kern_param) c
                 _op<_type, _type>, BcastType::VEC_VEC_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<const _type*>(src2.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<const _type*>(src2.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
                     src0.layout.total_nr_elems())); \
     } \
@@ -113,10 +113,10 @@ void ElemwiseImpl::AlgoTernaryFma3VecVecScalar::exec(
                 _op<_type, _type>, BcastType::VEC_VEC_SCALAR>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<const _type*>(src2.raw_ptr)[0], \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<const _type*>(src2.raw_ptr())[0], \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
                     src0.layout.total_nr_elems())); \
     } \
@@ -149,10 +149,10 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec(
                 _op<_type, _type>, BcastType::BCAST101_VEC_BCAST101>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<const _type*>(src2.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<const _type*>(src2.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
                     binfo.x, binfo.y, binfo.z)); \
     } \
@@ -187,11 +187,11 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast111CVecBcast111C::exec(
                 BcastType::BCAST111C_VEC_BCAST111C>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
                     is_vector(src1.layout) ? 0 : src1.layout.stride[0] - binfo.z, \
-                    static_cast<const _type*>(src2.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                    static_cast<const _type*>(src2.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
                     binfo.x, binfo.y, binfo.z)); \
     } \
@@ -228,10 +228,10 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101xXVecBcast101xX::exec(
                 BcastType::BCAST101xX_VEC_BCAST101xX>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<const _type*>(src2.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<const _type*>(src2.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
                     batch_size, binfo.x, binfo.y, binfo.z)); \
     } \
@@ -268,10 +268,10 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101xXVec::exec(
                 _op<_type, _type>, BcastType::VEC_BCAST101xX_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<const _type*>(src2.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<const _type*>(src2.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
                     batch_size, binfo.x, binfo.y, binfo.z)); \
     } \
@@ -306,10 +306,10 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101Vec::exec(
                 _op<_type, _type>, BcastType::VEC_BCAST101_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<const _type*>(src2.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<const _type*>(src2.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
                     binfo.x, binfo.y, binfo.z)); \
     } \
@@ -343,12 +343,12 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast111CVec::exec(
                 _op<_type, _type>, BcastType::VEC_BCAST111C_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
                     is_vector(src0.layout) ? 0 : src0.layout.stride[0] - binfo.z, \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<const _type*>(src2.raw_ptr), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<const _type*>(src2.raw_ptr()), \
                     is_vector(src2.layout) ? 0 : src2.layout.stride[0] - binfo.z, \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
                     binfo.x, binfo.y, binfo.z)); \
     } \
@@ -380,10 +380,10 @@ void ElemwiseImpl::AlgoTernaryFma3VecScalarVec::exec(
                 _op<_type, _type>, BcastType::VEC_SCALAR_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr)[0], \
-                    static_cast<const _type*>(src2.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr())[0], \
+                    static_cast<const _type*>(src2.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
                     src0.layout.total_nr_elems())); \
     } \
@@ -414,10 +414,10 @@ void ElemwiseImpl::AlgoTernaryFma3VecScalarScalar::exec(
                 _op<_type, _type>, BcastType::VEC_SCALAR_SCALAR>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr)[0], \
-                    static_cast<const _type*>(src2.raw_ptr)[0], \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr())[0], \
+                    static_cast<const _type*>(src2.raw_ptr())[0], \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
                     src0.layout.total_nr_elems())); \
     } \
@@ -76,8 +76,8 @@ void ElemwiseImpl::AlgoUnary::exec(const KernParam& kern_param) const {
             size_t offset = task_id * nr_elems_per_thread; \
             size_t nr_elems_thread = \
                     std::min(nr_elems - offset, nr_elems_per_thread); \
-            run(static_cast<const _type*>(src0.raw_ptr) + offset, \
-                static_cast<_type*>(dst_tensor.raw_ptr) + offset, \
+            run(static_cast<const _type*>(src0.raw_ptr()) + offset, \
+                static_cast<_type*>(dst_tensor.raw_ptr()) + offset, \
                 src0.layout.dtype, dst_tensor.layout.dtype, nr_elems_thread); \
         }; \
         MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
@@ -148,17 +148,17 @@ void ElemwiseMultiTypeImpl::neon_round_shr_saturate_bcast_scalar<int32_t>(
 template <typename ctype>
 void ElemwiseMultiTypeImpl::dispatch_round_shr_saturate_iXxi8xi8_bcast_scalar(
-        const ElemwiseOpParamN<2>& param, megdnn::dt_int8* dst) {
-    auto a_ptr = param[0].ptr<ctype>();
+        const ElemwiseOpParamN<2>& param, const TensorND& dst) {
     auto k = param[1].ptr<dt_int8>()[0];
     size_t size = param.size;
-    MEGDNN_DISPATCH_CPU_KERN_OPR(
-            neon_round_shr_saturate_bcast_scalar(a_ptr, k, size, dst));
+    auto src = param[0];
+    MEGDNN_DISPATCH_CPU_KERN_OPR(neon_round_shr_saturate_bcast_scalar(
+            src.ptr<ctype>(), k, size, static_cast<dt_int8*>(dst.raw_ptr())));
 }

 void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi8(
-        const ElemwiseOpParamN<2>& param, megdnn::dt_int8* dst) {
+        const ElemwiseOpParamN<2>& param, const TensorND& dst) {
     if (is_vector(param[0].layout) && is_broadcasted_scalar(param[1].layout)) {
         switch (param[0].layout.dtype.enumv()) {
 #define cb(t) \
@@ -282,7 +282,7 @@ void neon_fuse_add_rmulh_round_shr_saturate_bcast_1c11_int32(
 }

 bool ElemwiseMultiTypeImpl::dispatch_fuse_add_rmulh_rshr(
-        const ElemwiseOpParamN<6>& param, megdnn::dt_int8* dst) {
+        const ElemwiseOpParamN<6>& param, const TensorND& dst) {
     BroadcastChannelInfo binfo;
     if (is_vector(param[0].layout) &&
         is_broadcasted_channel_like(param[1].layout, binfo) &&
@@ -294,16 +294,18 @@ bool ElemwiseMultiTypeImpl::dispatch_fuse_add_rmulh_rshr(
         auto minv = param[4].ptr<dt_int8>()[0];
         auto maxv = param[5].ptr<dt_int8>()[0];
         switch (param[0].layout.dtype.enumv()) {
-#define DISPATCH(stype, suffix) \
-    case DTypeTrait<stype>::enumv: { \
-        auto x_ptr = param[0].ptr<DTypeTrait<stype>::ctype>(); \
-        auto b_ptr = param[1].ptr<DTypeTrait<stype>::ctype>(); \
-        auto M = param[2].ptr<DTypeTrait<stype>::ctype>()[0]; \
-        MEGDNN_DISPATCH_CPU_KERN_OPR( \
-                neon_fuse_add_rmulh_round_shr_saturate_bcast_1c11_##suffix( \
-                        binfo.x, binfo.y, binfo.z, x_ptr, b_ptr, M, offset, minv, \
-                        maxv, param.size, dst)); \
-        break; \
+#define DISPATCH(stype, suffix) \
+    case DTypeTrait<stype>::enumv: { \
+        auto M = param[2].ptr<DTypeTrait<stype>::ctype>()[0]; \
+        auto src0 = param[0]; \
+        auto src1 = param[1]; \
+        MEGDNN_DISPATCH_CPU_KERN_OPR( \
+                neon_fuse_add_rmulh_round_shr_saturate_bcast_1c11_##suffix( \
+                        binfo.x, binfo.y, binfo.z, \
+                        src0.ptr<DTypeTrait<stype>::ctype>(), \
+                        src1.ptr<DTypeTrait<stype>::ctype>(), M, offset, minv, maxv, \
+                        param.size, static_cast<dt_int8*>(dst.raw_ptr()))); \
+        break; \
     }
             DISPATCH(dtype::Int16, int16)
             DISPATCH(dtype::Int32, int32)
@@ -317,7 +319,7 @@ bool ElemwiseMultiTypeImpl::dispatch_fuse_add_rmulh_rshr(
 }

 void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8(
-        const ElemwiseOpParamN<6>& param, megdnn::dt_int8* dst) {
+        const ElemwiseOpParamN<6>& param, const TensorND& dst) {
     if (dispatch_fuse_add_rmulh_rshr(param, dst))
         return;
     fallback::ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8(
@@ -325,7 +327,7 @@ void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8(
 }

 void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8(
-        const ElemwiseOpParamN<6>& param, megdnn::dt_int8* dst) {
+        const ElemwiseOpParamN<6>& param, const TensorND& dst) {
     if (dispatch_fuse_add_rmulh_rshr(param, dst))
         return;
     fallback::ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8(
@@ -23,18 +23,18 @@ class ElemwiseMultiTypeImpl : public fallback::ElemwiseMultiTypeImpl {
     template <typename ctype>
     void dispatch_round_shr_saturate_iXxi8xi8_bcast_scalar(
-            const ElemwiseOpParamN<2>& param, megdnn::dt_int8* dst);
+            const ElemwiseOpParamN<2>& param, const TensorND& dst);

     bool dispatch_fuse_add_rmulh_rshr(
-            const ElemwiseOpParamN<6>& param, megdnn::dt_int8* dst);
+            const ElemwiseOpParamN<6>& param, const TensorND& dst);

 protected:
     void on_round_shr_saturate_iXxi8xi8(
-            const ElemwiseOpParamN<2>& param, dt_int8* dst) override;
+            const ElemwiseOpParamN<2>& param, const TensorND& dst) override;
     void on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8(
-            const ElemwiseOpParamN<6>& param, dt_int8* dst) override;
+            const ElemwiseOpParamN<6>& param, const TensorND& dst) override;
     void on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8(
-            const ElemwiseOpParamN<6>& param, dt_int8* dst) override;
+            const ElemwiseOpParamN<6>& param, const TensorND& dst) override;

     void on_quantized_mode(
             const ElemwiseOpParamN<1>& param, const TensorND& dst,
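These declarations show the signature migration behind the earlier hunks: the multi-type elemwise entry points take the destination as a `const TensorND&` instead of a raw `dt_int8*`, so the pointer (and, if needed, the layout) is picked up inside the dispatched kernel. A small sketch of the old versus new shape on a toy interface (hypothetical names, not the megdnn classes):

    #include <cstdint>

    struct TensorSketch {
        void* storage = nullptr;
        void* raw_ptr() const { return storage; }
    };

    // Old style: the caller resolved the destination pointer up front.
    void on_op_old(std::int8_t* dst) { (void)dst; }

    // New style: the tensor is passed through and resolved inside the kernel.
    void on_op_new(const TensorSketch& dst) {
        auto* out = static_cast<std::int8_t*>(dst.raw_ptr());
        (void)out;
    }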
@@ -117,27 +117,27 @@ void PoolingImpl::AlgoFilterxModexStride1::exec(const PoolingKernParam& param) c
     auto PW = param.padding[1];
     auto FH = param.filter[0];
-    void* src_ptr = param.src_ptr;
-    void* dst_ptr = param.dst_ptr;
-#define DISPATCH_FUNC(Pooler, NeonPooler, window, midout_type_id) \
-    MIDOUT_BEGIN( \
-            megdnn_arm_common_pooling, midout_iv(0), midout_iv(midout_type_id), \
-            Pooler::MIDOUT_CASE_NUM, NeonPooler::MIDOUT_CASE_NUM, window) { \
-        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
-                    src_dtype = param.src_type](size_t index, size_t) { \
-            size_t n = index / C; \
-            size_t c = index % C; \
-            do_pooling_compact<Pooler MEGDNN_COMMA NeonPooler MEGDNN_COMMA window>( \
-                    static_cast<const typename Pooler::ctype*>(src_ptr) + \
-                            n * C * IH * IW + c * IH * IW, \
-                    static_cast<typename Pooler::ctype*>(dst_ptr) + n * C * OH * OW + \
-                            c * OH * OW, \
-                    src_dtype, IH, IW, OH, OW, PH, PW); \
-        }; \
-        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
-                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
-    } \
+    auto src_ptr = param.src_ptr;
+    auto dst_ptr = param.dst_ptr;
+#define DISPATCH_FUNC(Pooler, NeonPooler, window, midout_type_id) \
+    MIDOUT_BEGIN( \
+            megdnn_arm_common_pooling, midout_iv(0), midout_iv(midout_type_id), \
+            Pooler::MIDOUT_CASE_NUM, NeonPooler::MIDOUT_CASE_NUM, window) { \
+        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
+                    src_dtype = param.src_type](size_t index, size_t) { \
+            size_t n = index / C; \
+            size_t c = index % C; \
+            do_pooling_compact<Pooler MEGDNN_COMMA NeonPooler MEGDNN_COMMA window>( \
+                    static_cast<const typename Pooler::ctype*>(src_ptr.get_ptr()) + \
+                            n * C * IH * IW + c * IH * IW, \
+                    static_cast<typename Pooler::ctype*>(dst_ptr.get_ptr()) + \
+                            n * C * OH * OW + c * OH * OW, \
+                    src_dtype, IH, IW, OH, OW, PH, PW); \
+        }; \
+        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
+                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
+    } \
     MIDOUT_END()

 #define DISPATCH_WINDOW(Pooler, NeonPooler, dtype, ctype, comp_type, midout_type_id) \
@@ -213,26 +213,26 @@ void PoolingImpl::AlgoFilter2ModexStride2::exec(const PoolingKernParam& param) c
     auto PH = param.padding[0];
     auto PW = param.padding[1];
-    void* src_ptr = param.src_ptr;
-    void* dst_ptr = param.dst_ptr;
-#define DISPATCH_FUNC(Pooler, mode, midout_type_id) \
-    MIDOUT_BEGIN( \
-            megdnn_arm_common_pooling, midout_iv(1), midout_iv(midout_type_id), \
-            Pooler::MIDOUT_CASE_NUM) { \
-        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
-                    src_dtype = param.src_type](size_t index, size_t) { \
-            size_t n = index / C; \
-            size_t c = index % C; \
-            do_pooling_2x2<Pooler MEGDNN_COMMA mode>( \
-                    static_cast<const typename Pooler::ctype*>(src_ptr) + \
-                            n * C * IH * IW + c * IH * IW, \
-                    static_cast<typename Pooler::ctype*>(dst_ptr) + n * C * OH * OW + \
-                            c * OH * OW, \
-                    src_dtype, IH, IW, OH, OW, PH, PW); \
-        }; \
-        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
-                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
-    } \
+    auto src_ptr = param.src_ptr;
+    auto dst_ptr = param.dst_ptr;
+#define DISPATCH_FUNC(Pooler, mode, midout_type_id) \
+    MIDOUT_BEGIN( \
+            megdnn_arm_common_pooling, midout_iv(1), midout_iv(midout_type_id), \
+            Pooler::MIDOUT_CASE_NUM) { \
+        auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
+                    src_dtype = param.src_type](size_t index, size_t) { \
+            size_t n = index / C; \
+            size_t c = index % C; \
+            do_pooling_2x2<Pooler MEGDNN_COMMA mode>( \
+                    static_cast<const typename Pooler::ctype*>(src_ptr.get_ptr()) + \
+                            n * C * IH * IW + c * IH * IW, \
+                    static_cast<typename Pooler::ctype*>(dst_ptr.get_ptr()) + \
+                            n * C * OH * OW + c * OH * OW, \
+                    src_dtype, IH, IW, OH, OW, PH, PW); \
+        }; \
+        MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
+                static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
+    } \
     MIDOUT_END()

 #define DISPATCH_MODE(dtype, ctype, comp_type, midout_type_id) \
@@ -286,8 +286,8 @@ void PoolingImpl::AlgoFilter3MaxStride2::exec(const PoolingKernParam& param) con
     auto PH = param.padding[0];
     auto PW = param.padding[1];
-    void* src_ptr = param.src_ptr;
-    void* dst_ptr = param.dst_ptr;
+    auto src_ptr = param.src_ptr;
+    auto dst_ptr = param.dst_ptr;
 #define DISPATCH_FUNC(type, func, midout_type_id) \
     MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), midout_iv(midout_type_id)) { \
@@ -300,9 +300,11 @@ void PoolingImpl::AlgoFilter3MaxStride2::exec(const PoolingKernParam& param) con
             size_t n = index / C; \
             size_t c = index % C; \
             do_max_pooling_3x3_s2x2_##func##_NEON( \
-                    static_cast<const type*>(src_ptr) + n * C * IH * IW + c * IH * IW, \
-                    static_cast<type*>(dst_ptr) + n * C * OH * OW + c * OH * OW, IH, \
-                    IW, OH, OW, PH, PW, ws); \
+                    static_cast<const type*>(src_ptr.get_ptr()) + n * C * IH * IW + \
+                            c * IH * IW, \
+                    static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW + \
+                            c * OH * OW, \
+                    IH, IW, OH, OW, PH, PW, ws); \
         }; \
         MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
                 static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
@@ -339,8 +341,8 @@ void PoolingImpl::AlgoFilter3AverageStride2::exec(const PoolingKernParam& param)
     auto PH = param.padding[0];
     auto PW = param.padding[1];
-    void* src_ptr = param.src_ptr;
-    void* dst_ptr = param.dst_ptr;
+    auto src_ptr = param.src_ptr;
+    auto dst_ptr = param.dst_ptr;
 #define DISPATCH_FUNC(type, MEGDNN_SIMD_WIDTH, midout_type_id) \
     MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(3), midout_iv(midout_type_id)) { \
@@ -353,9 +355,11 @@ void PoolingImpl::AlgoFilter3AverageStride2::exec(const PoolingKernParam& param)
             size_t n = index / C; \
             size_t c = index % C; \
             do_average_pooling_3x3_s2x2_NEON( \
-                    static_cast<const type*>(src_ptr) + n * C * IH * IW + c * IH * IW, \
-                    static_cast<type*>(dst_ptr) + n * C * OH * OW + c * OH * OW, IH, \
-                    IW, OH, OW, PH, PW, ws, MEGDNN_SIMD_WIDTH); \
+                    static_cast<const type*>(src_ptr.get_ptr()) + n * C * IH * IW + \
+                            c * IH * IW, \
+                    static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW + \
+                            c * OH * OW, \
+                    IH, IW, OH, OW, PH, PW, ws, MEGDNN_SIMD_WIDTH); \
         }; \
         MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
                 static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
@@ -392,8 +396,8 @@ void PoolingImpl::AlgoFilter4MaxStride2::exec(const PoolingKernParam& param) con
     auto PH = param.padding[0];
     auto PW = param.padding[1];
-    void* src_ptr = param.src_ptr;
-    void* dst_ptr = param.dst_ptr;
+    auto src_ptr = param.src_ptr;
+    auto dst_ptr = param.dst_ptr;
 #define DISPATCH_FUNC(type, func, midout_type_id) \
     MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(4), midout_iv(midout_type_id)) { \
@@ -402,8 +406,10 @@ void PoolingImpl::AlgoFilter4MaxStride2::exec(const PoolingKernParam& param) con
             size_t n = index / C; \
             size_t c = index % C; \
             do_max_pooling_w4x4_s2x2_##func##_NEON( \
-                    static_cast<const type*>(src_ptr) + n * C * IH * IW + c * IH * IW, \
-                    static_cast<type*>(dst_ptr) + n * C * OH * OW + c * OH * OW, \
+                    static_cast<const type*>(src_ptr.get_ptr()) + n * C * IH * IW + \
+                            c * IH * IW, \
+                    static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW + \
+                            c * OH * OW, \
                     src_dtype, IH, IW, OH, OW, PH, PW); \
         }; \
         MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
@@ -446,8 +452,8 @@ void PoolingImpl::AlgoFilter5MaxStride2::exec(const PoolingKernParam& param) con
     auto PH = param.padding[0];
     auto PW = param.padding[1];
-    void* src_ptr = param.src_ptr;
-    void* dst_ptr = param.dst_ptr;
+    auto src_ptr = param.src_ptr;
+    auto dst_ptr = param.dst_ptr;
 #define DISPATCH_FUNC(dtype, type, midout_type_id, MEGDNN_SIMD_WIDTH) \
     MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(5), midout_iv(midout_type_id)) { \
@@ -460,9 +466,11 @@ void PoolingImpl::AlgoFilter5MaxStride2::exec(const PoolingKernParam& param) con
             size_t n = index / C; \
             size_t c = index % C; \
             do_max_pooling_w5x5_s2x2_NEON<dtype>( \
-                    static_cast<const type*>(src_ptr) + n * C * IH * IW + c * IH * IW, \
-                    static_cast<type*>(dst_ptr) + n * C * OH * OW + c * OH * OW, IH, \
-                    IW, OH, OW, PH, PW, ws, MEGDNN_SIMD_WIDTH); \
+                    static_cast<const type*>(src_ptr.get_ptr()) + n * C * IH * IW + \
+                            c * IH * IW, \
+                    static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW + \
+                            c * OH * OW, \
+                    IH, IW, OH, OW, PH, PW, ws, MEGDNN_SIMD_WIDTH); \
         }; \
         MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
                 static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
@@ -593,8 +601,8 @@ void PoolingImpl::AlgoFilter3ModexStridexNCHW44::exec(
     auto PW = param.padding[1];
     auto SW = param.stride[0];
-    void* src_ptr = param.src_ptr;
-    void* dst_ptr = param.dst_ptr;
+    auto src_ptr = param.src_ptr;
+    auto dst_ptr = param.dst_ptr;
 #define DISPATCH_FUNC(type, func, i, mode) \
     MIDOUT_BEGIN( \
@@ -608,9 +616,9 @@ void PoolingImpl::AlgoFilter3ModexStridexNCHW44::exec(
             size_t n = index / C; \
             size_t c = index % C; \
             do_##mode##_pooling_3x3_stride##i##_##func##_nchw44_NEON( \
-                    static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \
-                            c * IH * IW * 4, \
-                    static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
+                    static_cast<const type*>(src_ptr.get_ptr()) + \
+                            n * C * IH * IW * 4 + c * IH * IW * 4, \
+                    static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW * 4 + \
                             c * OH * OW * 4, \
                     IH, IW, OH, OW, PH, PW, ws); \
         }; \
@@ -685,8 +693,8 @@ void PoolingImpl::AlgoFilter2ModexStridexNCHW44::exec(
     auto PW = param.padding[1];
     auto SW = param.stride[0];
-    void* src_ptr = param.src_ptr;
-    void* dst_ptr = param.dst_ptr;
+    auto src_ptr = param.src_ptr;
+    auto dst_ptr = param.dst_ptr;
 #define DISPATCH_FUNC(type, func, i, mode) \
     MIDOUT_BEGIN( \
@@ -700,9 +708,9 @@ void PoolingImpl::AlgoFilter2ModexStridexNCHW44::exec(
             size_t n = index / C; \
             size_t c = index % C; \
             do_##mode##_pooling_2x2_stride##i##_##func##_nchw44_NEON( \
-                    static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \
-                            c * IH * IW * 4, \
-                    static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
+                    static_cast<const type*>(src_ptr.get_ptr()) + \
+                            n * C * IH * IW * 4 + c * IH * IW * 4, \
+                    static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW * 4 + \
                             c * OH * OW * 4, \
                     IH, IW, OH, OW, PH, PW, ws); \
         }; \
@@ -778,8 +786,8 @@ void PoolingImpl::AlgoFilter4ModexStridexNCHW44::exec(
     auto PW = param.padding[1];
     auto SW = param.stride[0];
-    void* src_ptr = param.src_ptr;
-    void* dst_ptr = param.dst_ptr;
+    auto src_ptr = param.src_ptr;
+    auto dst_ptr = param.dst_ptr;
 #define DISPATCH_FUNC(type, func, i, mode) \
     MIDOUT_BEGIN( \
@@ -793,9 +801,9 @@ void PoolingImpl::AlgoFilter4ModexStridexNCHW44::exec(
             size_t n = index / C; \
             size_t c = index % C; \
             do_##mode##_pooling_4x4_stride##i##_##func##_nchw44_NEON( \
-                    static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \
-                            c * IH * IW * 4, \
-                    static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
+                    static_cast<const type*>(src_ptr.get_ptr()) + \
+                            n * C * IH * IW * 4 + c * IH * IW * 4, \
+                    static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW * 4 + \
                             c * OH * OW * 4, \
                     IH, IW, OH, OW, PH, PW, ws); \
         }; \
@@ -870,8 +878,8 @@ void PoolingImpl::AlgoFilter5ModexStridexNCHW44::exec(
     auto PW = param.padding[1];
     auto SW = param.stride[0];
-    void* src_ptr = param.src_ptr;
-    void* dst_ptr = param.dst_ptr;
+    auto src_ptr = param.src_ptr;
+    auto dst_ptr = param.dst_ptr;
 #define DISPATCH_FUNC(type, func, i, mode) \
     MIDOUT_BEGIN( \
@@ -885,9 +893,9 @@ void PoolingImpl::AlgoFilter5ModexStridexNCHW44::exec(
             size_t n = index / C; \
             size_t c = index % C; \
             do_##mode##_pooling_5x5_stride##i##_##func##_nchw44_NEON( \
-                    static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \
-                            c * IH * IW * 4, \
-                    static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
+                    static_cast<const type*>(src_ptr.get_ptr()) + \
+                            n * C * IH * IW * 4 + c * IH * IW * 4, \
+                    static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW * 4 + \
                             c * OH * OW * 4, \
                     IH, IW, OH, OW, PH, PW, ws); \
         }; \
@@ -50,8 +50,8 @@ void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec(
     int sh = param.stride[0];
     int fh = param.filter[0];
-    void* src_ptr = param.src_ptr;
-    void* dst_ptr = param.dst_ptr;
+    auto src_ptr = param.src_ptr;
+    auto dst_ptr = param.dst_ptr;
 #define DISPATCH_FUNC(filter, stride, mode) \
     MIDOUT_BEGIN( \
@@ -60,9 +60,10 @@ void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec(
         auto run = [ih, iw, oh, ow, ph, pw, src_ptr, dst_ptr](size_t index, size_t) { \
             const int c_idx = index; \
             pooling_fp32_nchw44<filter, stride, mode>( \
-                    static_cast<const float*>(src_ptr) + c_idx * ih * iw * 4, \
-                    static_cast<float*>(dst_ptr) + c_idx * oh * ow * 4, ih, iw, oh, \
-                    ow, ph, pw); \
+                    static_cast<const float*>(src_ptr.get_ptr()) + \
+                            c_idx * ih * iw * 4, \
+                    static_cast<float*>(dst_ptr.get_ptr()) + c_idx * oh * ow * 4, ih, \
+                    iw, oh, ow, ph, pw); \
         }; \
         MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
                 static_cast<::megdnn::naive::HandleImpl*>(param.handle), n* ic, run); \
@@ -89,8 +89,8 @@ PoolingImpl::PoolingKernParam PoolingImpl::make_pooling_kern_param(
     PoolingKernParam ret;
     static_cast<PoolingKernSizeParam&>(ret) =
             make_pooling_kern_szie_param(opr, src.layout, dst.layout);
-    ret.src_ptr = src.raw_ptr;
-    ret.dst_ptr = dst.raw_ptr;
+    ret.src_ptr = src.get_ref_ptr();
+    ret.dst_ptr = dst.get_ref_ptr();
     ret.workspace_ptr = workspace.raw_ptr;
     ret.workspace_size = workspace.size;
     return ret;
@@ -56,21 +56,21 @@ public: | |||||
}; | }; | ||||
struct PoolingKernParam : public PoolingKernSizeParam { | struct PoolingKernParam : public PoolingKernSizeParam { | ||||
void* src_ptr; | |||||
void* dst_ptr; | |||||
RefPtr src_ptr; | |||||
RefPtr dst_ptr; | |||||
void* workspace_ptr; | void* workspace_ptr; | ||||
size_t workspace_size; | size_t workspace_size; | ||||
template <typename T> | template <typename T> | ||||
const T* src() const { | const T* src() const { | ||||
src_type.assert_is_compatible_ctype<T>(); | src_type.assert_is_compatible_ctype<T>(); | ||||
return static_cast<const T*>(src_ptr); | |||||
return static_cast<const T*>(src_ptr.get_ptr()); | |||||
} | } | ||||
template <typename T> | template <typename T> | ||||
T* dst() const { | T* dst() const { | ||||
dst_type.assert_is_compatible_ctype<T>(); | dst_type.assert_is_compatible_ctype<T>(); | ||||
return static_cast<T*>(dst_ptr); | |||||
return static_cast<T*>(dst_ptr.get_ptr()); | |||||
} | } | ||||
template <typename T> | template <typename T> | ||||
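Editor's note: with `src_ptr`/`dst_ptr` stored as `RefPtr`, the typed `src<T>()`/`dst<T>()` helpers now cast through `get_ptr()` at call time. A rough, self-contained sketch of that accessor pattern follows; the dtype compatibility check is omitted and all names are illustrative, not the actual `PoolingKernParam` API.

    #include <cstdio>
    #include <vector>

    struct ToyRefPtr {
        void* ptr = nullptr;
        void* get_ptr() const { return ptr; }
    };

    struct ToyKernParam {
        ToyRefPtr src_ptr, dst_ptr;
        // The real struct validates the dtype first; here we only do the cast.
        template <typename T>
        const T* src() const { return static_cast<const T*>(src_ptr.get_ptr()); }
        template <typename T>
        T* dst() const { return static_cast<T*>(dst_ptr.get_ptr()); }
    };

    int main() {
        std::vector<float> in{1.f, 2.f}, out(2, 0.f);
        ToyKernParam p;
        p.src_ptr.ptr = in.data();
        p.dst_ptr.ptr = out.data();
        for (int i = 0; i < 2; ++i) p.dst<float>()[i] = p.src<float>()[i] * 2.f;
        std::printf("%f %f\n", out[0], out[1]);
    }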
@@ -816,8 +816,8 @@ void ReduceImpl::exec( | |||||
MIDOUT_BEGIN( \ | MIDOUT_BEGIN( \ | ||||
megdnn_arm_common_reduce, ctype, dtype, comp_type, midout_iv(1)) { \ | megdnn_arm_common_reduce, ctype, dtype, comp_type, midout_iv(1)) { \ | ||||
MEGDNN_DISPATCH_CPU_KERN_OPR(do_reduce( \ | MEGDNN_DISPATCH_CPU_KERN_OPR(do_reduce( \ | ||||
reinterpret_cast<ctype*>(src.raw_ptr), \ | |||||
reinterpret_cast<ctype*>(dst.raw_ptr), src_type, A, B, C)); \ | |||||
reinterpret_cast<ctype*>(src.raw_ptr()), \ | |||||
reinterpret_cast<ctype*>(dst.raw_ptr()), src_type, A, B, C)); \ | |||||
execed = true; \ | execed = true; \ | ||||
} \ | } \ | ||||
MIDOUT_END(); \ | MIDOUT_END(); \ | ||||
@@ -828,8 +828,8 @@ void ReduceImpl::exec( | |||||
MIDOUT_BEGIN( \ | MIDOUT_BEGIN( \ | ||||
megdnn_arm_common_reduce, ctype, dtype, comp_type, midout_iv(1)) { \ | megdnn_arm_common_reduce, ctype, dtype, comp_type, midout_iv(1)) { \ | ||||
MEGDNN_DISPATCH_CPU_KERN_OPR(do_reduce( \ | MEGDNN_DISPATCH_CPU_KERN_OPR(do_reduce( \ | ||||
reinterpret_cast<ctype*>(src.raw_ptr), \ | |||||
reinterpret_cast<ctype*>(dst.raw_ptr), src_type, A, B, C)); \ | |||||
reinterpret_cast<ctype*>(src.raw_ptr()), \ | |||||
reinterpret_cast<ctype*>(dst.raw_ptr()), src_type, A, B, C)); \ | |||||
execed = true; \ | execed = true; \ | ||||
} \ | } \ | ||||
MIDOUT_END(); \ | MIDOUT_END(); \ | ||||
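Editor's note: these reduce hunks, like most of this patch, replace the old `raw_ptr` data member with a `raw_ptr()` call, which suggests the tensor now keeps its address behind a rebindable handle and exposes it through an accessor. A hypothetical sketch of that shape (not the real `TensorND`):

    #include <cstdio>
    #include <memory>

    // Hypothetical tensor: the raw address lives behind a shared slot so that
    // rebinding the storage is visible to everyone holding the tensor.
    class ToyTensor {
        std::shared_ptr<void*> m_slot{std::make_shared<void*>(nullptr)};
    public:
        void reset_ptr(void* p) { *m_slot = p; }
        void* raw_ptr() const { return *m_slot; }  // accessor, no longer a field
    };

    int main() {
        float buf[2] = {3.f, 4.f};
        ToyTensor t;
        t.reset_ptr(buf);
        std::printf("%f\n", static_cast<float*>(t.raw_ptr())[1]);
    }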
@@ -72,14 +72,14 @@ void resize_direct_nchwxx( | |||||
void megdnn::arm_common::resize_direct_nearest_nchw44_fp32( | void megdnn::arm_common::resize_direct_nearest_nchw44_fp32( | ||||
const ResizeImpl::KernParam<float>& kern_param) { | const ResizeImpl::KernParam<float>& kern_param) { | ||||
resize_direct_nchwxx<float, InterpolationMode::INTER_NEAREST>( | resize_direct_nchwxx<float, InterpolationMode::INTER_NEAREST>( | ||||
kern_param.sptr, kern_param.dptr, kern_param.n * kern_param.c / 4, | |||||
kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c / 4, | |||||
kern_param.ih, kern_param.iw, kern_param.oh, kern_param.ow); | kern_param.ih, kern_param.iw, kern_param.oh, kern_param.ow); | ||||
} | } | ||||
void megdnn::arm_common::resize_direct_linear_nchw44_fp32( | void megdnn::arm_common::resize_direct_linear_nchw44_fp32( | ||||
const ResizeImpl::KernParam<float>& kern_param) { | const ResizeImpl::KernParam<float>& kern_param) { | ||||
resize_direct_nchwxx<float, InterpolationMode::INTER_LINEAR>( | resize_direct_nchwxx<float, InterpolationMode::INTER_LINEAR>( | ||||
kern_param.sptr, kern_param.dptr, kern_param.n * kern_param.c / 4, | |||||
kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c / 4, | |||||
kern_param.ih, kern_param.iw, kern_param.oh, kern_param.ow); | kern_param.ih, kern_param.iw, kern_param.oh, kern_param.ow); | ||||
} | } | ||||
@@ -87,8 +87,8 @@ void megdnn::arm_common::resize_direct_linear_nchw44_fp32( | |||||
void megdnn::arm_common::resize_direct_nearest_nchw88_fp16( | void megdnn::arm_common::resize_direct_nearest_nchw88_fp16( | ||||
const ResizeImpl::KernParam<dt_float16>& kern_param) { | const ResizeImpl::KernParam<dt_float16>& kern_param) { | ||||
auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr); | |||||
auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); | |||||
auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr.get_ptr()); | |||||
auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr.get_ptr()); | |||||
resize_direct_nchwxx<__fp16, InterpolationMode::INTER_NEAREST>( | resize_direct_nchwxx<__fp16, InterpolationMode::INTER_NEAREST>( | ||||
sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih, kern_param.iw, | sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih, kern_param.iw, | ||||
kern_param.oh, kern_param.ow); | kern_param.oh, kern_param.ow); | ||||
@@ -96,8 +96,8 @@ void megdnn::arm_common::resize_direct_nearest_nchw88_fp16( | |||||
void megdnn::arm_common::resize_direct_linear_nchw88_fp16( | void megdnn::arm_common::resize_direct_linear_nchw88_fp16( | ||||
const ResizeImpl::KernParam<dt_float16>& kern_param) { | const ResizeImpl::KernParam<dt_float16>& kern_param) { | ||||
auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr); | |||||
auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); | |||||
auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr.get_ptr()); | |||||
auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr.get_ptr()); | |||||
resize_direct_nchwxx<__fp16, InterpolationMode::INTER_LINEAR>( | resize_direct_nchwxx<__fp16, InterpolationMode::INTER_LINEAR>( | ||||
sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih, kern_param.iw, | sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih, kern_param.iw, | ||||
kern_param.oh, kern_param.ow); | kern_param.oh, kern_param.ow); | ||||
@@ -191,14 +191,14 @@ void nearest_upsample2_nchw( | |||||
void megdnn::arm_common::resize_linear_upsample2_nchw_fp32( | void megdnn::arm_common::resize_linear_upsample2_nchw_fp32( | ||||
const ResizeImpl::KernParam<float>& kern_param) { | const ResizeImpl::KernParam<float>& kern_param) { | ||||
linear_upsample2_nchw( | linear_upsample2_nchw( | ||||
kern_param.sptr, kern_param.dptr, kern_param.n * kern_param.c, | |||||
kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c, | |||||
kern_param.ih, kern_param.iw); | kern_param.ih, kern_param.iw); | ||||
} | } | ||||
void megdnn::arm_common::resize_nearest_upsample2_nchw_fp32( | void megdnn::arm_common::resize_nearest_upsample2_nchw_fp32( | ||||
const ResizeImpl::KernParam<float>& kern_param) { | const ResizeImpl::KernParam<float>& kern_param) { | ||||
nearest_upsample2_nchw( | nearest_upsample2_nchw( | ||||
kern_param.sptr, kern_param.dptr, kern_param.n * kern_param.c, | |||||
kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c, | |||||
kern_param.ih, kern_param.iw); | kern_param.ih, kern_param.iw); | ||||
} | } | ||||
@@ -206,16 +206,16 @@ void megdnn::arm_common::resize_nearest_upsample2_nchw_fp32( | |||||
void megdnn::arm_common::resize_linear_upsample2_nchw_fp16( | void megdnn::arm_common::resize_linear_upsample2_nchw_fp16( | ||||
const ResizeImpl::KernParam<dt_float16>& kern_param) { | const ResizeImpl::KernParam<dt_float16>& kern_param) { | ||||
auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr); | |||||
auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); | |||||
auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr.get_ptr()); | |||||
auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr.get_ptr()); | |||||
linear_upsample2_nchw( | linear_upsample2_nchw( | ||||
sptr, dptr, kern_param.n * kern_param.c, kern_param.ih, kern_param.iw); | sptr, dptr, kern_param.n * kern_param.c, kern_param.ih, kern_param.iw); | ||||
} | } | ||||
void megdnn::arm_common::resize_nearest_upsample2_nchw_fp16( | void megdnn::arm_common::resize_nearest_upsample2_nchw_fp16( | ||||
const ResizeImpl::KernParam<dt_float16>& kern_param) { | const ResizeImpl::KernParam<dt_float16>& kern_param) { | ||||
auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr); | |||||
auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); | |||||
auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr.get_ptr()); | |||||
auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr.get_ptr()); | |||||
nearest_upsample2_nchw( | nearest_upsample2_nchw( | ||||
sptr, dptr, kern_param.n * kern_param.c, kern_param.ih, kern_param.iw); | sptr, dptr, kern_param.n * kern_param.c, kern_param.ih, kern_param.iw); | ||||
} | } | ||||
@@ -158,14 +158,14 @@ void nearest_upsample2_nchwxx( | |||||
void megdnn::arm_common::resize_linear_upsample2_nchw44_fp32( | void megdnn::arm_common::resize_linear_upsample2_nchw44_fp32( | ||||
const ResizeImpl::KernParam<float>& kern_param) { | const ResizeImpl::KernParam<float>& kern_param) { | ||||
linear_upsample2_nchwxx( | linear_upsample2_nchwxx( | ||||
kern_param.sptr, kern_param.dptr, kern_param.n * kern_param.c / 4, | |||||
kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c / 4, | |||||
kern_param.ih, kern_param.iw); | kern_param.ih, kern_param.iw); | ||||
} | } | ||||
void megdnn::arm_common::resize_nearest_upsample2_nchw44_fp32( | void megdnn::arm_common::resize_nearest_upsample2_nchw44_fp32( | ||||
const ResizeImpl::KernParam<float>& kern_param) { | const ResizeImpl::KernParam<float>& kern_param) { | ||||
nearest_upsample2_nchwxx( | nearest_upsample2_nchwxx( | ||||
kern_param.sptr, kern_param.dptr, kern_param.n * kern_param.c / 4, | |||||
kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c / 4, | |||||
kern_param.ih, kern_param.iw); | kern_param.ih, kern_param.iw); | ||||
} | } | ||||
@@ -173,16 +173,16 @@ void megdnn::arm_common::resize_nearest_upsample2_nchw44_fp32( | |||||
void megdnn::arm_common::resize_linear_upsample2_nchw88_fp16( | void megdnn::arm_common::resize_linear_upsample2_nchw88_fp16( | ||||
const ResizeImpl::KernParam<dt_float16>& kern_param) { | const ResizeImpl::KernParam<dt_float16>& kern_param) { | ||||
auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr); | |||||
auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); | |||||
auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr.get_ptr()); | |||||
auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr.get_ptr()); | |||||
linear_upsample2_nchwxx( | linear_upsample2_nchwxx( | ||||
sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih, kern_param.iw); | sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih, kern_param.iw); | ||||
} | } | ||||
void megdnn::arm_common::resize_nearest_upsample2_nchw88_fp16( | void megdnn::arm_common::resize_nearest_upsample2_nchw88_fp16( | ||||
const ResizeImpl::KernParam<dt_float16>& kern_param) { | const ResizeImpl::KernParam<dt_float16>& kern_param) { | ||||
auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr); | |||||
auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); | |||||
auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr.get_ptr()); | |||||
auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr.get_ptr()); | |||||
nearest_upsample2_nchwxx( | nearest_upsample2_nchwxx( | ||||
sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih, kern_param.iw); | sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih, kern_param.iw); | ||||
} | } | ||||
@@ -78,9 +78,9 @@ void SeparableFilterImpl::separable_filter_exec_8u( | |||||
megdnn_assert(src.layout.dtype == dtype::Uint8()); | megdnn_assert(src.layout.dtype == dtype::Uint8()); | ||||
Mat<float> kernel_column( | Mat<float> kernel_column( | ||||
1, filter_y.layout.shape[3], 1, static_cast<float*>(filter_y.raw_ptr)); | |||||
1, filter_y.layout.shape[3], 1, static_cast<float*>(filter_y.raw_ptr())); | |||||
Mat<float> kernel_row( | Mat<float> kernel_row( | ||||
1, filter_x.layout.shape[3], 1, static_cast<float*>(filter_x.raw_ptr)); | |||||
1, filter_x.layout.shape[3], 1, static_cast<float*>(filter_x.raw_ptr())); | |||||
size_t src_channels = src.layout.shape[3]; | size_t src_channels = src.layout.shape[3]; | ||||
@@ -128,9 +128,9 @@ void SeparableFilterImpl::separable_filter_exec( | |||||
_megdnn_tensor_in src, _megdnn_tensor_in filter_x, _megdnn_tensor_in filter_y, | _megdnn_tensor_in src, _megdnn_tensor_in filter_x, _megdnn_tensor_in filter_y, | ||||
_megdnn_tensor_out dst) { | _megdnn_tensor_out dst) { | ||||
Mat<T> kernel_column( | Mat<T> kernel_column( | ||||
1, filter_y.layout.shape[3], 1, static_cast<T*>(filter_y.raw_ptr)); | |||||
1, filter_y.layout.shape[3], 1, static_cast<T*>(filter_y.raw_ptr())); | |||||
Mat<T> kernel_row( | Mat<T> kernel_row( | ||||
1, filter_x.layout.shape[3], 1, static_cast<T*>(filter_x.raw_ptr)); | |||||
1, filter_x.layout.shape[3], 1, static_cast<T*>(filter_x.raw_ptr())); | |||||
size_t src_channels = src.layout.shape[3]; | size_t src_channels = src.layout.shape[3]; | ||||
T border_value[4] = {0, 0, 0, 0}; | T border_value[4] = {0, 0, 0, 0}; | ||||
@@ -483,18 +483,18 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||||
#undef DISPATCH_QUANTIZED | #undef DISPATCH_QUANTIZED | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
#define DISPATCH_FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \ | |||||
if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \ | |||||
dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \ | |||||
MIDOUT_BEGIN(megdnn_arm_typecvt_float, midout_iv(_midout_iv)) { \ | |||||
using _TypeCvter = FloatTypeCvter<_stype, _dtype>; \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \ | |||||
reinterpret_cast<_stype*>(src.raw_ptr), \ | |||||
reinterpret_cast<_dtype*>(dst.raw_ptr), src_dtype, dst_dtype, \ | |||||
nr_elems)); \ | |||||
execed = true; \ | |||||
} \ | |||||
MIDOUT_END(); \ | |||||
#define DISPATCH_FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \ | |||||
if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \ | |||||
dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \ | |||||
MIDOUT_BEGIN(megdnn_arm_typecvt_float, midout_iv(_midout_iv)) { \ | |||||
using _TypeCvter = FloatTypeCvter<_stype, _dtype>; \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \ | |||||
reinterpret_cast<_stype*>(src.raw_ptr()), \ | |||||
reinterpret_cast<_dtype*>(dst.raw_ptr()), src_dtype, dst_dtype, \ | |||||
nr_elems)); \ | |||||
execed = true; \ | |||||
} \ | |||||
MIDOUT_END(); \ | |||||
} | } | ||||
DISPATCH_FLOAT(dt_float16, __fp16, float, float, 0); | DISPATCH_FLOAT(dt_float16, __fp16, float, float, 0); | ||||
DISPATCH_FLOAT(float, float, dt_float16, __fp16, 1); | DISPATCH_FLOAT(float, float, dt_float16, __fp16, 1); | ||||
@@ -167,21 +167,17 @@ void megdnn::arm_common::warp_perspective_cv_exec( | |||||
megdnn_assert( | megdnn_assert( | ||||
ch == 1 || ch == 3 || ch == 2, | ch == 1 || ch == 3 || ch == 2, | ||||
"unsupported src channel: %zu, avaiable channel size: 1/2/3", ch); | "unsupported src channel: %zu, avaiable channel size: 1/2/3", ch); | ||||
const float* trans_ptr = trans.ptr<dt_float32>(); | |||||
const int* midx_ptr = nullptr; | |||||
if (mat_idx.raw_ptr) { | |||||
megdnn_assert(mat_idx.layout.ndim == 1); | |||||
midx_ptr = mat_idx.ptr<int>(); | |||||
} | |||||
if (dst.layout.dtype.enumv() == DTypeEnum::Float32) { | if (dst.layout.dtype.enumv() == DTypeEnum::Float32) { | ||||
#define cb(_imode, _bmode, _ch) \ | #define cb(_imode, _bmode, _ch) \ | ||||
auto task = [src, trans_ptr, midx_ptr, dst, border_value, parallelism_batch]( \ | |||||
auto task = [src, trans, mat_idx, dst, border_value, parallelism_batch]( \ | |||||
size_t index, size_t) { \ | size_t index, size_t) { \ | ||||
size_t batch_id = index / parallelism_batch; \ | size_t batch_id = index / parallelism_batch; \ | ||||
size_t task_id = index % parallelism_batch; \ | size_t task_id = index % parallelism_batch; \ | ||||
size_t src_id = batch_id; \ | size_t src_id = batch_id; \ | ||||
if (midx_ptr) { \ | |||||
src_id = midx_ptr[batch_id]; \ | |||||
const float* trans_ptr = trans.ptr<dt_float32>(); \ | |||||
if (mat_idx.raw_ptr()) { \ | |||||
megdnn_assert(mat_idx.layout.ndim == 1); \ | |||||
src_id = mat_idx.ptr<int>()[batch_id]; \ | |||||
megdnn_assert( \ | megdnn_assert( \ | ||||
src_id < src.layout.shape[0], \ | src_id < src.layout.shape[0], \ | ||||
"mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", batch_id, \ | "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", batch_id, \ | ||||
@@ -202,13 +198,15 @@ void megdnn::arm_common::warp_perspective_cv_exec( | |||||
#undef cb | #undef cb | ||||
} else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) { | } else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) { | ||||
#define cb(_imode, _bmode, _ch) \ | #define cb(_imode, _bmode, _ch) \ | ||||
auto task = [src, trans_ptr, midx_ptr, dst, border_value, parallelism_batch]( \ | |||||
auto task = [src, trans, mat_idx, dst, border_value, parallelism_batch]( \ | |||||
size_t index, size_t) { \ | size_t index, size_t) { \ | ||||
size_t batch_id = index / parallelism_batch; \ | size_t batch_id = index / parallelism_batch; \ | ||||
size_t task_id = index % parallelism_batch; \ | size_t task_id = index % parallelism_batch; \ | ||||
size_t src_id = batch_id; \ | size_t src_id = batch_id; \ | ||||
if (midx_ptr) { \ | |||||
src_id = midx_ptr[batch_id]; \ | |||||
const float* trans_ptr = trans.ptr<dt_float32>(); \ | |||||
if (mat_idx.raw_ptr()) { \ | |||||
megdnn_assert(mat_idx.layout.ndim == 1); \ | |||||
src_id = mat_idx.ptr<int>()[batch_id]; \ | |||||
megdnn_assert( \ | megdnn_assert( \ | ||||
src_id < src.layout.shape[0], \ | src_id < src.layout.shape[0], \ | ||||
"mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", batch_id, \ | "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", batch_id, \ | ||||
@@ -136,10 +136,10 @@ void armv7::RelayoutForwardImpl::exec( | |||||
relayout::TransposeParam trans_param; | relayout::TransposeParam trans_param; | ||||
bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param); | bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param); | ||||
if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 1) { | if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 1) { | ||||
auto sptr = static_cast<TransposeByte*>(src.raw_ptr), | |||||
dptr = static_cast<TransposeByte*>(dst.raw_ptr); | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose<TransposeByte>( | MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose<TransposeByte>( | ||||
trans_param.batch, trans_param.m, trans_param.n, sptr, dptr)); | |||||
trans_param.batch, trans_param.m, trans_param.n, | |||||
static_cast<TransposeByte*>(src.raw_ptr()), | |||||
static_cast<TransposeByte*>(dst.raw_ptr()))); | |||||
return; | return; | ||||
} | } | ||||
exec_after_preprocess(src, dst, trans ? &trans_param : nullptr); | exec_after_preprocess(src, dst, trans ? &trans_param : nullptr); | ||||
@@ -288,11 +288,13 @@ void RotateImpl::exec( | |||||
return fallback::RotateImpl::exec(src, dst, workspace); | return fallback::RotateImpl::exec(src, dst, workspace); | ||||
} | } | ||||
auto clockwise = param().clockwise; | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR({ | MEGDNN_DISPATCH_CPU_KERN_OPR({ | ||||
for (size_t i = 0; i < src.layout.shape[0]; ++i) { | for (size_t i = 0; i < src.layout.shape[0]; ++i) { | ||||
Mat<uchar> src_mat = TensorND2Mat<uchar>(src, i); | Mat<uchar> src_mat = TensorND2Mat<uchar>(src, i); | ||||
Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, i); | Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, i); | ||||
rotate(src_mat, dst_mat, param().clockwise); | |||||
rotate(src_mat, dst_mat, clockwise); | |||||
} | } | ||||
}); | }); | ||||
} | } | ||||
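Editor's note: the rotate hunk reads `param().clockwise` once into a local and captures that instead of calling `param()` inside the dispatched block. One plausible reason (an inference, not something the diff states) is that the deferred kernel should not hold a reference to the operator object; a small sketch under that assumption:

    #include <cstdio>
    #include <functional>
    #include <vector>

    struct ToyRotateOpr {
        bool clockwise = true;

        void exec(std::vector<std::function<void()>>& queue) {
            // Copy the parameter now; the queued task no longer touches members
            // of this operator object when it eventually runs.
            auto cw = clockwise;
            queue.push_back([cw] {
                std::printf("rotate %s\n", cw ? "clockwise" : "counter-clockwise");
            });
        }
    };

    int main() {
        std::vector<std::function<void()>> queue;
        {
            ToyRotateOpr opr;
            opr.exec(queue);
        }                // operator destroyed before the task runs
        queue.front()(); // still safe: only the copied flag is captured
    }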
@@ -36,7 +36,7 @@ ChecksumForward::Result ChecksumForwardImpl::exec( | |||||
megcoreComputingHandle_t comp_handle = handle()->megcore_computing_handle(); | megcoreComputingHandle_t comp_handle = handle()->megcore_computing_handle(); | ||||
megcoreGetDeviceHandle(comp_handle, &dev_handle); | megcoreGetDeviceHandle(comp_handle, &dev_handle); | ||||
megcoreMemcpy( | megcoreMemcpy( | ||||
comp_handle, cpu_data.data(), data.raw_ptr, cpu_data.size(), | |||||
comp_handle, cpu_data.data(), data.raw_ptr(), cpu_data.size(), | |||||
megcoreMemcpyDeviceToHost); | megcoreMemcpyDeviceToHost); | ||||
megcoreSynchronize(comp_handle); | megcoreSynchronize(comp_handle); | ||||
@@ -62,7 +62,7 @@ ChecksumForward::Result ChecksumForwardImpl::exec( | |||||
check_exec(data.layout, workspace.size); | check_exec(data.layout, workspace.size); | ||||
auto queue = cnrt_queue(handle()); | auto queue = cnrt_queue(handle()); | ||||
auto ptr = static_cast<uint8_t*>(data.raw_ptr); | |||||
auto ptr = static_cast<uint8_t*>(data.raw_ptr()); | |||||
size_t size_all = data.layout.shape[0], size_ints = size_all / sizeof(uint32_t); | size_t size_all = data.layout.shape[0], size_ints = size_all / sizeof(uint32_t); | ||||
auto last_val_size = std::min<size_t>(size_all, 4); | auto last_val_size = std::min<size_t>(size_all, 4); | ||||
cnrt_check(cnrtMemcpyAsync( | cnrt_check(cnrtMemcpyAsync( | ||||
@@ -72,7 +72,7 @@ ChecksumForward::Result ChecksumForwardImpl::exec( | |||||
auto&& device_info = current_device_info(); | auto&& device_info = current_device_info(); | ||||
bang_c_wrapper( | bang_c_wrapper( | ||||
reinterpret_cast<uint32_t*>(workspace.raw_ptr), | reinterpret_cast<uint32_t*>(workspace.raw_ptr), | ||||
static_cast<uint32_t*>(data.raw_ptr), size_ints, queue, | |||||
static_cast<uint32_t*>(data.raw_ptr()), size_ints, queue, | |||||
device_info.core_version); | device_info.core_version); | ||||
cnrt_check(cnrtMemcpyAsync( | cnrt_check(cnrtMemcpyAsync( | ||||
&result.checksum, workspace.raw_ptr, sizeof(result.checksum), queue, | &result.checksum, workspace.raw_ptr, sizeof(result.checksum), queue, | ||||
@@ -38,10 +38,9 @@ void ConcatSplitBase::check_layout_common( | |||||
megdnn_assert_eq_size_t(src.ndim, ndim); | megdnn_assert_eq_size_t(src.ndim, ndim); | ||||
} | } | ||||
// ensure param().axis is correct | // ensure param().axis is correct | ||||
auto errmsg = "param().axis=" + std::to_string(param().axis) + | |||||
", ndim=" + std::to_string(ndim); | |||||
MEGDNN_MARK_USED_VAR(errmsg); | |||||
megdnn_assert(param().axis < static_cast<int32_t>(ndim), "%s", errmsg.c_str()); | |||||
megdnn_assert( | |||||
param().axis < static_cast<int32_t>(ndim), "param().axis=%u, ndim=%zu", | |||||
param().axis, ndim); | |||||
// ensure shape size for each axis is correct | // ensure shape size for each axis is correct | ||||
for (size_t i = 0; i < ndim; ++i) { | for (size_t i = 0; i < ndim; ++i) { | ||||
if (i == static_cast<size_t>(param().axis)) { | if (i == static_cast<size_t>(param().axis)) { | ||||
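Editor's note: the new assertion formats its message in place instead of building a `std::string` beforehand, so nothing is allocated when the check passes. A stripped-down sketch of such a printf-style assert macro; the macro name is made up, but the real `megdnn_assert` accepts a format string the same way.

    #include <cstdio>
    #include <cstdlib>

    // Formats the message only when the condition fails. At least one format
    // argument is expected after `fmt` in this simplified version.
    #define TOY_ASSERT(cond, fmt, ...)                               \
        do {                                                         \
            if (!(cond)) {                                           \
                std::fprintf(stderr, "assertion failed: " fmt "\n",  \
                             __VA_ARGS__);                           \
                std::abort();                                        \
            }                                                        \
        } while (0)

    int main() {
        int axis = 1;
        size_t ndim = 4;
        TOY_ASSERT(axis < static_cast<int>(ndim), "param().axis=%d, ndim=%zu",
                   axis, ndim);
        std::puts("check passed without any string allocation");
    }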
@@ -24,28 +24,24 @@ void ElemwiseMultiTypeImplHelper::exec( | |||||
_megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) { | _megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) { | ||||
switch (m_param.mode) { | switch (m_param.mode) { | ||||
case Mode::FUSE_MUL_ADD3_INT16x32x32x32: | case Mode::FUSE_MUL_ADD3_INT16x32x32x32: | ||||
on_fuse_mul_add3_int16x32x32x32( | |||||
make_elemwise_op_param<3>(src, dst), dst.ptr<dt_int32>()); | |||||
on_fuse_mul_add3_int16x32x32x32(make_elemwise_op_param<3>(src, dst), dst); | |||||
break; | break; | ||||
case Mode::FUSE_MUL_ADD3_IXxF32xF32xI8: | case Mode::FUSE_MUL_ADD3_IXxF32xF32xI8: | ||||
on_fuse_mul_add3_iXxf32xf32xi8( | |||||
make_elemwise_op_param<3>(src, dst), dst.ptr<dt_int8>()); | |||||
on_fuse_mul_add3_iXxf32xf32xi8(make_elemwise_op_param<3>(src, dst), dst); | |||||
break; | break; | ||||
case Mode::ROUND_SHR_SATURATE_IXxI8xI8: | case Mode::ROUND_SHR_SATURATE_IXxI8xI8: | ||||
on_round_shr_saturate_iXxi8xi8( | |||||
make_elemwise_op_param<2>(src, dst), dst.ptr<dt_int8>()); | |||||
on_round_shr_saturate_iXxi8xi8(make_elemwise_op_param<2>(src, dst), dst); | |||||
break; | break; | ||||
case Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8: | case Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8: | ||||
on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( | on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( | ||||
make_elemwise_op_param<6>(src, dst), dst.ptr<dt_int8>()); | |||||
make_elemwise_op_param<6>(src, dst), dst); | |||||
break; | break; | ||||
case Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8: | case Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8: | ||||
on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( | on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( | ||||
make_elemwise_op_param<6>(src, dst), dst.ptr<dt_int8>()); | |||||
make_elemwise_op_param<6>(src, dst), dst); | |||||
break; | break; | ||||
case Mode::ROUND_SHR_SATURATE_IXxI8xI16: | case Mode::ROUND_SHR_SATURATE_IXxI8xI16: | ||||
on_round_shr_saturate_iXxi8xi16( | |||||
make_elemwise_op_param<2>(src, dst), dst.ptr<dt_int16>()); | |||||
on_round_shr_saturate_iXxi8xi16(make_elemwise_op_param<2>(src, dst), dst); | |||||
break; | break; | ||||
ON_QUANTIZED_MODE(RELU, 1); | ON_QUANTIZED_MODE(RELU, 1); | ||||
ON_QUANTIZED_MODE(ABS, 1); | ON_QUANTIZED_MODE(ABS, 1); | ||||
@@ -33,22 +33,22 @@ class ElemwiseMultiTypeImplHelper : public ElemwiseMultiType, | |||||
protected: | protected: | ||||
virtual void on_fuse_mul_add3_int16x32x32x32( | virtual void on_fuse_mul_add3_int16x32x32x32( | ||||
const ElemwiseOpParamN<3>& param, dt_int32* dst) = 0; | |||||
const ElemwiseOpParamN<3>& param, const TensorND& dst) = 0; | |||||
virtual void on_fuse_mul_add3_iXxf32xf32xi8( | virtual void on_fuse_mul_add3_iXxf32xf32xi8( | ||||
const ElemwiseOpParamN<3>& param, dt_int8* dst) = 0; | |||||
const ElemwiseOpParamN<3>& param, const TensorND& dst) = 0; | |||||
virtual void on_round_shr_saturate_iXxi8xi8( | virtual void on_round_shr_saturate_iXxi8xi8( | ||||
const ElemwiseOpParamN<2>& param, dt_int8* dst) = 0; | |||||
const ElemwiseOpParamN<2>& param, const TensorND& dst) = 0; | |||||
virtual void on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( | virtual void on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( | ||||
const ElemwiseOpParamN<6>& param, dt_int8* dst) = 0; | |||||
const ElemwiseOpParamN<6>& param, const TensorND& dst) = 0; | |||||
virtual void on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( | virtual void on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( | ||||
const ElemwiseOpParamN<6>& param, dt_int8* dst) = 0; | |||||
const ElemwiseOpParamN<6>& param, const TensorND& dst) = 0; | |||||
virtual void on_round_shr_saturate_iXxi8xi16( | virtual void on_round_shr_saturate_iXxi8xi16( | ||||
const ElemwiseOpParamN<2>& param, dt_int16* dst) = 0; | |||||
const ElemwiseOpParamN<2>& param, const TensorND& dst) = 0; | |||||
virtual void on_quantized_mode( | virtual void on_quantized_mode( | ||||
const ElemwiseOpParamN<1>& param, const TensorND& dst, | const ElemwiseOpParamN<1>& param, const TensorND& dst, | ||||
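Editor's note: the virtual hooks now receive the whole destination tensor rather than a pre-cast element pointer, so each implementation pulls the typed pointer itself. A minimal sketch of that interface shape with made-up names; the real hooks also see the tensor's layout and dtype.

    #include <cstdio>
    #include <vector>

    struct ToyTensor {
        void* raw = nullptr;
        template <typename T>
        T* ptr() const { return static_cast<T*>(raw); }
    };

    struct ToyElemwiseBase {
        // The hook sees the tensor, not a dt_int32*; the override decides how
        // to interpret it.
        virtual void on_fuse_mul_add3(const ToyTensor& dst) = 0;
        virtual ~ToyElemwiseBase() = default;
    };

    struct ToyElemwiseImpl : ToyElemwiseBase {
        void on_fuse_mul_add3(const ToyTensor& dst) override {
            dst.ptr<int>()[0] = 42;  // cast happens inside the implementation
        }
    };

    int main() {
        std::vector<int> out(1, 0);
        ToyTensor dst{out.data()};
        ToyElemwiseImpl impl;
        impl.on_fuse_mul_add3(dst);
        std::printf("%d\n", out[0]);
    }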
@@ -29,9 +29,9 @@ template <int N, int OC> | |||||
void local_xcorr_tpl(const LocalKParam& kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET; | void local_xcorr_tpl(const LocalKParam& kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET; | ||||
template <int N, int OC> | template <int N, int OC> | ||||
void local_xcorr_tpl(const LocalKParam& kparam) { | void local_xcorr_tpl(const LocalKParam& kparam) { | ||||
const float* src = static_cast<const float*>(kparam.src); | |||||
const float* filter = static_cast<const float*>(kparam.filter); | |||||
float* dst = static_cast<float*>(kparam.dst); | |||||
const float* src = static_cast<const float*>(kparam.src.get_ptr()); | |||||
const float* filter = static_cast<const float*>(kparam.filter.get_ptr()); | |||||
float* dst = static_cast<float*>(kparam.dst.get_ptr()); | |||||
float* workspace = static_cast<float*>(kparam.workspace); | float* workspace = static_cast<float*>(kparam.workspace); | ||||
const int IC = kparam.ic, IH = kparam.ih, IW = kparam.iw, OH = kparam.oh, | const int IC = kparam.ic, IH = kparam.ih, IW = kparam.iw, OH = kparam.oh, | ||||
OW = kparam.ow, FH = kparam.fh, FW = kparam.fw; | OW = kparam.ow, FH = kparam.fh, FW = kparam.fw; | ||||
@@ -191,9 +191,9 @@ template <int N, int OC> | |||||
void local_conv_tpl(const LocalKParam& kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET; | void local_conv_tpl(const LocalKParam& kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET; | ||||
template <int N, int OC> | template <int N, int OC> | ||||
void local_conv_tpl(const LocalKParam& kparam) { | void local_conv_tpl(const LocalKParam& kparam) { | ||||
const float* src = static_cast<const float*>(kparam.src); | |||||
const float* filter = static_cast<const float*>(kparam.filter); | |||||
float* dst = static_cast<float*>(kparam.dst); | |||||
const float* src = static_cast<const float*>(kparam.src.get_ptr()); | |||||
const float* filter = static_cast<const float*>(kparam.filter.get_ptr()); | |||||
float* dst = static_cast<float*>(kparam.dst.get_ptr()); | |||||
float* workspace = static_cast<float*>(kparam.workspace); | float* workspace = static_cast<float*>(kparam.workspace); | ||||
const int IC = kparam.ic, IH = kparam.ih, IW = kparam.iw, OH = kparam.oh, | const int IC = kparam.ic, IH = kparam.ih, IW = kparam.iw, OH = kparam.oh, | ||||
OW = kparam.ow, FH = kparam.fh, FW = kparam.fw; | OW = kparam.ow, FH = kparam.fh, FW = kparam.fw; | ||||
@@ -11,9 +11,7 @@ | |||||
#pragma once | #pragma once | ||||
#include "megdnn/dtype.h" | #include "megdnn/dtype.h" | ||||
#if MEGDNN_CC_HOST | |||||
#include "megdnn/basic_types.h" | #include "megdnn/basic_types.h" | ||||
#endif | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace reduce { | namespace reduce { | ||||
@@ -24,16 +22,14 @@ struct SumOp { | |||||
const wtype INIT; | const wtype INIT; | ||||
src_ctype* src; | |||||
dst_ctype* dst; | |||||
RefPtr src; | |||||
RefPtr dst; | |||||
const size_t B; | const size_t B; | ||||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
return lhs + rhs; | |||||
} | |||||
MEGDNN_HOST MEGDNN_DEVICE SumOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
wtype read(uint32_t idx) { return src.ptr<src_ctype>()[idx]; } | |||||
void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; } | |||||
static wtype apply(wtype lhs, wtype rhs) { return lhs + rhs; } | |||||
SumOp(const RefPtr& src, const RefPtr& dst, size_t B) | |||||
: INIT(wtype(0)), src(src), dst(dst), B(B) {} | : INIT(wtype(0)), src(src), dst(dst), B(B) {} | ||||
}; | }; | ||||
@@ -43,18 +39,16 @@ struct MeanOp { | |||||
const wtype INIT; | const wtype INIT; | ||||
src_ctype* src; | |||||
dst_ctype* dst; | |||||
RefPtr src; | |||||
RefPtr dst; | |||||
const size_t B; | const size_t B; | ||||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { | |||||
dst[idx] = val / static_cast<wtype>(B); | |||||
} | |||||
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
return lhs + rhs; | |||||
wtype read(uint32_t idx) { return src.ptr<src_ctype>()[idx]; } | |||||
void write(uint32_t idx, wtype val) { | |||||
dst.ptr<dst_ctype>()[idx] = val / static_cast<wtype>(B); | |||||
} | } | ||||
MEGDNN_HOST MEGDNN_DEVICE MeanOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
static wtype apply(wtype lhs, wtype rhs) { return lhs + rhs; } | |||||
MeanOp(const RefPtr& src, const RefPtr& dst, size_t B) | |||||
: INIT(wtype(0)), src(src), dst(dst), B(B) {} | : INIT(wtype(0)), src(src), dst(dst), B(B) {} | ||||
}; | }; | ||||
@@ -64,18 +58,17 @@ struct SumSqrOp { | |||||
const wtype INIT; | const wtype INIT; | ||||
src_ctype* src; | |||||
dst_ctype* dst; | |||||
RefPtr src; | |||||
RefPtr dst; | |||||
const size_t B; | const size_t B; | ||||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { | |||||
return static_cast<wtype>(src[idx]) * static_cast<wtype>(src[idx]); | |||||
wtype read(uint32_t idx) { | |||||
return static_cast<wtype>(src.ptr<src_ctype>()[idx]) * | |||||
static_cast<wtype>(src.ptr<src_ctype>()[idx]); | |||||
} | } | ||||
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
return lhs + rhs; | |||||
} | |||||
MEGDNN_HOST MEGDNN_DEVICE SumSqrOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; } | |||||
static wtype apply(wtype lhs, wtype rhs) { return lhs + rhs; } | |||||
SumSqrOp(const RefPtr& src, const RefPtr& dst, size_t B) | |||||
: INIT(wtype(0)), src(src), dst(dst), B(B) {} | : INIT(wtype(0)), src(src), dst(dst), B(B) {} | ||||
}; | }; | ||||
@@ -84,16 +77,14 @@ struct ProdOp { | |||||
typedef wtype_ wtype; | typedef wtype_ wtype; | ||||
const wtype INIT; | const wtype INIT; | ||||
src_ctype* src; | |||||
dst_ctype* dst; | |||||
RefPtr src; | |||||
RefPtr dst; | |||||
const size_t B; | const size_t B; | ||||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
return lhs * rhs; | |||||
} | |||||
MEGDNN_HOST MEGDNN_DEVICE ProdOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
wtype read(uint32_t idx) { return src.ptr<src_ctype>()[idx]; } | |||||
void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; } | |||||
static wtype apply(wtype lhs, wtype rhs) { return lhs * rhs; } | |||||
ProdOp(const RefPtr& src, const RefPtr& dst, size_t B) | |||||
: INIT(wtype(1)), src(src), dst(dst), B(B) {} | : INIT(wtype(1)), src(src), dst(dst), B(B) {} | ||||
}; | }; | ||||
@@ -102,20 +93,14 @@ struct MinOp { | |||||
typedef wtype_ wtype; | typedef wtype_ wtype; | ||||
const wtype INIT; | const wtype INIT; | ||||
src_ctype* src; | |||||
dst_ctype* dst; | |||||
RefPtr src; | |||||
RefPtr dst; | |||||
const size_t B; | const size_t B; | ||||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
#if defined(__CUDA_ARCH__) | |||||
return lhs < rhs ? lhs : rhs; | |||||
#else | |||||
return std::min(lhs, rhs); | |||||
#endif | |||||
} | |||||
MEGDNN_HOST MEGDNN_DEVICE MinOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
wtype read(uint32_t idx) { return src.ptr<src_ctype>()[idx]; } | |||||
void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; } | |||||
static wtype apply(wtype lhs, wtype rhs) { return std::min(lhs, rhs); } | |||||
MinOp(const RefPtr& src, const RefPtr& dst, size_t B) | |||||
: INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {} | : INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {} | ||||
}; | }; | ||||
@@ -124,20 +109,16 @@ struct MinOp<src_ctype, dst_ctype, dt_float32> { | |||||
typedef dt_float32 wtype; | typedef dt_float32 wtype; | ||||
const wtype INIT; | const wtype INIT; | ||||
src_ctype* src; | |||||
dst_ctype* dst; | |||||
RefPtr src; | |||||
RefPtr dst; | |||||
const size_t B; | const size_t B; | ||||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
#if defined(__CUDA_ARCH__) | |||||
return (isnan(lhs) || lhs < rhs) ? lhs : rhs; | |||||
#else | |||||
wtype read(uint32_t idx) { return src.ptr<src_ctype>()[idx]; } | |||||
void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; } | |||||
static wtype apply(wtype lhs, wtype rhs) { | |||||
return (std::isnan(lhs) || lhs < rhs) ? lhs : rhs; | return (std::isnan(lhs) || lhs < rhs) ? lhs : rhs; | ||||
#endif | |||||
} | } | ||||
MEGDNN_HOST MEGDNN_DEVICE MinOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
MinOp(const RefPtr& src, const RefPtr& dst, size_t B) | |||||
: INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {} | : INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {} | ||||
}; | }; | ||||
@@ -146,20 +127,14 @@ struct MaxOp { | |||||
typedef wtype_ wtype; | typedef wtype_ wtype; | ||||
const wtype INIT; | const wtype INIT; | ||||
src_ctype* src; | |||||
dst_ctype* dst; | |||||
RefPtr src; | |||||
RefPtr dst; | |||||
const size_t B; | const size_t B; | ||||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
#if defined(__CUDA_ARCH__) | |||||
return lhs > rhs ? lhs : rhs; | |||||
#else | |||||
return std::max(lhs, rhs); | |||||
#endif | |||||
} | |||||
MEGDNN_HOST MEGDNN_DEVICE MaxOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
wtype read(uint32_t idx) { return src.ptr<src_ctype>()[idx]; } | |||||
void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; } | |||||
static wtype apply(wtype lhs, wtype rhs) { return std::max(lhs, rhs); } | |||||
MaxOp(const RefPtr& src, const RefPtr& dst, size_t B) | |||||
: INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {} | : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {} | ||||
}; | }; | ||||
@@ -168,20 +143,16 @@ struct MaxOp<src_ctype, dst_ctype, dt_float32> { | |||||
typedef dt_float32 wtype; | typedef dt_float32 wtype; | ||||
const wtype INIT; | const wtype INIT; | ||||
src_ctype* src; | |||||
dst_ctype* dst; | |||||
RefPtr src; | |||||
RefPtr dst; | |||||
const size_t B; | const size_t B; | ||||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
#if defined(__CUDA_ARCH__) | |||||
return (isnan(lhs) || lhs > rhs) ? lhs : rhs; | |||||
#else | |||||
wtype read(uint32_t idx) { return src.ptr<src_ctype>()[idx]; } | |||||
void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; } | |||||
static wtype apply(wtype lhs, wtype rhs) { | |||||
return (std::isnan(lhs) || lhs > rhs) ? lhs : rhs; | return (std::isnan(lhs) || lhs > rhs) ? lhs : rhs; | ||||
#endif | |||||
} | } | ||||
MEGDNN_HOST MEGDNN_DEVICE MaxOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
MaxOp(const RefPtr& src, const RefPtr& dst, size_t B) | |||||
: INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {} | : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {} | ||||
}; | }; | ||||
@@ -190,28 +161,19 @@ struct CheckNonFiniteOp { | |||||
typedef wtype_ wtype; | typedef wtype_ wtype; | ||||
const wtype INIT; | const wtype INIT; | ||||
src_ctype* src; | |||||
dst_ctype* dst; | |||||
RefPtr src; | |||||
RefPtr dst; | |||||
const size_t B; | const size_t B; | ||||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { | |||||
#if defined(__CUDA_ARCH__) | |||||
return !isfinite(src[idx]); | |||||
#else | |||||
return !std::isfinite(src[idx]); | |||||
#endif | |||||
} | |||||
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
return lhs | rhs; | |||||
} | |||||
MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
wtype read(uint32_t idx) { return !std::isfinite(src.ptr<src_ctype>()[idx]); } | |||||
void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; } | |||||
static wtype apply(wtype lhs, wtype rhs) { return lhs | rhs; } | |||||
MEGDNN_HOST MEGDNN_DEVICE | |||||
CheckNonFiniteOp(const RefPtr& src, const RefPtr& dst, size_t B) | |||||
: INIT(wtype(0)), src(src), dst(dst), B(B) {} | : INIT(wtype(0)), src(src), dst(dst), B(B) {} | ||||
}; | }; | ||||
#if MEGDNN_CC_HOST | |||||
void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, size_t axis); | void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, size_t axis); | ||||
#endif | |||||
} // namespace reduce | } // namespace reduce | ||||
} // namespace megdnn | } // namespace megdnn | ||||
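Editor's note: after this rewrite the host-side reduce ops hold `RefPtr`s, index through `ptr<ctype>()`, and drop the CUDA-only branches (which move to the new device header below). A self-contained approximation of the new `SumOp` shape, with a toy ref-pointer standing in for `RefPtr`:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct ToyRefPtr {
        void* raw = nullptr;
        template <typename T>
        T* ptr() const { return static_cast<T*>(raw); }
    };

    template <typename src_ctype, typename dst_ctype, typename wtype>
    struct ToySumOp {
        const wtype INIT;
        ToyRefPtr src, dst;
        const size_t B;
        wtype read(uint32_t idx) { return src.ptr<src_ctype>()[idx]; }
        void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; }
        static wtype apply(wtype lhs, wtype rhs) { return lhs + rhs; }
        ToySumOp(const ToyRefPtr& src, const ToyRefPtr& dst, size_t B)
                : INIT(wtype(0)), src(src), dst(dst), B(B) {}
    };

    int main() {
        std::vector<float> in{1.f, 2.f, 3.f}, out(1, 0.f);
        ToySumOp<float, float, float> op({in.data()}, {out.data()}, in.size());
        float acc = op.INIT;
        for (uint32_t i = 0; i < in.size(); ++i) acc = op.apply(acc, op.read(i));
        op.write(0, acc);
        std::printf("%f\n", out[0]);  // 6.0
    }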
@@ -0,0 +1,222 @@ | |||||
/** | |||||
* \file dnn/src/common/reduce_helper_device.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megdnn/dtype.h" | |||||
#if MEGDNN_CC_HOST | |||||
#include "megdnn/basic_types.h" | |||||
#endif | |||||
namespace megdnn { | |||||
namespace device_reduce { | |||||
template <typename src_ctype, typename dst_ctype, typename wtype_> | |||||
struct SumOp { | |||||
typedef wtype_ wtype; | |||||
const wtype INIT; | |||||
src_ctype* src; | |||||
dst_ctype* dst; | |||||
const size_t B; | |||||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
return lhs + rhs; | |||||
} | |||||
MEGDNN_HOST MEGDNN_DEVICE SumOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
: INIT(wtype(0)), src(src), dst(dst), B(B) {} | |||||
}; | |||||
template <typename src_ctype, typename dst_ctype, typename wtype_> | |||||
struct MeanOp { | |||||
typedef wtype_ wtype; | |||||
const wtype INIT; | |||||
src_ctype* src; | |||||
dst_ctype* dst; | |||||
const size_t B; | |||||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { | |||||
dst[idx] = val / static_cast<wtype>(B); | |||||
} | |||||
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
return lhs + rhs; | |||||
} | |||||
MEGDNN_HOST MEGDNN_DEVICE MeanOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
: INIT(wtype(0)), src(src), dst(dst), B(B) {} | |||||
}; | |||||
template <typename src_ctype, typename dst_ctype, typename wtype_> | |||||
struct SumSqrOp { | |||||
typedef wtype_ wtype; | |||||
const wtype INIT; | |||||
src_ctype* src; | |||||
dst_ctype* dst; | |||||
const size_t B; | |||||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { | |||||
return static_cast<wtype>(src[idx]) * static_cast<wtype>(src[idx]); | |||||
} | |||||
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
return lhs + rhs; | |||||
} | |||||
MEGDNN_HOST MEGDNN_DEVICE SumSqrOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
: INIT(wtype(0)), src(src), dst(dst), B(B) {} | |||||
}; | |||||
template <typename src_ctype, typename dst_ctype, typename wtype_> | |||||
struct ProdOp { | |||||
typedef wtype_ wtype; | |||||
const wtype INIT; | |||||
src_ctype* src; | |||||
dst_ctype* dst; | |||||
const size_t B; | |||||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
return lhs * rhs; | |||||
} | |||||
MEGDNN_HOST MEGDNN_DEVICE ProdOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
: INIT(wtype(1)), src(src), dst(dst), B(B) {} | |||||
}; | |||||
template <typename src_ctype, typename dst_ctype, typename wtype_> | |||||
struct MinOp { | |||||
typedef wtype_ wtype; | |||||
const wtype INIT; | |||||
src_ctype* src; | |||||
dst_ctype* dst; | |||||
const size_t B; | |||||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
#if defined(__CUDA_ARCH__) | |||||
return lhs < rhs ? lhs : rhs; | |||||
#else | |||||
return std::min(lhs, rhs); | |||||
#endif | |||||
} | |||||
MEGDNN_HOST MEGDNN_DEVICE MinOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
: INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {} | |||||
}; | |||||
template <typename src_ctype, typename dst_ctype> | |||||
struct MinOp<src_ctype, dst_ctype, dt_float32> { | |||||
typedef dt_float32 wtype; | |||||
const wtype INIT; | |||||
src_ctype* src; | |||||
dst_ctype* dst; | |||||
const size_t B; | |||||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
#if defined(__CUDA_ARCH__) | |||||
return (isnan(lhs) || lhs < rhs) ? lhs : rhs; | |||||
#else | |||||
return (std::isnan(lhs) || lhs < rhs) ? lhs : rhs; | |||||
#endif | |||||
} | |||||
MEGDNN_HOST MEGDNN_DEVICE MinOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
: INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {} | |||||
}; | |||||
template <typename src_ctype, typename dst_ctype, typename wtype_> | |||||
struct MaxOp { | |||||
typedef wtype_ wtype; | |||||
const wtype INIT; | |||||
src_ctype* src; | |||||
dst_ctype* dst; | |||||
const size_t B; | |||||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
#if defined(__CUDA_ARCH__) | |||||
return lhs > rhs ? lhs : rhs; | |||||
#else | |||||
return std::max(lhs, rhs); | |||||
#endif | |||||
} | |||||
MEGDNN_HOST MEGDNN_DEVICE MaxOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
: INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {} | |||||
}; | |||||
template <typename src_ctype, typename dst_ctype> | |||||
struct MaxOp<src_ctype, dst_ctype, dt_float32> { | |||||
typedef dt_float32 wtype; | |||||
const wtype INIT; | |||||
src_ctype* src; | |||||
dst_ctype* dst; | |||||
const size_t B; | |||||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
#if defined(__CUDA_ARCH__) | |||||
return (isnan(lhs) || lhs > rhs) ? lhs : rhs; | |||||
#else | |||||
return (std::isnan(lhs) || lhs > rhs) ? lhs : rhs; | |||||
#endif | |||||
} | |||||
MEGDNN_HOST MEGDNN_DEVICE MaxOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
: INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {} | |||||
}; | |||||
template <typename src_ctype, typename dst_ctype, typename wtype_> | |||||
struct CheckNonFiniteOp { | |||||
typedef wtype_ wtype; | |||||
const wtype INIT; | |||||
src_ctype* src; | |||||
dst_ctype* dst; | |||||
const size_t B; | |||||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { | |||||
#if defined(__CUDA_ARCH__) | |||||
return !isfinite(src[idx]); | |||||
#else | |||||
return !std::isfinite(src[idx]); | |||||
#endif | |||||
} | |||||
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
return lhs | rhs; | |||||
} | |||||
MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
: INIT(wtype(0)), src(src), dst(dst), B(B) {} | |||||
}; | |||||
} // namespace device_reduce | |||||
namespace reduce { | |||||
#if MEGDNN_CC_HOST | |||||
void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, size_t axis); | |||||
#endif | |||||
} // namespace reduce | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
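Editor's note: the new `reduce_helper_device.h` keeps the raw-pointer ops, including the `__CUDA_ARCH__` branches, so CUDA kernels still get a POD-style functor while the host header switches to `RefPtr`. A tiny illustration of how such a dual host/device `apply` reads, with the qualifiers spelled out instead of the MEGDNN macros; it compiles as plain C++ or as CUDA.

    #include <algorithm>
    #include <cstdio>

    #ifdef __CUDACC__
    #define TOY_HOST_DEVICE __host__ __device__
    #else
    #define TOY_HOST_DEVICE
    #endif

    template <typename T>
    struct ToyMinOp {
        static TOY_HOST_DEVICE T apply(T lhs, T rhs) {
    #if defined(__CUDA_ARCH__)
            return lhs < rhs ? lhs : rhs;  // device code: no <algorithm>
    #else
            return std::min(lhs, rhs);     // host code: use the standard library
    #endif
        }
    };

    int main() {
        std::printf("%d\n", ToyMinOp<int>::apply(3, 7));  // 3
    }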
@@ -362,6 +362,10 @@ static inline void copy_plane_in_bytes( | |||||
megcoreDeviceHandle_t get_device_handle(Handle* handle); | megcoreDeviceHandle_t get_device_handle(Handle* handle); | ||||
static inline void incr_refp(RefPtr& ptr, ptrdiff_t delta) { | |||||
ptr += (size_t)delta; | |||||
} | |||||
static inline void incr_voidp(void*& ptr, ptrdiff_t delta) { | static inline void incr_voidp(void*& ptr, ptrdiff_t delta) { | ||||
ptr = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(ptr) + delta); | ptr = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(ptr) + delta); | ||||
} | } | ||||
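Editor's note: the added `incr_refp` advances a `RefPtr` by a byte offset through `operator+=`, mirroring the existing `incr_voidp`. A hedged sketch of how such an increment could look on a toy ref-pointer; the real `RefPtr` API is only inferred from this hunk.

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    struct ToyRefPtr {
        void* raw = nullptr;
        ToyRefPtr& operator+=(size_t bytes) {
            raw = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(raw) + bytes);
            return *this;
        }
        void* get_ptr() const { return raw; }
    };

    // Mirrors incr_voidp: advance by a (non-negative, in this sketch) byte delta.
    static inline void toy_incr_refp(ToyRefPtr& ptr, ptrdiff_t delta) {
        ptr += static_cast<size_t>(delta);
    }

    int main() {
        float buf[4] = {0.f, 1.f, 2.f, 3.f};
        ToyRefPtr p{buf};
        toy_incr_refp(p, 2 * sizeof(float));  // skip the first two floats
        std::printf("%f\n", static_cast<float*>(p.get_ptr())[0]);  // 2.0
    }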
@@ -674,7 +678,8 @@ struct CompTypeCvter { | |||||
comp.layout.dtype.enumv() != DTypeTrait<CompType>::enumv) { | comp.layout.dtype.enumv() != DTypeTrait<CompType>::enumv) { | ||||
comp.layout.dtype = CompType(); | comp.layout.dtype = CompType(); | ||||
comp.layout.init_contiguous_stride(); | comp.layout.init_contiguous_stride(); | ||||
comp.raw_ptr = m_workspace_bundle->get(m_workspace_idx++); | |||||
comp = TensorND{ | |||||
m_workspace_bundle->get(m_workspace_idx++), comp.layout}; | |||||
if (src.layout.ndim) { | if (src.layout.ndim) { | ||||
m_cvt_opr->exec(src, comp); | m_cvt_opr->exec(src, comp); | ||||
} | } | ||||
@@ -699,7 +704,7 @@ struct CompTypeCvter { | |||||
* \brief get TensorND raw_ptr+low_byte pointer. | * \brief get TensorND raw_ptr+low_byte pointer. | ||||
*/ | */ | ||||
inline dt_byte* get_low_ptr(const TensorND* tensor) { | inline dt_byte* get_low_ptr(const TensorND* tensor) { | ||||
return static_cast<dt_byte*>(tensor->raw_ptr) + tensor->layout.span().low_byte; | |||||
return static_cast<dt_byte*>(tensor->raw_ptr()) + tensor->layout.span().low_byte; | |||||
} | } | ||||
/*! | /*! | ||||
@@ -11,7 +11,7 @@ | |||||
#include "src/cuda/argmxx/opr_impl.h" | #include "src/cuda/argmxx/opr_impl.h" | ||||
#include "src/common/argmxx_helper.h" | #include "src/common/argmxx_helper.h" | ||||
#include "src/common/reduce_helper.h" | |||||
#include "src/common/reduce_helper_device.h" | |||||
#include "src/cuda/reduce_helper.cuh" | #include "src/cuda/reduce_helper.cuh" | ||||
#include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
@@ -117,32 +117,34 @@ void BNForwardImpl::exec( | |||||
#if CUDNN_VERSION >= 7410 | #if CUDNN_VERSION >= 7410 | ||||
cudnn_check(cudnnBatchNormalizationForwardTrainingEx( | cudnn_check(cudnnBatchNormalizationForwardTrainingEx( | ||||
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, &alpha, | handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, &alpha, | ||||
&beta, // one & zero | |||||
tensor_desc.xy_desc.desc, src.raw_ptr, // xDesc & x | |||||
nullptr, nullptr, // zDesc & z | |||||
tensor_desc.xy_desc.desc, dst.raw_ptr, // yDesc & y | |||||
tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc | |||||
bn_scale.raw_ptr, bn_bias.raw_ptr, m_param.avg_factor, mean.raw_ptr, | |||||
variance.raw_ptr, m_param.epsilon, batch_mean.raw_ptr, | |||||
batch_inv_variance.raw_ptr, nullptr, workspace.raw_ptr, | |||||
workspace.size, reserve.raw_ptr, reserve.layout.access_bytes())); | |||||
&beta, // one & zero | |||||
tensor_desc.xy_desc.desc, src.raw_ptr(), // xDesc & x | |||||
nullptr, nullptr, // zDesc & z | |||||
tensor_desc.xy_desc.desc, dst.raw_ptr(), // yDesc & y | |||||
tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc | |||||
bn_scale.raw_ptr(), bn_bias.raw_ptr(), m_param.avg_factor, | |||||
mean.raw_ptr(), variance.raw_ptr(), m_param.epsilon, | |||||
batch_mean.raw_ptr(), batch_inv_variance.raw_ptr(), nullptr, | |||||
workspace.raw_ptr, workspace.size, reserve.raw_ptr(), | |||||
reserve.layout.access_bytes())); | |||||
#else | #else | ||||
cudnn_check(cudnnBatchNormalizationForwardTraining( | cudnn_check(cudnnBatchNormalizationForwardTraining( | ||||
handle, tensor_desc.bn_mode, &alpha, &beta, | handle, tensor_desc.bn_mode, &alpha, &beta, | ||||
tensor_desc.xy_desc.desc, src.raw_ptr, // xDesc & x | |||||
tensor_desc.xy_desc.desc, dst.raw_ptr, // yDesc & y | |||||
tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc | |||||
bn_scale.raw_ptr, bn_bias.raw_ptr, m_param.avg_factor, mean.raw_ptr, | |||||
variance.raw_ptr, m_param.epsilon, batch_mean.raw_ptr, | |||||
batch_inv_variance.raw_ptr)); | |||||
tensor_desc.xy_desc.desc, src.raw_ptr(), // xDesc & x | |||||
tensor_desc.xy_desc.desc, dst.raw_ptr(), // yDesc & y | |||||
tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc | |||||
bn_scale.raw_ptr(), bn_bias.raw_ptr(), m_param.avg_factor, | |||||
mean.raw_ptr(), variance.raw_ptr(), m_param.epsilon, | |||||
batch_mean.raw_ptr(), batch_inv_variance.raw_ptr())); | |||||
#endif // CUDNN_VERSION >= 7410 | #endif // CUDNN_VERSION >= 7410 | ||||
break; | break; | ||||
case param::BN::FwdMode::INFERENCE: | case param::BN::FwdMode::INFERENCE: | ||||
cudnn_check(cudnnBatchNormalizationForwardInference( | cudnn_check(cudnnBatchNormalizationForwardInference( | ||||
handle, tensor_desc.bn_mode, &alpha, &beta, | handle, tensor_desc.bn_mode, &alpha, &beta, | ||||
tensor_desc.xy_desc.desc, src.raw_ptr, tensor_desc.xy_desc.desc, | |||||
dst.raw_ptr, tensor_desc.param_desc.desc, bn_scale.raw_ptr, | |||||
bn_bias.raw_ptr, mean.raw_ptr, variance.raw_ptr, m_param.epsilon)); | |||||
tensor_desc.xy_desc.desc, src.raw_ptr(), tensor_desc.xy_desc.desc, | |||||
dst.raw_ptr(), tensor_desc.param_desc.desc, bn_scale.raw_ptr(), | |||||
bn_bias.raw_ptr(), mean.raw_ptr(), variance.raw_ptr(), | |||||
m_param.epsilon)); | |||||
break; | break; | ||||
default: | default: | ||||
megdnn_throw("Unknown forward mode type of batch normalization."); | megdnn_throw("Unknown forward mode type of batch normalization."); | ||||
@@ -198,27 +200,27 @@ void BNBackwardImpl::exec( | |||||
cudnn_check(cudnnBatchNormalizationBackwardEx( | cudnn_check(cudnnBatchNormalizationBackwardEx( | ||||
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, &alpha, &beta, &alpha, | handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, &alpha, &beta, &alpha, | ||||
&beta, tensor_desc.xy_desc.desc, | &beta, tensor_desc.xy_desc.desc, | ||||
x.raw_ptr, // xDesc & x | |||||
nullptr, nullptr, // yDesc & y | |||||
tensor_desc.xy_desc.desc, dy.raw_ptr, // dyDesc & dy | |||||
nullptr, nullptr, // dzDesc & dz | |||||
tensor_desc.xy_desc.desc, dx.raw_ptr, // dxDesc & dx | |||||
tensor_desc.param_desc.desc, bn_scale.raw_ptr, // bnScale | |||||
nullptr, // bnBias | |||||
d_bn_scale.raw_ptr, d_bn_bias.raw_ptr, // dScale, dBias | |||||
m_param.epsilon, saved_batch_mean.raw_ptr, saved_batch_inv_variance.raw_ptr, | |||||
nullptr, workspace.raw_ptr, workspace.size, reserve.raw_ptr, | |||||
reserve.layout.access_bytes())); | |||||
x.raw_ptr(), // xDesc & x | |||||
nullptr, nullptr, // yDesc & y | |||||
tensor_desc.xy_desc.desc, dy.raw_ptr(), // dyDesc & dy | |||||
nullptr, nullptr, // dzDesc & dz | |||||
tensor_desc.xy_desc.desc, dx.raw_ptr(), // dxDesc & dx | |||||
tensor_desc.param_desc.desc, bn_scale.raw_ptr(), // bnScale | |||||
nullptr, // bnBias | |||||
d_bn_scale.raw_ptr(), d_bn_bias.raw_ptr(), // dScale, dBias | |||||
m_param.epsilon, saved_batch_mean.raw_ptr(), | |||||
saved_batch_inv_variance.raw_ptr(), nullptr, workspace.raw_ptr, | |||||
workspace.size, reserve.raw_ptr(), reserve.layout.access_bytes())); | |||||
#else | #else | ||||
cudnn_check(cudnnBatchNormalizationBackward( | cudnn_check(cudnnBatchNormalizationBackward( | ||||
handle, tensor_desc.bn_mode, &alpha, &beta, &alpha, &beta, | handle, tensor_desc.bn_mode, &alpha, &beta, &alpha, &beta, | ||||
tensor_desc.xy_desc.desc, x.raw_ptr, // xDesc & x | |||||
tensor_desc.xy_desc.desc, dy.raw_ptr, // dyDesc & dy | |||||
tensor_desc.xy_desc.desc, dx.raw_ptr, // dxDesc & dx | |||||
tensor_desc.param_desc.desc, bn_scale.raw_ptr, // bnScale | |||||
d_bn_scale.raw_ptr, d_bn_bias.raw_ptr, // dScale, dBias | |||||
m_param.epsilon, saved_batch_mean.raw_ptr, | |||||
saved_batch_inv_variance.raw_ptr)); | |||||
tensor_desc.xy_desc.desc, x.raw_ptr(), // xDesc & x | |||||
tensor_desc.xy_desc.desc, dy.raw_ptr(), // dyDesc & dy | |||||
tensor_desc.xy_desc.desc, dx.raw_ptr(), // dxDesc & dx | |||||
tensor_desc.param_desc.desc, bn_scale.raw_ptr(), // bnScale | |||||
d_bn_scale.raw_ptr(), d_bn_bias.raw_ptr(), // dScale, dBias | |||||
m_param.epsilon, saved_batch_mean.raw_ptr(), | |||||
saved_batch_inv_variance.raw_ptr())); | |||||
#endif | #endif | ||||
} | } | ||||
@@ -80,9 +80,9 @@ void BatchedMatrixMulForwardImpl::AlgoBruteForce::exec(const ExecArgs& args) con | |||||
rep(n, N) { | rep(n, N) { | ||||
TensorND A_, B_, C_; | TensorND A_, B_, C_; | ||||
auto tensor_n_from_batch = [n](const TensorND& in, TensorND& out) { | auto tensor_n_from_batch = [n](const TensorND& in, TensorND& out) { | ||||
out.raw_ptr = static_cast<void*>( | |||||
static_cast<dt_byte*>(in.raw_ptr) + | |||||
n * in.layout.stride[0] * in.layout.dtype.size()); | |||||
out.reset_ptr(static_cast<void*>( | |||||
static_cast<dt_byte*>(in.raw_ptr()) + | |||||
n * in.layout.stride[0] * in.layout.dtype.size())); | |||||
out.layout = in.layout.remove_axis(0); | out.layout = in.layout.remove_axis(0); | ||||
}; | }; | ||||
tensor_n_from_batch(args.tensor_a, A_); | tensor_n_from_batch(args.tensor_a, A_); | ||||
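Editor's note: `tensor_n_from_batch` now rebinds the per-batch view with `reset_ptr` instead of assigning the old `raw_ptr` field, keeping whatever bookkeeping the ref-pointer does consistent. A small sketch of carving such a view; the types are toys and the real layout arithmetic (stride * dtype size, remove_axis) is reduced to a byte stride.

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    struct ToyTensor {
        void* raw = nullptr;
        size_t batch_stride_bytes = 0;
        void reset_ptr(void* p) { raw = p; }
        void* raw_ptr() const { return raw; }
    };

    // Make `out` point at batch `n` of `in`, the way tensor_n_from_batch does.
    static void toy_tensor_n_from_batch(size_t n, const ToyTensor& in, ToyTensor& out) {
        out.reset_ptr(static_cast<unsigned char*>(in.raw_ptr()) +
                      n * in.batch_stride_bytes);
        out.batch_stride_bytes = 0;  // a single-batch view
    }

    int main() {
        std::vector<float> storage{10.f, 11.f, 20.f, 21.f};  // 2 batches x 2 floats
        ToyTensor whole{storage.data(), 2 * sizeof(float)};
        ToyTensor view;
        toy_tensor_n_from_batch(1, whole, view);
        std::printf("%f\n", static_cast<float*>(view.raw_ptr())[0]);  // 20.0
    }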
@@ -76,13 +76,13 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const { | |||||
static_cast<void*>(workspace.raw_ptr + 2 * batch * sizeof(uintptr_t))); | static_cast<void*>(workspace.raw_ptr + 2 * batch * sizeof(uintptr_t))); | ||||
arange<uintptr_t>( | arange<uintptr_t>( | ||||
As, reinterpret_cast<uintptr_t>(args.tensor_a.raw_ptr), | |||||
As, reinterpret_cast<uintptr_t>(args.tensor_a.raw_ptr()), | |||||
args.layout_a.stride[0] * dtype.size(), batch, stream); | args.layout_a.stride[0] * dtype.size(), batch, stream); | ||||
arange<uintptr_t>( | arange<uintptr_t>( | ||||
Bs, reinterpret_cast<uintptr_t>(args.tensor_b.raw_ptr), | |||||
Bs, reinterpret_cast<uintptr_t>(args.tensor_b.raw_ptr()), | |||||
args.layout_b.stride[0] * dtype.size(), batch, stream); | args.layout_b.stride[0] * dtype.size(), batch, stream); | ||||
arange<uintptr_t>( | arange<uintptr_t>( | ||||
Cs, reinterpret_cast<uintptr_t>(args.tensor_c.raw_ptr), | |||||
Cs, reinterpret_cast<uintptr_t>(args.tensor_c.raw_ptr()), | |||||
args.layout_c.stride[0] * dtype.size(), batch, stream); | args.layout_c.stride[0] * dtype.size(), batch, stream); | ||||
auto io32_c32 = [&]() { | auto io32_c32 = [&]() { | ||||
@@ -62,10 +62,10 @@ void BatchedMatrixMulForwardImpl::AlgoCublasLt::exec(const ExecArgs& args) const | |||||
"workspace bundle size should be 1(ws_algo)"); | "workspace bundle size should be 1(ws_algo)"); | ||||
cublas_check(cublasLtMatmul( | cublas_check(cublasLtMatmul( | ||||
cublasLt_handle, desc.matmul_desc, one_half, | cublasLt_handle, desc.matmul_desc, one_half, | ||||
static_cast<const __half*>(args.tensor_b.raw_ptr), desc.layout_b, | |||||
static_cast<const __half*>(args.tensor_a.raw_ptr), desc.layout_a, | |||||
zero_half, static_cast<const __half*>(args.tensor_c.raw_ptr), | |||||
desc.layout_c, static_cast<__half*>(args.tensor_c.raw_ptr), | |||||
static_cast<const __half*>(args.tensor_b.raw_ptr()), desc.layout_b, | |||||
static_cast<const __half*>(args.tensor_a.raw_ptr()), desc.layout_a, | |||||
zero_half, static_cast<const __half*>(args.tensor_c.raw_ptr()), | |||||
desc.layout_c, static_cast<__half*>(args.tensor_c.raw_ptr()), | |||||
desc.layout_c, &algo, ws_bundle.get(0), ws_bundle.get_size(0), stream)); | desc.layout_c, &algo, ws_bundle.get(0), ws_bundle.get_size(0), stream)); | ||||
}; | }; | ||||
auto batched_sgemm = [&]() { | auto batched_sgemm = [&]() { | ||||
@@ -77,7 +77,7 @@ void BatchedMatrixMulForwardImpl::AlgoCublasLt::exec(const ExecArgs& args) const | |||||
auto dev_a = (desc.dt_a == CUDA_R_16F) | auto dev_a = (desc.dt_a == CUDA_R_16F) | ||||
? static_cast<void*>(args.tensor_a.ptr<dt_float16>()) | ? static_cast<void*>(args.tensor_a.ptr<dt_float16>()) | ||||
: static_cast<void*>(args.tensor_a.ptr<dt_float32>()); | : static_cast<void*>(args.tensor_a.ptr<dt_float32>()); | ||||
auto dev_c = static_cast<void*>(args.tensor_c.raw_ptr); | |||||
auto dev_c = static_cast<void*>(args.tensor_c.raw_ptr()); | |||||
megdnn_assert( | megdnn_assert( | ||||
ws_bundle.nr_workspace() == 1, | ws_bundle.nr_workspace() == 1, | ||||
"workspace bundle size should be 1(ws_algo)"); | "workspace bundle size should be 1(ws_algo)"); | ||||
@@ -104,14 +104,14 @@ void BatchedMatrixMulForwardImpl::AlgoCublasLt::exec(const ExecArgs& args) const | |||||
transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE, &pm, | transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE, &pm, | ||||
sizeof(pm))); | sizeof(pm))); | ||||
cublas_check(cublasLtMatrixTransform( | cublas_check(cublasLtMatrixTransform( | ||||
cublasLt_handle, transform_desc, one, args.tensor_b.raw_ptr, | |||||
cublasLt_handle, transform_desc, one, args.tensor_b.raw_ptr(), | |||||
desc.layout_b, zero, nullptr, nullptr, ws_b, desc.layout_trans_b, | desc.layout_b, zero, nullptr, nullptr, ws_b, desc.layout_trans_b, | ||||
stream)); | stream)); | ||||
cublas_check(cublasLtMatrixTransformDescSetAttribute( | cublas_check(cublasLtMatrixTransformDescSetAttribute( | ||||
transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &trans_a, | transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &trans_a, | ||||
sizeof(trans_a))); | sizeof(trans_a))); | ||||
cublas_check(cublasLtMatrixTransform( | cublas_check(cublasLtMatrixTransform( | ||||
cublasLt_handle, transform_desc, one, args.tensor_a.raw_ptr, | |||||
cublasLt_handle, transform_desc, one, args.tensor_a.raw_ptr(), | |||||
desc.layout_a, zero, nullptr, nullptr, ws_a, desc.layout_trans_a, | desc.layout_a, zero, nullptr, nullptr, ws_a, desc.layout_trans_a, | ||||
stream)); | stream)); | ||||
cublas_check(cublasLtMatmul( | cublas_check(cublasLtMatmul( | ||||
@@ -124,7 +124,7 @@ void BatchedMatrixMulForwardImpl::AlgoCublasLt::exec(const ExecArgs& args) const | |||||
sizeof(trans_c))); | sizeof(trans_c))); | ||||
cublas_check(cublasLtMatrixTransform( | cublas_check(cublasLtMatrixTransform( | ||||
cublasLt_handle, transform_desc, one, ws_c, desc.layout_trans_c, zero, | cublasLt_handle, transform_desc, one, ws_c, desc.layout_trans_c, zero, | ||||
nullptr, nullptr, args.tensor_c.raw_ptr, desc.layout_c, stream)); | |||||
nullptr, nullptr, args.tensor_c.raw_ptr(), desc.layout_c, stream)); | |||||
cublas_check(cublasLtMatrixTransformDescDestroy(transform_desc)); | cublas_check(cublasLtMatrixTransformDescDestroy(transform_desc)); | ||||
}; | }; | ||||
@@ -8,7 +8,7 @@ | |||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "src/common/reduce_helper.h" | |||||
#include "src/common/reduce_helper_device.h" | |||||
#include "megdnn/dtype.h" | #include "megdnn/dtype.h" | ||||
#include "src/cuda/reduce_helper.cuh" | #include "src/cuda/reduce_helper.cuh" | ||||
@@ -18,7 +18,9 @@ namespace cuda { | |||||
#define COMMA , | #define COMMA , | ||||
INST_REDUCE(reduce::CheckNonFiniteOp<dt_float32 COMMA dt_int32 COMMA dt_int32>, false); | |||||
INST_REDUCE( | |||||
device_reduce::CheckNonFiniteOp<dt_float32 COMMA dt_int32 COMMA dt_int32>, | |||||
false); | |||||
#undef COMMA | #undef COMMA | ||||
} // namespace cuda | } // namespace cuda | ||||
@@ -15,12 +15,12 @@ | |||||
#include "src/cuda/handle.h" | #include "src/cuda/handle.h" | ||||
#include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
#include "src/common/reduce_helper.h" | |||||
#include "src/common/reduce_helper_device.h" | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace cuda { | namespace cuda { | ||||
using reduce::CheckNonFiniteOp; | |||||
using device_reduce::CheckNonFiniteOp; | |||||
size_t CheckNonFiniteImpl::get_workspace_in_bytes( | size_t CheckNonFiniteImpl::get_workspace_in_bytes( | ||||
const TensorLayout& src, const TensorLayout& dst) { | const TensorLayout& src, const TensorLayout& dst) { | ||||
@@ -45,7 +45,7 @@ ChecksumForward::Result ChecksumForwardImpl::exec( | |||||
check_exec(data.layout, workspace.size); | check_exec(data.layout, workspace.size); | ||||
auto stream = cuda_stream(handle()); | auto stream = cuda_stream(handle()); | ||||
auto ptr = static_cast<uint8_t*>(data.raw_ptr); | |||||
auto ptr = static_cast<uint8_t*>(data.raw_ptr()); | |||||
size_t size_all = data.layout.shape[0], size_ints = size_all / sizeof(uint32_t); | size_t size_all = data.layout.shape[0], size_ints = size_all / sizeof(uint32_t); | ||||
auto last_val_size = std::min<size_t>(size_all, 4); | auto last_val_size = std::min<size_t>(size_all, 4); | ||||
cuda_check(cudaMemcpyAsync( | cuda_check(cudaMemcpyAsync( | ||||
@@ -54,7 +54,7 @@ ChecksumForward::Result ChecksumForwardImpl::exec( | |||||
if (size_ints) { | if (size_ints) { | ||||
checksum::calc( | checksum::calc( | ||||
static_cast<uint32_t*>(wbundle.get(1)), | static_cast<uint32_t*>(wbundle.get(1)), | ||||
static_cast<uint32_t*>(data.raw_ptr), | |||||
static_cast<uint32_t*>(data.raw_ptr()), | |||||
static_cast<uint32_t*>(wbundle.get(0)), size_ints, stream); | static_cast<uint32_t*>(wbundle.get(0)), size_ints, stream); | ||||
cuda_check(cudaMemcpyAsync( | cuda_check(cudaMemcpyAsync( | ||||
&result.checksum, wbundle.get(1), sizeof(result.checksum), | &result.checksum, wbundle.get(1), sizeof(result.checksum), | ||||
@@ -135,9 +135,9 @@ size_t ConvBiasForwardImpl::AlgoBatchedMatmul::get_workspace_in_bytes( | |||||
void ConvBiasForwardImpl::AlgoBatchedMatmul::exec(const ExecArgs& args) const { | void ConvBiasForwardImpl::AlgoBatchedMatmul::exec(const ExecArgs& args) const { | ||||
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | ||||
auto conv_dst_tensor = *args.dst_tensor; | |||||
TensorND conv_dst_tensor = *args.dst_tensor; | |||||
if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | ||||
conv_dst_tensor.raw_ptr = bundle.get(1); | |||||
conv_dst_tensor = TensorND{bundle.get(1), args.dst_tensor->layout}; | |||||
conv_dst_tensor.layout.dtype = DType(); | conv_dst_tensor.layout.dtype = DType(); | ||||
args.opr->check_or_deduce_dtype_fwd( | args.opr->check_or_deduce_dtype_fwd( | ||||
args.src_layout->dtype, args.filter_layout->dtype, | args.src_layout->dtype, args.filter_layout->dtype, | ||||
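Editor's note: this hunk introduces the pattern repeated across the conv_bias algorithms below: the temporary `conv_dst_tensor` is rebuilt as a whole `TensorND{ptr, layout}` rather than patching its (now private) pointer field, which keeps pointer and layout paired. A minimal before/after sketch with an illustrative `TensorND`:

```cpp
#include <cstdio>

// Illustrative stand-ins; the real TensorND/TensorLayout are richer.
struct TensorLayout { int dtype_tag; };
struct TensorND {
    void* ptr = nullptr;
    TensorLayout layout{};
    TensorND() = default;
    TensorND(void* p, const TensorLayout& l) : ptr(p), layout(l) {}
    void* raw_ptr() const { return ptr; }
};

int main() {
    char workspace[64];
    TensorLayout dst_layout{42};
    TensorND dst{nullptr, dst_layout};

    // Old pattern (impossible once raw_ptr is no longer a public member):
    //   TensorND conv_dst = dst; conv_dst.raw_ptr = workspace;
    // New pattern: rebuild the tensor from the workspace pointer plus the
    // destination layout in one step.
    TensorND conv_dst{workspace, dst.layout};
    conv_dst.layout.dtype_tag = 0;  // mirrors `conv_dst_tensor.layout.dtype = DType()`

    std::printf("%d\n", conv_dst.raw_ptr() == workspace);
}
```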
@@ -150,9 +150,9 @@ void ConvBiasForwardImpl::AlgoBatchedMatmul::exec(const ExecArgs& args) const { | |||||
{ | { | ||||
auto config = prepare_sub_opr(args); | auto config = prepare_sub_opr(args); | ||||
TensorND A{args.filter_tensor->raw_ptr, config.first[0]}, | |||||
B{args.src_tensor->raw_ptr, config.first[1]}, | |||||
C{args.dst_tensor->raw_ptr, config.first[2]}; | |||||
TensorND A{args.filter_tensor->raw_ptr(), config.first[0]}, | |||||
B{args.src_tensor->raw_ptr(), config.first[1]}, | |||||
C{args.dst_tensor->raw_ptr(), config.first[2]}; | |||||
config.second->exec(A, B, C, bundle.get_workspace(0)); | config.second->exec(A, B, C, bundle.get_workspace(0)); | ||||
} | } | ||||
handle_bias_and_nonlinear( | handle_bias_and_nonlinear( | ||||
@@ -52,9 +52,9 @@ size_t ConvBiasForwardImpl::AlgoChanwise::get_workspace_in_bytes( | |||||
void ConvBiasForwardImpl::AlgoChanwise::exec(const ExecArgs& args) const { | void ConvBiasForwardImpl::AlgoChanwise::exec(const ExecArgs& args) const { | ||||
WorkspaceBundle bundle{args.workspace.raw_ptr, {get_workspace_in_bytes(args)}}; | WorkspaceBundle bundle{args.workspace.raw_ptr, {get_workspace_in_bytes(args)}}; | ||||
auto conv_dst_tensor = *args.dst_tensor; | |||||
TensorND conv_dst_tensor = *args.dst_tensor; | |||||
if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | ||||
conv_dst_tensor.raw_ptr = bundle.get(0); | |||||
conv_dst_tensor = TensorND{bundle.get(0), args.dst_tensor->layout}; | |||||
conv_dst_tensor.layout.dtype = DType(); | conv_dst_tensor.layout.dtype = DType(); | ||||
args.opr->check_or_deduce_dtype_fwd( | args.opr->check_or_deduce_dtype_fwd( | ||||
args.src_layout->dtype, args.filter_layout->dtype, | args.src_layout->dtype, args.filter_layout->dtype, | ||||
@@ -74,9 +74,9 @@ void ConvBiasForwardImpl::AlgoChanwise::exec(const ExecArgs& args) const { | |||||
#if CUDA_VERSION >= 9000 | #if CUDA_VERSION >= 9000 | ||||
if (is_compute_capability_required(5, 3)) { | if (is_compute_capability_required(5, 3)) { | ||||
chanwise::run_fwd( | chanwise::run_fwd( | ||||
static_cast<half*>(conv_dst_tensor.raw_ptr), | |||||
static_cast<half*>(args.src_tensor->raw_ptr), | |||||
static_cast<half*>(args.filter_tensor->raw_ptr), kparam, | |||||
static_cast<half*>(conv_dst_tensor.raw_ptr()), | |||||
static_cast<half*>(args.src_tensor->raw_ptr()), | |||||
static_cast<half*>(args.filter_tensor->raw_ptr()), kparam, | |||||
stream); | stream); | ||||
} else { | } else { | ||||
chanwise::run_fwd( | chanwise::run_fwd( | ||||
@@ -50,9 +50,9 @@ size_t ConvBiasForwardImpl::AlgoChanwise8x8x32::get_workspace_in_bytes( | |||||
void ConvBiasForwardImpl::AlgoChanwise8x8x32::exec(const ExecArgs& args) const { | void ConvBiasForwardImpl::AlgoChanwise8x8x32::exec(const ExecArgs& args) const { | ||||
WorkspaceBundle bundle{args.workspace.raw_ptr, {get_workspace_in_bytes(args)}}; | WorkspaceBundle bundle{args.workspace.raw_ptr, {get_workspace_in_bytes(args)}}; | ||||
auto conv_dst_tensor = *args.dst_tensor; | |||||
TensorND conv_dst_tensor = *args.dst_tensor; | |||||
if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | ||||
conv_dst_tensor.raw_ptr = bundle.get(0); | |||||
conv_dst_tensor = TensorND{bundle.get(0), args.dst_tensor->layout}; | |||||
conv_dst_tensor.layout.dtype = DType(); | conv_dst_tensor.layout.dtype = DType(); | ||||
args.opr->check_or_deduce_dtype_fwd( | args.opr->check_or_deduce_dtype_fwd( | ||||
args.src_layout->dtype, args.filter_layout->dtype, | args.src_layout->dtype, args.filter_layout->dtype, | ||||
@@ -65,9 +65,9 @@ size_t ConvBiasForwardImpl::AlgoChanwiseSmall::get_workspace_in_bytes( | |||||
void ConvBiasForwardImpl::AlgoChanwiseSmall::exec(const ExecArgs& args) const { | void ConvBiasForwardImpl::AlgoChanwiseSmall::exec(const ExecArgs& args) const { | ||||
WorkspaceBundle bundle{args.workspace.raw_ptr, {get_workspace_in_bytes(args)}}; | WorkspaceBundle bundle{args.workspace.raw_ptr, {get_workspace_in_bytes(args)}}; | ||||
auto conv_dst_tensor = *args.dst_tensor; | |||||
TensorND conv_dst_tensor = *args.dst_tensor; | |||||
if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | ||||
conv_dst_tensor.raw_ptr = bundle.get(0); | |||||
conv_dst_tensor = TensorND{bundle.get(0), conv_dst_tensor.layout}; | |||||
conv_dst_tensor.layout.dtype = DType(); | conv_dst_tensor.layout.dtype = DType(); | ||||
args.opr->check_or_deduce_dtype_fwd( | args.opr->check_or_deduce_dtype_fwd( | ||||
args.src_layout->dtype, args.filter_layout->dtype, | args.src_layout->dtype, args.filter_layout->dtype, | ||||
@@ -85,9 +85,9 @@ void ConvBiasForwardImpl::AlgoChanwiseSmall::exec(const ExecArgs& args) const { | |||||
#if CUDA_VERSION >= 9000 | #if CUDA_VERSION >= 9000 | ||||
case DTypeEnum::Float16: | case DTypeEnum::Float16: | ||||
chanwise::run_fwd_small( | chanwise::run_fwd_small( | ||||
static_cast<half*>(conv_dst_tensor.raw_ptr), | |||||
static_cast<half*>(args.src_tensor->raw_ptr), | |||||
static_cast<half*>(args.filter_tensor->raw_ptr), kparam, | |||||
static_cast<half*>(conv_dst_tensor.raw_ptr()), | |||||
static_cast<half*>(args.src_tensor->raw_ptr()), | |||||
static_cast<half*>(args.filter_tensor->raw_ptr()), kparam, | |||||
stream); | stream); | ||||
break; | break; | ||||
#endif | #endif | ||||
@@ -100,9 +100,9 @@ size_t ConvBiasForwardImpl::AlgoCUDNNConv::get_workspace_in_bytes( | |||||
void ConvBiasForwardImpl::AlgoCUDNNConv::exec(const ExecArgs& args) const { | void ConvBiasForwardImpl::AlgoCUDNNConv::exec(const ExecArgs& args) const { | ||||
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | ||||
auto conv_dst_tensor = *args.dst_tensor; | |||||
TensorND conv_dst_tensor = *args.dst_tensor; | |||||
if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | ||||
conv_dst_tensor.raw_ptr = bundle.get(1); | |||||
conv_dst_tensor = TensorND{bundle.get(1), args.dst_tensor->layout}; | |||||
conv_dst_tensor.layout.dtype = DType(); | conv_dst_tensor.layout.dtype = DType(); | ||||
args.opr->check_or_deduce_dtype_fwd( | args.opr->check_or_deduce_dtype_fwd( | ||||
args.src_layout->dtype, args.filter_layout->dtype, | args.src_layout->dtype, args.filter_layout->dtype, | ||||
@@ -120,10 +120,10 @@ void ConvBiasForwardImpl::AlgoCUDNNConv::exec(const ExecArgs& args) const { | |||||
float alpha = 1.0f, beta = 0.0f; | float alpha = 1.0f, beta = 0.0f; | ||||
auto status = cudnnConvolutionForward( | auto status = cudnnConvolutionForward( | ||||
conv_args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | conv_args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | ||||
conv_args.src_tensor->raw_ptr, D.filter_desc.desc, | |||||
conv_args.filter_tensor->raw_ptr, D.conv_desc.conv_desc, m_cudnn_enum, | |||||
conv_args.src_tensor->raw_ptr(), D.filter_desc.desc, | |||||
conv_args.filter_tensor->raw_ptr(), D.conv_desc.conv_desc, m_cudnn_enum, | |||||
conv_workspace.raw_ptr, conv_workspace.size, &beta, D.dst_desc.desc, | conv_workspace.raw_ptr, conv_workspace.size, &beta, D.dst_desc.desc, | ||||
conv_args.dst_tensor->raw_ptr); | |||||
conv_args.dst_tensor->raw_ptr()); | |||||
megdnn_assert( | megdnn_assert( | ||||
status == CUDNN_STATUS_SUCCESS, "conv fwd failed: %s; info: %s", | status == CUDNN_STATUS_SUCCESS, "conv fwd failed: %s; info: %s", | ||||
cudnnGetErrorString(status), conv_args.to_string().c_str()); | cudnnGetErrorString(status), conv_args.to_string().c_str()); | ||||
@@ -231,7 +231,7 @@ void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::exec( | |||||
auto workspace_ptr = args.workspace.raw_ptr; | auto workspace_ptr = args.workspace.raw_ptr; | ||||
auto workspace_size = args.workspace.size; | auto workspace_size = args.workspace.size; | ||||
auto bias_ptr = args.bias_tensor->raw_ptr; | |||||
auto bias_ptr = args.bias_tensor->raw_ptr(); | |||||
if (args.bias_layout && args.bias_layout->dtype != dtype::Float32() && | if (args.bias_layout && args.bias_layout->dtype != dtype::Float32() && | ||||
args.src_layout->dtype.category() != DTypeCategory::FLOAT) { | args.src_layout->dtype.category() != DTypeCategory::FLOAT) { | ||||
auto cvt = args.handle->create_operator<TypeCvt>(); | auto cvt = args.handle->create_operator<TypeCvt>(); | ||||
@@ -242,7 +242,7 @@ void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::exec( | |||||
auto bias_size_in_bytes = float_bias_layout.span().dist_byte(); | auto bias_size_in_bytes = float_bias_layout.span().dist_byte(); | ||||
megdnn_assert(args.workspace.size >= bias_size_in_bytes); | megdnn_assert(args.workspace.size >= bias_size_in_bytes); | ||||
cvt->exec( | cvt->exec( | ||||
{args.bias_tensor->raw_ptr, converted_bias_layout}, | |||||
{args.bias_tensor->raw_ptr(), converted_bias_layout}, | |||||
TensorND{workspace_ptr, float_bias_layout}); | TensorND{workspace_ptr, float_bias_layout}); | ||||
bias_ptr = workspace_ptr; | bias_ptr = workspace_ptr; | ||||
@@ -254,19 +254,19 @@ void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::exec( | |||||
if (args.z_layout->ndim == 0) { | if (args.z_layout->ndim == 0) { | ||||
status = cudnnConvolutionBiasActivationForward( | status = cudnnConvolutionBiasActivationForward( | ||||
args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | ||||
args.src_tensor->raw_ptr, D.filter_desc.desc, | |||||
args.filter_tensor->raw_ptr, D.conv_desc.conv_desc, m_cudnn_enum, | |||||
args.src_tensor->raw_ptr(), D.filter_desc.desc, | |||||
args.filter_tensor->raw_ptr(), D.conv_desc.conv_desc, m_cudnn_enum, | |||||
workspace_ptr, workspace_size, &beta, D.dst_desc.desc, | workspace_ptr, workspace_size, &beta, D.dst_desc.desc, | ||||
args.dst_tensor->raw_ptr, D.bias_desc.desc, bias_ptr, | |||||
D.conv_desc.act_desc, D.dst_desc.desc, args.dst_tensor->raw_ptr); | |||||
args.dst_tensor->raw_ptr(), D.bias_desc.desc, bias_ptr, | |||||
D.conv_desc.act_desc, D.dst_desc.desc, args.dst_tensor->raw_ptr()); | |||||
} else { | } else { | ||||
status = cudnnConvolutionBiasActivationForward( | status = cudnnConvolutionBiasActivationForward( | ||||
args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | ||||
args.src_tensor->raw_ptr, D.filter_desc.desc, | |||||
args.filter_tensor->raw_ptr, D.conv_desc.conv_desc, m_cudnn_enum, | |||||
args.src_tensor->raw_ptr(), D.filter_desc.desc, | |||||
args.filter_tensor->raw_ptr(), D.conv_desc.conv_desc, m_cudnn_enum, | |||||
workspace_ptr, workspace_size, &beta, D.z_desc.desc, | workspace_ptr, workspace_size, &beta, D.z_desc.desc, | ||||
args.z_tensor->raw_ptr, D.bias_desc.desc, bias_ptr, | |||||
D.conv_desc.act_desc, D.dst_desc.desc, args.dst_tensor->raw_ptr); | |||||
args.z_tensor->raw_ptr(), D.bias_desc.desc, bias_ptr, | |||||
D.conv_desc.act_desc, D.dst_desc.desc, args.dst_tensor->raw_ptr()); | |||||
} | } | ||||
megdnn_assert( | megdnn_assert( | ||||
@@ -142,9 +142,10 @@ size_t ConvBiasForwardImpl::AlgoGroupConvGeneral::get_workspace_in_bytes( | |||||
void ConvBiasForwardImpl::AlgoGroupConvGeneral::exec(const ExecArgs& args) const { | void ConvBiasForwardImpl::AlgoGroupConvGeneral::exec(const ExecArgs& args) const { | ||||
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | ||||
auto conv_dst_tensor = *args.dst_tensor; | |||||
TensorND conv_dst_tensor = *args.dst_tensor; | |||||
if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | ||||
conv_dst_tensor.raw_ptr = bundle.get(bundle.nr_workspace() - 1); | |||||
conv_dst_tensor = TensorND{ | |||||
bundle.get(bundle.nr_workspace() - 1), args.dst_tensor->layout}; | |||||
conv_dst_tensor.layout.dtype = DType(); | conv_dst_tensor.layout.dtype = DType(); | ||||
args.opr->check_or_deduce_dtype_fwd( | args.opr->check_or_deduce_dtype_fwd( | ||||
args.src_layout->dtype, args.filter_layout->dtype, | args.src_layout->dtype, args.filter_layout->dtype, | ||||
@@ -156,11 +157,11 @@ void ConvBiasForwardImpl::AlgoGroupConvGeneral::exec(const ExecArgs& args) const | |||||
sub_args.dst_layout = &conv_dst_tensor.layout; | sub_args.dst_layout = &conv_dst_tensor.layout; | ||||
auto config = prepare_sub_opr(sub_args); | auto config = prepare_sub_opr(sub_args); | ||||
TensorND tsrc{args.src_tensor->raw_ptr, config.first[0]}; | |||||
TensorND tfilter{args.filter_tensor->raw_ptr, config.first[1]}; | |||||
TensorND tbias{args.bias_tensor->raw_ptr, config.first[2]}; | |||||
TensorND tz{args.z_tensor->raw_ptr, config.first[3]}; | |||||
TensorND tdst{conv_dst_tensor.raw_ptr, config.first[4]}; | |||||
TensorND tsrc{args.src_tensor->raw_ptr(), config.first[0]}; | |||||
TensorND tfilter{args.filter_tensor->raw_ptr(), config.first[1]}; | |||||
TensorND tbias{args.bias_tensor->raw_ptr(), config.first[2]}; | |||||
TensorND tz{args.z_tensor->raw_ptr(), config.first[3]}; | |||||
TensorND tdst{conv_dst_tensor.raw_ptr(), config.first[4]}; | |||||
size_t c_pos; | size_t c_pos; | ||||
if (args.filter_meta.format == Param::Format::NCHW || | if (args.filter_meta.format == Param::Format::NCHW || | ||||
@@ -187,9 +188,9 @@ void ConvBiasForwardImpl::AlgoGroupConvGeneral::exec(const ExecArgs& args) const | |||||
for (uint32_t g = 0; g < grp; ++g) { | for (uint32_t g = 0; g < grp; ++g) { | ||||
config.second->exec( | config.second->exec( | ||||
tsrc, tfilter, tbias, tz, tdst, nullptr, bundle.get_workspace(0)); | tsrc, tfilter, tbias, tz, tdst, nullptr, bundle.get_workspace(0)); | ||||
incr_voidp(tsrc.raw_ptr, strd_src); | |||||
incr_voidp(tdst.raw_ptr, strd_dst); | |||||
incr_voidp(tfilter.raw_ptr, strd_flt); | |||||
incr_refp(tsrc.get_ref_ptr(), strd_src); | |||||
incr_refp(tdst.get_ref_ptr(), strd_dst); | |||||
incr_refp(tfilter.get_ref_ptr(), strd_flt); | |||||
} | } | ||||
} | } | ||||
handle_bias_and_nonlinear( | handle_bias_and_nonlinear( | ||||
@@ -189,19 +189,19 @@ SmallVector<size_t> matmul_get_workspace_bundle(const BiasForwardSizeArgs& args) | |||||
} | } | ||||
void flip_filter( | void flip_filter( | ||||
const BiasForwardSizeArgs& args, const Workspace& workspace, void*& raw_ptr) { | |||||
const BiasForwardSizeArgs& args, const Workspace& workspace, RefPtr& ref_ptr) { | |||||
auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
megdnn_assert(fm.group == 1 && fm.spatial_ndim == 2); | megdnn_assert(fm.group == 1 && fm.spatial_ndim == 2); | ||||
auto OC = fm.ocpg, IC = fm.icpg, FH = fm.spatial[0], FW = fm.spatial[1]; | auto OC = fm.ocpg, IC = fm.icpg, FH = fm.spatial[0], FW = fm.spatial[1]; | ||||
auto dtype = fm.dtype; | auto dtype = fm.dtype; | ||||
megdnn_assert(workspace.size >= dtype.size() * OC * IC * FH * FW); | megdnn_assert(workspace.size >= dtype.size() * OC * IC * FH * FW); | ||||
TensorND src{raw_ptr, {{OC, IC, FH, FW}, dtype}}, | |||||
TensorND src{{{OC, IC, FH, FW}, dtype}, ref_ptr}, | |||||
dst{workspace.raw_ptr + (FH * FW - 1) * dtype.size(), src.layout}; | dst{workspace.raw_ptr + (FH * FW - 1) * dtype.size(), src.layout}; | ||||
dst.layout.stride[2] = -dst.layout.stride[2]; | dst.layout.stride[2] = -dst.layout.stride[2]; | ||||
dst.layout.stride[3] = -dst.layout.stride[3]; | dst.layout.stride[3] = -dst.layout.stride[3]; | ||||
args.handle->relayout_opr()->exec(src, dst); | args.handle->relayout_opr()->exec(src, dst); | ||||
raw_ptr = workspace.raw_ptr; | |||||
ref_ptr.reset(workspace.raw_ptr); | |||||
} | } | ||||
} // namespace conv_bias | } // namespace conv_bias | ||||
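Editor's note: this is the one genuine interface change in the section. `flip_filter` no longer mutates a raw `void*&`; it takes a `RefPtr&` and rebinds it to the workspace, and the caller in `AlgoMatmul::exec_internal` correspondingly passes `A.get_ref_ptr()`. A sketch of that contract, assuming a simplified `RefPtr` (the real type lives in MegDNN's tensor headers):

```cpp
#include <cassert>
#include <cstdio>
#include <memory>

// Simplified stand-in for the RefPtr shared by TensorND; illustrative only.
class RefPtr {
    std::shared_ptr<void*> m_ptr = std::make_shared<void*>(nullptr);
public:
    void reset(void* p) { *m_ptr = p; }
    void* get() const { return *m_ptr; }
};

// Old signature: void flip_filter(..., void*& raw_ptr);
// New signature: the filter pointer travels inside a RefPtr, so rebinding it
// here is visible to every TensorND that shares the same RefPtr.
void flip_filter(RefPtr& filter, void* workspace) {
    // ... relayout the flipped filter into `workspace` ...
    filter.reset(workspace);
}

int main() {
    char original[16], workspace[16];
    RefPtr filter;
    filter.reset(original);
    flip_filter(filter, workspace);
    assert(filter.get() == workspace);
    std::printf("filter now points at workspace: %d\n", filter.get() == workspace);
}
```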
@@ -58,7 +58,7 @@ SmallVector<size_t> matmul_get_workspace_bundle(const BiasForwardSizeArgs& args) | |||||
* change \p raw_ptr to workspace. | * change \p raw_ptr to workspace. | ||||
*/ | */ | ||||
void flip_filter( | void flip_filter( | ||||
const BiasForwardSizeArgs& args, const Workspace& workspace, void*& raw_ptr); | |||||
const BiasForwardSizeArgs& args, const Workspace& workspace, RefPtr& ref_ptr); | |||||
struct CUDNNForwardDescs { | struct CUDNNForwardDescs { | ||||
TensorDesc src_desc, dst_desc, bias_desc, z_desc; | TensorDesc src_desc, dst_desc, bias_desc, z_desc; | ||||
@@ -39,7 +39,7 @@ SmallVector<TensorLayout> ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGem | |||||
void ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::exec_preprocess( | void ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::exec_preprocess( | ||||
const ExecArgs& args) const { | const ExecArgs& args) const { | ||||
megdnn_assert(args.preprocessed_filter->tensors.size() == 1); | megdnn_assert(args.preprocessed_filter->tensors.size() == 1); | ||||
void* filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
void* filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr(); | |||||
reorder_filter(args, filter_ptr); | reorder_filter(args, filter_ptr); | ||||
} | } | ||||
@@ -48,12 +48,12 @@ std::tuple<void*, void*> ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm | |||||
void* filter_ptr = nullptr; | void* filter_ptr = nullptr; | ||||
if (args.preprocessed_filter) { | if (args.preprocessed_filter) { | ||||
megdnn_assert(args.preprocessed_filter->tensors.size() == 1); | megdnn_assert(args.preprocessed_filter->tensors.size() == 1); | ||||
filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr(); | |||||
} else { | } else { | ||||
filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | ||||
reorder_filter(args, filter_ptr); | reorder_filter(args, filter_ptr); | ||||
} | } | ||||
void* bias_ptr = args.bias_tensor->raw_ptr; | |||||
void* bias_ptr = args.bias_tensor->raw_ptr(); | |||||
return {filter_ptr, bias_ptr}; | return {filter_ptr, bias_ptr}; | ||||
} | } | ||||
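Editor's note: the int4 implicit-GEMM algorithms keep the same fallback logic before and after the change: reuse the filter that `exec_preprocess` already reordered when a preprocessed filter is supplied, otherwise reorder into the per-exec workspace; only the `raw_ptr()` spelling changes. A hedged sketch of that fallback, with placeholder names rather than the real MegDNN structures:

```cpp
#include <cstdio>
#include <vector>

// Illustrative placeholders for the tensors carried by PreprocessedFilter.
struct Tensor { void* ptr; void* raw_ptr() const { return ptr; } };
struct PreprocessedFilter { std::vector<Tensor> tensors; };

void* reorder_filter_into(void* workspace) {
    // ... launch the reorder kernel writing into `workspace` ...
    return workspace;
}

void* pick_filter(const PreprocessedFilter* pre, void* workspace) {
    if (pre) {
        // exec_preprocess already wrote the reordered filter here.
        return pre->tensors[0].raw_ptr();
    }
    // No preprocessing pass ran; fall back to reordering on the fly.
    return reorder_filter_into(workspace);
}

int main() {
    char ws[8];
    Tensor t{ws};
    PreprocessedFilter pre{{t}};
    std::printf("%d %d\n", pick_filter(&pre, nullptr) == ws,
                pick_filter(nullptr, ws) == ws);
}
```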
@@ -39,7 +39,7 @@ SmallVector<TensorLayout> ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm: | |||||
void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::exec_preprocess( | void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::exec_preprocess( | ||||
const ExecArgs& args) const { | const ExecArgs& args) const { | ||||
megdnn_assert(args.preprocessed_filter->tensors.size() == 1); | megdnn_assert(args.preprocessed_filter->tensors.size() == 1); | ||||
void* filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
void* filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr(); | |||||
reorder_filter(args, m_algo_param.access_size, filter_ptr); | reorder_filter(args, m_algo_param.access_size, filter_ptr); | ||||
} | } | ||||
@@ -48,12 +48,12 @@ std::tuple<void*, void*> ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm:: | |||||
void* filter_ptr = nullptr; | void* filter_ptr = nullptr; | ||||
if (args.preprocessed_filter) { | if (args.preprocessed_filter) { | ||||
megdnn_assert(args.preprocessed_filter->tensors.size() == 1); | megdnn_assert(args.preprocessed_filter->tensors.size() == 1); | ||||
filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr(); | |||||
} else { | } else { | ||||
filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | ||||
reorder_filter(args, m_algo_param.access_size, filter_ptr); | reorder_filter(args, m_algo_param.access_size, filter_ptr); | ||||
} | } | ||||
void* bias_ptr = args.bias_tensor->raw_ptr; | |||||
void* bias_ptr = args.bias_tensor->raw_ptr(); | |||||
return {filter_ptr, bias_ptr}; | return {filter_ptr, bias_ptr}; | ||||
} | } | ||||
@@ -103,7 +103,7 @@ void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::exec( | |||||
std::tie(filter_ptr, bias_ptr) = prepare_filter_bias(args); | std::tie(filter_ptr, bias_ptr) = prepare_filter_bias(args); | ||||
if (args.z_layout->ndim > 0) | if (args.z_layout->ndim > 0) | ||||
z_ptr = args.z_tensor->raw_ptr; | |||||
z_ptr = args.z_tensor->raw_ptr(); | |||||
// \note these constants of cutlass epilogue will be passed to method | // \note these constants of cutlass epilogue will be passed to method | ||||
// `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, | // `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, | ||||
@@ -131,8 +131,8 @@ void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::exec( | |||||
use_conv_filter_unity_opt, without_shared_load); | use_conv_filter_unity_opt, without_shared_load); | ||||
execute_cutlass_conv_op( | execute_cutlass_conv_op( | ||||
op, args.src_tensor->raw_ptr, filter_ptr, bias_ptr, z_ptr, | |||||
args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, fw, ho, wo, ph, | |||||
op, args.src_tensor->raw_ptr(), filter_ptr, bias_ptr, z_ptr, | |||||
args.dst_tensor->raw_ptr(), nullptr, n, hi, wi, ci, co, fh, fw, ho, wo, ph, | |||||
pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, &threshold, | pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, &threshold, | ||||
&dst_scale, stream, &src_zero); | &dst_scale, stream, &src_zero); | ||||
@@ -159,7 +159,7 @@ void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::reorder_filter( | |||||
// filter: KCRS64 => CRSK64 and reorder oc | // filter: KCRS64 => CRSK64 and reorder oc | ||||
cutlass_wrapper::reorder_ncxhwx_imma_filter<4, 64>( | cutlass_wrapper::reorder_ncxhwx_imma_filter<4, 64>( | ||||
reinterpret_cast<int8_t*>(reordered_filter), | reinterpret_cast<int8_t*>(reordered_filter), | ||||
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, fh, fw, | |||||
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr()), co, ci, fh, fw, | |||||
true, stream); | true, stream); | ||||
} | } | ||||
#endif | #endif | ||||
@@ -115,7 +115,7 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::exec( | |||||
std::tie(filter_ptr, bias_ptr) = prepare_filter_bias(args); | std::tie(filter_ptr, bias_ptr) = prepare_filter_bias(args); | ||||
if (args.z_layout->ndim > 0) | if (args.z_layout->ndim > 0) | ||||
z_ptr = args.z_tensor->raw_ptr; | |||||
z_ptr = args.z_tensor->raw_ptr(); | |||||
// \note these constants of cutlass epilogue will be passed to method | // \note these constants of cutlass epilogue will be passed to method | ||||
// `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, | // `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, | ||||
@@ -151,8 +151,8 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::exec( | |||||
use_conv_filter_unity_opt, without_shared_load); | use_conv_filter_unity_opt, without_shared_load); | ||||
execute_cutlass_conv_op( | execute_cutlass_conv_op( | ||||
op, args.src_tensor->raw_ptr, filter_ptr, bias_ptr, z_ptr, | |||||
args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, fw, ho, wo, ph, | |||||
op, args.src_tensor->raw_ptr(), filter_ptr, bias_ptr, z_ptr, | |||||
args.dst_tensor->raw_ptr(), nullptr, n, hi, wi, ci, co, fh, fw, ho, wo, ph, | |||||
pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, &threshold, | pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, &threshold, | ||||
&dst_scale, stream, &src_zero); | &dst_scale, stream, &src_zero); | ||||
@@ -188,7 +188,7 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::reorder_filter( | |||||
cutlass_wrapper::reorder_nhwc_imma_filter<4>( | cutlass_wrapper::reorder_nhwc_imma_filter<4>( | ||||
reinterpret_cast<int8_t*>(reordered_filter), | reinterpret_cast<int8_t*>(reordered_filter), | ||||
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, fh, fw, | |||||
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr()), co, ci, fh, fw, | |||||
trans_oc, alignbits, oc_iterleaved, stream); | trans_oc, alignbits, oc_iterleaved, stream); | ||||
} | } | ||||
#endif | #endif | ||||
@@ -158,18 +158,15 @@ void ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmReorderFilter::exec( | |||||
UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); | UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); | ||||
// reorder filter | // reorder filter | ||||
{ | { | ||||
TensorLayout in = *(args.filter_layout); | |||||
TensorLayout out = {{ci / 16, 4, fh, fw, co, 4}, in.dtype}; | |||||
TensorLayout out = { | |||||
{ci / 16, 4, fh, fw, co, 4}, args.filter_tensor->layout.dtype}; | |||||
out.stride[0] = 16 * co * fh * fw; | out.stride[0] = 16 * co * fh * fw; | ||||
out.stride[1] = 4; | out.stride[1] = 4; | ||||
out.stride[2] = fw * co * 16; | out.stride[2] = fw * co * 16; | ||||
out.stride[3] = co * 16; | out.stride[3] = co * 16; | ||||
out.stride[4] = 16; | out.stride[4] = 16; | ||||
out.stride[5] = 1; | out.stride[5] = 1; | ||||
TensorND ts_in, ts_out; | |||||
ts_in.layout = in, ts_out.layout = out; | |||||
ts_in.raw_ptr = args.filter_tensor->raw_ptr, | |||||
ts_out.raw_ptr = args.workspace.raw_ptr; | |||||
TensorND ts_in = *args.filter_tensor, ts_out{args.workspace.raw_ptr, out}; | |||||
args.opr->handle()->create_operator<RelayoutForward>()->exec(ts_in, ts_out); | args.opr->handle()->create_operator<RelayoutForward>()->exec(ts_in, ts_out); | ||||
} | } | ||||
@@ -160,18 +160,15 @@ void ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::exec( | |||||
UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); | UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); | ||||
// reorder filter | // reorder filter | ||||
{ | { | ||||
TensorLayout in = *(args.filter_layout); | |||||
TensorLayout out = {{ci / 16, 4, fh, fw, co, 4}, in.dtype}; | |||||
TensorLayout out = { | |||||
{ci / 16, 4, fh, fw, co, 4}, args.filter_tensor->layout.dtype}; | |||||
out.stride[0] = 16 * co * fh * fw; | out.stride[0] = 16 * co * fh * fw; | ||||
out.stride[1] = 4; | out.stride[1] = 4; | ||||
out.stride[2] = fw * co * 16; | out.stride[2] = fw * co * 16; | ||||
out.stride[3] = co * 16; | out.stride[3] = co * 16; | ||||
out.stride[4] = 16; | out.stride[4] = 16; | ||||
out.stride[5] = 1; | out.stride[5] = 1; | ||||
TensorND ts_in, ts_out; | |||||
ts_in.layout = in, ts_out.layout = out; | |||||
ts_in.raw_ptr = args.filter_tensor->raw_ptr, | |||||
ts_out.raw_ptr = args.workspace.raw_ptr; | |||||
TensorND ts_in = *args.filter_tensor, ts_out{args.workspace.raw_ptr, out}; | |||||
args.opr->handle()->create_operator<RelayoutForward>()->exec(ts_in, ts_out); | args.opr->handle()->create_operator<RelayoutForward>()->exec(ts_in, ts_out); | ||||
} | } | ||||
@@ -125,11 +125,11 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||||
filter_ptr = reinterpret_cast<int8_t*>(args.workspace.raw_ptr); | filter_ptr = reinterpret_cast<int8_t*>(args.workspace.raw_ptr); | ||||
// filter: KCRS32 => CRSK32 and reorder oc | // filter: KCRS32 => CRSK32 and reorder oc | ||||
cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>( | cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>( | ||||
filter_ptr, reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, | |||||
ci, fh, fw, trans_oc, stream); | |||||
filter_ptr, reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr()), | |||||
co, ci, fh, fw, trans_oc, stream); | |||||
} else { | } else { | ||||
filter_ptr = | |||||
reinterpret_cast<int8_t*>(args.preprocessed_filter->tensors[0].raw_ptr); | |||||
filter_ptr = reinterpret_cast<int8_t*>( | |||||
args.preprocessed_filter->tensors[0].raw_ptr()); | |||||
} | } | ||||
float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale, | float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale, | ||||
@@ -157,9 +157,9 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||||
use_conv_filter_unity_opt, without_shared_load); | use_conv_filter_unity_opt, without_shared_load); | ||||
execute_cutlass_conv_op( | execute_cutlass_conv_op( | ||||
op, args.src_tensor->raw_ptr, filter_ptr, args.bias_tensor->raw_ptr, | |||||
z_dev_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, fw, ho, | |||||
wo, ph, pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, | |||||
op, args.src_tensor->raw_ptr(), filter_ptr, args.bias_tensor->raw_ptr(), | |||||
z_dev_ptr, args.dst_tensor->raw_ptr(), nullptr, n, hi, wi, ci, co, fh, fw, | |||||
ho, wo, ph, pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, | |||||
&threshold, &dst_scale, stream); | &threshold, &dst_scale, stream); | ||||
after_kernel_launch(); | after_kernel_launch(); | ||||
@@ -204,8 +204,8 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec_preprocess( | |||||
cudaStream_t stream = cuda_stream(args.opr->handle()); | cudaStream_t stream = cuda_stream(args.opr->handle()); | ||||
// filter: KCRS32 => CRSK32 and reorder oc | // filter: KCRS32 => CRSK32 and reorder oc | ||||
cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>( | cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>( | ||||
reinterpret_cast<int8_t*>(args.preprocessed_filter->tensors[0].raw_ptr), | |||||
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, fh, fw, | |||||
reinterpret_cast<int8_t*>(args.preprocessed_filter->tensors[0].raw_ptr()), | |||||
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr()), co, ci, fh, fw, | |||||
trans_oc, stream); | trans_oc, stream); | ||||
} | } | ||||
#endif | #endif | ||||
@@ -155,16 +155,13 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||||
src.init_contiguous_stride(); | src.init_contiguous_stride(); | ||||
TensorLayout dst = src; | TensorLayout dst = src; | ||||
dst.stride[0] = 1, dst.stride[1] = dst[0]; | dst.stride[0] = 1, dst.stride[1] = dst[0]; | ||||
TensorND ts_src, ts_dst; | |||||
ts_src.raw_ptr = args.filter_tensor->raw_ptr; | |||||
ts_src.layout = src; | |||||
ts_dst.raw_ptr = args.workspace.raw_ptr; | |||||
ts_dst.layout = dst; | |||||
TensorND ts_src{args.filter_tensor->raw_ptr(), src}, | |||||
ts_dst{args.workspace.raw_ptr, dst}; | |||||
auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | ||||
transpose->exec(ts_src, ts_dst); | transpose->exec(ts_src, ts_dst); | ||||
} else { | } else { | ||||
filter_ptr = | |||||
reinterpret_cast<int8_t*>(args.preprocessed_filter->tensors[0].raw_ptr); | |||||
filter_ptr = reinterpret_cast<int8_t*>( | |||||
args.preprocessed_filter->tensors[0].raw_ptr()); | |||||
} | } | ||||
float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale, | float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale, | ||||
@@ -190,7 +187,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||||
float delta = 0.f; | float delta = 0.f; | ||||
void* z_ptr = nullptr; | void* z_ptr = nullptr; | ||||
if (args.z_layout->ndim > 0) { | if (args.z_layout->ndim > 0) { | ||||
z_ptr = args.z_tensor->raw_ptr; | |||||
z_ptr = args.z_tensor->raw_ptr(); | |||||
gamma = 1.f; | gamma = 1.f; | ||||
if (args.z_layout->dtype.category() == DTypeCategory::QUANTIZED) { | if (args.z_layout->dtype.category() == DTypeCategory::QUANTIZED) { | ||||
megdnn_assert( | megdnn_assert( | ||||
@@ -213,10 +210,10 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||||
use_conv_filter_unity_opt, without_shared_load); | use_conv_filter_unity_opt, without_shared_load); | ||||
execute_cutlass_conv_op( | execute_cutlass_conv_op( | ||||
op, args.src_tensor->raw_ptr, filter_ptr, args.bias_tensor->raw_ptr, z_ptr, | |||||
args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, fw, ho, wo, ph, | |||||
pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, &threshold, | |||||
&dst_scale, stream); | |||||
op, args.src_tensor->raw_ptr(), filter_ptr, args.bias_tensor->raw_ptr(), | |||||
z_ptr, args.dst_tensor->raw_ptr(), nullptr, n, hi, wi, ci, co, fh, fw, ho, | |||||
wo, ph, pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, | |||||
&threshold, &dst_scale, stream); | |||||
after_kernel_launch(); | after_kernel_launch(); | ||||
} | } | ||||
@@ -261,11 +258,8 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec_preprocess( | |||||
src.init_contiguous_stride(); | src.init_contiguous_stride(); | ||||
TensorLayout dst = src; | TensorLayout dst = src; | ||||
dst.stride[0] = 1, dst.stride[1] = dst[0]; | dst.stride[0] = 1, dst.stride[1] = dst[0]; | ||||
TensorND ts_src, ts_dst; | |||||
ts_src.raw_ptr = args.filter_tensor->raw_ptr; | |||||
ts_src.layout = src; | |||||
ts_dst.raw_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
ts_dst.layout = dst; | |||||
TensorND ts_src{args.filter_tensor->raw_ptr(), src}, | |||||
ts_dst{args.preprocessed_filter->tensors[0].raw_ptr(), dst}; | |||||
auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | ||||
transpose->exec(ts_src, ts_dst); | transpose->exec(ts_src, ts_dst); | ||||
} | } | ||||
@@ -96,11 +96,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::exec( | |||||
src.init_contiguous_stride(); | src.init_contiguous_stride(); | ||||
TensorLayout dst = src; | TensorLayout dst = src; | ||||
dst.stride[0] = 1, dst.stride[1] = dst[0]; | dst.stride[0] = 1, dst.stride[1] = dst[0]; | ||||
TensorND ts_src, ts_dst; | |||||
ts_src.raw_ptr = args.src_tensor->raw_ptr; | |||||
ts_src.layout = src; | |||||
ts_dst.raw_ptr = ws_src; | |||||
ts_dst.layout = dst; | |||||
TensorND ts_src{args.src_tensor->raw_ptr(), src}, ts_dst{ws_src, dst}; | |||||
auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | ||||
transpose->exec(ts_src, ts_dst); | transpose->exec(ts_src, ts_dst); | ||||
} | } | ||||
@@ -111,11 +107,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::exec( | |||||
src.init_contiguous_stride(); | src.init_contiguous_stride(); | ||||
TensorLayout dst = src; | TensorLayout dst = src; | ||||
dst.stride[0] = 1, dst.stride[1] = dst[0]; | dst.stride[0] = 1, dst.stride[1] = dst[0]; | ||||
TensorND ts_src, ts_dst; | |||||
ts_src.raw_ptr = args.filter_tensor->raw_ptr; | |||||
ts_src.layout = src; | |||||
ts_dst.raw_ptr = ws_filter; | |||||
ts_dst.layout = dst; | |||||
TensorND ts_src{args.filter_tensor->raw_ptr(), src}, ts_dst{ws_filter, dst}; | |||||
auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | ||||
transpose->exec(ts_src, ts_dst); | transpose->exec(ts_src, ts_dst); | ||||
} | } | ||||
@@ -142,11 +134,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::exec( | |||||
src.init_contiguous_stride(); | src.init_contiguous_stride(); | ||||
TensorLayout dst = src; | TensorLayout dst = src; | ||||
dst.stride[0] = 1, dst.stride[1] = dst[0]; | dst.stride[0] = 1, dst.stride[1] = dst[0]; | ||||
TensorND ts_src, ts_dst; | |||||
ts_src.raw_ptr = args.z_tensor->raw_ptr; | |||||
ts_src.layout = src; | |||||
ts_dst.raw_ptr = ws_z; | |||||
ts_dst.layout = dst; | |||||
TensorND ts_src{args.z_tensor->raw_ptr(), src}, ts_dst{ws_z, dst}; | |||||
auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | ||||
transpose->exec(ts_src, ts_dst); | transpose->exec(ts_src, ts_dst); | ||||
z_dev_ptr = reinterpret_cast<int8_t*>(ws_z); | z_dev_ptr = reinterpret_cast<int8_t*>(ws_z); | ||||
@@ -168,11 +156,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::exec( | |||||
src.init_contiguous_stride(); | src.init_contiguous_stride(); | ||||
TensorLayout dst = src; | TensorLayout dst = src; | ||||
dst.stride[0] = 1, dst.stride[1] = dst[0]; | dst.stride[0] = 1, dst.stride[1] = dst[0]; | ||||
TensorND ts_src, ts_dst; | |||||
ts_src.raw_ptr = ws_dst; | |||||
ts_src.layout = src; | |||||
ts_dst.raw_ptr = args.dst_tensor->raw_ptr; | |||||
ts_dst.layout = dst; | |||||
TensorND ts_src{ws_dst, src}, ts_dst{args.dst_tensor->raw_ptr(), dst}; | |||||
auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | ||||
transpose->exec(ts_src, ts_dst); | transpose->exec(ts_src, ts_dst); | ||||
} | } | ||||
@@ -114,7 +114,7 @@ SmallVector<TensorLayout> ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm:: | |||||
void ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::exec_preprocess( | void ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::exec_preprocess( | ||||
const ExecArgs& args) const { | const ExecArgs& args) const { | ||||
void* filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
void* filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr(); | |||||
reorder_filter(args, m_algo_param.access_size, filter_ptr); | reorder_filter(args, m_algo_param.access_size, filter_ptr); | ||||
} | } | ||||
@@ -189,15 +189,15 @@ void ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::exec( | |||||
void* z_ptr = nullptr; | void* z_ptr = nullptr; | ||||
if (args.preprocessed_filter) { | if (args.preprocessed_filter) { | ||||
filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr(); | |||||
} else { | } else { | ||||
filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | ||||
reorder_filter(args, m_algo_param.access_size, filter_ptr); | reorder_filter(args, m_algo_param.access_size, filter_ptr); | ||||
} | } | ||||
bias_ptr = args.bias_tensor->raw_ptr; | |||||
bias_ptr = args.bias_tensor->raw_ptr(); | |||||
if (args.z_layout->ndim > 0) | if (args.z_layout->ndim > 0) | ||||
z_ptr = args.z_tensor->raw_ptr; | |||||
z_ptr = args.z_tensor->raw_ptr(); | |||||
// \note these constants of cutlass epilogue will be passed to method | // \note these constants of cutlass epilogue will be passed to method | ||||
// `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, | // `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, | ||||
@@ -233,8 +233,8 @@ void ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::exec( | |||||
use_conv_filter_unity_opt, without_shared_load); | use_conv_filter_unity_opt, without_shared_load); | ||||
execute_cutlass_conv_op( | execute_cutlass_conv_op( | ||||
op, args.src_tensor->raw_ptr, filter_ptr, bias_ptr, z_ptr, | |||||
args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, fw, ho, wo, ph, | |||||
op, args.src_tensor->raw_ptr(), filter_ptr, bias_ptr, z_ptr, | |||||
args.dst_tensor->raw_ptr(), nullptr, n, hi, wi, ci, co, fh, fw, ho, wo, ph, | |||||
pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, &threshold, | pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, &threshold, | ||||
&dst_scale, stream); | &dst_scale, stream); | ||||
@@ -272,7 +272,7 @@ void ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::reorder_filter( | |||||
cutlass_wrapper::reorder_nhwc_imma_filter<8>( | cutlass_wrapper::reorder_nhwc_imma_filter<8>( | ||||
reinterpret_cast<int8_t*>(reordered_filter), | reinterpret_cast<int8_t*>(reordered_filter), | ||||
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, fh, fw, | |||||
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr()), co, ci, fh, fw, | |||||
trans_oc, alignbits, oc_iterleaved, stream); | trans_oc, alignbits, oc_iterleaved, stream); | ||||
} | } | ||||
#endif | #endif | ||||
@@ -52,8 +52,8 @@ SmallVector<TensorLayout> ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGe | |||||
void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::exec_preprocess( | void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::exec_preprocess( | ||||
const ExecArgs& args) const { | const ExecArgs& args) const { | ||||
megdnn_assert(args.preprocessed_filter->tensors.size() == 2); | megdnn_assert(args.preprocessed_filter->tensors.size() == 2); | ||||
void* filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
void* bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr; | |||||
void* filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr(); | |||||
void* bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr(); | |||||
void* reduce_filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | void* reduce_filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | ||||
void* reduce_workspace = reinterpret_cast<void*>( | void* reduce_workspace = reinterpret_cast<void*>( | ||||
args.workspace.raw_ptr + args.bias_layout->span().dist_byte()); | args.workspace.raw_ptr + args.bias_layout->span().dist_byte()); | ||||
@@ -67,8 +67,8 @@ std::tuple<void*, void*> ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGem | |||||
void* bias_ptr = nullptr; | void* bias_ptr = nullptr; | ||||
if (args.preprocessed_filter) { | if (args.preprocessed_filter) { | ||||
megdnn_assert(args.preprocessed_filter->tensors.size() == 2); | megdnn_assert(args.preprocessed_filter->tensors.size() == 2); | ||||
filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr; | |||||
filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr(); | |||||
bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr(); | |||||
return {filter_ptr, bias_ptr}; | return {filter_ptr, bias_ptr}; | ||||
} else { | } else { | ||||
filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | ||||
@@ -130,7 +130,7 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::update_bias( | |||||
int src_zero_point = | int src_zero_point = | ||||
args.src_tensor->layout.dtype.param<dtype::Quantized4Asymm>().zero_point; | args.src_tensor->layout.dtype.param<dtype::Quantized4Asymm>().zero_point; | ||||
do_dispatch_reduce_filter_and_update_bias_4bit<true>( | do_dispatch_reduce_filter_and_update_bias_4bit<true>( | ||||
reinterpret_cast<uint8_t*>(args.filter_tensor->raw_ptr), | |||||
reinterpret_cast<uint8_t*>(args.filter_tensor->raw_ptr()), | |||||
args.bias_tensor->compatible_ptr<int32_t>(), co, ci * fh * fw / 8, | args.bias_tensor->compatible_ptr<int32_t>(), co, ci * fh * fw / 8, | ||||
reinterpret_cast<int32_t*>(updated_bias), | reinterpret_cast<int32_t*>(updated_bias), | ||||
reinterpret_cast<int32_t*>(reduce_workspace), src_zero_point, stream); | reinterpret_cast<int32_t*>(reduce_workspace), src_zero_point, stream); | ||||
@@ -52,8 +52,8 @@ SmallVector<TensorLayout> ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm | |||||
void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::exec_preprocess( | void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::exec_preprocess( | ||||
const ExecArgs& args) const { | const ExecArgs& args) const { | ||||
megdnn_assert(args.preprocessed_filter->tensors.size() == 2); | megdnn_assert(args.preprocessed_filter->tensors.size() == 2); | ||||
void* filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
void* bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr; | |||||
void* filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr(); | |||||
void* bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr(); | |||||
void* reduce_filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | void* reduce_filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | ||||
void* reduce_workspace = reinterpret_cast<void*>( | void* reduce_workspace = reinterpret_cast<void*>( | ||||
args.workspace.raw_ptr + args.bias_layout->span().dist_byte()); | args.workspace.raw_ptr + args.bias_layout->span().dist_byte()); | ||||
@@ -67,8 +67,8 @@ std::tuple<void*, void*> ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm: | |||||
void* bias_ptr = nullptr; | void* bias_ptr = nullptr; | ||||
if (args.preprocessed_filter) { | if (args.preprocessed_filter) { | ||||
megdnn_assert(args.preprocessed_filter->tensors.size() == 2); | megdnn_assert(args.preprocessed_filter->tensors.size() == 2); | ||||
filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr; | |||||
filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr(); | |||||
bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr(); | |||||
return {filter_ptr, bias_ptr}; | return {filter_ptr, bias_ptr}; | ||||
} else { | } else { | ||||
filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | ||||
@@ -146,7 +146,7 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::update_bias( | |||||
int src_zero_point = | int src_zero_point = | ||||
args.src_tensor->layout.dtype.param<dtype::Quantized4Asymm>().zero_point; | args.src_tensor->layout.dtype.param<dtype::Quantized4Asymm>().zero_point; | ||||
do_dispatch_reduce_filter_and_update_bias_4bit<true>( | do_dispatch_reduce_filter_and_update_bias_4bit<true>( | ||||
reinterpret_cast<uint8_t*>(args.filter_tensor->raw_ptr), | |||||
reinterpret_cast<uint8_t*>(args.filter_tensor->raw_ptr()), | |||||
args.bias_tensor->compatible_ptr<int32_t>(), co, ci * fh * fw / 8, | args.bias_tensor->compatible_ptr<int32_t>(), co, ci * fh * fw / 8, | ||||
reinterpret_cast<int32_t*>(updated_bias), | reinterpret_cast<int32_t*>(updated_bias), | ||||
reinterpret_cast<int32_t*>(reduce_workspace), src_zero_point, stream); | reinterpret_cast<int32_t*>(reduce_workspace), src_zero_point, stream); | ||||
@@ -40,9 +40,9 @@ size_t ConvBiasForwardImpl::AlgoInplaceMatmul::get_workspace_in_bytes( | |||||
void ConvBiasForwardImpl::AlgoInplaceMatmul::exec(const ExecArgs& args) const { | void ConvBiasForwardImpl::AlgoInplaceMatmul::exec(const ExecArgs& args) const { | ||||
WorkspaceBundle bundle{args.workspace.raw_ptr, {get_workspace_in_bytes(args)}}; | WorkspaceBundle bundle{args.workspace.raw_ptr, {get_workspace_in_bytes(args)}}; | ||||
auto conv_dst_tensor = *args.dst_tensor; | |||||
TensorND conv_dst_tensor = *args.dst_tensor; | |||||
if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | ||||
conv_dst_tensor.raw_ptr = bundle.get(0); | |||||
conv_dst_tensor = TensorND{bundle.get(0), args.dst_tensor->layout}; | |||||
conv_dst_tensor.layout.dtype = DType(); | conv_dst_tensor.layout.dtype = DType(); | ||||
args.opr->check_or_deduce_dtype_fwd( | args.opr->check_or_deduce_dtype_fwd( | ||||
args.src_layout->dtype, args.filter_layout->dtype, | args.src_layout->dtype, args.filter_layout->dtype, | ||||
@@ -115,9 +115,10 @@ size_t ConvBiasForwardImpl::AlgoMatmul::get_workspace_in_bytes( | |||||
void ConvBiasForwardImpl::AlgoMatmul::exec(const ExecArgs& args) const { | void ConvBiasForwardImpl::AlgoMatmul::exec(const ExecArgs& args) const { | ||||
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | ||||
auto conv_dst_tensor = *args.dst_tensor; | |||||
TensorND conv_dst_tensor = *args.dst_tensor; | |||||
if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | ||||
conv_dst_tensor.raw_ptr = bundle.get(bundle.nr_workspace() - 1); | |||||
conv_dst_tensor = TensorND{ | |||||
bundle.get(bundle.nr_workspace() - 1), args.dst_tensor->layout}; | |||||
conv_dst_tensor.layout.dtype = DType(); | conv_dst_tensor.layout.dtype = DType(); | ||||
args.opr->check_or_deduce_dtype_fwd( | args.opr->check_or_deduce_dtype_fwd( | ||||
args.src_layout->dtype, args.filter_layout->dtype, | args.src_layout->dtype, args.filter_layout->dtype, | ||||
@@ -168,7 +169,7 @@ void ConvBiasForwardImpl::AlgoMatmul::exec_internal( | |||||
C(dst_t, config.first[2]); | C(dst_t, config.first[2]); | ||||
size_t matmul_ws_idx = 2; | size_t matmul_ws_idx = 2; | ||||
if (fm.should_flip) { | if (fm.should_flip) { | ||||
conv_bias::flip_filter(args, bundle.get_workspace(2), A.raw_ptr); | |||||
conv_bias::flip_filter(args, bundle.get_workspace(2), A.get_ref_ptr()); | |||||
matmul_ws_idx = 3; | matmul_ws_idx = 3; | ||||
} | } | ||||
@@ -128,12 +128,10 @@ void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec_internal(const ExecArgs& args) | |||||
auto bundle = get_bundle<format>(args); | auto bundle = get_bundle<format>(args); | ||||
bundle.set(args.workspace.raw_ptr); | bundle.set(args.workspace.raw_ptr); | ||||
TensorND src_tensor, dst_tensor, filter_tensor; | |||||
if (format == Param::Format::NHWC) { | |||||
src_tensor = *args.src_tensor; | |||||
dst_tensor = *args.dst_tensor; | |||||
filter_tensor = *args.filter_tensor; | |||||
} else { | |||||
TensorND src_tensor = *args.src_tensor; | |||||
TensorND dst_tensor = *args.dst_tensor; | |||||
TensorND filter_tensor = *args.filter_tensor; | |||||
if (format == Param::Format::NCHW4) { | |||||
// NCHW4 | // NCHW4 | ||||
auto to_nhwc = [](const TensorLayout& layout, void* raw_ptr) -> TensorND { | auto to_nhwc = [](const TensorLayout& layout, void* raw_ptr) -> TensorND { | ||||
return {raw_ptr, | return {raw_ptr, | ||||
@@ -147,7 +145,7 @@ void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec_internal(const ExecArgs& args) | |||||
auto N = src.layout[0], C = src.layout[1] * 4, H = src.layout[2], | auto N = src.layout[0], C = src.layout[1] * 4, H = src.layout[2], | ||||
W = src.layout[3]; | W = src.layout[3]; | ||||
args.handle->relayout_opr()->exec( | args.handle->relayout_opr()->exec( | ||||
{src.raw_ptr, | |||||
{src.raw_ptr(), | |||||
TensorLayout{ | TensorLayout{ | ||||
{N, H, W, C / 4, 4}, | {N, H, W, C / 4, 4}, | ||||
{src.layout.stride[0], src.layout.stride[2], | {src.layout.stride[0], src.layout.stride[2], | ||||
@@ -156,8 +154,8 @@ void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec_internal(const ExecArgs& args) | |||||
src.layout.dtype}}, | src.layout.dtype}}, | ||||
{dst_ptr, TensorLayout{{N, H, W, C / 4, 4}, src.layout.dtype}}); | {dst_ptr, TensorLayout{{N, H, W, C / 4, 4}, src.layout.dtype}}); | ||||
}; | }; | ||||
relayout(*args.src_tensor, src_tensor.raw_ptr); | |||||
relayout(*args.filter_tensor, filter_tensor.raw_ptr); | |||||
relayout(*args.src_tensor, src_tensor.raw_ptr()); | |||||
relayout(*args.filter_tensor, filter_tensor.raw_ptr()); | |||||
} | } | ||||
size_t N, IH, IW, IC; | size_t N, IH, IW, IC; | ||||
@@ -193,7 +191,7 @@ void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec_internal(const ExecArgs& args) | |||||
// copy (OC, FH*FW*IC) to (OC, FH*FW*IC) with stride=LD | // copy (OC, FH*FW*IC) to (OC, FH*FW*IC) with stride=LD | ||||
inp1 = static_cast<int8_t*>(bundle.get(1)); | inp1 = static_cast<int8_t*>(bundle.get(1)); | ||||
cuda_check(cudaMemcpy2DAsync( | cuda_check(cudaMemcpy2DAsync( | ||||
inp1, LD * sizeof(int8_t), filter_tensor.raw_ptr, | |||||
inp1, LD * sizeof(int8_t), filter_tensor.raw_ptr(), | |||||
FH * FW * IC * sizeof(int8_t), FH * FW * IC * sizeof(int8_t), OC, | FH * FW * IC * sizeof(int8_t), FH * FW * IC * sizeof(int8_t), OC, | ||||
cudaMemcpyDeviceToDevice, stream)); | cudaMemcpyDeviceToDevice, stream)); | ||||
inp1_stride = LD; | inp1_stride = LD; | ||||
@@ -222,12 +220,13 @@ void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec_internal(const ExecArgs& args) | |||||
void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec(const ExecArgs& args) const { | void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec(const ExecArgs& args) const { | ||||
ExecArgs conv_args = args; | ExecArgs conv_args = args; | ||||
auto conv_dst_tensor = *args.dst_tensor; | |||||
TensorND conv_dst_tensor = *args.dst_tensor; | |||||
if (args.filter_meta.format == Param::Format::NHWC) { | if (args.filter_meta.format == Param::Format::NHWC) { | ||||
auto bundle = get_bundle<Param::Format::NHWC>(args); | auto bundle = get_bundle<Param::Format::NHWC>(args); | ||||
bundle.set(args.workspace.raw_ptr); | bundle.set(args.workspace.raw_ptr); | ||||
if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | ||||
conv_dst_tensor.raw_ptr = bundle.get(bundle.nr_workspace() - 1); | |||||
conv_dst_tensor = TensorND{ | |||||
bundle.get(bundle.nr_workspace() - 1), args.dst_tensor->layout}; | |||||
conv_dst_tensor.layout.dtype = DType(); | conv_dst_tensor.layout.dtype = DType(); | ||||
args.opr->check_or_deduce_dtype_fwd( | args.opr->check_or_deduce_dtype_fwd( | ||||
args.src_layout->dtype, args.filter_layout->dtype, | args.src_layout->dtype, args.filter_layout->dtype, | ||||
@@ -239,7 +238,8 @@ void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec(const ExecArgs& args) const { | |||||
auto bundle = get_bundle<Param::Format::NCHW4>(args); | auto bundle = get_bundle<Param::Format::NCHW4>(args); | ||||
bundle.set(args.workspace.raw_ptr); | bundle.set(args.workspace.raw_ptr); | ||||
if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | ||||
conv_dst_tensor.raw_ptr = bundle.get(bundle.nr_workspace() - 1); | |||||
conv_dst_tensor = TensorND{ | |||||
bundle.get(bundle.nr_workspace() - 1), args.dst_tensor->layout}; | |||||
conv_dst_tensor.layout.dtype = DType(); | conv_dst_tensor.layout.dtype = DType(); | ||||
args.opr->check_or_deduce_dtype_fwd( | args.opr->check_or_deduce_dtype_fwd( | ||||
args.src_layout->dtype, args.filter_layout->dtype, | args.src_layout->dtype, args.filter_layout->dtype, | ||||
@@ -131,26 +131,26 @@ void ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::exec(const ExecArgs& args) const | |||||
auto&& stream = cuda_stream(handle); | auto&& stream = cuda_stream(handle); | ||||
// zp filter | // zp filter | ||||
do_dispatch_reduce_with_scale_filter_4bit<false>( | do_dispatch_reduce_with_scale_filter_4bit<false>( | ||||
static_cast<uint8_t*>(args.filter_tensor->raw_ptr), -zp_data, OC, | |||||
static_cast<uint8_t*>(args.filter_tensor->raw_ptr()), -zp_data, OC, | |||||
FH * FW * IC / 8, ws_zp_filter.ptr<int32_t>(), stream); | FH * FW * IC / 8, ws_zp_filter.ptr<int32_t>(), stream); | ||||
// zp data | // zp data | ||||
do_dispatch_reduce_with_scale_data_u4( | do_dispatch_reduce_with_scale_data_u4( | ||||
ws_zp_data.ptr<int32_t>(), static_cast<uint8_t*>(args.src_tensor->raw_ptr), | |||||
N, IH, IW, OH, OW, PH, PW, FH, FW, SH, SW, IC, -zp_filter, | |||||
static_cast<uint8_t>(zp_data), stream); | |||||
ws_zp_data.ptr<int32_t>(), | |||||
static_cast<uint8_t*>(args.src_tensor->raw_ptr()), N, IH, IW, OH, OW, PH, | |||||
PW, FH, FW, SH, SW, IC, -zp_filter, static_cast<uint8_t>(zp_data), stream); | |||||
// do conv | // do conv | ||||
if (use_kernel_fhxfw(args)) { | if (use_kernel_fhxfw(args)) { | ||||
wmma_conv_integer_subbyte::_do_wmma_conv_integer_subbyte_fhxfw( | wmma_conv_integer_subbyte::_do_wmma_conv_integer_subbyte_fhxfw( | ||||
static_cast<uint8_t*>(args.src_tensor->raw_ptr), | |||||
static_cast<uint8_t*>(args.filter_tensor->raw_ptr), | |||||
static_cast<uint8_t*>(args.src_tensor->raw_ptr()), | |||||
static_cast<uint8_t*>(args.filter_tensor->raw_ptr()), | |||||
args.dst_tensor->compatible_ptr<int32_t>(), N, IH, IW, OH, OW, PH, PW, | args.dst_tensor->compatible_ptr<int32_t>(), N, IH, IW, OH, OW, PH, PW, | ||||
IC, OC, FH, FW, SH, SW, static_cast<uint8_t>(zp_data), stream); | IC, OC, FH, FW, SH, SW, static_cast<uint8_t>(zp_data), stream); | ||||
} else { | } else { | ||||
auto&& ws_relayout_filter = ws_bundle.get_workspace(2); | auto&& ws_relayout_filter = ws_bundle.get_workspace(2); | ||||
wmma_conv_integer_subbyte::_do_wmma_conv_integer_subbyte_1xfw( | wmma_conv_integer_subbyte::_do_wmma_conv_integer_subbyte_1xfw( | ||||
static_cast<uint8_t*>(args.src_tensor->raw_ptr), | |||||
static_cast<uint8_t*>(args.filter_tensor->raw_ptr), | |||||
static_cast<uint8_t*>(args.src_tensor->raw_ptr()), | |||||
static_cast<uint8_t*>(args.filter_tensor->raw_ptr()), | |||||
args.dst_tensor->compatible_ptr<int32_t>(), | args.dst_tensor->compatible_ptr<int32_t>(), | ||||
ws_relayout_filter.ptr<uint8_t>(), N, IH, IW, OH, OW, PH, PW, IC, OC, | ws_relayout_filter.ptr<uint8_t>(), N, IH, IW, OH, OW, PH, PW, IC, OC, | ||||
FH, FW, SH, SW, static_cast<uint8_t>(zp_data), stream); | FH, FW, SH, SW, static_cast<uint8_t>(zp_data), stream); | ||||
@@ -60,9 +60,9 @@ void ConvolutionBackwardDataImpl::AlgoChanwise::exec(const ExecArgs& args) const | |||||
#if CUDA_VERSION >= 9000 | #if CUDA_VERSION >= 9000 | ||||
if (is_compute_capability_required(5, 3)) { | if (is_compute_capability_required(5, 3)) { | ||||
return chanwise::run_bwd_data( | return chanwise::run_bwd_data( | ||||
static_cast<__half*>(args.grad_tensor->raw_ptr), | |||||
static_cast<__half*>(args.diff_tensor->raw_ptr), | |||||
static_cast<__half*>(args.filter_tensor->raw_ptr), kparam, | |||||
static_cast<__half*>(args.grad_tensor->raw_ptr()), | |||||
static_cast<__half*>(args.diff_tensor->raw_ptr()), | |||||
static_cast<__half*>(args.filter_tensor->raw_ptr()), kparam, | |||||
stream); | stream); | ||||
} else { | } else { | ||||
return chanwise::run_bwd_data( | return chanwise::run_bwd_data( | ||||
@@ -68,9 +68,9 @@ void ConvolutionBackwardDataImpl::AlgoChanwiseSmall::exec(const ExecArgs& args) | |||||
#if CUDA_VERSION >= 9000 | #if CUDA_VERSION >= 9000 | ||||
case DTypeEnum::Float16: | case DTypeEnum::Float16: | ||||
return chanwise::run_bwd_data_small( | return chanwise::run_bwd_data_small( | ||||
static_cast<half*>(args.grad_tensor->raw_ptr), | |||||
static_cast<half*>(args.diff_tensor->raw_ptr), | |||||
static_cast<half*>(args.filter_tensor->raw_ptr), kparam, stream); | |||||
static_cast<half*>(args.grad_tensor->raw_ptr()), | |||||
static_cast<half*>(args.diff_tensor->raw_ptr()), | |||||
static_cast<half*>(args.filter_tensor->raw_ptr()), kparam, stream); | |||||
#endif | #endif | ||||
default: | default: | ||||
break; | break; | ||||
@@ -71,9 +71,10 @@ void ConvolutionBackwardDataImpl::AlgoCUDNN::exec(const ExecArgs& args) const { | |||||
float alpha = 1.0f, beta = 0.0f; | float alpha = 1.0f, beta = 0.0f; | ||||
auto status = cudnnConvolutionBackwardData( | auto status = cudnnConvolutionBackwardData( | ||||
args.handle->cudnn_handle(), &alpha, D.filter_desc.desc, | args.handle->cudnn_handle(), &alpha, D.filter_desc.desc, | ||||
args.filter_tensor->raw_ptr, D.diff_desc.desc, args.diff_tensor->raw_ptr, | |||||
D.conv_desc.desc, m_cudnn_enum, args.workspace.raw_ptr, args.workspace.size, | |||||
&beta, D.grad_desc.desc, args.grad_tensor->raw_ptr); | |||||
args.filter_tensor->raw_ptr(), D.diff_desc.desc, | |||||
args.diff_tensor->raw_ptr(), D.conv_desc.desc, m_cudnn_enum, | |||||
args.workspace.raw_ptr, args.workspace.size, &beta, D.grad_desc.desc, | |||||
args.grad_tensor->raw_ptr()); | |||||
megdnn_assert( | megdnn_assert( | ||||
status == CUDNN_STATUS_SUCCESS, "conv bwd_data failed: %s; info: %s", | status == CUDNN_STATUS_SUCCESS, "conv bwd_data failed: %s; info: %s", | ||||
cudnnGetErrorString(status), args.to_string().c_str()); | cudnnGetErrorString(status), args.to_string().c_str()); | ||||
@@ -103,9 +103,9 @@ void ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::exec( | |||||
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | ||||
{ | { | ||||
auto config = prepare_sub_opr(args); | auto config = prepare_sub_opr(args); | ||||
TensorND tfilter{args.filter_tensor->raw_ptr, config.first[0]}; | |||||
TensorND tdiff{args.diff_tensor->raw_ptr, config.first[1]}; | |||||
TensorND tgrad{args.grad_tensor->raw_ptr, config.first[2]}; | |||||
TensorND tfilter{args.filter_tensor->raw_ptr(), config.first[0]}; | |||||
TensorND tdiff{args.diff_tensor->raw_ptr(), config.first[1]}; | |||||
TensorND tgrad{args.grad_tensor->raw_ptr(), config.first[2]}; | |||||
size_t c_pos = 1; | size_t c_pos = 1; | ||||
@@ -121,9 +121,9 @@ void ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::exec( | |||||
auto grp = args.filter_meta.group; | auto grp = args.filter_meta.group; | ||||
for (uint32_t g = 0; g < grp; ++g) { | for (uint32_t g = 0; g < grp; ++g) { | ||||
config.second->exec(tfilter, tdiff, tgrad, bundle.get_workspace(0)); | config.second->exec(tfilter, tdiff, tgrad, bundle.get_workspace(0)); | ||||
incr_voidp(tfilter.raw_ptr, strd_flt); | |||||
incr_voidp(tdiff.raw_ptr, strd_diff); | |||||
incr_voidp(tgrad.raw_ptr, strd_grad); | |||||
incr_refp(tfilter.get_ref_ptr(), strd_flt); | |||||
incr_refp(tdiff.get_ref_ptr(), strd_diff); | |||||
incr_refp(tgrad.get_ref_ptr(), strd_grad); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
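The grouped-convolution loops above stop bumping a cached void* with incr_voidp and instead advance the tensor's RefPtr with incr_refp by the per-group byte stride. A sketch of what the two helpers boil down to, assuming only a get()/reset()-style handle like the one sketched earlier (names here are illustrative, not megdnn's):

#include <cstddef>
#include <cstdint>

// Old style: mutate a raw pointer the caller cached from the tensor.
inline void incr_void_sketch(void*& p, ptrdiff_t delta_bytes) {
    p = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(p) + delta_bytes);
}

// New style: rebind the shared handle so the owning TensorND moves with it.
template <typename RefPtrLike>
inline void incr_ref_sketch(RefPtrLike& ref, ptrdiff_t delta_bytes) {
    ref.reset(reinterpret_cast<void*>(
            reinterpret_cast<uintptr_t>(ref.get()) + delta_bytes));
}

In the group loops the delta is the per-group span in bytes, e.g. layout.stride[1] * layout.shape[1] * layout.dtype.size() as spelled out in the GroupLocalForward hunk further down, so each iteration operates on the next group's slice.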
@@ -140,7 +140,8 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( | |||||
auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>(); | auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>(); | ||||
relayout->exec( | relayout->exec( | ||||
{args.filter_tensor->raw_ptr, exec_src}, {inner_filter_ptr, exec_dst}); | |||||
{args.filter_tensor->raw_ptr(), exec_src}, | |||||
{inner_filter_ptr, exec_dst}); | |||||
} | } | ||||
{ | { | ||||
inner_diff_ptr = reinterpret_cast<int8_t*>(bundle.get(1)); | inner_diff_ptr = reinterpret_cast<int8_t*>(bundle.get(1)); | ||||
@@ -152,7 +153,7 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( | |||||
auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>(); | auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>(); | ||||
relayout->exec( | relayout->exec( | ||||
{args.diff_tensor->raw_ptr, exec_src}, {inner_diff_ptr, exec_dst}); | |||||
{args.diff_tensor->raw_ptr(), exec_src}, {inner_diff_ptr, exec_dst}); | |||||
} | } | ||||
int8_t* inner_grad_ptr = reinterpret_cast<int8_t*>(bundle.get(2)); | int8_t* inner_grad_ptr = reinterpret_cast<int8_t*>(bundle.get(2)); | ||||
@@ -196,7 +197,7 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( | |||||
auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>(); | auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>(); | ||||
relayout->exec( | relayout->exec( | ||||
{inner_grad_ptr, exec_src}, {args.grad_tensor->raw_ptr, exec_dst}); | |||||
{inner_grad_ptr, exec_src}, {args.grad_tensor->raw_ptr(), exec_dst}); | |||||
} | } | ||||
} | } | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -143,7 +143,7 @@ void ConvolutionBackwardDataImpl::AlgoMatmul::exec_internal(const ExecArgs& args | |||||
TensorND A(args.filter_tensor->ptr<T>(), Al), B(col, Bl), C(diff_t, Cl); | TensorND A(args.filter_tensor->ptr<T>(), Al), B(col, Bl), C(diff_t, Cl); | ||||
if (fm.should_flip) { | if (fm.should_flip) { | ||||
convolution::flip_filter( | convolution::flip_filter( | ||||
args.as_fwd_args(), wbundle.get_workspace(2), A.raw_ptr); | |||||
args.as_fwd_args(), wbundle.get_workspace(2), A.get_ref_ptr()); | |||||
config.second->exec(A, C, B, wbundle.get_workspace(3)); | config.second->exec(A, C, B, wbundle.get_workspace(3)); | ||||
} else { | } else { | ||||
config.second->exec(A, C, B, wbundle.get_workspace(2)); | config.second->exec(A, C, B, wbundle.get_workspace(2)); | ||||
@@ -50,9 +50,9 @@ void ConvolutionBackwardFilterImpl::AlgoChanwise::exec(const ExecArgs& args) con | |||||
#if CUDA_VERSION >= 9000 | #if CUDA_VERSION >= 9000 | ||||
if (is_compute_capability_required(5, 3)) { | if (is_compute_capability_required(5, 3)) { | ||||
return chanwise::run_bwd_filter( | return chanwise::run_bwd_filter( | ||||
static_cast<__half*>(args.grad_tensor->raw_ptr), | |||||
static_cast<__half*>(args.src_tensor->raw_ptr), | |||||
static_cast<__half*>(args.diff_tensor->raw_ptr), kparam, | |||||
static_cast<__half*>(args.grad_tensor->raw_ptr()), | |||||
static_cast<__half*>(args.src_tensor->raw_ptr()), | |||||
static_cast<__half*>(args.diff_tensor->raw_ptr()), kparam, | |||||
stream); | stream); | ||||
} else { | } else { | ||||
return chanwise::run_bwd_filter( | return chanwise::run_bwd_filter( | ||||
@@ -71,9 +71,9 @@ void ConvolutionBackwardFilterImpl::AlgoCUDNN::exec(const ExecArgs& args) const | |||||
float alpha = 1.0f, beta = 0.0f; | float alpha = 1.0f, beta = 0.0f; | ||||
auto status = cudnnConvolutionBackwardFilter( | auto status = cudnnConvolutionBackwardFilter( | ||||
args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | ||||
args.src_tensor->raw_ptr, D.diff_desc.desc, args.diff_tensor->raw_ptr, | |||||
args.src_tensor->raw_ptr(), D.diff_desc.desc, args.diff_tensor->raw_ptr(), | |||||
D.conv_desc.desc, m_cudnn_enum, args.workspace.raw_ptr, args.workspace.size, | D.conv_desc.desc, m_cudnn_enum, args.workspace.raw_ptr, args.workspace.size, | ||||
&beta, D.grad_desc.desc, args.grad_tensor->raw_ptr); | |||||
&beta, D.grad_desc.desc, args.grad_tensor->raw_ptr()); | |||||
megdnn_assert( | megdnn_assert( | ||||
status == CUDNN_STATUS_SUCCESS, "conv bwd_data failed: %s; info: %s", | status == CUDNN_STATUS_SUCCESS, "conv bwd_data failed: %s; info: %s", | ||||
cudnnGetErrorString(status), args.to_string().c_str()); | cudnnGetErrorString(status), args.to_string().c_str()); | ||||
@@ -101,9 +101,9 @@ void ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::exec( | |||||
{ | { | ||||
auto config = prepare_sub_opr(args); | auto config = prepare_sub_opr(args); | ||||
TensorND tsrc{args.src_tensor->raw_ptr, config.first[0]}; | |||||
TensorND tdiff{args.diff_tensor->raw_ptr, config.first[1]}; | |||||
TensorND tgrad{args.grad_tensor->raw_ptr, config.first[2]}; | |||||
TensorND tsrc{args.src_tensor->raw_ptr(), config.first[0]}; | |||||
TensorND tdiff{args.diff_tensor->raw_ptr(), config.first[1]}; | |||||
TensorND tgrad{args.grad_tensor->raw_ptr(), config.first[2]}; | |||||
size_t c_pos = 1; | size_t c_pos = 1; | ||||
@@ -118,9 +118,9 @@ void ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::exec( | |||||
auto grp = fm.group; | auto grp = fm.group; | ||||
for (uint32_t g = 0; g < grp; ++g) { | for (uint32_t g = 0; g < grp; ++g) { | ||||
config.second->exec(tsrc, tdiff, tgrad, bundle.get_workspace(0)); | config.second->exec(tsrc, tdiff, tgrad, bundle.get_workspace(0)); | ||||
incr_voidp(tsrc.raw_ptr, strd_src); | |||||
incr_voidp(tdiff.raw_ptr, strd_diff); | |||||
incr_voidp(tgrad.raw_ptr, strd_grad); | |||||
incr_refp(tsrc.get_ref_ptr(), strd_src); | |||||
incr_refp(tdiff.get_ref_ptr(), strd_diff); | |||||
incr_refp(tgrad.get_ref_ptr(), strd_grad); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -133,7 +133,7 @@ void ConvolutionBackwardFilterImpl::AlgoMatmul::exec_internal(const ExecArgs& ar | |||||
froml.stride[0] = args.diff_layout->stride[0]; | froml.stride[0] = args.diff_layout->stride[0]; | ||||
tol.stride[0] = 1; | tol.stride[0] = 1; | ||||
tol.stride[1] = N; | tol.stride[1] = N; | ||||
TensorND from(args.diff_tensor->ptr<T>(), froml), to(diff_t, tol); | |||||
TensorND from(args.diff_tensor->raw_ptr(), froml), to(diff_t, tol); | |||||
args.handle->relayout_opr()->exec(from, to); | args.handle->relayout_opr()->exec(from, to); | ||||
} | } | ||||
{ | { | ||||
@@ -149,13 +149,13 @@ void ConvolutionBackwardFilterImpl::AlgoMatmul::exec_internal(const ExecArgs& ar | |||||
Cl({OC, OH * OW * N}, typename DTypeTrait<T>::dtype()); | Cl({OC, OH * OW * N}, typename DTypeTrait<T>::dtype()); | ||||
TensorND A(args.grad_tensor->ptr<T>(), Al), B(col, Bl), C(diff_t, Cl); | TensorND A(args.grad_tensor->ptr<T>(), Al), B(col, Bl), C(diff_t, Cl); | ||||
if (fm.should_flip) { | if (fm.should_flip) { | ||||
A.raw_ptr = wbundle.get(2); | |||||
A.reset_ptr(wbundle.get(2)); | |||||
config.second->exec(C, B, A, wbundle.get_workspace(3)); | config.second->exec(C, B, A, wbundle.get_workspace(3)); | ||||
convolution::flip_filter( | convolution::flip_filter( | ||||
args.as_fwd_args(), | args.as_fwd_args(), | ||||
{static_cast<dt_byte*>(args.grad_tensor->raw_ptr), | |||||
{static_cast<dt_byte*>(args.grad_tensor->raw_ptr()), | |||||
wbundle.get_size(2)}, | wbundle.get_size(2)}, | ||||
A.raw_ptr); | |||||
A.get_ref_ptr()); | |||||
} else { | } else { | ||||
config.second->exec(C, B, A, wbundle.get_workspace(2)); | config.second->exec(C, B, A, wbundle.get_workspace(2)); | ||||
} | } | ||||
@@ -68,19 +68,19 @@ SmallVector<size_t> convolution::matmul_get_workspace_bundle( | |||||
} | } | ||||
void convolution::flip_filter( | void convolution::flip_filter( | ||||
const ForwardSizeArgs& args, const Workspace& workspace, void*& raw_ptr) { | |||||
const ForwardSizeArgs& args, const Workspace& workspace, RefPtr& ref_ptr) { | |||||
auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
megdnn_assert(fm.group == 1 && fm.spatial_ndim == 2); | megdnn_assert(fm.group == 1 && fm.spatial_ndim == 2); | ||||
auto OC = fm.ocpg, IC = fm.icpg, FH = fm.spatial[0], FW = fm.spatial[1]; | auto OC = fm.ocpg, IC = fm.icpg, FH = fm.spatial[0], FW = fm.spatial[1]; | ||||
auto dtype = fm.dtype; | auto dtype = fm.dtype; | ||||
megdnn_assert(workspace.size >= dtype.size() * OC * IC * FH * FW); | megdnn_assert(workspace.size >= dtype.size() * OC * IC * FH * FW); | ||||
TensorND src{raw_ptr, {{OC, IC, FH, FW}, dtype}}, | |||||
TensorND src{{{OC, IC, FH, FW}, dtype}, ref_ptr}, | |||||
dst{workspace.raw_ptr + (FH * FW - 1) * dtype.size(), src.layout}; | dst{workspace.raw_ptr + (FH * FW - 1) * dtype.size(), src.layout}; | ||||
dst.layout.stride[2] = -dst.layout.stride[2]; | dst.layout.stride[2] = -dst.layout.stride[2]; | ||||
dst.layout.stride[3] = -dst.layout.stride[3]; | dst.layout.stride[3] = -dst.layout.stride[3]; | ||||
args.handle->relayout_opr()->exec(src, dst); | args.handle->relayout_opr()->exec(src, dst); | ||||
raw_ptr = workspace.raw_ptr; | |||||
ref_ptr.reset(workspace.raw_ptr); | |||||
} | } | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
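flip_filter used to receive the filter address as an in/out void*& and overwrite it with the workspace address; after this hunk it receives the tensor's RefPtr, builds the source TensorND from {layout, ref_ptr}, and calls reset() so the caller's tensor follows the relayouted copy. A reduced sketch of that control flow with the relayout itself elided; WorkspaceSketch and RefPtrLike are assumptions, not the real megdnn types:

#include <cstddef>

struct WorkspaceSketch {
    unsigned char* raw_ptr;
    size_t size;
};

// Old shape, for contrast:
//   void flip_filter(..., const WorkspaceSketch& ws, void*& raw_ptr) {
//       ...relayout the flipped filter into ws...
//       raw_ptr = ws.raw_ptr;              // caller keeps a loose raw copy
//   }

// New shape: the shared handle is rebound instead of a raw pointer.
template <typename RefPtrLike>
void flip_filter_sketch(const WorkspaceSketch& ws, RefPtrLike& ref) {
    // ... relayout the flipped filter from ref.get() into ws.raw_ptr ...
    ref.reset(ws.raw_ptr);  // repoint the caller's tensor at the workspace copy
}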
@@ -85,7 +85,7 @@ struct CUDNNBwdFilterDescs { | |||||
* change \p raw_ptr to workspace. | * change \p raw_ptr to workspace. | ||||
*/ | */ | ||||
void flip_filter( | void flip_filter( | ||||
const ForwardSizeArgs& args, const Workspace& workspace, void*& raw_ptr); | |||||
const ForwardSizeArgs& args, const Workspace& workspace, RefPtr& raw_ptr); | |||||
} // namespace convolution | } // namespace convolution | ||||
} // namespace cuda | } // namespace cuda | ||||
@@ -55,9 +55,10 @@ void Convolution3DBackwardDataImpl::AlgoCUDNN::exec(const ExecArgs& args) const | |||||
float alpha = 1.0f, beta = 0.0f; | float alpha = 1.0f, beta = 0.0f; | ||||
auto status = cudnnConvolutionBackwardData( | auto status = cudnnConvolutionBackwardData( | ||||
args.handle->cudnn_handle(), &alpha, D.filter_desc.desc, | args.handle->cudnn_handle(), &alpha, D.filter_desc.desc, | ||||
args.filter_tensor->raw_ptr, D.diff_desc.desc, args.diff_tensor->raw_ptr, | |||||
D.conv_desc.desc, m_cudnn_enum, args.workspace.raw_ptr, args.workspace.size, | |||||
&beta, D.grad_desc.desc, args.grad_tensor->raw_ptr); | |||||
args.filter_tensor->raw_ptr(), D.diff_desc.desc, | |||||
args.diff_tensor->raw_ptr(), D.conv_desc.desc, m_cudnn_enum, | |||||
args.workspace.raw_ptr, args.workspace.size, &beta, D.grad_desc.desc, | |||||
args.grad_tensor->raw_ptr()); | |||||
megdnn_assert( | megdnn_assert( | ||||
status == CUDNN_STATUS_SUCCESS, "conv bwd_data failed: %s; info: %s", | status == CUDNN_STATUS_SUCCESS, "conv bwd_data failed: %s; info: %s", | ||||
cudnnGetErrorString(status), args.to_string().c_str()); | cudnnGetErrorString(status), args.to_string().c_str()); | ||||
@@ -96,9 +96,9 @@ void Convolution3DBackwardDataImpl::AlgoGroupConvGeneral::exec( | |||||
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | ||||
{ | { | ||||
auto config = prepare_sub_opr(args); | auto config = prepare_sub_opr(args); | ||||
TensorND tfilter{args.filter_tensor->raw_ptr, config.first[0]}; | |||||
TensorND tdiff{args.diff_tensor->raw_ptr, config.first[1]}; | |||||
TensorND tgrad{args.grad_tensor->raw_ptr, config.first[2]}; | |||||
TensorND tfilter{args.filter_tensor->raw_ptr(), config.first[0]}; | |||||
TensorND tdiff{args.diff_tensor->raw_ptr(), config.first[1]}; | |||||
TensorND tgrad{args.grad_tensor->raw_ptr(), config.first[2]}; | |||||
size_t c_pos = 1; | size_t c_pos = 1; | ||||
auto grp = args.filter_meta.group; | auto grp = args.filter_meta.group; | ||||
@@ -114,9 +114,9 @@ void Convolution3DBackwardDataImpl::AlgoGroupConvGeneral::exec( | |||||
for (uint32_t g = 0; g < grp; ++g) { | for (uint32_t g = 0; g < grp; ++g) { | ||||
config.second->exec(tfilter, tdiff, tgrad, bundle.get_workspace(0)); | config.second->exec(tfilter, tdiff, tgrad, bundle.get_workspace(0)); | ||||
incr_voidp(tfilter.raw_ptr, strd_flt); | |||||
incr_voidp(tdiff.raw_ptr, strd_diff); | |||||
incr_voidp(tgrad.raw_ptr, strd_grad); | |||||
incr_refp(tfilter.get_ref_ptr(), strd_flt); | |||||
incr_refp(tdiff.get_ref_ptr(), strd_diff); | |||||
incr_refp(tgrad.get_ref_ptr(), strd_grad); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -56,9 +56,9 @@ void Convolution3DBackwardFilterImpl::AlgoCUDNN::exec(const ExecArgs& args) cons | |||||
float alpha = 1.0f, beta = 0.0f; | float alpha = 1.0f, beta = 0.0f; | ||||
auto status = cudnnConvolutionBackwardFilter( | auto status = cudnnConvolutionBackwardFilter( | ||||
args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | ||||
args.src_tensor->raw_ptr, D.diff_desc.desc, args.diff_tensor->raw_ptr, | |||||
args.src_tensor->raw_ptr(), D.diff_desc.desc, args.diff_tensor->raw_ptr(), | |||||
D.conv_desc.desc, m_cudnn_enum, args.workspace.raw_ptr, args.workspace.size, | D.conv_desc.desc, m_cudnn_enum, args.workspace.raw_ptr, args.workspace.size, | ||||
&beta, D.grad_desc.desc, args.grad_tensor->raw_ptr); | |||||
&beta, D.grad_desc.desc, args.grad_tensor->raw_ptr()); | |||||
megdnn_assert( | megdnn_assert( | ||||
status == CUDNN_STATUS_SUCCESS, "conv bwd_data failed: %s; info: %s", | status == CUDNN_STATUS_SUCCESS, "conv bwd_data failed: %s; info: %s", | ||||
cudnnGetErrorString(status), args.to_string().c_str()); | cudnnGetErrorString(status), args.to_string().c_str()); | ||||
@@ -98,9 +98,9 @@ void Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral::exec( | |||||
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | ||||
{ | { | ||||
auto config = prepare_sub_opr(args); | auto config = prepare_sub_opr(args); | ||||
TensorND tsrc{args.src_tensor->raw_ptr, config.first[0]}; | |||||
TensorND tdiff{args.diff_tensor->raw_ptr, config.first[1]}; | |||||
TensorND tgrad{args.grad_tensor->raw_ptr, config.first[2]}; | |||||
TensorND tsrc{args.src_tensor->raw_ptr(), config.first[0]}; | |||||
TensorND tdiff{args.diff_tensor->raw_ptr(), config.first[1]}; | |||||
TensorND tgrad{args.grad_tensor->raw_ptr(), config.first[2]}; | |||||
size_t c_pos = 1; | size_t c_pos = 1; | ||||
auto grp = args.grad_filter_meta.group; | auto grp = args.grad_filter_meta.group; | ||||
@@ -116,9 +116,9 @@ void Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral::exec( | |||||
for (uint32_t g = 0; g < grp; ++g) { | for (uint32_t g = 0; g < grp; ++g) { | ||||
config.second->exec(tsrc, tdiff, tgrad, bundle.get_workspace(0)); | config.second->exec(tsrc, tdiff, tgrad, bundle.get_workspace(0)); | ||||
incr_voidp(tsrc.raw_ptr, strd_src); | |||||
incr_voidp(tdiff.raw_ptr, strd_diff); | |||||
incr_voidp(tgrad.raw_ptr, strd_grad); | |||||
incr_refp(tsrc.get_ref_ptr(), strd_src); | |||||
incr_refp(tdiff.get_ref_ptr(), strd_diff); | |||||
incr_refp(tgrad.get_ref_ptr(), strd_grad); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -54,17 +54,17 @@ size_t Convolution3DForwardImpl::Algo1x1x1::get_workspace_in_bytes( | |||||
void Convolution3DForwardImpl::Algo1x1x1::exec(const ExecArgs& args) const { | void Convolution3DForwardImpl::Algo1x1x1::exec(const ExecArgs& args) const { | ||||
TensorND A, B, C; | TensorND A, B, C; | ||||
extract_matmul_layouts(args, A.layout, B.layout, C.layout); | extract_matmul_layouts(args, A.layout, B.layout, C.layout); | ||||
A.raw_ptr = args.filter_tensor->raw_ptr; | |||||
B.raw_ptr = args.src_tensor->raw_ptr; | |||||
C.raw_ptr = args.dst_tensor->raw_ptr; | |||||
A.reset_ptr(args.filter_tensor->raw_ptr()); | |||||
B.reset_ptr(args.src_tensor->raw_ptr()); | |||||
C.reset_ptr(args.dst_tensor->raw_ptr()); | |||||
size_t batch = args.src_layout->shape[0]; | size_t batch = args.src_layout->shape[0]; | ||||
auto mm = args.handle->matmul_opr(); | auto mm = args.handle->matmul_opr(); | ||||
auto strd_B = args.src_layout->stride[0] * args.src_layout->dtype.size(), | auto strd_B = args.src_layout->stride[0] * args.src_layout->dtype.size(), | ||||
strd_C = args.dst_layout->stride[0] * args.dst_layout->dtype.size(); | strd_C = args.dst_layout->stride[0] * args.dst_layout->dtype.size(); | ||||
for (size_t i = 0; i < batch; ++i) { | for (size_t i = 0; i < batch; ++i) { | ||||
mm->exec(A, B, C, args.workspace); | mm->exec(A, B, C, args.workspace); | ||||
incr_voidp(B.raw_ptr, strd_B); | |||||
incr_voidp(C.raw_ptr, strd_C); | |||||
incr_refp(B.get_ref_ptr(), strd_B); | |||||
incr_refp(C.get_ref_ptr(), strd_C); | |||||
} | } | ||||
} | } | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
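Algo1x1x1 above now seeds its matmul operands with reset_ptr() and walks the batch by advancing the B and C handles, rather than reassigning raw pointer members each iteration. A self-contained sketch of that loop structure; Tensor here is any type exposing a get()/reset()-style get_ref_ptr(), and MatMul any callable taking (A, B, C):

#include <cstddef>
#include <cstdint>

template <typename Tensor, typename MatMul>
void batched_1x1x1_sketch(
        Tensor& A, Tensor& B, Tensor& C, MatMul&& mm, size_t batch,
        ptrdiff_t strd_B, ptrdiff_t strd_C) {
    auto advance = [](Tensor& t, ptrdiff_t delta_bytes) {
        auto& ref = t.get_ref_ptr();
        ref.reset(reinterpret_cast<void*>(
                reinterpret_cast<uintptr_t>(ref.get()) + delta_bytes));
    };
    for (size_t i = 0; i < batch; ++i) {
        mm(A, B, C);         // one GEMM per batch element; filter A stays fixed
        advance(B, strd_B);  // slide src to the next batch item
        advance(C, strd_C);  // slide dst to the next batch item
    }
}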
@@ -53,9 +53,10 @@ void Convolution3DForwardImpl::AlgoCUDNN::exec(const ExecArgs& args) const { | |||||
float alpha = 1.0f, beta = 0.0f; | float alpha = 1.0f, beta = 0.0f; | ||||
auto status = cudnnConvolutionForward( | auto status = cudnnConvolutionForward( | ||||
args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | ||||
args.src_tensor->raw_ptr, D.filter_desc.desc, args.filter_tensor->raw_ptr, | |||||
D.conv_desc.desc, m_cudnn_enum, args.workspace.raw_ptr, args.workspace.size, | |||||
&beta, D.dst_desc.desc, args.dst_tensor->raw_ptr); | |||||
args.src_tensor->raw_ptr(), D.filter_desc.desc, | |||||
args.filter_tensor->raw_ptr(), D.conv_desc.desc, m_cudnn_enum, | |||||
args.workspace.raw_ptr, args.workspace.size, &beta, D.dst_desc.desc, | |||||
args.dst_tensor->raw_ptr()); | |||||
megdnn_assert( | megdnn_assert( | ||||
status == CUDNN_STATUS_SUCCESS, "conv fwd failed: %s; info: %s", | status == CUDNN_STATUS_SUCCESS, "conv fwd failed: %s; info: %s", | ||||
cudnnGetErrorString(status), args.to_string().c_str()); | cudnnGetErrorString(status), args.to_string().c_str()); | ||||
@@ -103,9 +103,9 @@ void Convolution3DForwardImpl::AlgoGroupConvGeneral::exec(const ExecArgs& args) | |||||
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | ||||
{ | { | ||||
auto config = prepare_sub_opr(args); | auto config = prepare_sub_opr(args); | ||||
TensorND tsrc{args.src_tensor->raw_ptr, config.first[0]}; | |||||
TensorND tfilter{args.filter_tensor->raw_ptr, config.first[1]}; | |||||
TensorND tdst{args.dst_tensor->raw_ptr, config.first[2]}; | |||||
TensorND tsrc{args.src_tensor->raw_ptr(), config.first[0]}; | |||||
TensorND tfilter{args.filter_tensor->raw_ptr(), config.first[1]}; | |||||
TensorND tdst{args.dst_tensor->raw_ptr(), config.first[2]}; | |||||
size_t c_pos; | size_t c_pos; | ||||
if (args.filter_meta.format == Param::Format::NCDHW) { | if (args.filter_meta.format == Param::Format::NCDHW) { | ||||
@@ -127,9 +127,9 @@ void Convolution3DForwardImpl::AlgoGroupConvGeneral::exec(const ExecArgs& args) | |||||
for (uint32_t g = 0; g < grp; ++g) { | for (uint32_t g = 0; g < grp; ++g) { | ||||
config.second->exec(tsrc, tfilter, tdst, bundle.get_workspace(0)); | config.second->exec(tsrc, tfilter, tdst, bundle.get_workspace(0)); | ||||
incr_voidp(tsrc.raw_ptr, strd_src); | |||||
incr_voidp(tdst.raw_ptr, strd_dst); | |||||
incr_voidp(tfilter.raw_ptr, strd_flt); | |||||
incr_refp(tsrc.get_ref_ptr(), strd_src); | |||||
incr_refp(tdst.get_ref_ptr(), strd_dst); | |||||
incr_refp(tfilter.get_ref_ptr(), strd_flt); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -35,20 +35,20 @@ bool convolution3d::is_cudnn_supported(const ForwardSizeArgs& args) { | |||||
} | } | ||||
void convolution3d::flip_filter( | void convolution3d::flip_filter( | ||||
const ForwardSizeArgs& args, const Workspace& workspace, void*& raw_ptr) { | |||||
const ForwardSizeArgs& args, const Workspace& workspace, RefPtr& ref_ptr) { | |||||
auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
megdnn_assert(fm.group == 1 && fm.spatial_ndim == 3); | megdnn_assert(fm.group == 1 && fm.spatial_ndim == 3); | ||||
auto OC = fm.ocpg, IC = fm.icpg, FD = fm.spatial[0], FH = fm.spatial[1], | auto OC = fm.ocpg, IC = fm.icpg, FD = fm.spatial[0], FH = fm.spatial[1], | ||||
FW = fm.spatial[2]; | FW = fm.spatial[2]; | ||||
auto dtype = DType::from_enum(fm.dtype_enum); | auto dtype = DType::from_enum(fm.dtype_enum); | ||||
megdnn_assert(workspace.size >= dtype.size() * OC * IC * FD * FH * FW); | megdnn_assert(workspace.size >= dtype.size() * OC * IC * FD * FH * FW); | ||||
TensorND src{raw_ptr, {{OC, IC, FD, FH, FW}, dtype}}, | |||||
TensorND src{{{OC, IC, FD, FH, FW}, dtype}, ref_ptr}, | |||||
dst{workspace.raw_ptr + (FD * FH * FW - 1) * dtype.size(), src.layout}; | dst{workspace.raw_ptr + (FD * FH * FW - 1) * dtype.size(), src.layout}; | ||||
dst.layout.stride[2] = -dst.layout.stride[2]; | dst.layout.stride[2] = -dst.layout.stride[2]; | ||||
dst.layout.stride[3] = -dst.layout.stride[3]; | dst.layout.stride[3] = -dst.layout.stride[3]; | ||||
dst.layout.stride[4] = -dst.layout.stride[4]; | dst.layout.stride[4] = -dst.layout.stride[4]; | ||||
args.handle->relayout_opr()->exec(src, dst); | args.handle->relayout_opr()->exec(src, dst); | ||||
raw_ptr = workspace.raw_ptr; | |||||
ref_ptr.reset(workspace.raw_ptr); | |||||
} | } | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -84,7 +84,7 @@ struct CUDNNBwdFilterDescs { | |||||
* change \p raw_ptr to workspace. | * change \p raw_ptr to workspace. | ||||
*/ | */ | ||||
void flip_filter( | void flip_filter( | ||||
const ForwardSizeArgs& args, const Workspace& workspace, void*& raw_ptr); | |||||
const ForwardSizeArgs& args, const Workspace& workspace, RefPtr& raw_ptr); | |||||
inline bool cudnn_get_convolution_fwd_algo_helper( | inline bool cudnn_get_convolution_fwd_algo_helper( | ||||
cudnnHandle_t cudnn_handle, const cudnnTensorDescriptor_t x_desc, | cudnnHandle_t cudnn_handle, const cudnnTensorDescriptor_t x_desc, | ||||
@@ -169,10 +169,10 @@ void ConvPoolingForwardImpl::exec( | |||||
nonlineMode = IDENTITY; | nonlineMode = IDENTITY; | ||||
} | } | ||||
float *src_ptr = static_cast<float*>(src.raw_ptr), | |||||
*filter_ptr = static_cast<float*>(filter.raw_ptr), | |||||
*bias_ptr = static_cast<float*>(bias.raw_ptr), | |||||
*dst_ptr = static_cast<float*>(dst.raw_ptr); | |||||
float *src_ptr = static_cast<float*>(src.raw_ptr()), | |||||
*filter_ptr = static_cast<float*>(filter.raw_ptr()), | |||||
*bias_ptr = static_cast<float*>(bias.raw_ptr()), | |||||
*dst_ptr = static_cast<float*>(dst.raw_ptr()); | |||||
switch (this->param().method) { | switch (this->param().method) { | ||||
case Param::Method::WITH_SHARED_MEM: | case Param::Method::WITH_SHARED_MEM: | ||||
@@ -12,7 +12,7 @@ | |||||
#include "./opr_impl.h" | #include "./opr_impl.h" | ||||
#include "./kern.cuh" | #include "./kern.cuh" | ||||
#include "src/common/reduce_helper.h" | |||||
#include "src/common/reduce_helper_device.h" | |||||
#include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
using namespace megdnn; | using namespace megdnn; | ||||
@@ -58,7 +58,7 @@ void DctChannelSelectForwardImpl::exec( | |||||
megdnn_assert( | megdnn_assert( | ||||
param().format == Param::Format::NCHW4, "qint8 only support nchw4"); | param().format == Param::Format::NCHW4, "qint8 only support nchw4"); | ||||
dct::call_kern_dct<dct_block, dct::DctLayoutFormat::NCHW4>( | dct::call_kern_dct<dct_block, dct::DctLayoutFormat::NCHW4>( | ||||
src.ptr<uint8_t>(), (int8_t*)dst.raw_ptr, in, ic, ih, iw, oc, | |||||
src.ptr<uint8_t>(), (int8_t*)dst.raw_ptr(), in, ic, ih, iw, oc, | |||||
with_fix_32_mask, mask_offset_ptr, mask_val_ptr, stream, error_info, | with_fix_32_mask, mask_offset_ptr, mask_val_ptr, stream, error_info, | ||||
m_error_tracker, | m_error_tracker, | ||||
dst.layout.dtype.param<::megdnn::dtype::QuantizedS8>().scale); | dst.layout.dtype.param<::megdnn::dtype::QuantizedS8>().scale); | ||||
@@ -227,7 +227,7 @@ INST(dt_quint8); | |||||
template <int ndim> | template <int ndim> | ||||
void ParamElemVisitor4bitBase<ndim, BCAST_OTHER>::host_init( | void ParamElemVisitor4bitBase<ndim, BCAST_OTHER>::host_init( | ||||
const TensorND& rv, int /*grid_size*/, int /*block_size*/) { | const TensorND& rv, int /*grid_size*/, int /*block_size*/) { | ||||
m_ptr = reinterpret_cast<Storage*>(rv.raw_ptr); | |||||
m_ptr = reinterpret_cast<Storage*>(rv.raw_ptr()); | |||||
ptrdiff_t min_stride = std::numeric_limits<ptrdiff_t>::max(); | ptrdiff_t min_stride = std::numeric_limits<ptrdiff_t>::max(); | ||||
for (size_t i = 0; i < rv.layout.ndim; ++i) { | for (size_t i = 0; i < rv.layout.ndim; ++i) { | ||||
m_stride[i] = rv.layout.stride[i]; | m_stride[i] = rv.layout.stride[i]; | ||||
@@ -21,31 +21,31 @@ using namespace megdnn; | |||||
using namespace cuda; | using namespace cuda; | ||||
void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16x32x32x32( | void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16x32x32x32( | ||||
const ElemwiseOpParamN<3>& param, dt_int32* dst) { | |||||
const ElemwiseOpParamN<3>& param, const TensorND& dst) { | |||||
BroadcastChannelInfo binfo0, binfo1; | BroadcastChannelInfo binfo0, binfo1; | ||||
if (is_vector(param[0].layout) && | if (is_vector(param[0].layout) && | ||||
is_broadcasted_channel_like(param[1].layout, binfo0) && | is_broadcasted_channel_like(param[1].layout, binfo0) && | ||||
is_broadcasted_channel_like(param[2].layout, binfo1) && binfo0 == binfo1) { | is_broadcasted_channel_like(param[2].layout, binfo1) && binfo0 == binfo1) { | ||||
elemwise_multi_type::fma3_int16x32x32x32_1c1( | elemwise_multi_type::fma3_int16x32x32x32_1c1( | ||||
param, dst, cuda_stream(this->handle())); | |||||
param, dst.ptr<dt_int32>(), cuda_stream(this->handle())); | |||||
return; | return; | ||||
} | } | ||||
megdnn_throw("unsupported fma3 int16x32x32x32 layout"); | megdnn_throw("unsupported fma3 int16x32x32x32 layout"); | ||||
} | } | ||||
void ElemwiseMultiTypeImpl::on_fuse_mul_add3_iXxf32xf32xi8( | void ElemwiseMultiTypeImpl::on_fuse_mul_add3_iXxf32xf32xi8( | ||||
const ElemwiseOpParamN<3>& param, dt_int8* dst) { | |||||
const ElemwiseOpParamN<3>& param, const TensorND& dst) { | |||||
Broadcast1xInfo binfo0, binfo1; | Broadcast1xInfo binfo0, binfo1; | ||||
auto p1 = param[1].ptr<float>(), p2 = param[2].ptr<float>(); | auto p1 = param[1].ptr<float>(), p2 = param[2].ptr<float>(); | ||||
auto stream = cuda_stream(this->handle()); | auto stream = cuda_stream(this->handle()); | ||||
if (is_vector(param[0].layout) && is_broadcasted_1x(param[1].layout, binfo0) && | if (is_vector(param[0].layout) && is_broadcasted_1x(param[1].layout, binfo0) && | ||||
is_broadcasted_1x(param[2].layout, binfo1) && binfo0 == binfo1) { | is_broadcasted_1x(param[2].layout, binfo1) && binfo0 == binfo1) { | ||||
switch (param[0].layout.dtype.enumv()) { | switch (param[0].layout.dtype.enumv()) { | ||||
#define cb(t) \ | |||||
case DTypeTrait<t>::enumv: \ | |||||
elemwise_multi_type::fma3_iXxf32xf32xi8_bcast_1x( \ | |||||
param[0].ptr<DTypeTrait<t>::ctype>(), p1, p2, dst, binfo0.x, binfo0.y, \ | |||||
stream); \ | |||||
#define cb(t) \ | |||||
case DTypeTrait<t>::enumv: \ | |||||
elemwise_multi_type::fma3_iXxf32xf32xi8_bcast_1x( \ | |||||
param[0].ptr<DTypeTrait<t>::ctype>(), p1, p2, dst.ptr<dt_int8>(), \ | |||||
binfo0.x, binfo0.y, stream); \ | |||||
return; | return; | ||||
MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) | ||||
#undef cb | #undef cb | ||||
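The ElemwiseMultiType overrides in these hunks change their destination argument from a raw typed pointer to const TensorND&, and each dispatch extracts the typed pointer (dst.ptr<dt_int32>(), dst.ptr<dt_int8>(), ...) at the point where the kernel is launched. A small sketch of that shift; TensorViewSketch is a stand-in, not the real TensorND:

#include <cstdint>

struct TensorViewSketch {
    void* data;
    template <typename T>
    T* ptr() const { return static_cast<T*>(data); }
};

// Old: the virtual interface already committed to one element type.
//   void on_fuse_mul_add3_int16x32x32x32(const Param& p, int32_t* dst);
// New: the interface passes the tensor; the backend takes the typed view.
void on_fuse_mul_add3_sketch(const TensorViewSketch& dst) {
    int32_t* out = dst.ptr<int32_t>();  // typed access deferred to the kernel launch
    (void)out;                          // ... launch the fma3 kernel with out ...
}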
@@ -58,14 +58,14 @@ void ElemwiseMultiTypeImpl::on_fuse_mul_add3_iXxf32xf32xi8( | |||||
} | } | ||||
void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi8( | void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi8( | ||||
const ElemwiseOpParamN<2>& param, dt_int8* dst) { | |||||
const ElemwiseOpParamN<2>& param, const TensorND& dst) { | |||||
auto stream = cuda_stream(this->handle()); | auto stream = cuda_stream(this->handle()); | ||||
if (is_vector(param[0].layout) && is_broadcasted_scalar(param[1].layout)) { | if (is_vector(param[0].layout) && is_broadcasted_scalar(param[1].layout)) { | ||||
switch (param[0].layout.dtype.enumv()) { | switch (param[0].layout.dtype.enumv()) { | ||||
#define DISPATCH(t) \ | |||||
case DTypeTrait<t>::enumv: \ | |||||
elemwise_multi_type::round_shr_saturate_iXxi8xiX_scalar< \ | |||||
DTypeTrait<t>::ctype, dt_int8>(param, dst, stream); \ | |||||
#define DISPATCH(t) \ | |||||
case DTypeTrait<t>::enumv: \ | |||||
elemwise_multi_type::round_shr_saturate_iXxi8xiX_scalar< \ | |||||
DTypeTrait<t>::ctype, dt_int8>(param, dst.ptr<dt_int8>(), stream); \ | |||||
return; | return; | ||||
DISPATCH(::megdnn::dtype::Int32) | DISPATCH(::megdnn::dtype::Int32) | ||||
DISPATCH(::megdnn::dtype::Int16) | DISPATCH(::megdnn::dtype::Int16) | ||||
@@ -85,7 +85,7 @@ void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi8( | |||||
} | } | ||||
void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( | void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( | ||||
const ElemwiseOpParamN<6>& param, dt_int8* dst) { | |||||
const ElemwiseOpParamN<6>& param, const TensorND& dst) { | |||||
auto stream = cuda_stream(this->handle()); | auto stream = cuda_stream(this->handle()); | ||||
BroadcastChannelInfo info; | BroadcastChannelInfo info; | ||||
if (is_vector(param[0].layout) && | if (is_vector(param[0].layout) && | ||||
@@ -95,7 +95,7 @@ void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( | |||||
is_broadcasted_scalar(param[4].layout) && | is_broadcasted_scalar(param[4].layout) && | ||||
is_broadcasted_scalar(param[5].layout)) { | is_broadcasted_scalar(param[5].layout)) { | ||||
elemwise_multi_type::fuse_add_rmulh_round_shr_saturate_bcast_1c11<dt_int16>( | elemwise_multi_type::fuse_add_rmulh_round_shr_saturate_bcast_1c11<dt_int16>( | ||||
param, dst, stream); | |||||
param, dst.ptr<dt_int8>(), stream); | |||||
return; | return; | ||||
} | } | ||||
megdnn_throw( | megdnn_throw( | ||||
@@ -106,7 +106,7 @@ void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( | |||||
} | } | ||||
void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( | void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( | ||||
const ElemwiseOpParamN<6>& param, dt_int8* dst) { | |||||
const ElemwiseOpParamN<6>& param, const TensorND& dst) { | |||||
auto stream = cuda_stream(this->handle()); | auto stream = cuda_stream(this->handle()); | ||||
BroadcastChannelInfo info; | BroadcastChannelInfo info; | ||||
if (is_vector(param[0].layout) && | if (is_vector(param[0].layout) && | ||||
@@ -116,7 +116,7 @@ void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( | |||||
is_broadcasted_scalar(param[4].layout) && | is_broadcasted_scalar(param[4].layout) && | ||||
is_broadcasted_scalar(param[5].layout)) { | is_broadcasted_scalar(param[5].layout)) { | ||||
elemwise_multi_type::fuse_add_rmulh_round_shr_saturate_bcast_1c11<dt_int32>( | elemwise_multi_type::fuse_add_rmulh_round_shr_saturate_bcast_1c11<dt_int32>( | ||||
param, dst, stream); | |||||
param, dst.ptr<dt_int8>(), stream); | |||||
return; | return; | ||||
} | } | ||||
megdnn_throw( | megdnn_throw( | ||||
@@ -127,14 +127,14 @@ void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( | |||||
} | } | ||||
void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi16( | void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi16( | ||||
const ElemwiseOpParamN<2>& param, dt_int16* dst) { | |||||
const ElemwiseOpParamN<2>& param, const TensorND& dst) { | |||||
auto stream = cuda_stream(this->handle()); | auto stream = cuda_stream(this->handle()); | ||||
if (is_vector(param[0].layout) && is_broadcasted_scalar(param[1].layout)) { | if (is_vector(param[0].layout) && is_broadcasted_scalar(param[1].layout)) { | ||||
switch (param[0].layout.dtype.enumv()) { | switch (param[0].layout.dtype.enumv()) { | ||||
#define DISPATCH(t) \ | |||||
case DTypeTrait<t>::enumv: \ | |||||
elemwise_multi_type::round_shr_saturate_iXxi8xiX_scalar< \ | |||||
DTypeTrait<t>::ctype, dt_int16>(param, dst, stream); \ | |||||
#define DISPATCH(t) \ | |||||
case DTypeTrait<t>::enumv: \ | |||||
elemwise_multi_type::round_shr_saturate_iXxi8xiX_scalar< \ | |||||
DTypeTrait<t>::ctype, dt_int16>(param, dst.ptr<dt_int16>(), stream); \ | |||||
return; | return; | ||||
DISPATCH(::megdnn::dtype::Int32) | DISPATCH(::megdnn::dtype::Int32) | ||||
DISPATCH(::megdnn::dtype::Int16) | DISPATCH(::megdnn::dtype::Int16) | ||||
@@ -227,22 +227,22 @@ IMPL_MODE_DISPATCHER(2, dt_quint4, dt_qint32); | |||||
#undef _cb_dispatch_mode | #undef _cb_dispatch_mode | ||||
#define _cb_dispatch_mode(_m) \ | |||||
case param::Elemwise::Mode::_m: \ | |||||
do { \ | |||||
using KernImpl = ElemwiseKern< \ | |||||
megcorePlatformCUDA, param_enumv::Elemwise::Mode::_m, float>; \ | |||||
using Op = kern_ops_quantized::QuantizedMultiTypeOp< \ | |||||
arity, src_ctype, dst_ctype, KernImpl>; \ | |||||
using dst_storage = typename VectTypeTrait<dst_ctype>::Storage; \ | |||||
dst_storage* dst = reinterpret_cast<dst_storage*>(dst_tensor.raw_ptr); \ | |||||
Op op(src_params, dst, dst_param); \ | |||||
ElemwiseOpParamN<1> param_dst; \ | |||||
param_dst[0] = dst_tensor; \ | |||||
param_dst.init_from_given_tensor(); \ | |||||
run_elemwise<Op, src_ctype, dst_ctype, arity>( \ | |||||
param, param_dst, stream, op); \ | |||||
return; \ | |||||
#define _cb_dispatch_mode(_m) \ | |||||
case param::Elemwise::Mode::_m: \ | |||||
do { \ | |||||
using KernImpl = ElemwiseKern< \ | |||||
megcorePlatformCUDA, param_enumv::Elemwise::Mode::_m, float>; \ | |||||
using Op = kern_ops_quantized::QuantizedMultiTypeOp< \ | |||||
arity, src_ctype, dst_ctype, KernImpl>; \ | |||||
using dst_storage = typename VectTypeTrait<dst_ctype>::Storage; \ | |||||
dst_storage* dst = reinterpret_cast<dst_storage*>(dst_tensor.raw_ptr()); \ | |||||
Op op(src_params, dst, dst_param); \ | |||||
ElemwiseOpParamN<1> param_dst; \ | |||||
param_dst[0] = dst_tensor; \ | |||||
param_dst.init_from_given_tensor(); \ | |||||
run_elemwise<Op, src_ctype, dst_ctype, arity>( \ | |||||
param, param_dst, stream, op); \ | |||||
return; \ | |||||
} while (0); | } while (0); | ||||
#define FOREACH(cb) \ | #define FOREACH(cb) \ | ||||
@@ -18,22 +18,22 @@ namespace cuda { | |||||
class ElemwiseMultiTypeImpl final : public ElemwiseMultiTypeImplHelper { | class ElemwiseMultiTypeImpl final : public ElemwiseMultiTypeImplHelper { | ||||
void on_fuse_mul_add3_int16x32x32x32( | void on_fuse_mul_add3_int16x32x32x32( | ||||
const ElemwiseOpParamN<3>& param, dt_int32* dst) override; | |||||
const ElemwiseOpParamN<3>& param, const TensorND& dst) override; | |||||
void on_fuse_mul_add3_iXxf32xf32xi8( | void on_fuse_mul_add3_iXxf32xf32xi8( | ||||
const ElemwiseOpParamN<3>& param, dt_int8* dst) override; | |||||
const ElemwiseOpParamN<3>& param, const TensorND& dst) override; | |||||
void on_round_shr_saturate_iXxi8xi8( | void on_round_shr_saturate_iXxi8xi8( | ||||
const ElemwiseOpParamN<2>& param, dt_int8* dst) override; | |||||
const ElemwiseOpParamN<2>& param, const TensorND& dst) override; | |||||
void on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( | void on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( | ||||
const ElemwiseOpParamN<6>& param, dt_int8* dst) override; | |||||
const ElemwiseOpParamN<6>& param, const TensorND& dst) override; | |||||
void on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( | void on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( | ||||
const ElemwiseOpParamN<6>& param, dt_int8* dst) override; | |||||
const ElemwiseOpParamN<6>& param, const TensorND& dst) override; | |||||
void on_round_shr_saturate_iXxi8xi16( | void on_round_shr_saturate_iXxi8xi16( | ||||
const ElemwiseOpParamN<2>& param, dt_int16* dst) override; | |||||
const ElemwiseOpParamN<2>& param, const TensorND& dst) override; | |||||
void on_quantized_mode( | void on_quantized_mode( | ||||
const ElemwiseOpParamN<1>& param, const TensorND& dst, | const ElemwiseOpParamN<1>& param, const TensorND& dst, | ||||
@@ -32,11 +32,6 @@ std::unique_ptr<LocalForward> get_opr(Handle* handle, param::Convolution param) | |||||
return std::move(opr); | return std::move(opr); | ||||
} | } | ||||
template <typename T> | |||||
void incr_ptr(T*& dst, ptrdiff_t delta) { | |||||
dst = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(dst) + delta); | |||||
} | |||||
TensorLayout prepare_src_dst(const TensorLayout& input, size_t g) { | TensorLayout prepare_src_dst(const TensorLayout& input, size_t g) { | ||||
TensorLayout ret = input; | TensorLayout ret = input; | ||||
megdnn_assert(ret[1] % g == 0); | megdnn_assert(ret[1] % g == 0); | ||||
@@ -84,18 +79,20 @@ void GroupLocalForwardImpl::exec( | |||||
SH, SW, stream); | SH, SW, stream); | ||||
} else { | } else { | ||||
auto&& opr = get_opr(handle, param()); | auto&& opr = get_opr(handle, param()); | ||||
TensorND src_g = {src.raw_ptr, prepare_src_dst(src.layout, G)}; | |||||
TensorND dst_g = {dst.raw_ptr, prepare_src_dst(dst.layout, G)}; | |||||
TensorND filter_g = {filter.raw_ptr, prepare_filter(filter.layout)}; | |||||
TensorND src_g = {src.raw_ptr(), prepare_src_dst(src.layout, G)}; | |||||
TensorND dst_g = {dst.raw_ptr(), prepare_src_dst(dst.layout, G)}; | |||||
TensorND filter_g = {filter.raw_ptr(), prepare_filter(filter.layout)}; | |||||
for (size_t g = 0; g < G; ++g) { | for (size_t g = 0; g < G; ++g) { | ||||
opr->exec(src_g, filter_g, dst_g, workspace); | opr->exec(src_g, filter_g, dst_g, workspace); | ||||
incr_ptr( | |||||
src_g.raw_ptr, src_g.layout.stride[1] * src_g.layout.shape[1] * | |||||
src_g.layout.dtype.size()); | |||||
incr_ptr( | |||||
dst_g.raw_ptr, dst_g.layout.stride[1] * dst_g.layout.shape[1] * | |||||
dst_g.layout.dtype.size()); | |||||
incr_ptr(filter_g.raw_ptr, filter_g.layout.span().dist_byte()); | |||||
incr_refp( | |||||
src_g.get_ref_ptr(), src_g.layout.stride[1] * | |||||
src_g.layout.shape[1] * | |||||
src_g.layout.dtype.size()); | |||||
incr_refp( | |||||
dst_g.get_ref_ptr(), dst_g.layout.stride[1] * | |||||
dst_g.layout.shape[1] * | |||||
dst_g.layout.dtype.size()); | |||||
incr_refp(filter_g.get_ref_ptr(), filter_g.layout.span().dist_byte()); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -106,7 +106,7 @@ void LocalShareBackwardDataImpl::AlgoBatchedMatMul::exec(const ExecArgs& args) c | |||||
B1.stride[4] = wo; | B1.stride[4] = wo; | ||||
B1.stride[5] = 1; | B1.stride[5] = 1; | ||||
B1.stride[6] = co * ho * wo; | B1.stride[6] = co * ho * wo; | ||||
TensorND ts_B1{args.diff_tensor->raw_ptr, B1}; | |||||
TensorND ts_B1{args.diff_tensor->raw_ptr(), B1}; | |||||
TensorLayout B2{ | TensorLayout B2{ | ||||
{groups * sgh * sgw, ocpg, ho / sgh * wo / sgw * n}, dtype::Float32()}; | {groups * sgh * sgw, ocpg, ho / sgh * wo / sgw * n}, dtype::Float32()}; | ||||
B2.init_contiguous_stride(); | B2.init_contiguous_stride(); | ||||
@@ -122,7 +122,7 @@ void LocalShareBackwardDataImpl::AlgoBatchedMatMul::exec(const ExecArgs& args) c | |||||
TensorLayout C{ | TensorLayout C{ | ||||
{groups * sgh * sgw, icpg * fh * fw, ho / sgh * wo / sgw * n}, | {groups * sgh * sgw, icpg * fh * fw, ho / sgh * wo / sgw * n}, | ||||
dtype::Float32()}; | dtype::Float32()}; | ||||
TensorND ts_A{args.filter_tensor->raw_ptr, A}; | |||||
TensorND ts_A{args.filter_tensor->raw_ptr(), A}; | |||||
TensorND ts_B{ws_pretranspose, B}; | TensorND ts_B{ws_pretranspose, B}; | ||||
TensorND ts_C{ws_col2im, C}; | TensorND ts_C{ws_col2im, C}; | ||||
Workspace ws_wrapper; | Workspace ws_wrapper; | ||||
@@ -113,7 +113,7 @@ void LocalShareBackwardFilterImpl::AlgoBatchedMatMul::exec(const ExecArgs& args) | |||||
B1.stride[4] = co * ho * wo; | B1.stride[4] = co * ho * wo; | ||||
B1.stride[5] = wo; | B1.stride[5] = wo; | ||||
B1.stride[6] = 1; | B1.stride[6] = 1; | ||||
TensorND ts_B1{args.diff_tensor->raw_ptr, B1}; | |||||
TensorND ts_B1{args.diff_tensor->raw_ptr(), B1}; | |||||
TensorLayout B2{ | TensorLayout B2{ | ||||
{groups * sgh * sgw, ocpg, ho / sgh * wo / sgw * n}, dtype::Float32()}; | {groups * sgh * sgw, ocpg, ho / sgh * wo / sgw * n}, dtype::Float32()}; | ||||
B2.init_contiguous_stride(); | B2.init_contiguous_stride(); | ||||
@@ -133,7 +133,7 @@ void LocalShareBackwardFilterImpl::AlgoBatchedMatMul::exec(const ExecArgs& args) | |||||
TensorLayout C{{groups * sgh * sgw, icpg * fh * fw, ocpg}, dtype::Float32()}; | TensorLayout C{{groups * sgh * sgw, icpg * fh * fw, ocpg}, dtype::Float32()}; | ||||
TensorND ts_A{ws_im2col, A}; | TensorND ts_A{ws_im2col, A}; | ||||
TensorND ts_B{ws_pretranspose, B}; | TensorND ts_B{ws_pretranspose, B}; | ||||
TensorND ts_C{args.grad_tensor->raw_ptr, C}; | |||||
TensorND ts_C{args.grad_tensor->raw_ptr(), C}; | |||||
Workspace ws_wrapper; | Workspace ws_wrapper; | ||||
ws_wrapper.raw_ptr = reinterpret_cast<dt_byte*>(ws_matmul); | ws_wrapper.raw_ptr = reinterpret_cast<dt_byte*>(ws_matmul); | ||||
ws_wrapper.size = ws.get_size(2); | ws_wrapper.size = ws.get_size(2); | ||||
@@ -100,7 +100,7 @@ void LocalShareForwardImpl::AlgoBatchedMatMul::exec(const ExecArgs& args) const | |||||
TensorLayout C{ | TensorLayout C{ | ||||
{groups * sgh * sgw, ho / sgh * wo / sgw * n, ocpg}, dtype::Float32()}; | {groups * sgh * sgw, ho / sgh * wo / sgw * n, ocpg}, dtype::Float32()}; | ||||
TensorND ts_A{ws_im2col, A}; | TensorND ts_A{ws_im2col, A}; | ||||
TensorND ts_B{args.filter_tensor->raw_ptr, B}; | |||||
TensorND ts_B{args.filter_tensor->raw_ptr(), B}; | |||||
TensorND ts_C{ws_posttranspose, C}; | TensorND ts_C{ws_posttranspose, C}; | ||||
Workspace ws_wrapper; | Workspace ws_wrapper; | ||||
ws_wrapper.raw_ptr = reinterpret_cast<dt_byte*>(ws_matmul); | ws_wrapper.raw_ptr = reinterpret_cast<dt_byte*>(ws_matmul); | ||||
@@ -119,7 +119,7 @@ void LocalShareForwardImpl::AlgoBatchedMatMul::exec(const ExecArgs& args) const | |||||
C1.stride[6] = ocpg; | C1.stride[6] = ocpg; | ||||
TensorLayout C2 = args.dst_layout; | TensorLayout C2 = args.dst_layout; | ||||
TensorND ts_C1{ws_posttranspose, C1}; | TensorND ts_C1{ws_posttranspose, C1}; | ||||
TensorND ts_C2{args.dst_tensor->raw_ptr, C2}; | |||||
TensorND ts_C2{args.dst_tensor->raw_ptr(), C2}; | |||||
auto&& relayout_opr = args.opr->handle()->create_operator<Relayout>(); | auto&& relayout_opr = args.opr->handle()->create_operator<Relayout>(); | ||||
relayout_opr->exec(ts_C1, ts_C2); | relayout_opr->exec(ts_C1, ts_C2); | ||||
} | } | ||||
@@ -29,7 +29,7 @@ void LRNForwardImpl::exec( | |||||
float alpha = 1.0f, beta = 0.0f; | float alpha = 1.0f, beta = 0.0f; | ||||
cudnn_check(cudnnLRNCrossChannelForward( | cudnn_check(cudnnLRNCrossChannelForward( | ||||
handle, lrn_desc.desc, CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha, src_desc.desc, | handle, lrn_desc.desc, CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha, src_desc.desc, | ||||
src.raw_ptr, &beta, dst_desc.desc, dst.raw_ptr)); | |||||
src.raw_ptr(), &beta, dst_desc.desc, dst.raw_ptr())); | |||||
} | } | ||||
void LRNBackwardImpl::setup_descs( | void LRNBackwardImpl::setup_descs( | ||||
@@ -51,8 +51,8 @@ void LRNBackwardImpl::exec( | |||||
float alpha = 1.0f, beta = 0.0f; | float alpha = 1.0f, beta = 0.0f; | ||||
cudnn_check(cudnnLRNCrossChannelBackward( | cudnn_check(cudnnLRNCrossChannelBackward( | ||||
handle, lrn_desc.desc, CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha, dst_desc.desc, | handle, lrn_desc.desc, CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha, dst_desc.desc, | ||||
dst.raw_ptr, diff_desc.desc, diff.raw_ptr, src_desc.desc, src.raw_ptr, | |||||
&beta, grad_desc.desc, grad.raw_ptr)); | |||||
dst.raw_ptr(), diff_desc.desc, diff.raw_ptr(), src_desc.desc, src.raw_ptr(), | |||||
&beta, grad_desc.desc, grad.raw_ptr())); | |||||
} | } | ||||
} // namespace cuda | } // namespace cuda | ||||
@@ -37,11 +37,11 @@ void MatrixInverseImpl::exec( | |||||
auto stream = handle->stream(); | auto stream = handle->stream(); | ||||
batched_matrix_mul::arange<uintptr_t>( | batched_matrix_mul::arange<uintptr_t>( | ||||
reinterpret_cast<uintptr_t*>(psrc_batch), | reinterpret_cast<uintptr_t*>(psrc_batch), | ||||
reinterpret_cast<uintptr_t>(src.raw_ptr), n * n * sizeof(float), batch, | |||||
reinterpret_cast<uintptr_t>(src.raw_ptr()), n * n * sizeof(float), batch, | |||||
stream); | stream); | ||||
batched_matrix_mul::arange<uintptr_t>( | batched_matrix_mul::arange<uintptr_t>( | ||||
reinterpret_cast<uintptr_t*>(pdst_batch), | reinterpret_cast<uintptr_t*>(pdst_batch), | ||||
reinterpret_cast<uintptr_t>(dst.raw_ptr), n * n * sizeof(float), batch, | |||||
reinterpret_cast<uintptr_t>(dst.raw_ptr()), n * n * sizeof(float), batch, | |||||
stream); | stream); | ||||
cublas_check(cublasSmatinvBatched( | cublas_check(cublasSmatinvBatched( | ||||
handle->cublas_handle(), n, psrc_batch, n, pdst_batch, n, info, batch)); | handle->cublas_handle(), n, psrc_batch, n, pdst_batch, n, info, batch)); | ||||