GitOrigin-RevId: e37fbe0ffe
tags/v0.3.2
@@ -213,11 +213,17 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoNaive::dispatch_kerns(
const NCBKernParam& param,
const NCBKernIndex& ncb_index) {
MIDOUT_BEGIN(megdnn_fallback_naive, 2) {
size_t group_id = ncb_index.ndrange_id[0];
size_t batch_id = ncb_index.ndrange_id[1];
size_t thread_id = ncb_index.thread_id;
auto thread_param = param;
thread_param.workspace_ptr = reinterpret_cast<void*>(
reinterpret_cast<ptrdiff_t>(param.workspace_ptr) +
thread_id * workspace_per_thread);
thread_param.filter_ptr = param.filter<void>(group_id);
thread_param.dst_ptr = param.dst<void>(batch_id, group_id);
thread_param.src_ptr = param.src<void>(batch_id, group_id);
thread_param.bias_ptr = param.bias<void>(batch_id, group_id);
kern_default(opr_param, thread_param);
}
MIDOUT_END();
@@ -111,7 +111,6 @@ static void copy_padding_kern(WorkspaceBundle bundle,
size_t channel_id = ncb_index.ndrange_id[2];
size_t padding_group_size = IH2 * IW2 * IC;
size_t input_channel_offset = IH * IW * channel_id;
size_t workspace_channel_offset = IH2 * IW2 * channel_id;
size_t workspace_group_offset = group_id * padding_group_size;
size_t workspace_batch_offset =
@@ -123,7 +122,7 @@ static void copy_padding_kern(WorkspaceBundle bundle,
src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point;
}
src_ctype* src = const_cast<src_ctype*>(
param.src<src_ctype>(batch_id, group_id) + input_channel_offset);
param.src<src_ctype>(batch_id, group_id, channel_id));
src_ctype* src2;
src2 = static_cast<src_ctype*>(
bundle.get(Im2colBundelIndex::BUNDLE_PADDING_INDEX)) +
@@ -246,10 +246,9 @@ void ConvBiasImpl::exec_with_ncb_kern(const NCBKernParam& param,
ConvBiasImpl::Algorithm* algo) {
auto ncb_kerns = ncb_algo_dispatch_kerns(algo, param);
for (auto&& kernel : ncb_kerns) {
auto run = [=](size_t index, size_t thread_id) {
auto copy_param = param;
auto run = [kernel, param](size_t index, size_t thread_id) {
CpuNDRange ndrange_id(kernel.global_size, index);
kernel.kern(copy_param, {thread_id, ndrange_id});
kernel.kern(param, {thread_id, ndrange_id});
};
static_cast<naive::HandleImpl*>(handle())->dispatch_kern(
run, kernel.global_size.total_size());
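Editor's note: the hunk above drops the per-index `copy_param` and pointer patching; the lambda now captures `kernel` and `param` by value once, and every dispatched index reuses the same immutable copy. A standalone sketch of that shape (toy stand-ins, not the real MegDNN types):

    #include <cstddef>
    #include <cstdio>
    #include <functional>
    #include <vector>

    struct Param { int tag; };                       // stand-in for NCBKernParam
    struct Kernel {                                  // stand-in for NCBKern
        std::function<void(const Param&, size_t, size_t)> kern;
        size_t total_size;
    };

    // stand-in for HandleImpl::dispatch_kern: invoke run for every index
    static void dispatch_kern(const std::function<void(size_t, size_t)>& run,
                              size_t total) {
        for (size_t i = 0; i < total; ++i) run(i, /*thread_id=*/0);
    }

    int main() {
        Param param{42};
        std::vector<Kernel> ncb_kerns{
                {[](const Param& p, size_t index, size_t thread_id) {
                     std::printf("tag=%d index=%zu thread=%zu\n", p.tag, index,
                                 thread_id);
                 },
                 3}};
        for (auto&& kernel : ncb_kerns) {
            // one by-value capture; no copy_param rebuilt per index
            auto run = [kernel, param](size_t index, size_t thread_id) {
                kernel.kern(param, index, thread_id);
            };
            dispatch_kern(run, kernel.total_size);
        }
    }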
@@ -328,28 +327,29 @@ const char* ConvBiasImpl::get_algorithm_set_name() const {
namespace megdnn {
namespace fallback {
//! when the format is nchwxx in channel-wise mode, multiple groups are
//! packed together, so pack_group_size is the number of packed groups
template <typename T>
const T* ConvBiasImpl::NCBKernParam::src(size_t batch_id, size_t group_id,
size_t group_pack_size) const {
src_type.assert_is_compatible_ctype<T>();
const T* ConvBiasImpl::NCBKernParam::src(size_t batch_id, size_t group_pack_id,
size_t channel_pack_id,
size_t group_pack_size,
size_t channel_pack_size) const {
size_t batch_offset = batch_id * inp_bs * src_type.size();
size_t group_offset = group_pack_size * group_id * filter_meta.icpg *
size_t group_offset = group_pack_size * group_pack_id * filter_meta.icpg *
isz[0] * isz[1] * src_type.size();
size_t channel_offset = channel_pack_size * channel_pack_id * isz[0] *
isz[1] * src_type.size();
return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(src_ptr) +
batch_offset + group_offset);
batch_offset + group_offset + channel_offset);
}
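Editor's note: the rewritten src() layers three byte offsets, whole batches, packed groups, and packed channels, with pack sizes defaulting to 1 for the old two-index callers. A minimal self-contained sketch of that arithmetic (hypothetical ParamSketch type, simplified fields):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    struct ParamSketch {
        const void* src_ptr;
        size_t inp_bs;      // src elements per batch
        size_t icpg;        // input channels per group
        size_t isz[2];      // IH, IW
        size_t dtype_size;  // bytes per element

        const void* src(size_t batch_id, size_t group_pack_id,
                        size_t channel_pack_id = 0, size_t group_pack_size = 1,
                        size_t channel_pack_size = 1) const {
            size_t batch_offset = batch_id * inp_bs * dtype_size;
            size_t group_offset = group_pack_size * group_pack_id * icpg *
                                  isz[0] * isz[1] * dtype_size;
            size_t channel_offset = channel_pack_size * channel_pack_id *
                                    isz[0] * isz[1] * dtype_size;
            return reinterpret_cast<const void*>(
                    reinterpret_cast<uintptr_t>(src_ptr) + batch_offset +
                    group_offset + channel_offset);
        }
    };

    int main() {
        float buf[2 * 2 * 2 * 9] = {};  // N=2, G=2, icpg=2, 3x3 planes
        ParamSketch p{buf, 2 * 2 * 9, 2, {3, 3}, sizeof(float)};
        const float* q = static_cast<const float*>(p.src(1, 1, 1));
        std::printf("element offset: %td\n", q - buf);  // 36 + 18 + 9 = 63
    }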
//! when the format is nchwxx in channel-wise mode, multiple groups are
//! packed together, so pack_group_size is the number of packed groups
template <typename T>
const T* ConvBiasImpl::NCBKernParam::filter(size_t group_id,
const T* ConvBiasImpl::NCBKernParam::filter(size_t group_pack_id,
size_t pack_group_size) const {
size_t group_offset = 0_z;
switch (filter_meta.format) {
case Param::Format::NCHW: {
group_offset = pack_group_size * group_id * filter_meta.icpg *
group_offset = pack_group_size * group_pack_id * filter_meta.icpg *
filter_meta.ocpg * filter_meta.spatial[0] *
filter_meta.spatial[1] * filter_type.size();
break;
@@ -359,15 +359,15 @@ const T* ConvBiasImpl::NCBKernParam::filter(size_t group_id,
size_t icpg = filter_meta.icpg;
size_t ocpg = filter_meta.ocpg;
//! four formats of weight layout
//! 1. {oc/8, ic/8, fh, fw, 8, 8}, 2. {g, oc/8, ic/8, fh,
//! fw, 8, 8}
//! 3. {g/8, 1, 1, fh, fw, 8, 8}, 3. {oc/8 ,fh, fw, ic, 8}
//! 1. {oc/8, ic/8, fh, fw, 8, 8},
//! 2. {g, oc/8, ic/8, fh, fw, 8, 8},
//! 3. {g/8, fh, fw, 1, 1, 8}, 4. {oc/8, fh, fw, ic, 8}
megdnn_assert((icpg % 8 == 0 && ocpg % 8 == 0) ||
(group % 8 == 0 && icpg == 1 && ocpg == 1 &&
pack_group_size > 1) ||
(group == 1 && ocpg % 8 == 0),
"The filter shape is not right for nchw88");
group_offset = pack_group_size * group_id * filter_meta.icpg *
group_offset = pack_group_size * group_pack_id * filter_meta.icpg *
filter_meta.ocpg * filter_meta.spatial[0] *
filter_meta.spatial[1] * filter_type.size();
@@ -380,7 +380,7 @@ const T* ConvBiasImpl::NCBKernParam::filter(size_t group_id,
//! 2. {alpha, alpha, ocpg/8, icpg/8, 8, 8}
//! 3. {g, alpha, alpha, oc, ic, 8, 8}
//! 4. {alpha, alpha, oc, ic}
group_offset = pack_group_size * group_id * filter_meta.icpg *
group_offset = pack_group_size * group_pack_id * filter_meta.icpg *
filter_meta.ocpg *
(filter_meta.spatial[0] + output_block_size - 1) *
(filter_meta.spatial[1] + output_block_size - 1) *
@@ -388,58 +388,66 @@ const T* ConvBiasImpl::NCBKernParam::filter(size_t group_id,
break;
}
default:
megdnn_assert("other filter format is not supported yet");
megdnn_assert(0, "other filter format is not supported yet");
}
return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(filter_ptr) +
group_offset);
}
//! when the format is nchwxx in channel-wise mode, multiple groups are
//! packed together, so pack_group_size is the number of packed groups
template <typename T>
const T* ConvBiasImpl::NCBKernParam::bias(size_t batch_id, size_t group_id,
size_t group_pack_size) const {
bias_type.assert_is_compatible_ctype<T>();
const T* ConvBiasImpl::NCBKernParam::bias(size_t batch_id, size_t group_pack_id,
size_t channel_pack_id,
size_t group_pack_size,
size_t channel_pack_size) const {
size_t batch_offset = 0_z;
size_t group_offset = 0_z;
size_t channel_offset = 0_z;
if (bias_mode == BiasMode::BIAS) {
batch_offset = batch_id * bias_bs * bias_type.size();
group_offset = group_pack_size * group_id * filter_meta.ocpg * osz[0] *
osz[1] * bias_type.size();
group_offset = group_pack_size * group_pack_id * filter_meta.ocpg *
osz[0] * osz[1] * bias_type.size();
channel_offset = channel_pack_size * channel_pack_id * osz[0] * osz[1] *
bias_type.size();
} else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
group_offset = group_pack_size * group_id * filter_meta.ocpg *
group_offset = group_pack_size * group_pack_id * filter_meta.ocpg *
bias_type.size();
channel_offset = channel_pack_size * channel_pack_id * bias_type.size();
}
return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(bias_ptr) +
batch_offset + group_offset);
batch_offset + group_offset + channel_offset);
}
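Editor's note: bias() keeps the per-mode distinction, under BiasMode::BIAS a group or channel step skips whole OH*OW planes (and batches advance by bias_bs), while BROADCAST_CHANNEL_BIAS steps over single per-channel scalars and ignores the batch. A hedged sketch of just that branch (simplified free function, not the real member):

    #include <cstddef>
    #include <cstdio>

    enum class BiasMode { BIAS, BROADCAST_CHANNEL_BIAS };

    // byte offset contributed by (group, channel) indices under each mode
    static size_t bias_offset(BiasMode mode, size_t group_pack_id,
                              size_t channel_pack_id, size_t ocpg, size_t OH,
                              size_t OW, size_t dtype_size) {
        if (mode == BiasMode::BIAS)  // one bias value per output element
            return group_pack_id * ocpg * OH * OW * dtype_size +
                   channel_pack_id * OH * OW * dtype_size;
        // one bias value per output channel
        return group_pack_id * ocpg * dtype_size + channel_pack_id * dtype_size;
    }

    int main() {
        std::printf("BIAS: %zu bytes\n",
                    bias_offset(BiasMode::BIAS, 1, 1, 4, 8, 8, 4));
        std::printf("BROADCAST_CHANNEL_BIAS: %zu bytes\n",
                    bias_offset(BiasMode::BROADCAST_CHANNEL_BIAS, 1, 1, 4, 8, 8, 4));
    }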
//! when the format is nchwxx in channel-wise mode, multiple groups are
//! packed together, so pack_group_size is the number of packed groups
template <typename T>
T* ConvBiasImpl::NCBKernParam::dst(size_t batch_id, size_t group_id,
size_t group_pack_size) const {
dst_type.assert_is_compatible_ctype<T>();
T* ConvBiasImpl::NCBKernParam::dst(size_t batch_id, size_t group_pack_id,
size_t channel_pack_id,
size_t group_pack_size,
size_t channel_pack_size) const {
size_t batch_offset = batch_id * out_bs * dst_type.size();
size_t group_offset = group_pack_size * group_id * filter_meta.ocpg *
size_t group_offset = group_pack_size * group_pack_id * filter_meta.ocpg *
osz[0] * osz[1] * dst_type.size();
size_t channel_offset = channel_pack_size * channel_pack_id * osz[0] *
osz[1] * dst_type.size();
return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(dst_ptr) +
batch_offset + group_offset);
batch_offset + group_offset + channel_offset);
}
#define INST(T) \
template const T* ConvBiasImpl::NCBKernParam::src<T>( \
size_t batch_id, size_t group_id, size_t group_pack_size) const; \
template const T* ConvBiasImpl::NCBKernParam::bias<T>( \
size_t batch_id, size_t group_id, size_t group_pack_size) const; \
template const T* ConvBiasImpl::NCBKernParam::filter<T>( \
size_t group_id, size_t group_pack_size) const; \
template T* ConvBiasImpl::NCBKernParam::dst<T>( \
size_t batch_id, size_t group_id, size_t group_pack_size) const;
#define INST(T) \
template const T* ConvBiasImpl::NCBKernParam::src<T>( \
size_t batch_id, size_t group_id, size_t channel_id, \
size_t group_pack_size, size_t channel_pack_size) const; \
template const T* ConvBiasImpl::NCBKernParam::bias<T>( \
size_t batch_id, size_t group_id, size_t channel_id, \
size_t group_pack_size, size_t channel_pack_size) const; \
template const T* ConvBiasImpl::NCBKernParam::filter<T>( \
size_t group_id, size_t group_pack_size) const; \
template T* ConvBiasImpl::NCBKernParam::dst<T>( \
size_t batch_id, size_t group_id, size_t channel_id, \
size_t group_pack_size, size_t channel_pack_size) const;
#define INST_DT(d) INST(DTypeTrait<d>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(INST_DT)
INST(void)
#undef INST
#undef INST_DT
} // namespace fallback
@@ -103,10 +103,32 @@ public:
src_type.assert_is_compatible_ctype<T>();
return static_cast<const T*>(src_ptr);
}
//! when the format is nchwxx, multiple channels are packed into one
//! channel_pack_id; channel_pack_size is the number of packed channels
//! when the format is nchwxx in channel-wise mode, multiple groups are
//! packed into one group_pack_id; group_pack_size is the number of groups
//! packed together, e.g. a weight shape of {g/8, 1, 1, Fh, Fw, 8}
template <typename T>
const T* src(size_t batch_id, size_t group_pack_id,
size_t channel_pack_id = 0, size_t group_pack_size = 1,
size_t channel_pack_size = 1) const;
template <typename T>
const T* bias(size_t batch_id, size_t group_pack_id,
size_t channel_pack_id = 0, size_t group_pack_size = 1,
size_t channel_pack_size = 1) const;
template <typename T>
const T* src(size_t batch_id, size_t group_id,
size_t group_pack_size = 1_z) const;
T* dst(size_t batch_id, size_t group_pack_id,
size_t channel_pack_id = 0, size_t group_pack_size = 1,
size_t channel_pack_size = 1) const;
//! when the format is nchwxx in channel-wise mode, multiple groups are
//! packed into one group_pack_id; group_pack_size is the number of groups
//! packed together, e.g. a weight shape of {g/8, 1, 1, Fh, Fw, 8}
template <typename T>
const T* filter(size_t group_pack_id,
size_t pack_group_size = 1_z) const;
template <typename T>
const T* filter() const {
@@ -115,29 +137,18 @@ public:
}
template <typename T>
const T* filter(size_t group_id, size_t pack_group_size = 1_z) const;
template <typename T>
const T* bias() const {
bias_type.assert_is_compatible_ctype<T>();
return static_cast<const T*>(bias_ptr);
}
template <typename T>
const T* bias(size_t batch_id, size_t group_id,
size_t group_pack_size = 1_z) const;
template <typename T>
T* dst() const {
dst_type.assert_is_compatible_ctype<T>();
return static_cast<T*>(dst_ptr);
}
template <typename T>
T* dst(size_t batch_id, size_t group_id,
size_t group_pack_size = 1_z) const;
template <typename T>
T* workspace() const {
return static_cast<T*>(workspace_ptr);
}
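Editor's note: these members are only declared in the header; the .cpp hunk above stamps out their definitions with the INST/INST_DT macros, so only the explicitly instantiated ctypes link. A toy sketch of that declaration/instantiation split (single translation unit for brevity, hypothetical advance_bytes helper):

    #include <cstddef>
    #include <cstdio>

    // "header": declaration only
    template <typename T>
    const T* advance_bytes(const void* p, size_t bytes);

    // ".cpp": definition plus explicit instantiations for supported ctypes
    template <typename T>
    const T* advance_bytes(const void* p, size_t bytes) {
        return reinterpret_cast<const T*>(static_cast<const char*>(p) + bytes);
    }

    #define INST(T) template const T* advance_bytes<T>(const void*, size_t);
    INST(float)
    INST(int)
    INST(void)  // mirrors the INST(void) above for untyped callers
    #undef INST

    int main() {
        float buf[2] = {1.f, 2.f};
        std::printf("%g\n", *advance_bytes<float>(buf, sizeof(float)));  // 2
    }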
@@ -197,9 +197,12 @@ ConvolutionImpl::AlgoFallback::dispatch_kern(
auto kern_fallback = [workspace_per_thread](const NCBKernParam& p,
const NCBKernIndex& ncb_index) {
UNPACK_CONV_F32_NCB_KERN_SIZES(p);
size_t batch_id = ncb_index.ndrange_id[1];
size_t group_id = ncb_index.ndrange_id[0];
MEGDNN_MARK_USED_VAR(N);
auto src = p.src<float>(), filter = p.filter<float>();
auto dst = p.dst<float>();
auto src = p.src<float>(batch_id, group_id),
filter = p.filter<float>(group_id);
auto dst = p.dst<float>(batch_id, group_id);
size_t thread_id = ncb_index.thread_id;
void* workspace_ptr = reinterpret_cast<void*>(
reinterpret_cast<ptrdiff_t>(p.workspace_ptr) +
@@ -20,18 +20,25 @@ namespace fallback {
template <typename ST, typename DT, typename CT>
void kern_naive_forward(const ConvolutionImpl::NCBKernParam& p,
const ConvolutionImpl::NCBKernIndex& /*index*/) {
const ConvolutionImpl::NCBKernIndex& ncb_index) {
size_t batch_id = ncb_index.ndrange_id[1];
size_t group_id = ncb_index.ndrange_id[0];
auto IC = p.filter_meta.icpg, IH = p.isz[0], IW = p.isz[1],
OC = p.filter_meta.ocpg, OH = p.osz[0], OW = p.osz[1];
ptrdiff_t fstrd = p.filter_meta.icpg * p.filter_meta.ocpg *
p.filter_meta.spatial[0] * p.filter_meta.spatial[1] *
p.filter_type.size();
ptrdiff_t istrd = p.filter_meta.icpg * p.src_type.size();
ptrdiff_t ostrd = p.filter_meta.ocpg * p.dst_type.size();
TensorND src, dst;
src.raw_ptr = const_cast<void*>(p.src_ptr);
dst.raw_ptr = p.dst_ptr;
src.layout.dtype = p.src_type;
dst.layout.dtype = p.dst_type;
if (p.filter_meta.format == param::Convolution::Format::NCHW) {
src.layout.init_contiguous_stride({1, IC, IH, IW});
dst.layout.init_contiguous_stride({1, OC, OH, OW});
istrd *= p.isz[0] * p.isz[1];
ostrd *= p.osz[0] * p.osz[1];
src.layout.init_contiguous_stride({1, IC, IH, IW});
dst.layout.init_contiguous_stride({1, OC, OH, OW});
} else {
// Must be NHWC
megdnn_assert(
@@ -41,9 +48,17 @@ void kern_naive_forward(const ConvolutionImpl::NCBKernParam& p,
src.layout.init_contiguous_stride({1, IH, IW, IC});
dst.layout.init_contiguous_stride({1, OH, OW, OC});
}
src.raw_ptr = reinterpret_cast<void*>(
reinterpret_cast<uintptr_t>(p.src_ptr) +
batch_id * p.inp_bs * p.src_type.size() + group_id * istrd);
dst.raw_ptr = reinterpret_cast<void*>(
reinterpret_cast<uintptr_t>(p.dst_ptr) +
batch_id * p.out_bs * p.dst_type.size() + group_id * ostrd);
ST* filter = reinterpret_cast<ST*>(
reinterpret_cast<uintptr_t>(p.filter_ptr) + group_id * fstrd);
std::copy(p.inp_s, p.inp_s + 4, src.layout.stride);
std::copy(p.out_s, p.out_s + 4, dst.layout.stride);
naive::convolution::forward<ST, ST, DT, CT>(src, p.filter<ST>(), dst,
naive::convolution::forward<ST, ST, DT, CT>(src, filter, dst,
p.filter_meta);
}
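Editor's note: kern_naive_forward now derives its src/dst/filter pointers from batch_id and group_id itself. The group strides mirror the layout, in NCHW a group step spans icpg whole IH*IW planes, while in NHWC channels interleave, so the step is just icpg elements. A small sketch of that stride rule (hypothetical helper, simplified fields):

    #include <cstddef>
    #include <cstdio>

    enum class Format { NCHW, NHWC };

    static ptrdiff_t src_group_stride(Format fmt, size_t icpg, size_t IH,
                                      size_t IW, size_t dtype_size) {
        ptrdiff_t istrd = static_cast<ptrdiff_t>(icpg * dtype_size);
        if (fmt == Format::NCHW)
            istrd *= IH * IW;  // whole channel planes per group
        // NHWC: channels are innermost, so a group step is icpg elements
        return istrd;
    }

    int main() {
        std::printf("NCHW group stride: %td bytes\n",
                    src_group_stride(Format::NCHW, 4, 8, 8, sizeof(float)));
        std::printf("NHWC group stride: %td bytes\n",
                    src_group_stride(Format::NHWC, 4, 8, 8, sizeof(float)));
    }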
@@ -189,41 +189,15 @@ ConvolutionImpl::NCBKernParam ConvolutionImpl::make_ncb_kern_param(
void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param,
Algorithm* algo) {
auto kerns = ncb_algo_dispatch_kern(algo, param);
size_t src_batch_stride = param.inp_bs * param.src_type.size();
size_t dst_batch_stride = param.out_bs * param.dst_type.size();
auto group = param.filter_meta.group;
auto fallback_handle = handle();
for (auto kernel : kerns) {
megdnn_assert(param.filter_meta.format == Param::Format::NCHW ||
param.filter_meta.format == Param::Format::NHWC ||
param.filter_meta.format == Param::Format::NHWC ||
param.filter_meta.format == Param::Format::NCHW88,
"invalid conv format");
ptrdiff_t istrd = 0, fstrd = 0, ostrd = 0;
fstrd = param.filter_meta.icpg * param.filter_meta.ocpg *
param.filter_meta.spatial[0] * param.filter_meta.spatial[1] *
param.filter_type.size();
istrd = param.filter_meta.icpg * param.src_type.size();
ostrd = param.filter_meta.ocpg * param.dst_type.size();
if (param.filter_meta.format == Param::Format::NCHW) {
istrd *= param.isz[0] * param.isz[1];
ostrd *= param.osz[0] * param.osz[1];
} else {
// must be NHWC. No action performed.
}
auto run = [=](size_t index, size_t thread_id) {
auto copy_param = param;
auto run = [param, kernel](size_t index, size_t thread_id) {
CpuNDRange ndrange_id(kernel.global_size, index);
size_t group_id = ndrange_id[0];
size_t batch_id = ndrange_id[1];
megdnn_assert(group_id < group,
"The group id should smaller than gruop");
//! The kernel ptr point to batch index
incr_ptr(copy_param.src_ptr,
group_id * istrd + batch_id * src_batch_stride);
incr_ptr(copy_param.filter_ptr, group_id * fstrd);
incr_ptr(copy_param.dst_ptr,
group_id * ostrd + batch_id * dst_batch_stride);
kernel.kern(copy_param, {thread_id, ndrange_id});
kernel.kern(param, {thread_id, ndrange_id});
};
static_cast<naive::HandleImpl*>(fallback_handle)
->dispatch_kern(run, kernel.global_size.total_size());
@@ -100,6 +100,43 @@ public:
T* workspace() const {
return static_cast<T*>(workspace_ptr);
}
//! when the format is nchwxx in channel-wise mode, multiple groups are
//! packed into one group_pack_id; group_pack_size is the number of groups
//! packed together, e.g. a weight shape of {g/8, 1, 1, Fh, Fw, 8}
template <typename T>
T* dst(size_t batch_id, size_t group_pack_id,
size_t group_pack_size = 1_z) const {
size_t batch_offset = batch_id * out_bs * dst_type.size();
size_t group_offset = group_pack_size * group_pack_id *
filter_meta.ocpg * osz[0] * osz[1] *
dst_type.size();
return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(dst_ptr) +
batch_offset + group_offset);
}
template <typename T>
const T* src(size_t batch_id, size_t group_pack_id,
size_t group_pack_size = 1_z) const {
size_t batch_offset = batch_id * inp_bs * src_type.size();
size_t group_offset = group_pack_size * group_pack_id *
filter_meta.icpg * isz[0] * isz[1] *
src_type.size();
return reinterpret_cast<T*>(reinterpret_cast<ptrdiff_t>(src_ptr) +
batch_offset + group_offset);
}
template <typename T>
const T* filter(size_t group_pack_id,
size_t pack_group_size = 1_z) const {
size_t group_offset = pack_group_size * group_pack_id *
filter_meta.icpg * filter_meta.ocpg *
filter_meta.spatial[0] *
filter_meta.spatial[1] * filter_type.size();
return reinterpret_cast<T*>(
reinterpret_cast<ptrdiff_t>(filter_ptr) + group_offset);
}
};
static void* const sm_fallback_conv_algo_type;
@@ -58,43 +58,45 @@ void get_rectified_size(size_t IH, size_t IW, size_t OH, size_t OW, size_t FH,
}
} // namespace
#define GET_KERN \
auto fm = param.filter_meta; \
size_t N = param.n; \
size_t IC = param.filter_meta.icpg; \
size_t OC = param.filter_meta.ocpg; \
size_t group = fm.group; \
WorkspaceBundle wbundle = get_bundle(param); \
SmallVector<NCBKern> ret_kerns; \
if (m_large_group) { \
auto exec_one_group = [wbundle](const NCBKernParam& kern_param, \
const NCBKernIndex& ncb_index) { \
auto fm = kern_param.filter_meta; \
size_t IC = fm.icpg; \
size_t OC = fm.ocpg; \
WorkspaceBundle bundle = wbundle; \
for (size_t ic = 0; ic < IC; ic++) { \
copy_padding_kern( \
bundle, kern_param, \
{ncb_index.thread_id, {ncb_index.thread_id, 0, ic}}); \
} \
for (size_t oc = 0; oc < OC; oc++) { \
do_conv_kern( \
bundle, kern_param, \
{ncb_index.thread_id, {ncb_index.thread_id, 0, oc}}); \
} \
}; \
ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); \
} else { \
WorkspaceBundle bundle = wbundle; \
auto copy_padding = \
std::bind(copy_padding_kern, bundle, std::placeholders::_1, \
std::placeholders::_2); \
ret_kerns.push_back({copy_padding, {group, N, IC}}); \
auto do_conv = std::bind(do_conv_kern, bundle, std::placeholders::_1, \
std::placeholders::_2); \
ret_kerns.push_back({do_conv, {group, N, OC}}); \
} \
#define GET_KERN \
auto fm = param.filter_meta; \
size_t N = param.n; \
size_t IC = param.filter_meta.icpg; \
size_t OC = param.filter_meta.ocpg; \
size_t group = fm.group; \
WorkspaceBundle wbundle = get_bundle(param); \
SmallVector<NCBKern> ret_kerns; \
if (m_large_group) { \
auto exec_one_group = [wbundle](const NCBKernParam& kern_param, \
const NCBKernIndex& ncb_index) { \
auto fm = kern_param.filter_meta; \
size_t IC = fm.icpg; \
size_t OC = fm.ocpg; \
WorkspaceBundle bundle = wbundle; \
for (size_t ic = 0; ic < IC; ic++) { \
copy_padding_kern(bundle, kern_param, ncb_index, \
{ncb_index.thread_id, 0, ic}); \
} \
for (size_t oc = 0; oc < OC; oc++) { \
do_conv_kern(bundle, kern_param, ncb_index, \
{ncb_index.thread_id, 0, oc}); \
} \
}; \
ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); \
} else { \
auto copy_padding = [wbundle](const NCBKernParam& kern_param, \
const NCBKernIndex& ncb_index) { \
copy_padding_kern(wbundle, kern_param, ncb_index, \
ncb_index.ndrange_id); \
}; \
ret_kerns.push_back({copy_padding, {group, N, IC}}); \
auto do_conv = [wbundle](const NCBKernParam& kern_param, \
const NCBKernIndex& ncb_index) { \
do_conv_kern(wbundle, kern_param, ncb_index, \
ncb_index.ndrange_id); \
}; \
ret_kerns.push_back({do_conv, {group, N, OC}}); \
} \
return ret_kerns;
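Editor's note: the rewritten GET_KERN keeps two scheduling shapes. For large groups, one fused kern per (group, batch) runs the padding and conv loops inline, with workspace indexed by thread_id; otherwise two kerns are dispatched over {group, N, IC} and {group, N, OC}, with workspace_ids taken straight from ndrange_id. A toy model of those two shapes (stand-in types, serial dispatch):

    #include <cstddef>
    #include <cstdio>
    #include <functional>
    #include <vector>

    struct NDRange { size_t g, n, c; };
    struct Kern {
        std::function<void(NDRange)> fn;
        NDRange global_size;
    };

    int main() {
        const size_t group = 2, N = 1, IC = 2, OC = 3;
        const bool large_group = true;
        std::vector<Kern> ret_kerns;
        if (large_group) {
            // one dispatch per (group, batch); channel loops run inside
            ret_kerns.push_back(
                    {[&](NDRange id) {
                         for (size_t ic = 0; ic < IC; ic++)
                             std::printf("pad  g=%zu ic=%zu\n", id.g, ic);
                         for (size_t oc = 0; oc < OC; oc++)
                             std::printf("conv g=%zu oc=%zu\n", id.g, oc);
                     },
                     {group, N, 1}});
        } else {
            // separate padding and conv kerns expose channel parallelism
            ret_kerns.push_back({[](NDRange id) {
                                     std::printf("pad  g=%zu ic=%zu\n", id.g, id.c);
                                 },
                                 {group, N, IC}});
            ret_kerns.push_back({[](NDRange id) {
                                     std::printf("conv g=%zu oc=%zu\n", id.g, id.c);
                                 },
                                 {group, N, OC}});
        }
        // serial stand-in for the NDRange dispatcher
        for (auto&& k : ret_kerns)
            for (size_t g = 0; g < k.global_size.g; ++g)
                for (size_t n = 0; n < k.global_size.n; ++n)
                    for (size_t c = 0; c < k.global_size.c; ++c)
                        k.fn({g, n, c});
    }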
/* ===================== direct algo ===================== */
@@ -145,7 +147,8 @@ size_t ConvBiasImpl::AlgoDirect::get_workspace(
//! Process one input channel copy padding
void ConvBiasImpl::AlgoDirect::copy_padding_kern(
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index) {
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids) {
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t IC = kern_param.filter_meta.icpg;
@@ -160,14 +163,18 @@ void ConvBiasImpl::AlgoDirect::copy_padding_kern(
get_rectified_img_size(IH, IW, FH, FW, OH, OW, PH, PW, IH2, IW2, OH2, OW2);
bool rectify_src = (IH != IH2 || IW != IW2);
size_t padding_group_size = IH2 * IW2 * IC;
const float* sptr = static_cast<const float*>(kern_param.src_ptr) +
ncb_index.ndrange_id[2] * IH * IW;
size_t batch_id = ncb_index.ndrange_id[1];
size_t group_id = ncb_index.ndrange_id[0];
size_t channel_id = workspace_ids[2];
const float* sptr = static_cast<const float*>(
kern_param.src<float>(batch_id, group_id)) +
channel_id * IH * IW;
bundle.set(kern_param.workspace_ptr);
//! Used to get the workspace offset
size_t workspace_group_id = ncb_index.ndrange_id[0],
workspace_batch_id = ncb_index.ndrange_id[1],
workspace_channel_id = ncb_index.ndrange_id[2];
size_t workspace_group_id = workspace_ids[0],
workspace_batch_id = workspace_ids[1],
workspace_channel_id = workspace_ids[2];
//! If large group, each thread has its own workspace, set group_id with
//! thread_id
if (rectify_src) {
@@ -234,7 +241,8 @@ void ConvBiasImpl::AlgoDirect::copy_padding_kern(
//! compute one output channel
void ConvBiasImpl::AlgoDirect::do_conv_kern(WorkspaceBundle bundle,
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
const NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids) {
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t IH = kern_param.isz[0];
@@ -265,14 +273,16 @@ void ConvBiasImpl::AlgoDirect::do_conv_kern(WorkspaceBundle bundle,
megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) {
bias_offset = 1_z;
}
size_t group_id = ncb_index.ndrange_id[0];
size_t batch_id = ncb_index.ndrange_id[1];
//! Used to get the workspace offset
size_t workspace_group_id = ncb_index.ndrange_id[0],
workspace_batch_id = ncb_index.ndrange_id[1],
oc = ncb_index.ndrange_id[2];
const float* sptr = kern_param.src<float>();
const float* filter = kern_param.filter<float>() + oc * FH * FW * IC;
const float* bias_ptr = kern_param.bias<float>() + oc * bias_offset;
float* dst = kern_param.dst<float>() + oc * OH * OW;
size_t workspace_group_id = workspace_ids[0],
workspace_batch_id = workspace_ids[1], oc = workspace_ids[2];
const float* sptr = kern_param.src<float>(batch_id, group_id);
const float* filter = kern_param.filter<float>(group_id) + oc * FH * FW * IC;
const float* bias_ptr =
kern_param.bias<float>(batch_id, group_id) + oc * bias_offset;
float* dst = kern_param.dst<float>(batch_id, group_id) + oc * OH * OW;
if (rectify_src) {
sptr = static_cast<float*>(bundle.get(0)) +
workspace_group_id * padding_group_size +
@@ -358,7 +368,8 @@ size_t ConvBiasImpl::AlgoDirectStride2::get_workspace(
//! Process one input channel copy padding
void ConvBiasImpl::AlgoDirectStride2::copy_padding_kern(
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index) {
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids) {
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t IC = kern_param.filter_meta.icpg;
@@ -373,13 +384,17 @@ void ConvBiasImpl::AlgoDirectStride2::copy_padding_kern(
get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2);
bool rectify_src = need_src_copy(kern_param);
size_t padding_group_size = IH2 * IW2 * IC;
const float* sptr = static_cast<const float*>(kern_param.src_ptr) +
ncb_index.ndrange_id[2] * IH * IW;
size_t group_id = ncb_index.ndrange_id[0];
size_t batch_id = ncb_index.ndrange_id[1];
size_t channel_id = workspace_ids[2];
const float* sptr = static_cast<const float*>(
kern_param.src<float>(batch_id, group_id)) +
channel_id * IH * IW;
bundle.set(kern_param.workspace_ptr);
//! Used to get the workspace offset
size_t workspace_group_id = ncb_index.ndrange_id[0],
workspace_batch_id = ncb_index.ndrange_id[1],
workspace_channel_id = ncb_index.ndrange_id[2];
size_t workspace_group_id = workspace_ids[0],
workspace_batch_id = workspace_ids[1],
workspace_channel_id = workspace_ids[2];
if (rectify_src) {
//! copy to sptr_base to eliminate padding effect
float* sptr_base = static_cast<float*>(bundle.get(0)) +
@@ -397,7 +412,7 @@ void ConvBiasImpl::AlgoDirectStride2::copy_padding_kern(
//! compute one output channel
void ConvBiasImpl::AlgoDirectStride2::do_conv_kern(
WorkspaceBundle bundle, const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) {
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t IH = kern_param.isz[0];
@@ -439,14 +454,17 @@ void ConvBiasImpl::AlgoDirectStride2::do_conv_kern(
megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) {
bias_offset = 1_z;
}
size_t group_id = ncb_index.ndrange_id[0];
size_t batch_id = ncb_index.ndrange_id[1];
//! Used to get the workspace offset
size_t workspace_group_id = ncb_index.ndrange_id[0],
workspace_batch_id = ncb_index.ndrange_id[1],
oc = ncb_index.ndrange_id[2];
const float* sptr = kern_param.src<float>();
const float* filter = kern_param.filter<float>() + oc * FH * FW * IC;
const float* bias_ptr = kern_param.bias<float>() + oc * bias_offset;
float* dst = kern_param.dst<float>() + oc * OH * OW;
size_t workspace_group_id = workspace_ids[0],
workspace_batch_id = workspace_ids[1], oc = workspace_ids[2];
const float* sptr = kern_param.src<float>(batch_id, group_id);
const float* filter =
kern_param.filter<float>(group_id) + oc * FH * FW * IC;
const float* bias_ptr =
kern_param.bias<float>(batch_id, group_id) + oc * bias_offset;
float* dst = kern_param.dst<float>(batch_id, group_id) + oc * OH * OW;
if (rectify_src) {
sptr = static_cast<float*>(bundle.get(0)) +
workspace_group_id * padding_group_size +
@@ -547,23 +565,22 @@ MatrixMul* ConvBiasImpl::AlgoMatrixMul::get_matmul_opr() {
}
void ConvBiasImpl::AlgoMatrixMul::kimpl(const NCBKernParam& param,
const NCBKernIndex&) {
const NCBKernIndex& ncb_index) {
UNPACK_CONV_F32_NCB_KERN_SIZES(param);
auto IH2 = IH + 2 * PH;
auto IW2 = IW + 2 * PW;
size_t group_id = ncb_index.ndrange_id[0];
bool is_xcorr = !param.filter_meta.should_flip;
auto bundle = get_bundle(param);
bundle.set(param.workspace_ptr);
// workspace = tmp..src2
for (size_t n = 0; n < N; ++n) {
float* src = const_cast<float*>(param.src<float>()) + n * param.inp_bs;
float* dst = param.dst<float>() + n * param.out_bs;
float* bias_ptr =
static_cast<float*>(const_cast<void*>(param.bias_ptr));
if (param.bias_mode == megdnn::BiasMode::BIAS) {
bias_ptr += n * param.out_bs;
}
float* src = const_cast<float*>(param.src<float>(n, group_id));
float* dst = param.dst<float>(n, group_id);
float* bias_ptr = static_cast<float*>(
const_cast<void*>(param.bias<void>(n, group_id)));
float *B, *src2;
if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) {
// special case: 1x1
@@ -613,7 +630,7 @@ void ConvBiasImpl::AlgoMatrixMul::kimpl(const NCBKernParam& param,
{
TensorND A_, B_, C_;
A_.layout = TensorLayout({OC, IC * FH * FW}, dtype::Float32());
A_.raw_ptr = const_cast<float*>(param.filter<float>());
A_.raw_ptr = const_cast<float*>(param.filter<float>(group_id));
B_.layout = TensorLayout({IC * FH * FW, OH * OW}, dtype::Float32());
B_.raw_ptr = B;
C_.layout = TensorLayout({OC, OH * OW}, dtype::Float32());
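Editor's note: AlgoMatrixMul::kimpl now resolves its pointers per (batch, group) but keeps the same im2col-plus-matmul core: IC*FH*FW input patches are unfolded into a matrix and one {OC, IC*FH*FW} x {IC*FH*FW, OH*OW} product yields the output. A naive, self-contained sketch of that idea (no MegDNN types; stride 1, no padding, cross-correlation):

    #include <cstdio>
    #include <vector>

    int main() {
        const int IC = 1, IH = 3, IW = 3, FH = 2, FW = 2, OC = 1;
        const int OH = IH - FH + 1, OW = IW - FW + 1;
        std::vector<float> src = {1, 2, 3, 4, 5, 6, 7, 8, 9};
        std::vector<float> filter = {1, 0, 0, 1};  // 2x2 kernel
        // im2col: unfold patches into B with shape {IC*FH*FW, OH*OW}
        std::vector<float> B(IC * FH * FW * OH * OW);
        for (int ic = 0; ic < IC; ++ic)
            for (int fh = 0; fh < FH; ++fh)
                for (int fw = 0; fw < FW; ++fw)
                    for (int oh = 0; oh < OH; ++oh)
                        for (int ow = 0; ow < OW; ++ow)
                            B[((ic * FH + fh) * FW + fw) * OH * OW + oh * OW + ow] =
                                    src[(ic * IH + oh + fh) * IW + ow + fw];
        // GEMM: dst{OC, OH*OW} = filter{OC, IC*FH*FW} * B
        std::vector<float> dst(OC * OH * OW, 0.f);
        for (int oc = 0; oc < OC; ++oc)
            for (int k = 0; k < IC * FH * FW; ++k)
                for (int x = 0; x < OH * OW; ++x)
                    dst[oc * OH * OW + x] +=
                            filter[oc * IC * FH * FW + k] * B[k * OH * OW + x];
        for (float v : dst) std::printf("%g ", v);  // 6 8 12 14
        std::printf("\n");
    }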
@@ -22,10 +22,12 @@ class ConvBiasImpl::AlgoDirect final : public AlgoBase {
static void copy_padding_kern(WorkspaceBundle bundle,
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index);
const NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids);
static void do_conv_kern(WorkspaceBundle bundle,
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index);
const NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids);
bool m_large_group;
public:
@@ -57,10 +59,12 @@ class ConvBiasImpl::AlgoDirectStride2 final : public AlgoBase {
static void copy_padding_kern(WorkspaceBundle bundle,
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index);
const NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids);
static void do_conv_kern(WorkspaceBundle bundle,
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index);
const NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids);
bool m_large_group;
public:
@@ -146,9 +146,11 @@ WorkspaceBundle ConvBiasImpl::AlgoMkldnnQint8::get_bundle(
} while (0)
void ConvBiasImpl::AlgoMkldnnQint8::kern_mkldnn_s8x8x32(
const NCBKernParam& param, const NCBKernIndex&) {
const NCBKernParam& param, const NCBKernIndex& ncb_index) {
UNPACK_CONV_F32_NCB_KERN_SIZES(param);
MEGDNN_MARK_USED_VAR(N);
size_t group_id = ncb_index.ndrange_id[0];
size_t batch_id = ncb_index.ndrange_id[1];
auto x86_handle = static_cast<HandleImpl*>(inplace_cpu_handle().get());
megdnn_assert(x86_handle != nullptr, "x86 handle cannot be null");
auto eng_mkldnn = x86_handle->mkldnn_engine();
@@ -167,10 +169,11 @@ void ConvBiasImpl::AlgoMkldnnQint8::kern_mkldnn_s8x8x32(
auto megdnn_dst_md = memory::desc({dst_shape}, memory::data_type::s32,
memory::format_tag::nchw);
auto megdnn_weight_memory = memory(megdnn_weight_md, eng_mkldnn,
const_cast<void*>(param.filter_ptr));
int8_t* src = const_cast<int8_t*>(param.src<int8_t>());
int32_t* dst = param.dst<int32_t>();
auto megdnn_weight_memory =
memory(megdnn_weight_md, eng_mkldnn,
const_cast<void*>(param.filter<void>(group_id)));
int8_t* src = const_cast<int8_t*>(param.src<int8_t>(batch_id, group_id));
int32_t* dst = param.dst<int32_t>(batch_id, group_id);
auto megdnn_src_memory =
memory(megdnn_src_md, eng_mkldnn, static_cast<void*>(src));
@@ -353,18 +356,18 @@ MatrixMul* ConvBiasImpl::AlgoMkldnnMatmulQint8::get_matmul_opr() {
}
void ConvBiasImpl::AlgoMkldnnMatmulQint8::kern_mkldnn_matmul_s8x8x32(
const NCBKernParam& param, const NCBKernIndex&) {
const NCBKernParam& param, const NCBKernIndex& ncb_index) {
UNPACK_CONV_F32_NCB_KERN_SIZES(param);
auto IH2 = IH + 2 * PH;
auto IW2 = IW + 2 * PW;
size_t group_id = ncb_index.ndrange_id[0];
bool is_xcorr = !param.filter_meta.should_flip;
auto bundle = get_bundle(param);
bundle.set(param.workspace_ptr);
for (size_t n = 0; n < N; ++n) {
int8_t* src =
const_cast<int8_t*>(param.src<int8_t>()) + n * param.inp_bs;
int32_t* dst = param.dst<int32_t>() + n * param.out_bs;
int8_t* src = const_cast<int8_t*>(param.src<int8_t>(n, group_id));
int32_t* dst = param.dst<int32_t>(n, group_id);
int8_t *B, *src2;
if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) {
// special case: 1x1
@@ -414,7 +417,7 @@ void ConvBiasImpl::AlgoMkldnnMatmulQint8::kern_mkldnn_matmul_s8x8x32(
{
TensorND A_, B_, C_;
A_.layout = TensorLayout({OC, IC * FH * FW}, dtype::Int8());
A_.raw_ptr = const_cast<int8_t*>(param.filter<int8_t>());
A_.raw_ptr = const_cast<int8_t*>(param.filter<int8_t>(group_id));
B_.layout = TensorLayout({IC * FH * FW, OH * OW}, dtype::Int8());
B_.raw_ptr = B;
C_.layout = TensorLayout({OC, OH * OW}, dtype::Int32());
@@ -47,8 +47,8 @@ void pack_src_conv_avx2_stride1(WorkspaceBundle bundle,
batch_id = ncb_index.ndrange_id[1],
channel_id = ncb_index.ndrange_id[2];
const int8_t* src_ptr =
kern_param.src<int8_t>() + ic_step * channel_id * c_stride;
const int8_t* src_ptr = kern_param.src<int8_t>(batch_id, group_id) +
ic_step * channel_id * c_stride;
bundle.set(kern_param.workspace_ptr);
int8_t* packed_src = static_cast<int8_t*>(bundle.get(0)) +
batch_id * group * packed_group_size +
@@ -129,7 +129,7 @@ static inline void pack_filter_conv_avx2_stride1(
size_t group_id = ncb_index.ndrange_id[0],
oc_index_id = ncb_index.ndrange_id[1];
const int8_t* pack_filter_ptr = kern_param.filter<int8_t>();
const int8_t* pack_filter_ptr = kern_param.filter<int8_t>(group_id);
bundle.set(kern_param.workspace_ptr);
int16_t* out_ptr = static_cast<int16_t*>(bundle.get(1)) +
group_id * round_up(oc, oc_step) * oc_out_stride;
@@ -632,19 +632,18 @@ void do_conv_kern(WorkspaceBundle bundle,
const uint32_t packed_group_size =
div_ceil(ic, ic_step) * pack_ih * pack_iw;
size_t workspace_group_id = ncb_index.ndrange_id[0],
workspace_batch_id = ncb_index.ndrange_id[1],
workspace_channel_id = ncb_index.ndrange_id[2];
size_t group_id = ncb_index.ndrange_id[0],
batch_id = ncb_index.ndrange_id[1],
channel_id = ncb_index.ndrange_id[2];
bundle.set(kern_param.workspace_ptr);
int8_t* src_ptr = static_cast<int8_t*>(bundle.get(0)) +
workspace_group_id * packed_group_size +
workspace_batch_id * group * packed_group_size;
int16_t* filter_ptr =
static_cast<int16_t*>(bundle.get(1)) +
workspace_group_id * round_up(oc, oc_step) * filter_round_size +
oc_step * workspace_channel_id * filter_round_size;
group_id * packed_group_size +
batch_id * group * packed_group_size;
int16_t* filter_ptr = static_cast<int16_t*>(bundle.get(1)) +
group_id * round_up(oc, oc_step) * filter_round_size +
oc_step * channel_id * filter_round_size;
bool need_post_process =
kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;
@@ -652,12 +651,11 @@ void do_conv_kern(WorkspaceBundle bundle,
int32_t* dst_tptr = nullptr;
if (need_post_process) {
dst_tptr = static_cast<int32_t*>(bundle.get(2)) +
workspace_batch_id * group * oc * oc_stride +
workspace_group_id * oc * oc_stride +
oc_step * workspace_channel_id * oh * ow;
batch_id * group * oc * oc_stride +
group_id * oc * oc_stride + oc_step * channel_id * oh * ow;
} else {
dst_tptr = kern_param.dst<int32_t>() +
oc_step * workspace_channel_id * oh * ow;
dst_tptr = kern_param.dst<int32_t>(batch_id, group_id) +
oc_step * channel_id * oh * ow;
}
const uint32_t oc_end = oc / oc_step * oc_step;
@@ -666,7 +664,7 @@ void do_conv_kern(WorkspaceBundle bundle,
const uint32_t oh_remain = oh - oh_end;
const uint32_t ow_end = ow / ow_step * ow_step;
const uint32_t ow_remain = ow - ow_end;
const uint32_t oc_index = oc_step * workspace_channel_id;
const uint32_t oc_index = oc_step * channel_id;
AlgoAVX2DirectConvStride1S8S8S32_forward<oc_step, ic_step, oh_step,
ow_step>(
@@ -684,29 +682,29 @@ void do_post_process(WorkspaceBundle bundle,
const uint32_t oh = kern_param.osz[0];
const uint32_t ow = kern_param.osz[1];
size_t workspace_group_id = ncb_index.ndrange_id[0],
workspace_batch_id = ncb_index.ndrange_id[1];
size_t group_id = ncb_index.ndrange_id[0],
batch_id = ncb_index.ndrange_id[1];
bundle.set(kern_param.workspace_ptr);
bool need_post_process =
kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;
void* dst_tptr = nullptr;
if (need_post_process) {
dst_tptr = static_cast<int32_t*>(bundle.get(2)) +
workspace_batch_id * group * oc * oh * ow +
workspace_group_id * oc * oh * ow;
batch_id * group * oc * oh * ow + group_id * oc * oh * ow;
} else {
dst_tptr = kern_param.dst<dt_int32>();
dst_tptr = kern_param.dst<dt_int32>(batch_id, group_id);
}
void* dst_ptr = kern_param.dst<void>(batch_id, group_id);
#define cb(_bias_ctype, _dst_ctype, _postprocess_mode) \
{ \
const dt_int32* bias_ptr = kern_param.bias<dt_int32>(); \
const dt_int32* bias_ptr = \
kern_param.bias<dt_int32>(batch_id, group_id); \
PostProcess<DTypeTrait<_bias_ctype>::ctype, \
DTypeTrait<_dst_ctype>::ctype, \
_postprocess_mode>::run(dst_tptr, \
const_cast<dt_int32*>(bias_ptr), \
kern_param.dst_ptr, \
kern_param.bias_mode, \
dst_ptr, kern_param.bias_mode, \
kern_param.nonlineMode, \
kern_param.bias_type, \
kern_param.dst_type, 1, oc, oh, \
@@ -45,8 +45,8 @@ void pack_src_conv_avx2_stride2(WorkspaceBundle bundle,
batch_id = ncb_index.ndrange_id[1],
channel_id = ncb_index.ndrange_id[2];
const int8_t* src_ptr =
kern_param.src<int8_t>() + ic_step * channel_id * c_stride;
const int8_t* src_ptr = kern_param.src<int8_t>(batch_id, group_id) +
ic_step * channel_id * c_stride;
bundle.set(kern_param.workspace_ptr);
int8_t* packed_src = static_cast<int8_t*>(bundle.get(0)) +
batch_id * group * packed_group_size +
@@ -187,7 +187,7 @@ static inline void pack_filter_conv_avx2_stride2(
size_t group_id = ncb_index.ndrange_id[0],
oc_index_id = ncb_index.ndrange_id[1];
const int8_t* pack_filter_ptr = kern_param.filter<int8_t>();
const int8_t* pack_filter_ptr = kern_param.filter<int8_t>(group_id);
bundle.set(kern_param.workspace_ptr);
int16_t* out_ptr = static_cast<int16_t*>(bundle.get(1)) +
group_id * round_up(oc, oc_step) * oc_out_stride;
@@ -705,18 +705,17 @@ void kernel_imp(WorkspaceBundle bundle,
const uint32_t packed_group_size =
div_ceil(ic, ic_step) * pack_ih * pack_iw;
size_t workspace_group_id = ncb_index.ndrange_id[0],
workspace_batch_id = ncb_index.ndrange_id[1],
workspace_channel_id = ncb_index.ndrange_id[2];
size_t group_id = ncb_index.ndrange_id[0],
batch_id = ncb_index.ndrange_id[1],
channel_id = ncb_index.ndrange_id[2];
bundle.set(kern_param.workspace_ptr);
int8_t* src_ptr = static_cast<int8_t*>(bundle.get(0)) +
workspace_group_id * packed_group_size +
workspace_batch_id * group * packed_group_size;
int16_t* filter_ptr =
static_cast<int16_t*>(bundle.get(1)) +
workspace_group_id * round_up(oc, oc_step) * filter_round_size +
oc_step * workspace_channel_id * filter_round_size;
group_id * packed_group_size +
batch_id * group * packed_group_size;
int16_t* filter_ptr = static_cast<int16_t*>(bundle.get(1)) +
group_id * round_up(oc, oc_step) * filter_round_size +
oc_step * channel_id * filter_round_size;
bool need_post_process =
kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;
@@ -724,12 +723,11 @@ void kernel_imp(WorkspaceBundle bundle,
int32_t* dst_tptr = nullptr;
if (need_post_process) {
dst_tptr = static_cast<int32_t*>(bundle.get(2)) +
workspace_batch_id * group * oc * oc_stride +
workspace_group_id * oc * oc_stride +
oc_step * workspace_channel_id * oh * ow;
batch_id * group * oc * oc_stride +
group_id * oc * oc_stride + oc_step * channel_id * oh * ow;
} else {
dst_tptr = kern_param.dst<int32_t>() +
oc_step * workspace_channel_id * oh * ow;
dst_tptr = kern_param.dst<int32_t>(batch_id, group_id) +
oc_step * channel_id * oh * ow;
}
const uint32_t oc_end = oc / oc_step * oc_step;
const uint32_t oc_remain = oc - oc_end;
@@ -737,7 +735,7 @@ void kernel_imp(WorkspaceBundle bundle,
const uint32_t oh_remain = oh - oh_end;
const uint32_t ow_end = ow / ow_step * ow_step;
const uint32_t ow_remain = ow - ow_end;
const uint32_t oc_index = oc_step * workspace_channel_id;
const uint32_t oc_index = oc_step * channel_id;
kernel_handle_oh_remain<oc_step, ic_step, oh_step, ow_step>(
oh_remain, oc_remain, ow_remain, filter_ptr, src_ptr, dst_tptr,
@@ -754,8 +752,8 @@ void do_post_process(WorkspaceBundle bundle,
const uint32_t oh = kern_param.osz[0];
const uint32_t ow = kern_param.osz[1];
size_t workspace_group_id = ncb_index.ndrange_id[0],
workspace_batch_id = ncb_index.ndrange_id[1];
size_t group_id = ncb_index.ndrange_id[0],
batch_id = ncb_index.ndrange_id[1];
bundle.set(kern_param.workspace_ptr);
bool need_post_process =
@@ -763,21 +761,22 @@ void do_post_process(WorkspaceBundle bundle,
void* dst_tptr = nullptr;
if (need_post_process) {
dst_tptr = static_cast<int32_t*>(bundle.get(2)) +
workspace_batch_id * group * oc * oh * ow +
workspace_group_id * oc * oh * ow;
batch_id * group * oc * oh * ow +
group_id * oc * oh * ow;
} else {
dst_tptr = kern_param.dst<dt_int32>();
dst_tptr = kern_param.dst<dt_int32>(batch_id, group_id);
}
void* dst_ptr = kern_param.dst<void>(batch_id, group_id);
#define cb(_bias_ctype, _dst_ctype, _postprocess_mode) \
{ \
const dt_int32* bias_ptr = kern_param.bias<dt_int32>(); \
const dt_int32* bias_ptr = \
kern_param.bias<dt_int32>(batch_id, group_id); \
PostProcess<DTypeTrait<_bias_ctype>::ctype, \
DTypeTrait<_dst_ctype>::ctype, \
_postprocess_mode>::run(dst_tptr, \
const_cast<dt_int32*>(bias_ptr), \
kern_param.dst_ptr, \
kern_param.bias_mode, \
dst_ptr, kern_param.bias_mode, \
kern_param.nonlineMode, \
kern_param.bias_type, \
kern_param.dst_type, 1, oc, oh, \