Browse Source

feat(dnn/arm_common/elemwise): add arm_common support chw44 elemwise

GitOrigin-RevId: aba44e0123
tags/v0.4.0
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
c59be192cd
4 changed files with 25 additions and 13 deletions
  1. +12
    -4
      dnn/src/common/elemwise/opr_impl_helper.cpp
  2. +3
    -1
      dnn/src/common/elemwise/opr_impl_helper.h
  3. +6
    -4
      dnn/src/x86/elemwise/opr_impl.cpp
  4. +4
    -4
      dnn/src/x86/elemwise_op.h

+ 12
- 4
dnn/src/common/elemwise/opr_impl_helper.cpp View File

@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */


#include "./opr_impl_helper.h" #include "./opr_impl_helper.h"
@@ -79,18 +80,19 @@ bool ElemwiseLayoutHelper::is_broadcasted_scalar(const TensorLayout& layout) {
} }
return true; return true;
} }
template <size_t slice_size>
bool ElemwiseLayoutHelper::is_broadcastedx_channel_like( bool ElemwiseLayoutHelper::is_broadcastedx_channel_like(
const TensorLayout& layout, BroadcastChannelInfo& info) { const TensorLayout& layout, BroadcastChannelInfo& info) {
if (layout.format.type() == TensorFormat::Type::DEFAULT && if (layout.format.type() == TensorFormat::Type::DEFAULT &&
layout.ndim == 3 && layout.stride[0] == 8 && layout.stride[1] == 0 &&
layout.stride[2] == 1) {
layout.ndim == 3 && layout.stride[0] == slice_size &&
layout.stride[1] == 0 && layout.stride[2] == 1) {
info.x = layout.shape[0]; info.x = layout.shape[0];
info.y = layout.shape[1]; info.y = layout.shape[1];
info.z = layout.shape[2]; info.z = layout.shape[2];
return true; return true;
} else if (layout.format.type() == TensorFormat::Type::DEFAULT && } else if (layout.format.type() == TensorFormat::Type::DEFAULT &&
layout.ndim == 4 && layout.stride[0] == 0 && layout.ndim == 4 && layout.stride[0] == 0 &&
layout.stride[1] == 8 && layout.stride[2] == 0 &&
layout.stride[1] == slice_size && layout.stride[2] == 0 &&
layout.stride[3] == 1) { layout.stride[3] == 1) {
info.x = layout.shape[1]; info.x = layout.shape[1];
info.y = layout.shape[2]; info.y = layout.shape[2];
@@ -99,6 +101,12 @@ bool ElemwiseLayoutHelper::is_broadcastedx_channel_like(
} }
return false; return false;
} }
#define INST(n) \
template bool ElemwiseLayoutHelper::is_broadcastedx_channel_like<n>( \
const TensorLayout& layout, BroadcastChannelInfo& info)
INST(4);
INST(8);
#undef INST


bool ElemwiseLayoutHelper::is_broadcasted_channel_like( bool ElemwiseLayoutHelper::is_broadcasted_channel_like(
const TensorLayout& layout, BroadcastChannelInfo& info) { const TensorLayout& layout, BroadcastChannelInfo& info) {


+ 3
- 1
dnn/src/common/elemwise/opr_impl_helper.h View File

@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */


#pragma once #pragma once
@@ -87,6 +88,7 @@ public:
* Note that Input can also be 3-dimensional, and must be [x, 1, z] * Note that Input can also be 3-dimensional, and must be [x, 1, z]
* broadcasted into [x, y, z] * broadcasted into [x, y, z]
*/ */
template <size_t slice_size>
static bool is_broadcastedx_channel_like(const TensorLayout& layout, static bool is_broadcastedx_channel_like(const TensorLayout& layout,
BroadcastChannelInfo& info); BroadcastChannelInfo& info);
}; };


+ 6
- 4
dnn/src/x86/elemwise/opr_impl.cpp View File

@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "src/x86/elemwise/opr_impl.h" #include "src/x86/elemwise/opr_impl.h"
#include "src/x86/elemwise_op.h" #include "src/x86/elemwise_op.h"
@@ -360,13 +361,14 @@ bool ElemwiseImpl::exec_binary() {
return true; \ return true; \
} }
{ {
bool normal_case = is_vector(src1.layout) &&
is_broadcastedx_channel_like(src0.layout, binfo);
bool normal_case =
is_vector(src1.layout) &&
is_broadcastedx_channel_like<8>(src0.layout, binfo);
bool swap_case = false; bool swap_case = false;
bool commutable = mode_trait().commutable; bool commutable = mode_trait().commutable;
if (!normal_case && commutable) { if (!normal_case && commutable) {
swap_case = is_vector(src0.layout) && swap_case = is_vector(src0.layout) &&
is_broadcastedx_channel_like(src1.layout, binfo);
is_broadcastedx_channel_like<8>(src1.layout, binfo);
} }


if ((swap_case || normal_case) && if ((swap_case || normal_case) &&


+ 4
- 4
dnn/src/x86/elemwise_op.h View File

@@ -414,7 +414,7 @@ struct OpCallerBinary<Op, SIMDType::AVX2, BCAST101x_VEC> {
const typename Op::src_ctype* src1, const typename Op::src_ctype* src1,
typename Op::dst_ctype* dst, DType src0_dtype, typename Op::dst_ctype* dst, DType src0_dtype,
DType src1_dtype, DType dst_dtype, size_t batch, DType src1_dtype, DType dst_dtype, size_t batch,
size_t nr_blocks_in_channel, size_t channel_stride,
size_t nr_channel_blocks, size_t channel_stride,
size_t channel_block_dim) { size_t channel_block_dim) {
megdnn_assert(channel_block_dim == 8, "avx2 only support nchw88"); megdnn_assert(channel_block_dim == 8, "avx2 only support nchw88");
Op op(src0_dtype, src1_dtype, dst_dtype); Op op(src0_dtype, src1_dtype, dst_dtype);
@@ -422,7 +422,7 @@ struct OpCallerBinary<Op, SIMDType::AVX2, BCAST101x_VEC> {
ParamElemVisitor<typename Op::src_ctype, SIMDType::AVX2> vis1; ParamElemVisitor<typename Op::src_ctype, SIMDType::AVX2> vis1;
for (size_t b = 0; b < batch; b++) { for (size_t b = 0; b < batch; b++) {
auto src0_ptr = src0; auto src0_ptr = src0;
for (size_t cb = 0; cb < nr_blocks_in_channel; cb++) {
for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
auto src0_block_ptr = src0_ptr + cb * channel_block_dim; auto src0_block_ptr = src0_ptr + cb * channel_block_dim;
auto channel_block_vec = vis0(src0_block_ptr); auto channel_block_vec = vis0(src0_block_ptr);
size_t img_index = 0; size_t img_index = 0;
@@ -451,12 +451,12 @@ struct OpCallerBinary<Op, SIMDType::NONE, BCAST101x_VEC> {
const typename Op::src_ctype* src1, const typename Op::src_ctype* src1,
typename Op::dst_ctype* dst, DType src0_dtype, typename Op::dst_ctype* dst, DType src0_dtype,
DType src1_dtype, DType dst_dtype, size_t batch, DType src1_dtype, DType dst_dtype, size_t batch,
size_t nr_blocks_in_channel, size_t channel_stride,
size_t nr_channel_blocks, size_t channel_stride,
size_t channel_block_dim) { size_t channel_block_dim) {
Op op(src0_dtype, src1_dtype, dst_dtype); Op op(src0_dtype, src1_dtype, dst_dtype);
for (size_t b = 0; b < batch; b++) { for (size_t b = 0; b < batch; b++) {
auto src0_ptr = src0; auto src0_ptr = src0;
for (size_t cb = 0; cb < nr_blocks_in_channel; cb++) {
for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
auto src0_block_ptr = src0_ptr + cb * channel_block_dim; auto src0_block_ptr = src0_ptr + cb * channel_block_dim;
for (size_t i = 0; i < channel_stride; i++) { for (size_t i = 0; i < channel_stride; i++) {
for (size_t c_iter = 0; c_iter < channel_block_dim; for (size_t c_iter = 0; c_iter < channel_block_dim;


Loading…
Cancel
Save