|
|
@@ -267,18 +267,13 @@ WorkspaceBundle ConvBiasImpl::AlgoMkldnnQint8::get_bundle( |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
#define REORDER_MEMORY(megdnn_memory, reorder_memory) \ |
|
|
|
do { \ |
|
|
|
if (megdnn_memory.get_desc() != conv_prim_desc.src_desc()) { \ |
|
|
|
reorder_memory = memory(conv_prim_desc.src_desc(), eng_mkldnn); \ |
|
|
|
auto reorder_pd = reorder::primitive_desc( \ |
|
|
|
eng_mkldnn, megdnn_memory.get_desc(), eng_mkldnn, \ |
|
|
|
reorder_memory.get_desc()); \ |
|
|
|
auto reorder_exe = reorder(reorder_pd); \ |
|
|
|
reorder_exe.execute(stream_mkldnn, megdnn_memory, reorder_memory); \ |
|
|
|
} else { \ |
|
|
|
reorder_memory = megdnn_memory; \ |
|
|
|
} \ |
|
|
|
#define REORDER_MEMORY(megdnn_memory, reorder_memory) \ |
|
|
|
do { \ |
|
|
|
auto reorder_pd = reorder::primitive_desc( \ |
|
|
|
eng_mkldnn, megdnn_memory.get_desc(), eng_mkldnn, \ |
|
|
|
reorder_memory.get_desc()); \ |
|
|
|
auto reorder_exe = reorder(reorder_pd); \ |
|
|
|
reorder_exe.execute(stream_mkldnn, megdnn_memory, reorder_memory); \ |
|
|
|
} while (0) |
|
|
|
|
|
|
|
void ConvBiasImpl::AlgoMkldnnQint8::kern_mkldnn_s8x8x32( |
|
|
@@ -340,7 +335,10 @@ void ConvBiasImpl::AlgoMkldnnQint8::kern_mkldnn_s8x8x32( |
|
|
|
|
|
|
|
auto conv = convolution_forward(conv_prim_desc); |
|
|
|
|
|
|
|
memory conv_src_memory, conv_weight_memory, conv_dst_memory; |
|
|
|
memory conv_src_memory = memory(conv_prim_desc.src_desc(), eng_mkldnn); |
|
|
|
memory conv_weight_memory = |
|
|
|
memory(conv_prim_desc.weights_desc(), eng_mkldnn); |
|
|
|
memory conv_dst_memory; |
|
|
|
|
|
|
|
REORDER_MEMORY(megdnn_src_memory, conv_src_memory); |
|
|
|
REORDER_MEMORY(megdnn_weight_memory, conv_weight_memory); |
|
|
@@ -354,7 +352,7 @@ void ConvBiasImpl::AlgoMkldnnQint8::kern_mkldnn_s8x8x32( |
|
|
|
conv.execute(stream_mkldnn, {{DNNL_ARG_SRC, conv_src_memory}, |
|
|
|
{DNNL_ARG_WEIGHTS, conv_weight_memory}, |
|
|
|
{DNNL_ARG_DST, conv_dst_memory}}); |
|
|
|
REORDER_MEMORY(megdnn_dst_memory, conv_dst_memory); |
|
|
|
REORDER_MEMORY(conv_dst_memory, megdnn_dst_memory); |
|
|
|
stream_mkldnn.wait(); |
|
|
|
} else { |
|
|
|
std::vector<primitive> net; |
|
|
|