GitOrigin-RevId: e37fbe0ffe
tags/v0.3.2
@@ -213,11 +213,17 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoNaive::dispatch_kerns( | |||||
const NCBKernParam& param, | const NCBKernParam& param, | ||||
const NCBKernIndex& ncb_index) { | const NCBKernIndex& ncb_index) { | ||||
MIDOUT_BEGIN(megdnn_fallback_naive, 2) { | MIDOUT_BEGIN(megdnn_fallback_naive, 2) { | ||||
size_t group_id = ncb_index.ndrange_id[0]; | |||||
size_t batch_id = ncb_index.ndrange_id[1]; | |||||
size_t thread_id = ncb_index.thread_id; | size_t thread_id = ncb_index.thread_id; | ||||
auto thread_param = param; | auto thread_param = param; | ||||
thread_param.workspace_ptr = reinterpret_cast<void*>( | thread_param.workspace_ptr = reinterpret_cast<void*>( | ||||
reinterpret_cast<ptrdiff_t>(param.workspace_ptr) + | reinterpret_cast<ptrdiff_t>(param.workspace_ptr) + | ||||
thread_id * workspace_per_thread); | thread_id * workspace_per_thread); | ||||
thread_param.filter_ptr = param.filter<void>(group_id); | |||||
thread_param.dst_ptr = param.dst<void>(batch_id, group_id); | |||||
thread_param.src_ptr = param.src<void>(batch_id, group_id); | |||||
thread_param.bias_ptr = param.bias<void>(batch_id, group_id); | |||||
kern_default(opr_param, thread_param); | kern_default(opr_param, thread_param); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
@@ -111,7 +111,6 @@ static void copy_padding_kern(WorkspaceBundle bundle, | |||||
size_t channel_id = ncb_index.ndrange_id[2]; | size_t channel_id = ncb_index.ndrange_id[2]; | ||||
size_t padding_group_size = IH2 * IW2 * IC; | size_t padding_group_size = IH2 * IW2 * IC; | ||||
size_t input_channel_offset = IH * IW * channel_id; | |||||
size_t workspace_channel_offset = IH2 * IW2 * channel_id; | size_t workspace_channel_offset = IH2 * IW2 * channel_id; | ||||
size_t workspace_group_offset = group_id * padding_group_size; | size_t workspace_group_offset = group_id * padding_group_size; | ||||
size_t workspace_batch_offset = | size_t workspace_batch_offset = | ||||
@@ -123,7 +122,7 @@ static void copy_padding_kern(WorkspaceBundle bundle, | |||||
src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; | src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; | ||||
} | } | ||||
src_ctype* src = const_cast<src_ctype*>( | src_ctype* src = const_cast<src_ctype*>( | ||||
param.src<src_ctype>(batch_id, group_id) + input_channel_offset); | |||||
param.src<src_ctype>(batch_id, group_id, channel_id)); | |||||
src_ctype* src2; | src_ctype* src2; | ||||
src2 = static_cast<src_ctype*>( | src2 = static_cast<src_ctype*>( | ||||
bundle.get(Im2colBundelIndex::BUNDLE_PADDING_INDEX)) + | bundle.get(Im2colBundelIndex::BUNDLE_PADDING_INDEX)) + | ||||
@@ -246,10 +246,9 @@ void ConvBiasImpl::exec_with_ncb_kern(const NCBKernParam& param, | |||||
ConvBiasImpl::Algorithm* algo) { | ConvBiasImpl::Algorithm* algo) { | ||||
auto ncb_kerns = ncb_algo_dispatch_kerns(algo, param); | auto ncb_kerns = ncb_algo_dispatch_kerns(algo, param); | ||||
for (auto&& kernel : ncb_kerns) { | for (auto&& kernel : ncb_kerns) { | ||||
auto run = [=](size_t index, size_t thread_id) { | |||||
auto copy_param = param; | |||||
auto run = [kernel, param](size_t index, size_t thread_id) { | |||||
CpuNDRange ndrange_id(kernel.global_size, index); | CpuNDRange ndrange_id(kernel.global_size, index); | ||||
kernel.kern(copy_param, {thread_id, ndrange_id}); | |||||
kernel.kern(param, {thread_id, ndrange_id}); | |||||
}; | }; | ||||
static_cast<naive::HandleImpl*>(handle())->dispatch_kern( | static_cast<naive::HandleImpl*>(handle())->dispatch_kern( | ||||
run, kernel.global_size.total_size()); | run, kernel.global_size.total_size()); | ||||
@@ -328,28 +327,29 @@ const char* ConvBiasImpl::get_algorithm_set_name() const { | |||||
namespace megdnn{ | namespace megdnn{ | ||||
namespace fallback { | namespace fallback { | ||||
//! when format is nchwxx and channel wise mode, multi group will pack | |||||
//! together, so pack_group_size is the number of packed group | |||||
template <typename T> | template <typename T> | ||||
const T* ConvBiasImpl::NCBKernParam::src(size_t batch_id, size_t group_id, | |||||
size_t group_pack_size) const { | |||||
src_type.assert_is_compatible_ctype<T>(); | |||||
const T* ConvBiasImpl::NCBKernParam::src(size_t batch_id, size_t group_pack_id, | |||||
size_t channel_pack_id, | |||||
size_t group_pack_size, | |||||
size_t channel_pack_size) const { | |||||
size_t batch_offset = batch_id * inp_bs * src_type.size(); | size_t batch_offset = batch_id * inp_bs * src_type.size(); | ||||
size_t group_offset = group_pack_size * group_id * filter_meta.icpg * | |||||
size_t group_offset = group_pack_size * group_pack_id * filter_meta.icpg * | |||||
isz[0] * isz[1] * src_type.size(); | isz[0] * isz[1] * src_type.size(); | ||||
size_t channel_offset = channel_pack_size * channel_pack_id * isz[0] * | |||||
isz[1] * src_type.size(); | |||||
return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(src_ptr) + | return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(src_ptr) + | ||||
batch_offset + group_offset); | |||||
batch_offset + group_offset + channel_offset); | |||||
} | } | ||||
//! when format is nchwxx and channel wise mode, multi group will pack | |||||
//! together, so pack_group_size is the number of packed group | |||||
template <typename T> | template <typename T> | ||||
const T* ConvBiasImpl::NCBKernParam::filter(size_t group_id, | |||||
const T* ConvBiasImpl::NCBKernParam::filter(size_t group_pack_id, | |||||
size_t pack_group_size) const { | size_t pack_group_size) const { | ||||
size_t group_offset = 0_z; | size_t group_offset = 0_z; | ||||
switch (filter_meta.format) { | switch (filter_meta.format) { | ||||
case Param::Format::NCHW: { | case Param::Format::NCHW: { | ||||
group_offset = pack_group_size * group_id * filter_meta.icpg * | |||||
group_offset = pack_group_size * group_pack_id * filter_meta.icpg * | |||||
filter_meta.ocpg * filter_meta.spatial[0] * | filter_meta.ocpg * filter_meta.spatial[0] * | ||||
filter_meta.spatial[1] * filter_type.size(); | filter_meta.spatial[1] * filter_type.size(); | ||||
break; | break; | ||||
@@ -359,15 +359,15 @@ const T* ConvBiasImpl::NCBKernParam::filter(size_t group_id, | |||||
size_t icpg = filter_meta.icpg; | size_t icpg = filter_meta.icpg; | ||||
size_t ocpg = filter_meta.ocpg; | size_t ocpg = filter_meta.ocpg; | ||||
//! four format of weight layout | //! four format of weight layout | ||||
//! 1. {oc/8, ic/8, fh, fw, 8, 8}, 2. {g, oc/8, ic/8, fh, | |||||
//! fw, 8, 8} | |||||
//! 3. {g/8, 1, 1, fh, fw, 8, 8}, 3. {oc/8 ,fh, fw, ic, 8} | |||||
//! 1. {oc/8, ic/8, fh, fw, 8, 8}, | |||||
//! 2. {g, oc/8, ic/8, fh, fw, 8, 8}, | |||||
//! 3. {g/8, fh, fw, 1, 1, 8}, 4. {oc/8, fh, fw, ic, 8} | |||||
megdnn_assert((icpg % 8 == 0 && ocpg % 8 == 0) || | megdnn_assert((icpg % 8 == 0 && ocpg % 8 == 0) || | ||||
(group % 8 == 0 && icpg == 1 && ocpg == 1 && | (group % 8 == 0 && icpg == 1 && ocpg == 1 && | ||||
pack_group_size > 1) || | pack_group_size > 1) || | ||||
(group == 1 && ocpg % 8 == 0), | (group == 1 && ocpg % 8 == 0), | ||||
"The filter shepe is not right of nchw88"); | "The filter shepe is not right of nchw88"); | ||||
group_offset = pack_group_size * group_id * filter_meta.icpg * | |||||
group_offset = pack_group_size * group_pack_id * filter_meta.icpg * | |||||
filter_meta.ocpg * filter_meta.spatial[0] * | filter_meta.ocpg * filter_meta.spatial[0] * | ||||
filter_meta.spatial[1] * filter_type.size(); | filter_meta.spatial[1] * filter_type.size(); | ||||
@@ -380,7 +380,7 @@ const T* ConvBiasImpl::NCBKernParam::filter(size_t group_id, | |||||
//! 2. {alpha, alpha, ocpg/8, icpg/8, 8, 8} | //! 2. {alpha, alpha, ocpg/8, icpg/8, 8, 8} | ||||
//! 3. {g, alpha, alpha, oc, ic, 8, 8} | //! 3. {g, alpha, alpha, oc, ic, 8, 8} | ||||
//! 4. {alpha, alpha, oc, ic} | //! 4. {alpha, alpha, oc, ic} | ||||
group_offset = pack_group_size * group_id * filter_meta.icpg * | |||||
group_offset = pack_group_size * group_pack_id * filter_meta.icpg * | |||||
filter_meta.ocpg * | filter_meta.ocpg * | ||||
(filter_meta.spatial[0] + output_block_size - 1) * | (filter_meta.spatial[0] + output_block_size - 1) * | ||||
(filter_meta.spatial[1] + output_block_size - 1) * | (filter_meta.spatial[1] + output_block_size - 1) * | ||||
@@ -388,58 +388,66 @@ const T* ConvBiasImpl::NCBKernParam::filter(size_t group_id, | |||||
break; | break; | ||||
} | } | ||||
default: | default: | ||||
megdnn_assert("other filter format is not support yet"); | |||||
megdnn_assert(0, "other filter format is not support yet"); | |||||
} | } | ||||
return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(filter_ptr) + | return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(filter_ptr) + | ||||
group_offset); | group_offset); | ||||
} | } | ||||
//! when format is nchwxx and channel wise mode, multi group will pack | |||||
//! together, so pack_group_size is the number of packed group | |||||
template <typename T> | template <typename T> | ||||
const T* ConvBiasImpl::NCBKernParam::bias(size_t batch_id, size_t group_id, | |||||
size_t group_pack_size) const { | |||||
bias_type.assert_is_compatible_ctype<T>(); | |||||
const T* ConvBiasImpl::NCBKernParam::bias(size_t batch_id, size_t group_pack_id, | |||||
size_t channel_pack_id, | |||||
size_t group_pack_size, | |||||
size_t channel_pack_size) const { | |||||
size_t batch_offset = 0_z; | size_t batch_offset = 0_z; | ||||
size_t group_offset = 0_z; | size_t group_offset = 0_z; | ||||
size_t channel_offset = 0_z; | |||||
if (bias_mode == BiasMode::BIAS) { | if (bias_mode == BiasMode::BIAS) { | ||||
batch_offset = batch_id * bias_bs * bias_type.size(); | batch_offset = batch_id * bias_bs * bias_type.size(); | ||||
group_offset = group_pack_size * group_id * filter_meta.ocpg * osz[0] * | |||||
osz[1] * bias_type.size(); | |||||
group_offset = group_pack_size * group_pack_id * filter_meta.ocpg * | |||||
osz[0] * osz[1] * bias_type.size(); | |||||
channel_offset = channel_pack_size * channel_pack_id * osz[0] * osz[1] * | |||||
bias_type.size(); | |||||
} else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | } else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
group_offset = group_pack_size * group_id * filter_meta.ocpg * | |||||
group_offset = group_pack_size * group_pack_id * filter_meta.ocpg * | |||||
bias_type.size(); | bias_type.size(); | ||||
channel_offset = channel_pack_size * channel_pack_id * bias_type.size(); | |||||
} | } | ||||
return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(bias_ptr) + | return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(bias_ptr) + | ||||
batch_offset + group_offset); | |||||
batch_offset + group_offset + channel_offset); | |||||
} | } | ||||
//! when format is nchwxx and channel wise mode, multi group will pack | |||||
//! together, so pack_group_size is the number of packed group | |||||
template <typename T> | template <typename T> | ||||
T* ConvBiasImpl::NCBKernParam::dst(size_t batch_id, size_t group_id, | |||||
size_t group_pack_size) const { | |||||
dst_type.assert_is_compatible_ctype<T>(); | |||||
T* ConvBiasImpl::NCBKernParam::dst(size_t batch_id, size_t group_pack_id, | |||||
size_t channel_pack_id, | |||||
size_t group_pack_size, | |||||
size_t channel_pack_size) const { | |||||
size_t batch_offset = batch_id * out_bs * dst_type.size(); | size_t batch_offset = batch_id * out_bs * dst_type.size(); | ||||
size_t group_offset = group_pack_size * group_id * filter_meta.ocpg * | |||||
size_t group_offset = group_pack_size * group_pack_id * filter_meta.ocpg * | |||||
osz[0] * osz[1] * dst_type.size(); | osz[0] * osz[1] * dst_type.size(); | ||||
size_t channel_offset = channel_pack_size * channel_pack_id * osz[0] * | |||||
osz[1] * dst_type.size(); | |||||
return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(dst_ptr) + | return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(dst_ptr) + | ||||
batch_offset + group_offset); | |||||
batch_offset + group_offset + channel_offset); | |||||
} | } | ||||
#define INST(T) \ | |||||
template const T* ConvBiasImpl::NCBKernParam::src<T>( \ | |||||
size_t batch_id, size_t group_id, size_t group_pack_size) const; \ | |||||
template const T* ConvBiasImpl::NCBKernParam::bias<T>( \ | |||||
size_t batch_id, size_t group_id, size_t group_pack_size) const; \ | |||||
template const T* ConvBiasImpl::NCBKernParam::filter<T>( \ | |||||
size_t group_id, size_t group_pack_size) const; \ | |||||
template T* ConvBiasImpl::NCBKernParam::dst<T>( \ | |||||
size_t batch_id, size_t group_id, size_t group_pack_size) const; | |||||
#define INST(T) \ | |||||
template const T* ConvBiasImpl::NCBKernParam::src<T>( \ | |||||
size_t batch_id, size_t group_id, size_t channel_id, \ | |||||
size_t group_pack_size, size_t channel_pack_size) const; \ | |||||
template const T* ConvBiasImpl::NCBKernParam::bias<T>( \ | |||||
size_t batch_id, size_t group_id, size_t channel_id, \ | |||||
size_t group_pack_size, size_t channel_pack_size) const; \ | |||||
template const T* ConvBiasImpl::NCBKernParam::filter<T>( \ | |||||
size_t group_id, size_t group_pack_size) const; \ | |||||
template T* ConvBiasImpl::NCBKernParam::dst<T>( \ | |||||
size_t batch_id, size_t group_id, size_t channel_id, \ | |||||
size_t group_pack_size, size_t channel_pack_size) const; | |||||
#define INST_DT(d) INST(DTypeTrait<d>::ctype) | #define INST_DT(d) INST(DTypeTrait<d>::ctype) | ||||
MEGDNN_FOREACH_COMPUTING_DTYPE(INST_DT) | MEGDNN_FOREACH_COMPUTING_DTYPE(INST_DT) | ||||
INST(void) | |||||
#undef INST | #undef INST | ||||
#undef INST_DT | #undef INST_DT | ||||
} // namespace fallback | } // namespace fallback | ||||
@@ -103,10 +103,32 @@ public: | |||||
src_type.assert_is_compatible_ctype<T>(); | src_type.assert_is_compatible_ctype<T>(); | ||||
return static_cast<const T*>(src_ptr); | return static_cast<const T*>(src_ptr); | ||||
} | } | ||||
//! when format is nchwxx, multi channel will pack into one | |||||
//! chnannel_pack_id. pack_channel_size is the number of packed channel | |||||
//! when format is nchwxx and channel wise, multi group will pack into | |||||
//! one group_pack_id. group_pack_size is the number of packed group | |||||
//! together, like weight shape is {g/8, 1, 1, Fh, Fw, 8} | |||||
template <typename T> | |||||
const T* src(size_t batch_id, size_t group_pack_id, | |||||
size_t channel_pack_id = 0, size_t group_pack_size = 1, | |||||
size_t channel_pack_size = 1) const; | |||||
template <typename T> | |||||
const T* bias(size_t batch_id, size_t group_pack_id, | |||||
size_t channel_pack_id = 0, size_t group_pack_size = 1, | |||||
size_t channel_pack_size = 1) const; | |||||
template <typename T> | template <typename T> | ||||
const T* src(size_t batch_id, size_t group_id, | |||||
size_t group_pack_size = 1_z) const; | |||||
T* dst(size_t batch_id, size_t group_pack_id, | |||||
size_t channel_pack_id = 0, size_t group_pack_size = 1, | |||||
size_t channel_pack_size = 1) const; | |||||
//! when format is nchwxx and channel wise, multi group will pack into | |||||
//! one group_pack_id. group_pack_size is the number of packed group | |||||
//! together, like weight shape is {g/8, 1, 1, Fh, Fw, 8} | |||||
template <typename T> | |||||
const T* filter(size_t group_pack_id, | |||||
size_t pack_group_size = 1_z) const; | |||||
template <typename T> | template <typename T> | ||||
const T* filter() const { | const T* filter() const { | ||||
@@ -115,29 +137,18 @@ public: | |||||
} | } | ||||
template <typename T> | template <typename T> | ||||
const T* filter(size_t group_id, size_t pack_group_size = 1_z) const; | |||||
template <typename T> | |||||
const T* bias() const { | const T* bias() const { | ||||
bias_type.assert_is_compatible_ctype<T>(); | bias_type.assert_is_compatible_ctype<T>(); | ||||
return static_cast<const T*>(bias_ptr); | return static_cast<const T*>(bias_ptr); | ||||
} | } | ||||
template <typename T> | template <typename T> | ||||
const T* bias(size_t batch_id, size_t group_id, | |||||
size_t group_pack_size = 1_z) const; | |||||
template <typename T> | |||||
T* dst() const { | T* dst() const { | ||||
dst_type.assert_is_compatible_ctype<T>(); | dst_type.assert_is_compatible_ctype<T>(); | ||||
return static_cast<T*>(dst_ptr); | return static_cast<T*>(dst_ptr); | ||||
} | } | ||||
template <typename T> | template <typename T> | ||||
T* dst(size_t batch_id, size_t group_id, | |||||
size_t group_pack_size = 1_z) const; | |||||
template <typename T> | |||||
T* workspace() const { | T* workspace() const { | ||||
return static_cast<T*>(workspace_ptr); | return static_cast<T*>(workspace_ptr); | ||||
} | } | ||||
@@ -197,9 +197,12 @@ ConvolutionImpl::AlgoFallback::dispatch_kern( | |||||
auto kern_fallback = [workspace_per_thread](const NCBKernParam& p, | auto kern_fallback = [workspace_per_thread](const NCBKernParam& p, | ||||
const NCBKernIndex& ncb_index) { | const NCBKernIndex& ncb_index) { | ||||
UNPACK_CONV_F32_NCB_KERN_SIZES(p); | UNPACK_CONV_F32_NCB_KERN_SIZES(p); | ||||
size_t batch_id = ncb_index.ndrange_id[1]; | |||||
size_t group_id = ncb_index.ndrange_id[0]; | |||||
MEGDNN_MARK_USED_VAR(N); | MEGDNN_MARK_USED_VAR(N); | ||||
auto src = p.src<float>(), filter = p.filter<float>(); | |||||
auto dst = p.dst<float>(); | |||||
auto src = p.src<float>(batch_id, group_id), | |||||
filter = p.filter<float>(group_id); | |||||
auto dst = p.dst<float>(batch_id, group_id); | |||||
size_t thread_id = ncb_index.thread_id; | size_t thread_id = ncb_index.thread_id; | ||||
void* workspace_ptr = reinterpret_cast<void*>( | void* workspace_ptr = reinterpret_cast<void*>( | ||||
reinterpret_cast<ptrdiff_t>(p.workspace_ptr) + | reinterpret_cast<ptrdiff_t>(p.workspace_ptr) + | ||||
@@ -20,18 +20,25 @@ namespace fallback { | |||||
template <typename ST, typename DT, typename CT> | template <typename ST, typename DT, typename CT> | ||||
void kern_naive_forward(const ConvolutionImpl::NCBKernParam& p, | void kern_naive_forward(const ConvolutionImpl::NCBKernParam& p, | ||||
const ConvolutionImpl::NCBKernIndex& /*index*/) { | |||||
const ConvolutionImpl::NCBKernIndex& ncb_index) { | |||||
size_t batch_id = ncb_index.ndrange_id[1]; | |||||
size_t group_id = ncb_index.ndrange_id[0]; | |||||
auto IC = p.filter_meta.icpg, IH = p.isz[0], IW = p.isz[1], | auto IC = p.filter_meta.icpg, IH = p.isz[0], IW = p.isz[1], | ||||
OC = p.filter_meta.ocpg, OH = p.osz[0], OW = p.osz[1]; | OC = p.filter_meta.ocpg, OH = p.osz[0], OW = p.osz[1]; | ||||
ptrdiff_t fstrd = p.filter_meta.icpg * p.filter_meta.ocpg * | |||||
p.filter_meta.spatial[0] * p.filter_meta.spatial[1] * | |||||
p.filter_type.size(); | |||||
ptrdiff_t istrd = p.filter_meta.icpg * p.src_type.size(); | |||||
ptrdiff_t ostrd = p.filter_meta.ocpg * p.dst_type.size(); | |||||
TensorND src, dst; | TensorND src, dst; | ||||
src.raw_ptr = const_cast<void*>(p.src_ptr); | |||||
dst.raw_ptr = p.dst_ptr; | |||||
src.layout.dtype = p.src_type; | src.layout.dtype = p.src_type; | ||||
dst.layout.dtype = p.dst_type; | dst.layout.dtype = p.dst_type; | ||||
if (p.filter_meta.format == param::Convolution::Format::NCHW) { | if (p.filter_meta.format == param::Convolution::Format::NCHW) { | ||||
src.layout.init_contiguous_stride({1, IC, IH, IW}); | |||||
dst.layout.init_contiguous_stride({1, OC, OH, OW}); | |||||
istrd *= p.isz[0] * p.isz[1]; | |||||
ostrd *= p.osz[0] * p.osz[1]; | |||||
src.layout.init_contiguous_stride({1, IC, IH, IW}); | |||||
dst.layout.init_contiguous_stride({1, OC, OH, OW}); | |||||
} else { | } else { | ||||
// Must be NHWC | // Must be NHWC | ||||
megdnn_assert( | megdnn_assert( | ||||
@@ -41,9 +48,17 @@ void kern_naive_forward(const ConvolutionImpl::NCBKernParam& p, | |||||
src.layout.init_contiguous_stride({1, IH, IW, IC}); | src.layout.init_contiguous_stride({1, IH, IW, IC}); | ||||
dst.layout.init_contiguous_stride({1, OH, OW, OC}); | dst.layout.init_contiguous_stride({1, OH, OW, OC}); | ||||
} | } | ||||
src.raw_ptr = reinterpret_cast<void*>( | |||||
reinterpret_cast<uintptr_t>(p.src_ptr) + | |||||
batch_id * p.inp_bs * p.src_type.size() + group_id * istrd); | |||||
dst.raw_ptr = reinterpret_cast<void*>( | |||||
reinterpret_cast<uintptr_t>(p.dst_ptr) + | |||||
batch_id * p.out_bs * p.dst_type.size() + group_id * ostrd); | |||||
ST* filter = reinterpret_cast<ST*>( | |||||
reinterpret_cast<uintptr_t>(p.filter_ptr) + group_id * fstrd); | |||||
std::copy(p.inp_s, p.inp_s + 4, src.layout.stride); | std::copy(p.inp_s, p.inp_s + 4, src.layout.stride); | ||||
std::copy(p.out_s, p.out_s + 4, dst.layout.stride); | std::copy(p.out_s, p.out_s + 4, dst.layout.stride); | ||||
naive::convolution::forward<ST, ST, DT, CT>(src, p.filter<ST>(), dst, | |||||
naive::convolution::forward<ST, ST, DT, CT>(src, filter, dst, | |||||
p.filter_meta); | p.filter_meta); | ||||
} | } | ||||
@@ -189,41 +189,15 @@ ConvolutionImpl::NCBKernParam ConvolutionImpl::make_ncb_kern_param( | |||||
void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param, | void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param, | ||||
Algorithm* algo) { | Algorithm* algo) { | ||||
auto kerns = ncb_algo_dispatch_kern(algo, param); | auto kerns = ncb_algo_dispatch_kern(algo, param); | ||||
size_t src_batch_stride = param.inp_bs * param.src_type.size(); | |||||
size_t dst_batch_stride = param.out_bs * param.dst_type.size(); | |||||
auto group = param.filter_meta.group; | |||||
auto fallback_handle = handle(); | auto fallback_handle = handle(); | ||||
for (auto kernel : kerns) { | for (auto kernel : kerns) { | ||||
megdnn_assert(param.filter_meta.format == Param::Format::NCHW || | megdnn_assert(param.filter_meta.format == Param::Format::NCHW || | ||||
param.filter_meta.format == Param::Format::NHWC || | |||||
param.filter_meta.format == Param::Format::NHWC || | |||||
param.filter_meta.format == Param::Format::NCHW88, | |||||
"invalid conv format"); | "invalid conv format"); | ||||
ptrdiff_t istrd = 0, fstrd = 0, ostrd = 0; | |||||
fstrd = param.filter_meta.icpg * param.filter_meta.ocpg * | |||||
param.filter_meta.spatial[0] * param.filter_meta.spatial[1] * | |||||
param.filter_type.size(); | |||||
istrd = param.filter_meta.icpg * param.src_type.size(); | |||||
ostrd = param.filter_meta.ocpg * param.dst_type.size(); | |||||
if (param.filter_meta.format == Param::Format::NCHW) { | |||||
istrd *= param.isz[0] * param.isz[1]; | |||||
ostrd *= param.osz[0] * param.osz[1]; | |||||
} else { | |||||
// must be NHWC. No action performed. | |||||
} | |||||
auto run = [=](size_t index, size_t thread_id) { | |||||
auto copy_param = param; | |||||
auto run = [param, kernel](size_t index, size_t thread_id) { | |||||
CpuNDRange ndrange_id(kernel.global_size, index); | CpuNDRange ndrange_id(kernel.global_size, index); | ||||
size_t group_id = ndrange_id[0]; | |||||
size_t batch_id = ndrange_id[1]; | |||||
megdnn_assert(group_id < group, | |||||
"The group id should smaller than gruop"); | |||||
//! The kernel ptr point to batch index | |||||
incr_ptr(copy_param.src_ptr, | |||||
group_id * istrd + batch_id * src_batch_stride); | |||||
incr_ptr(copy_param.filter_ptr, group_id * fstrd); | |||||
incr_ptr(copy_param.dst_ptr, | |||||
group_id * ostrd + batch_id * dst_batch_stride); | |||||
kernel.kern(copy_param, {thread_id, ndrange_id}); | |||||
kernel.kern(param, {thread_id, ndrange_id}); | |||||
}; | }; | ||||
static_cast<naive::HandleImpl*>(fallback_handle) | static_cast<naive::HandleImpl*>(fallback_handle) | ||||
->dispatch_kern(run, kernel.global_size.total_size()); | ->dispatch_kern(run, kernel.global_size.total_size()); | ||||
@@ -100,6 +100,43 @@ public: | |||||
T* workspace() const { | T* workspace() const { | ||||
return static_cast<T*>(workspace_ptr); | return static_cast<T*>(workspace_ptr); | ||||
} | } | ||||
//! when format is nchwxx and channel wise, multi group will pack into | |||||
//! one group_pack_id. group_pack_size is the number of packed group | |||||
//! together, like weight shape is {g/8, 1, 1, Fh, Fw, 8} | |||||
template <typename T> | |||||
T* dst(size_t batch_id, size_t group_pack_id, | |||||
size_t group_pack_size = 1_z) const{ | |||||
size_t batch_offset = batch_id * out_bs * dst_type.size(); | |||||
size_t group_offset = group_pack_size * group_pack_id * | |||||
filter_meta.ocpg * osz[0] * osz[1] * | |||||
dst_type.size(); | |||||
return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(dst_ptr) + | |||||
batch_offset + group_offset); | |||||
} | |||||
template <typename T> | |||||
const T* src(size_t batch_id, size_t group_pack_id, | |||||
size_t group_pack_size = 1_z) const { | |||||
size_t batch_offset = batch_id * inp_bs * src_type.size(); | |||||
size_t group_offset = group_pack_size * group_pack_id * | |||||
filter_meta.icpg * isz[0] * isz[1] * | |||||
src_type.size(); | |||||
return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(src_ptr) + | |||||
batch_offset + group_offset); | |||||
} | |||||
template <typename T> | |||||
const T* filter(size_t group_pack_id, | |||||
size_t pack_group_size = 1_z) const { | |||||
size_t group_offset = pack_group_size * group_pack_id * | |||||
filter_meta.icpg * filter_meta.ocpg * | |||||
filter_meta.spatial[0] * | |||||
filter_meta.spatial[1] * filter_type.size(); | |||||
return reinterpret_cast<T*>( | |||||
reinterpret_cast<ptrdiff_t>(filter_ptr) + group_offset); | |||||
} | |||||
}; | }; | ||||
static void* const sm_fallback_conv_algo_type; | static void* const sm_fallback_conv_algo_type; | ||||
@@ -58,43 +58,45 @@ void get_rectified_size(size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, | |||||
} | } | ||||
} // namespace | } // namespace | ||||
#define GET_KERN \ | |||||
auto fm = param.filter_meta; \ | |||||
size_t N = param.n; \ | |||||
size_t IC = param.filter_meta.icpg; \ | |||||
size_t OC = param.filter_meta.ocpg; \ | |||||
size_t group = fm.group; \ | |||||
WorkspaceBundle wbundle = get_bundle(param); \ | |||||
SmallVector<NCBKern> ret_kerns; \ | |||||
if (m_large_group) { \ | |||||
auto exec_one_group = [wbundle](const NCBKernParam& kern_param, \ | |||||
const NCBKernIndex& ncb_index) { \ | |||||
auto fm = kern_param.filter_meta; \ | |||||
size_t IC = fm.icpg; \ | |||||
size_t OC = fm.ocpg; \ | |||||
WorkspaceBundle bundle = wbundle; \ | |||||
for (size_t ic = 0; ic < IC; ic++) { \ | |||||
copy_padding_kern( \ | |||||
bundle, kern_param, \ | |||||
{ncb_index.thread_id, {ncb_index.thread_id, 0, ic}}); \ | |||||
} \ | |||||
for (size_t oc = 0; oc < OC; oc++) { \ | |||||
do_conv_kern( \ | |||||
bundle, kern_param, \ | |||||
{ncb_index.thread_id, {ncb_index.thread_id, 0, oc}}); \ | |||||
} \ | |||||
}; \ | |||||
ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); \ | |||||
} else { \ | |||||
WorkspaceBundle bundle = wbundle; \ | |||||
auto copy_padding = \ | |||||
std::bind(copy_padding_kern, bundle, std::placeholders::_1, \ | |||||
std::placeholders::_2); \ | |||||
ret_kerns.push_back({copy_padding, {group, N, IC}}); \ | |||||
auto do_conv = std::bind(do_conv_kern, bundle, std::placeholders::_1, \ | |||||
std::placeholders::_2); \ | |||||
ret_kerns.push_back({do_conv, {group, N, OC}}); \ | |||||
} \ | |||||
#define GET_KERN \ | |||||
auto fm = param.filter_meta; \ | |||||
size_t N = param.n; \ | |||||
size_t IC = param.filter_meta.icpg; \ | |||||
size_t OC = param.filter_meta.ocpg; \ | |||||
size_t group = fm.group; \ | |||||
WorkspaceBundle wbundle = get_bundle(param); \ | |||||
SmallVector<NCBKern> ret_kerns; \ | |||||
if (m_large_group) { \ | |||||
auto exec_one_group = [wbundle](const NCBKernParam& kern_param, \ | |||||
const NCBKernIndex& ncb_index) { \ | |||||
auto fm = kern_param.filter_meta; \ | |||||
size_t IC = fm.icpg; \ | |||||
size_t OC = fm.ocpg; \ | |||||
WorkspaceBundle bundle = wbundle; \ | |||||
for (size_t ic = 0; ic < IC; ic++) { \ | |||||
copy_padding_kern(bundle, kern_param, ncb_index, \ | |||||
{ncb_index.thread_id, 0, ic}); \ | |||||
} \ | |||||
for (size_t oc = 0; oc < OC; oc++) { \ | |||||
do_conv_kern(bundle, kern_param, ncb_index, \ | |||||
{ncb_index.thread_id, 0, oc}); \ | |||||
} \ | |||||
}; \ | |||||
ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); \ | |||||
} else { \ | |||||
auto copy_padding = [wbundle](const NCBKernParam& kern_param, \ | |||||
const NCBKernIndex& ncb_index) { \ | |||||
copy_padding_kern(wbundle, kern_param, ncb_index, \ | |||||
ncb_index.ndrange_id); \ | |||||
}; \ | |||||
ret_kerns.push_back({copy_padding, {group, N, IC}}); \ | |||||
auto do_conv = [wbundle](const NCBKernParam& kern_param, \ | |||||
const NCBKernIndex& ncb_index) { \ | |||||
do_conv_kern(wbundle, kern_param, ncb_index, \ | |||||
ncb_index.ndrange_id); \ | |||||
}; \ | |||||
ret_kerns.push_back({do_conv, {group, N, OC}}); \ | |||||
} \ | |||||
return ret_kerns; | return ret_kerns; | ||||
/* ===================== direct algo ===================== */ | /* ===================== direct algo ===================== */ | ||||
@@ -145,7 +147,8 @@ size_t ConvBiasImpl::AlgoDirect::get_workspace( | |||||
//! Process one input channel copy padding | //! Process one input channel copy padding | ||||
void ConvBiasImpl::AlgoDirect::copy_padding_kern( | void ConvBiasImpl::AlgoDirect::copy_padding_kern( | ||||
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | ||||
const ConvBiasImpl::NCBKernIndex& ncb_index) { | |||||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
const CpuNDRange& workspace_ids) { | |||||
size_t IH = kern_param.isz[0]; | size_t IH = kern_param.isz[0]; | ||||
size_t IW = kern_param.isz[1]; | size_t IW = kern_param.isz[1]; | ||||
size_t IC = kern_param.filter_meta.icpg; | size_t IC = kern_param.filter_meta.icpg; | ||||
@@ -160,14 +163,18 @@ void ConvBiasImpl::AlgoDirect::copy_padding_kern( | |||||
get_rectified_img_size(IH, IW, FH, FW, OH, OW, PH, PW, IH2, IW2, OH2, OW2); | get_rectified_img_size(IH, IW, FH, FW, OH, OW, PH, PW, IH2, IW2, OH2, OW2); | ||||
bool rectify_src = (IH != IH2 || IW != IW2); | bool rectify_src = (IH != IH2 || IW != IW2); | ||||
size_t padding_group_size = IH2 * IW2 * IC; | size_t padding_group_size = IH2 * IW2 * IC; | ||||
const float* sptr = static_cast<const float*>(kern_param.src_ptr) + | |||||
ncb_index.ndrange_id[2] * IH * IW; | |||||
size_t batch_id = ncb_index.ndrange_id[1]; | |||||
size_t group_id = ncb_index.ndrange_id[0]; | |||||
size_t channel_id = workspace_ids[2]; | |||||
const float* sptr = static_cast<const float*>( | |||||
kern_param.src<float>(batch_id, group_id)) + | |||||
channel_id * IH * IW; | |||||
bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
//! Used for get the workspace offset | //! Used for get the workspace offset | ||||
size_t workspace_group_id = ncb_index.ndrange_id[0], | |||||
workspace_batch_id = ncb_index.ndrange_id[1], | |||||
workspace_channel_id = ncb_index.ndrange_id[2]; | |||||
size_t workspace_group_id = workspace_ids[0], | |||||
workspace_batch_id = workspace_ids[1], | |||||
workspace_channel_id = workspace_ids[2]; | |||||
//! If large group, each thread has its own worspace, set group_id with | //! If large group, each thread has its own worspace, set group_id with | ||||
//! thread_id | //! thread_id | ||||
if (rectify_src) { | if (rectify_src) { | ||||
@@ -234,7 +241,8 @@ void ConvBiasImpl::AlgoDirect::copy_padding_kern( | |||||
//! compute one output channel | //! compute one output channel | ||||
void ConvBiasImpl::AlgoDirect::do_conv_kern(WorkspaceBundle bundle, | void ConvBiasImpl::AlgoDirect::do_conv_kern(WorkspaceBundle bundle, | ||||
const NCBKernParam& kern_param, | const NCBKernParam& kern_param, | ||||
const NCBKernIndex& ncb_index) { | |||||
const NCBKernIndex& ncb_index, | |||||
const CpuNDRange& workspace_ids) { | |||||
size_t OH = kern_param.osz[0]; | size_t OH = kern_param.osz[0]; | ||||
size_t OW = kern_param.osz[1]; | size_t OW = kern_param.osz[1]; | ||||
size_t IH = kern_param.isz[0]; | size_t IH = kern_param.isz[0]; | ||||
@@ -265,14 +273,16 @@ void ConvBiasImpl::AlgoDirect::do_conv_kern(WorkspaceBundle bundle, | |||||
megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) { | megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
bias_offset = 1_z; | bias_offset = 1_z; | ||||
} | } | ||||
size_t group_id = ncb_index.ndrange_id[0]; | |||||
size_t batch_id = ncb_index.ndrange_id[1]; | |||||
//! Used for get the workspace offset | //! Used for get the workspace offset | ||||
size_t workspace_group_id = ncb_index.ndrange_id[0], | |||||
workspace_batch_id = ncb_index.ndrange_id[1], | |||||
oc = ncb_index.ndrange_id[2]; | |||||
const float* sptr = kern_param.src<float>(); | |||||
const float* filter = kern_param.filter<float>() + oc * FH * FW * IC; | |||||
const float* bias_ptr = kern_param.bias<float>() + oc * bias_offset; | |||||
float* dst = kern_param.dst<float>() + oc * OH * OW; | |||||
size_t workspace_group_id = workspace_ids[0], | |||||
workspace_batch_id = workspace_ids[1], oc = workspace_ids[2]; | |||||
const float* sptr = kern_param.src<float>(batch_id, group_id); | |||||
const float* filter = kern_param.filter<float>(group_id) + oc * FH * FW * IC; | |||||
const float* bias_ptr = | |||||
kern_param.bias<float>(batch_id, group_id) + oc * bias_offset; | |||||
float* dst = kern_param.dst<float>(batch_id, group_id) + oc * OH * OW; | |||||
if (rectify_src) { | if (rectify_src) { | ||||
sptr = static_cast<float*>(bundle.get(0)) + | sptr = static_cast<float*>(bundle.get(0)) + | ||||
workspace_group_id * padding_group_size + | workspace_group_id * padding_group_size + | ||||
@@ -358,7 +368,8 @@ size_t ConvBiasImpl::AlgoDirectStride2::get_workspace( | |||||
//! Process one input channel copy padding | //! Process one input channel copy padding | ||||
void ConvBiasImpl::AlgoDirectStride2::copy_padding_kern( | void ConvBiasImpl::AlgoDirectStride2::copy_padding_kern( | ||||
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | ||||
const ConvBiasImpl::NCBKernIndex& ncb_index) { | |||||
const ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
const CpuNDRange& workspace_ids) { | |||||
size_t IH = kern_param.isz[0]; | size_t IH = kern_param.isz[0]; | ||||
size_t IW = kern_param.isz[1]; | size_t IW = kern_param.isz[1]; | ||||
size_t IC = kern_param.filter_meta.icpg; | size_t IC = kern_param.filter_meta.icpg; | ||||
@@ -373,13 +384,17 @@ void ConvBiasImpl::AlgoDirectStride2::copy_padding_kern( | |||||
get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2); | get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2); | ||||
bool rectify_src = need_src_copy(kern_param); | bool rectify_src = need_src_copy(kern_param); | ||||
size_t padding_group_size = IH2 * IW2 * IC; | size_t padding_group_size = IH2 * IW2 * IC; | ||||
const float* sptr = static_cast<const float*>(kern_param.src_ptr) + | |||||
ncb_index.ndrange_id[2] * IH * IW; | |||||
size_t group_id = ncb_index.ndrange_id[0]; | |||||
size_t batch_id = ncb_index.ndrange_id[1]; | |||||
size_t channel_id = workspace_ids[2]; | |||||
const float* sptr = static_cast<const float*>( | |||||
kern_param.src<float>(batch_id, group_id)) + | |||||
channel_id * IH * IW; | |||||
bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
//! Used for get the workspace offset | //! Used for get the workspace offset | ||||
size_t workspace_group_id = ncb_index.ndrange_id[0], | |||||
workspace_batch_id = ncb_index.ndrange_id[1], | |||||
workspace_channel_id = ncb_index.ndrange_id[2]; | |||||
size_t workspace_group_id = workspace_ids[0], | |||||
workspace_batch_id = workspace_ids[1], | |||||
workspace_channel_id = workspace_ids[2]; | |||||
if (rectify_src) { | if (rectify_src) { | ||||
//! copy to sptr_base to eliminate padding effect | //! copy to sptr_base to eliminate padding effect | ||||
float* sptr_base = static_cast<float*>(bundle.get(0)) + | float* sptr_base = static_cast<float*>(bundle.get(0)) + | ||||
@@ -397,7 +412,7 @@ void ConvBiasImpl::AlgoDirectStride2::copy_padding_kern( | |||||
//! compute one output channel | //! compute one output channel | ||||
void ConvBiasImpl::AlgoDirectStride2::do_conv_kern( | void ConvBiasImpl::AlgoDirectStride2::do_conv_kern( | ||||
WorkspaceBundle bundle, const NCBKernParam& kern_param, | WorkspaceBundle bundle, const NCBKernParam& kern_param, | ||||
const NCBKernIndex& ncb_index) { | |||||
const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { | |||||
size_t OH = kern_param.osz[0]; | size_t OH = kern_param.osz[0]; | ||||
size_t OW = kern_param.osz[1]; | size_t OW = kern_param.osz[1]; | ||||
size_t IH = kern_param.isz[0]; | size_t IH = kern_param.isz[0]; | ||||
@@ -439,14 +454,17 @@ void ConvBiasImpl::AlgoDirectStride2::do_conv_kern( | |||||
megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) { | megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
bias_offset = 1_z; | bias_offset = 1_z; | ||||
} | } | ||||
size_t group_id = ncb_index.ndrange_id[0]; | |||||
size_t batch_id = ncb_index.ndrange_id[1]; | |||||
//! Used for get the workspace offset | //! Used for get the workspace offset | ||||
size_t workspace_group_id = ncb_index.ndrange_id[0], | |||||
workspace_batch_id = ncb_index.ndrange_id[1], | |||||
oc = ncb_index.ndrange_id[2]; | |||||
const float* sptr = kern_param.src<float>(); | |||||
const float* filter = kern_param.filter<float>() + oc * FH * FW * IC; | |||||
const float* bias_ptr = kern_param.bias<float>() + oc * bias_offset; | |||||
float* dst = kern_param.dst<float>() + oc * OH * OW; | |||||
size_t workspace_group_id = workspace_ids[0], | |||||
workspace_batch_id = workspace_ids[1], oc = workspace_ids[2]; | |||||
const float* sptr = kern_param.src<float>(batch_id, group_id); | |||||
const float* filter = | |||||
kern_param.filter<float>(group_id) + oc * FH * FW * IC; | |||||
const float* bias_ptr = | |||||
kern_param.bias<float>(batch_id, group_id) + oc * bias_offset; | |||||
float* dst = kern_param.dst<float>(batch_id, group_id) + oc * OH * OW; | |||||
if (rectify_src) { | if (rectify_src) { | ||||
sptr = static_cast<float*>(bundle.get(0)) + | sptr = static_cast<float*>(bundle.get(0)) + | ||||
workspace_group_id * padding_group_size + | workspace_group_id * padding_group_size + | ||||
@@ -547,23 +565,22 @@ MatrixMul* ConvBiasImpl::AlgoMatrixMul::get_matmul_opr() { | |||||
} | } | ||||
void ConvBiasImpl::AlgoMatrixMul::kimpl(const NCBKernParam& param, | void ConvBiasImpl::AlgoMatrixMul::kimpl(const NCBKernParam& param, | ||||
const NCBKernIndex&) { | |||||
const NCBKernIndex& ncb_index) { | |||||
UNPACK_CONV_F32_NCB_KERN_SIZES(param); | UNPACK_CONV_F32_NCB_KERN_SIZES(param); | ||||
auto IH2 = IH + 2 * PH; | auto IH2 = IH + 2 * PH; | ||||
auto IW2 = IW + 2 * PW; | auto IW2 = IW + 2 * PW; | ||||
size_t group_id = ncb_index.ndrange_id[0]; | |||||
bool is_xcorr = !param.filter_meta.should_flip; | bool is_xcorr = !param.filter_meta.should_flip; | ||||
auto bundle = get_bundle(param); | auto bundle = get_bundle(param); | ||||
bundle.set(param.workspace_ptr); | bundle.set(param.workspace_ptr); | ||||
// workspace = tmp..src2 | // workspace = tmp..src2 | ||||
for (size_t n = 0; n < N; ++n) { | for (size_t n = 0; n < N; ++n) { | ||||
float* src = const_cast<float*>(param.src<float>()) + n * param.inp_bs; | |||||
float* dst = param.dst<float>() + n * param.out_bs; | |||||
float* bias_ptr = | |||||
static_cast<float*>(const_cast<void*>(param.bias_ptr)); | |||||
if (param.bias_mode == megdnn::BiasMode::BIAS) { | |||||
bias_ptr += n * param.out_bs; | |||||
} | |||||
float* src = const_cast<float*>(param.src<float>(n, group_id)); | |||||
float* dst = param.dst<float>(n, group_id); | |||||
float* bias_ptr = static_cast<float*>( | |||||
const_cast<void*>(param.bias<void>(n, group_id))); | |||||
float *B, *src2; | float *B, *src2; | ||||
if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) { | if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) { | ||||
// special case: 1x1 | // special case: 1x1 | ||||
@@ -613,7 +630,7 @@ void ConvBiasImpl::AlgoMatrixMul::kimpl(const NCBKernParam& param, | |||||
{ | { | ||||
TensorND A_, B_, C_; | TensorND A_, B_, C_; | ||||
A_.layout = TensorLayout({OC, IC * FH * FW}, dtype::Float32()); | A_.layout = TensorLayout({OC, IC * FH * FW}, dtype::Float32()); | ||||
A_.raw_ptr = const_cast<float*>(param.filter<float>()); | |||||
A_.raw_ptr = const_cast<float*>(param.filter<float>(group_id)); | |||||
B_.layout = TensorLayout({IC * FH * FW, OH * OW}, dtype::Float32()); | B_.layout = TensorLayout({IC * FH * FW, OH * OW}, dtype::Float32()); | ||||
B_.raw_ptr = B; | B_.raw_ptr = B; | ||||
C_.layout = TensorLayout({OC, OH * OW}, dtype::Float32()); | C_.layout = TensorLayout({OC, OH * OW}, dtype::Float32()); | ||||
@@ -22,10 +22,12 @@ class ConvBiasImpl::AlgoDirect final : public AlgoBase { | |||||
static void copy_padding_kern(WorkspaceBundle bundle, | static void copy_padding_kern(WorkspaceBundle bundle, | ||||
const NCBKernParam& kern_param, | const NCBKernParam& kern_param, | ||||
const NCBKernIndex& ncb_index); | |||||
const NCBKernIndex& ncb_index, | |||||
const CpuNDRange& workspace_ids); | |||||
static void do_conv_kern(WorkspaceBundle bundle, | static void do_conv_kern(WorkspaceBundle bundle, | ||||
const NCBKernParam& kern_param, | const NCBKernParam& kern_param, | ||||
const NCBKernIndex& ncb_index); | |||||
const NCBKernIndex& ncb_index, | |||||
const CpuNDRange& workspace_ids); | |||||
bool m_large_group; | bool m_large_group; | ||||
public: | public: | ||||
@@ -57,10 +59,12 @@ class ConvBiasImpl::AlgoDirectStride2 final : public AlgoBase { | |||||
static void copy_padding_kern(WorkspaceBundle bundle, | static void copy_padding_kern(WorkspaceBundle bundle, | ||||
const NCBKernParam& kern_param, | const NCBKernParam& kern_param, | ||||
const NCBKernIndex& ncb_index); | |||||
const NCBKernIndex& ncb_index, | |||||
const CpuNDRange& workspace_ids); | |||||
static void do_conv_kern(WorkspaceBundle bundle, | static void do_conv_kern(WorkspaceBundle bundle, | ||||
const NCBKernParam& kern_param, | const NCBKernParam& kern_param, | ||||
const NCBKernIndex& ncb_index); | |||||
const NCBKernIndex& ncb_index, | |||||
const CpuNDRange& workspace_ids); | |||||
bool m_large_group; | bool m_large_group; | ||||
public: | public: | ||||
@@ -146,9 +146,11 @@ WorkspaceBundle ConvBiasImpl::AlgoMkldnnQint8::get_bundle( | |||||
} while (0) | } while (0) | ||||
void ConvBiasImpl::AlgoMkldnnQint8::kern_mkldnn_s8x8x32( | void ConvBiasImpl::AlgoMkldnnQint8::kern_mkldnn_s8x8x32( | ||||
const NCBKernParam& param, const NCBKernIndex&) { | |||||
const NCBKernParam& param, const NCBKernIndex& ncb_index) { | |||||
UNPACK_CONV_F32_NCB_KERN_SIZES(param); | UNPACK_CONV_F32_NCB_KERN_SIZES(param); | ||||
MEGDNN_MARK_USED_VAR(N); | MEGDNN_MARK_USED_VAR(N); | ||||
size_t group_id = ncb_index.ndrange_id[0]; | |||||
size_t batch_id = ncb_index.ndrange_id[1]; | |||||
auto x86_handle = static_cast<HandleImpl*>(inplace_cpu_handle().get()); | auto x86_handle = static_cast<HandleImpl*>(inplace_cpu_handle().get()); | ||||
megdnn_assert(x86_handle != nullptr, "x86 handle can not be null"); | megdnn_assert(x86_handle != nullptr, "x86 handle can not be null"); | ||||
auto eng_mkldnn = x86_handle->mkldnn_engine(); | auto eng_mkldnn = x86_handle->mkldnn_engine(); | ||||
@@ -167,10 +169,11 @@ void ConvBiasImpl::AlgoMkldnnQint8::kern_mkldnn_s8x8x32( | |||||
auto megdnn_dst_md = memory::desc({dst_shape}, memory::data_type::s32, | auto megdnn_dst_md = memory::desc({dst_shape}, memory::data_type::s32, | ||||
memory::format_tag::nchw); | memory::format_tag::nchw); | ||||
auto megdnn_weight_memory = memory(megdnn_weight_md, eng_mkldnn, | |||||
const_cast<void*>(param.filter_ptr)); | |||||
int8_t* src = const_cast<int8_t*>(param.src<int8_t>()); | |||||
int32_t* dst = param.dst<int32_t>(); | |||||
auto megdnn_weight_memory = | |||||
memory(megdnn_weight_md, eng_mkldnn, | |||||
const_cast<void*>(param.filter<void>(group_id))); | |||||
int8_t* src = const_cast<int8_t*>(param.src<int8_t>(batch_id, group_id)); | |||||
int32_t* dst = param.dst<int32_t>(batch_id, group_id); | |||||
auto megdnn_src_memory = | auto megdnn_src_memory = | ||||
memory(megdnn_src_md, eng_mkldnn, static_cast<void*>(src)); | memory(megdnn_src_md, eng_mkldnn, static_cast<void*>(src)); | ||||
@@ -353,18 +356,18 @@ MatrixMul* ConvBiasImpl::AlgoMkldnnMatmulQint8::get_matmul_opr() { | |||||
} | } | ||||
void ConvBiasImpl::AlgoMkldnnMatmulQint8::kern_mkldnn_matmul_s8x8x32( | void ConvBiasImpl::AlgoMkldnnMatmulQint8::kern_mkldnn_matmul_s8x8x32( | ||||
const NCBKernParam& param, const NCBKernIndex&) { | |||||
const NCBKernParam& param, const NCBKernIndex& ncb_index) { | |||||
UNPACK_CONV_F32_NCB_KERN_SIZES(param); | UNPACK_CONV_F32_NCB_KERN_SIZES(param); | ||||
auto IH2 = IH + 2 * PH; | auto IH2 = IH + 2 * PH; | ||||
auto IW2 = IW + 2 * PW; | auto IW2 = IW + 2 * PW; | ||||
size_t group_id = ncb_index.ndrange_id[0]; | |||||
bool is_xcorr = !param.filter_meta.should_flip; | bool is_xcorr = !param.filter_meta.should_flip; | ||||
auto bundle = get_bundle(param); | auto bundle = get_bundle(param); | ||||
bundle.set(param.workspace_ptr); | bundle.set(param.workspace_ptr); | ||||
for (size_t n = 0; n < N; ++n) { | for (size_t n = 0; n < N; ++n) { | ||||
int8_t* src = | |||||
const_cast<int8_t*>(param.src<int8_t>()) + n * param.inp_bs; | |||||
int32_t* dst = param.dst<int32_t>() + n * param.out_bs; | |||||
int8_t* src = const_cast<int8_t*>(param.src<int8_t>(n, group_id)); | |||||
int32_t* dst = param.dst<int32_t>(n, group_id); | |||||
int8_t *B, *src2; | int8_t *B, *src2; | ||||
if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) { | if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) { | ||||
// special case: 1x1 | // special case: 1x1 | ||||
@@ -414,7 +417,7 @@ void ConvBiasImpl::AlgoMkldnnMatmulQint8::kern_mkldnn_matmul_s8x8x32( | |||||
{ | { | ||||
TensorND A_, B_, C_; | TensorND A_, B_, C_; | ||||
A_.layout = TensorLayout({OC, IC * FH * FW}, dtype::Int8()); | A_.layout = TensorLayout({OC, IC * FH * FW}, dtype::Int8()); | ||||
A_.raw_ptr = const_cast<int8_t*>(param.filter<int8_t>()); | |||||
A_.raw_ptr = const_cast<int8_t*>(param.filter<int8_t>(group_id)); | |||||
B_.layout = TensorLayout({IC * FH * FW, OH * OW}, dtype::Int8()); | B_.layout = TensorLayout({IC * FH * FW, OH * OW}, dtype::Int8()); | ||||
B_.raw_ptr = B; | B_.raw_ptr = B; | ||||
C_.layout = TensorLayout({OC, OH * OW}, dtype::Int32()); | C_.layout = TensorLayout({OC, OH * OW}, dtype::Int32()); | ||||
@@ -47,8 +47,8 @@ void pack_src_conv_avx2_stride1(WorkspaceBundle bundle, | |||||
batch_id = ncb_index.ndrange_id[1], | batch_id = ncb_index.ndrange_id[1], | ||||
channel_id = ncb_index.ndrange_id[2]; | channel_id = ncb_index.ndrange_id[2]; | ||||
const int8_t* src_ptr = | |||||
kern_param.src<int8_t>() + ic_step * channel_id * c_stride; | |||||
const int8_t* src_ptr = kern_param.src<int8_t>(batch_id, group_id) + | |||||
ic_step * channel_id * c_stride; | |||||
bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
int8_t* packed_src = static_cast<int8_t*>(bundle.get(0)) + | int8_t* packed_src = static_cast<int8_t*>(bundle.get(0)) + | ||||
batch_id * group * packed_group_size + | batch_id * group * packed_group_size + | ||||
@@ -129,7 +129,7 @@ static inline void pack_filter_conv_avx2_stride1( | |||||
size_t group_id = ncb_index.ndrange_id[0], | size_t group_id = ncb_index.ndrange_id[0], | ||||
oc_index_id = ncb_index.ndrange_id[1]; | oc_index_id = ncb_index.ndrange_id[1]; | ||||
const int8_t* pack_filter_ptr = kern_param.filter<int8_t>(); | |||||
const int8_t* pack_filter_ptr = kern_param.filter<int8_t>(group_id); | |||||
bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
int16_t* out_ptr = static_cast<int16_t*>(bundle.get(1)) + | int16_t* out_ptr = static_cast<int16_t*>(bundle.get(1)) + | ||||
group_id * round_up(oc, oc_step) * oc_out_stride; | group_id * round_up(oc, oc_step) * oc_out_stride; | ||||
@@ -632,19 +632,18 @@ void do_conv_kern(WorkspaceBundle bundle, | |||||
const uint32_t packed_group_size = | const uint32_t packed_group_size = | ||||
div_ceil(ic, ic_step) * pack_ih * pack_iw; | div_ceil(ic, ic_step) * pack_ih * pack_iw; | ||||
size_t workspace_group_id = ncb_index.ndrange_id[0], | |||||
workspace_batch_id = ncb_index.ndrange_id[1], | |||||
workspace_channel_id = ncb_index.ndrange_id[2]; | |||||
size_t group_id = ncb_index.ndrange_id[0], | |||||
batch_id = ncb_index.ndrange_id[1], | |||||
channel_id = ncb_index.ndrange_id[2]; | |||||
bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
int8_t* src_ptr = static_cast<int8_t*>(bundle.get(0)) + | int8_t* src_ptr = static_cast<int8_t*>(bundle.get(0)) + | ||||
workspace_group_id * packed_group_size + | |||||
workspace_batch_id * group * packed_group_size; | |||||
int16_t* filter_ptr = | |||||
static_cast<int16_t*>(bundle.get(1)) + | |||||
workspace_group_id * round_up(oc, oc_step) * filter_round_size + | |||||
oc_step * workspace_channel_id * filter_round_size; | |||||
group_id * packed_group_size + | |||||
batch_id * group * packed_group_size; | |||||
int16_t* filter_ptr = static_cast<int16_t*>(bundle.get(1)) + | |||||
group_id * round_up(oc, oc_step) * filter_round_size + | |||||
oc_step * channel_id * filter_round_size; | |||||
bool need_post_process = | bool need_post_process = | ||||
kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; | kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; | ||||
@@ -652,12 +651,11 @@ void do_conv_kern(WorkspaceBundle bundle, | |||||
int32_t* dst_tptr = nullptr; | int32_t* dst_tptr = nullptr; | ||||
if (need_post_process) { | if (need_post_process) { | ||||
dst_tptr = static_cast<int32_t*>(bundle.get(2)) + | dst_tptr = static_cast<int32_t*>(bundle.get(2)) + | ||||
workspace_batch_id * group * oc * oc_stride + | |||||
workspace_group_id * oc * oc_stride + | |||||
oc_step * workspace_channel_id * oh * ow; | |||||
batch_id * group * oc * oc_stride + | |||||
group_id * oc * oc_stride + oc_step * channel_id * oh * ow; | |||||
} else { | } else { | ||||
dst_tptr = kern_param.dst<int32_t>() + | |||||
oc_step * workspace_channel_id * oh * ow; | |||||
dst_tptr = kern_param.dst<int32_t>(batch_id, group_id) + | |||||
oc_step * channel_id * oh * ow; | |||||
} | } | ||||
const uint32_t oc_end = oc / oc_step * oc_step; | const uint32_t oc_end = oc / oc_step * oc_step; | ||||
@@ -666,7 +664,7 @@ void do_conv_kern(WorkspaceBundle bundle, | |||||
const uint32_t oh_remain = oh - oh_end; | const uint32_t oh_remain = oh - oh_end; | ||||
const uint32_t ow_end = ow / ow_step * ow_step; | const uint32_t ow_end = ow / ow_step * ow_step; | ||||
const uint32_t ow_remain = ow - ow_end; | const uint32_t ow_remain = ow - ow_end; | ||||
const uint32_t oc_index = oc_step * workspace_channel_id; | |||||
const uint32_t oc_index = oc_step * channel_id; | |||||
AlgoAVX2DirectConvStride1S8S8S32_forward<oc_step, ic_step, oh_step, | AlgoAVX2DirectConvStride1S8S8S32_forward<oc_step, ic_step, oh_step, | ||||
ow_step>( | ow_step>( | ||||
@@ -684,29 +682,29 @@ void do_post_process(WorkspaceBundle bundle, | |||||
const uint32_t oh = kern_param.osz[0]; | const uint32_t oh = kern_param.osz[0]; | ||||
const uint32_t ow = kern_param.osz[1]; | const uint32_t ow = kern_param.osz[1]; | ||||
size_t workspace_group_id = ncb_index.ndrange_id[0], | |||||
workspace_batch_id = ncb_index.ndrange_id[1]; | |||||
size_t group_id = ncb_index.ndrange_id[0], | |||||
batch_id = ncb_index.ndrange_id[1]; | |||||
bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
bool need_post_process = | bool need_post_process = | ||||
kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; | kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; | ||||
void* dst_tptr = nullptr; | void* dst_tptr = nullptr; | ||||
if (need_post_process) { | if (need_post_process) { | ||||
dst_tptr = static_cast<int32_t*>(bundle.get(2)) + | dst_tptr = static_cast<int32_t*>(bundle.get(2)) + | ||||
workspace_batch_id * group * oc * oh * ow + | |||||
workspace_group_id * oc * oh * ow; | |||||
batch_id * group * oc * oh * ow + group_id * oc * oh * ow; | |||||
} else { | } else { | ||||
dst_tptr = kern_param.dst<dt_int32>(); | |||||
dst_tptr = kern_param.dst<dt_int32>(batch_id, group_id); | |||||
} | } | ||||
void* dst_ptr = kern_param.dst<void>(batch_id, group_id); | |||||
#define cb(_bias_ctype, _dst_ctype, _postprocess_mode) \ | #define cb(_bias_ctype, _dst_ctype, _postprocess_mode) \ | ||||
{ \ | { \ | ||||
const dt_int32* bias_ptr = kern_param.bias<dt_int32>(); \ | |||||
const dt_int32* bias_ptr = \ | |||||
kern_param.bias<dt_int32>(batch_id, group_id); \ | |||||
PostProcess<DTypeTrait<_bias_ctype>::ctype, \ | PostProcess<DTypeTrait<_bias_ctype>::ctype, \ | ||||
DTypeTrait<_dst_ctype>::ctype, \ | DTypeTrait<_dst_ctype>::ctype, \ | ||||
_postprocess_mode>::run(dst_tptr, \ | _postprocess_mode>::run(dst_tptr, \ | ||||
const_cast<dt_int32*>(bias_ptr), \ | const_cast<dt_int32*>(bias_ptr), \ | ||||
kern_param.dst_ptr, \ | |||||
kern_param.bias_mode, \ | |||||
dst_ptr, kern_param.bias_mode, \ | |||||
kern_param.nonlineMode, \ | kern_param.nonlineMode, \ | ||||
kern_param.bias_type, \ | kern_param.bias_type, \ | ||||
kern_param.dst_type, 1, oc, oh, \ | kern_param.dst_type, 1, oc, oh, \ | ||||
@@ -45,8 +45,8 @@ void pack_src_conv_avx2_stride2(WorkspaceBundle bundle, | |||||
batch_id = ncb_index.ndrange_id[1], | batch_id = ncb_index.ndrange_id[1], | ||||
channel_id = ncb_index.ndrange_id[2]; | channel_id = ncb_index.ndrange_id[2]; | ||||
const int8_t* src_ptr = | |||||
kern_param.src<int8_t>() + ic_step * channel_id * c_stride; | |||||
const int8_t* src_ptr = kern_param.src<int8_t>(batch_id, group_id) + | |||||
ic_step * channel_id * c_stride; | |||||
bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
int8_t* packed_src = static_cast<int8_t*>(bundle.get(0)) + | int8_t* packed_src = static_cast<int8_t*>(bundle.get(0)) + | ||||
batch_id * group * packed_group_size + | batch_id * group * packed_group_size + | ||||
@@ -187,7 +187,7 @@ static inline void pack_filter_conv_avx2_stride2( | |||||
size_t group_id = ncb_index.ndrange_id[0], | size_t group_id = ncb_index.ndrange_id[0], | ||||
oc_index_id = ncb_index.ndrange_id[1]; | oc_index_id = ncb_index.ndrange_id[1]; | ||||
const int8_t* pack_filter_ptr = kern_param.filter<int8_t>(); | |||||
const int8_t* pack_filter_ptr = kern_param.filter<int8_t>(group_id); | |||||
bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
int16_t* out_ptr = static_cast<int16_t*>(bundle.get(1)) + | int16_t* out_ptr = static_cast<int16_t*>(bundle.get(1)) + | ||||
group_id * round_up(oc, oc_step) * oc_out_stride; | group_id * round_up(oc, oc_step) * oc_out_stride; | ||||
@@ -705,18 +705,17 @@ void kernel_imp(WorkspaceBundle bundle, | |||||
const uint32_t packed_group_size = | const uint32_t packed_group_size = | ||||
div_ceil(ic, ic_step) * pack_ih * pack_iw; | div_ceil(ic, ic_step) * pack_ih * pack_iw; | ||||
size_t workspace_group_id = ncb_index.ndrange_id[0], | |||||
workspace_batch_id = ncb_index.ndrange_id[1], | |||||
workspace_channel_id = ncb_index.ndrange_id[2]; | |||||
size_t group_id = ncb_index.ndrange_id[0], | |||||
batch_id = ncb_index.ndrange_id[1], | |||||
channel_id = ncb_index.ndrange_id[2]; | |||||
bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
int8_t* src_ptr = static_cast<int8_t*>(bundle.get(0)) + | int8_t* src_ptr = static_cast<int8_t*>(bundle.get(0)) + | ||||
workspace_group_id * packed_group_size + | |||||
workspace_batch_id * group * packed_group_size; | |||||
int16_t* filter_ptr = | |||||
static_cast<int16_t*>(bundle.get(1)) + | |||||
workspace_group_id * round_up(oc, oc_step) * filter_round_size + | |||||
oc_step * workspace_channel_id * filter_round_size; | |||||
group_id * packed_group_size + | |||||
batch_id * group * packed_group_size; | |||||
int16_t* filter_ptr = static_cast<int16_t*>(bundle.get(1)) + | |||||
group_id * round_up(oc, oc_step) * filter_round_size + | |||||
oc_step * channel_id * filter_round_size; | |||||
bool need_post_process = | bool need_post_process = | ||||
kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; | kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; | ||||
@@ -724,12 +723,11 @@ void kernel_imp(WorkspaceBundle bundle, | |||||
int32_t* dst_tptr = nullptr; | int32_t* dst_tptr = nullptr; | ||||
if (need_post_process) { | if (need_post_process) { | ||||
dst_tptr = static_cast<int32_t*>(bundle.get(2)) + | dst_tptr = static_cast<int32_t*>(bundle.get(2)) + | ||||
workspace_batch_id * group * oc * oc_stride + | |||||
workspace_group_id * oc * oc_stride + | |||||
oc_step * workspace_channel_id * oh * ow; | |||||
batch_id * group * oc * oc_stride + | |||||
group_id * oc * oc_stride + oc_step * channel_id * oh * ow; | |||||
} else { | } else { | ||||
dst_tptr = kern_param.dst<int32_t>() + | |||||
oc_step * workspace_channel_id * oh * ow; | |||||
dst_tptr = kern_param.dst<int32_t>(batch_id, group_id) + | |||||
oc_step * channel_id * oh * ow; | |||||
} | } | ||||
const uint32_t oc_end = oc / oc_step * oc_step; | const uint32_t oc_end = oc / oc_step * oc_step; | ||||
const uint32_t oc_remain = oc - oc_end; | const uint32_t oc_remain = oc - oc_end; | ||||
@@ -737,7 +735,7 @@ void kernel_imp(WorkspaceBundle bundle, | |||||
const uint32_t oh_remain = oh - oh_end; | const uint32_t oh_remain = oh - oh_end; | ||||
const uint32_t ow_end = ow / ow_step * ow_step; | const uint32_t ow_end = ow / ow_step * ow_step; | ||||
const uint32_t ow_remain = ow - ow_end; | const uint32_t ow_remain = ow - ow_end; | ||||
const uint32_t oc_index = oc_step * workspace_channel_id; | |||||
const uint32_t oc_index = oc_step * channel_id; | |||||
kernel_handle_oh_remain<oc_step, ic_step, oh_step, ow_step>( | kernel_handle_oh_remain<oc_step, ic_step, oh_step, ow_step>( | ||||
oh_remain, oc_remain, ow_remain, filter_ptr, src_ptr, dst_tptr, | oh_remain, oc_remain, ow_remain, filter_ptr, src_ptr, dst_tptr, | ||||
@@ -754,8 +752,8 @@ void do_post_process(WorkspaceBundle bundle, | |||||
const uint32_t oh = kern_param.osz[0]; | const uint32_t oh = kern_param.osz[0]; | ||||
const uint32_t ow = kern_param.osz[1]; | const uint32_t ow = kern_param.osz[1]; | ||||
size_t workspace_group_id = ncb_index.ndrange_id[0], | |||||
workspace_batch_id = ncb_index.ndrange_id[1]; | |||||
size_t group_id = ncb_index.ndrange_id[0], | |||||
batch_id = ncb_index.ndrange_id[1]; | |||||
bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
bool need_post_process = | bool need_post_process = | ||||
@@ -763,21 +761,22 @@ void do_post_process(WorkspaceBundle bundle, | |||||
void* dst_tptr = nullptr; | void* dst_tptr = nullptr; | ||||
if (need_post_process) { | if (need_post_process) { | ||||
dst_tptr = static_cast<int32_t*>(bundle.get(2)) + | dst_tptr = static_cast<int32_t*>(bundle.get(2)) + | ||||
workspace_batch_id * group * oc * oh * ow + | |||||
workspace_group_id * oc * oh * ow; | |||||
batch_id * group * oc * oh * ow + | |||||
group_id * oc * oh * ow; | |||||
} else { | } else { | ||||
dst_tptr = kern_param.dst<dt_int32>(); | |||||
dst_tptr = kern_param.dst<dt_int32>(batch_id, group_id); | |||||
} | } | ||||
void* dst_ptr = kern_param.dst<void>(batch_id, group_id); | |||||
#define cb(_bias_ctype, _dst_ctype, _postprocess_mode) \ | #define cb(_bias_ctype, _dst_ctype, _postprocess_mode) \ | ||||
{ \ | { \ | ||||
const dt_int32* bias_ptr = kern_param.bias<dt_int32>(); \ | |||||
const dt_int32* bias_ptr = \ | |||||
kern_param.bias<dt_int32>(batch_id, group_id); \ | |||||
PostProcess<DTypeTrait<_bias_ctype>::ctype, \ | PostProcess<DTypeTrait<_bias_ctype>::ctype, \ | ||||
DTypeTrait<_dst_ctype>::ctype, \ | DTypeTrait<_dst_ctype>::ctype, \ | ||||
_postprocess_mode>::run(dst_tptr, \ | _postprocess_mode>::run(dst_tptr, \ | ||||
const_cast<dt_int32*>(bias_ptr), \ | const_cast<dt_int32*>(bias_ptr), \ | ||||
kern_param.dst_ptr, \ | |||||
kern_param.bias_mode, \ | |||||
dst_ptr, kern_param.bias_mode, \ | |||||
kern_param.nonlineMode, \ | kern_param.nonlineMode, \ | ||||
kern_param.bias_type, \ | kern_param.bias_type, \ | ||||
kern_param.dst_type, 1, oc, oh, \ | kern_param.dst_type, 1, oc, oh, \ | ||||