GitOrigin-RevId: aba44e0123
tags/v0.4.0
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "./opr_impl_helper.h" | #include "./opr_impl_helper.h" | ||||
@@ -79,18 +80,19 @@ bool ElemwiseLayoutHelper::is_broadcasted_scalar(const TensorLayout& layout) { | |||||
} | } | ||||
return true; | return true; | ||||
} | } | ||||
template <size_t slice_size> | |||||
bool ElemwiseLayoutHelper::is_broadcastedx_channel_like( | bool ElemwiseLayoutHelper::is_broadcastedx_channel_like( | ||||
const TensorLayout& layout, BroadcastChannelInfo& info) { | const TensorLayout& layout, BroadcastChannelInfo& info) { | ||||
if (layout.format.type() == TensorFormat::Type::DEFAULT && | if (layout.format.type() == TensorFormat::Type::DEFAULT && | ||||
layout.ndim == 3 && layout.stride[0] == 8 && layout.stride[1] == 0 && | |||||
layout.stride[2] == 1) { | |||||
layout.ndim == 3 && layout.stride[0] == slice_size && | |||||
layout.stride[1] == 0 && layout.stride[2] == 1) { | |||||
info.x = layout.shape[0]; | info.x = layout.shape[0]; | ||||
info.y = layout.shape[1]; | info.y = layout.shape[1]; | ||||
info.z = layout.shape[2]; | info.z = layout.shape[2]; | ||||
return true; | return true; | ||||
} else if (layout.format.type() == TensorFormat::Type::DEFAULT && | } else if (layout.format.type() == TensorFormat::Type::DEFAULT && | ||||
layout.ndim == 4 && layout.stride[0] == 0 && | layout.ndim == 4 && layout.stride[0] == 0 && | ||||
layout.stride[1] == 8 && layout.stride[2] == 0 && | |||||
layout.stride[1] == slice_size && layout.stride[2] == 0 && | |||||
layout.stride[3] == 1) { | layout.stride[3] == 1) { | ||||
info.x = layout.shape[1]; | info.x = layout.shape[1]; | ||||
info.y = layout.shape[2]; | info.y = layout.shape[2]; | ||||
@@ -99,6 +101,12 @@ bool ElemwiseLayoutHelper::is_broadcastedx_channel_like( | |||||
} | } | ||||
return false; | return false; | ||||
} | } | ||||
#define INST(n) \ | |||||
template bool ElemwiseLayoutHelper::is_broadcastedx_channel_like<n>( \ | |||||
const TensorLayout& layout, BroadcastChannelInfo& info) | |||||
INST(4); | |||||
INST(8); | |||||
#undef INST | |||||
bool ElemwiseLayoutHelper::is_broadcasted_channel_like( | bool ElemwiseLayoutHelper::is_broadcasted_channel_like( | ||||
const TensorLayout& layout, BroadcastChannelInfo& info) { | const TensorLayout& layout, BroadcastChannelInfo& info) { | ||||
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
@@ -87,6 +88,7 @@ public: | |||||
* Note that Input can also be 3-dimensional, and must be [x, 1, z] | * Note that Input can also be 3-dimensional, and must be [x, 1, z] | ||||
* broadacsted into [x, y, z] | * broadacsted into [x, y, z] | ||||
*/ | */ | ||||
template <size_t slice_size> | |||||
static bool is_broadcastedx_channel_like(const TensorLayout& layout, | static bool is_broadcastedx_channel_like(const TensorLayout& layout, | ||||
BroadcastChannelInfo& info); | BroadcastChannelInfo& info); | ||||
}; | }; | ||||
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "src/x86/elemwise/opr_impl.h" | #include "src/x86/elemwise/opr_impl.h" | ||||
#include "src/x86/elemwise_op.h" | #include "src/x86/elemwise_op.h" | ||||
@@ -360,13 +361,14 @@ bool ElemwiseImpl::exec_binary() { | |||||
return true; \ | return true; \ | ||||
} | } | ||||
{ | { | ||||
bool normal_case = is_vector(src1.layout) && | |||||
is_broadcastedx_channel_like(src0.layout, binfo); | |||||
bool normal_case = | |||||
is_vector(src1.layout) && | |||||
is_broadcastedx_channel_like<8>(src0.layout, binfo); | |||||
bool swap_case = false; | bool swap_case = false; | ||||
bool commutable = mode_trait().commutable; | bool commutable = mode_trait().commutable; | ||||
if (!normal_case && commutable) { | if (!normal_case && commutable) { | ||||
swap_case = is_vector(src0.layout) && | swap_case = is_vector(src0.layout) && | ||||
is_broadcastedx_channel_like(src1.layout, binfo); | |||||
is_broadcastedx_channel_like<8>(src1.layout, binfo); | |||||
} | } | ||||
if ((swap_case || normal_case) && | if ((swap_case || normal_case) && | ||||
@@ -414,7 +414,7 @@ struct OpCallerBinary<Op, SIMDType::AVX2, BCAST101x_VEC> { | |||||
const typename Op::src_ctype* src1, | const typename Op::src_ctype* src1, | ||||
typename Op::dst_ctype* dst, DType src0_dtype, | typename Op::dst_ctype* dst, DType src0_dtype, | ||||
DType src1_dtype, DType dst_dtype, size_t batch, | DType src1_dtype, DType dst_dtype, size_t batch, | ||||
size_t nr_blocks_in_channel, size_t channel_stride, | |||||
size_t nr_channel_blocks, size_t channel_stride, | |||||
size_t channel_block_dim) { | size_t channel_block_dim) { | ||||
megdnn_assert(channel_block_dim == 8, "avx2 only support nchw88"); | megdnn_assert(channel_block_dim == 8, "avx2 only support nchw88"); | ||||
Op op(src0_dtype, src1_dtype, dst_dtype); | Op op(src0_dtype, src1_dtype, dst_dtype); | ||||
@@ -422,7 +422,7 @@ struct OpCallerBinary<Op, SIMDType::AVX2, BCAST101x_VEC> { | |||||
ParamElemVisitor<typename Op::src_ctype, SIMDType::AVX2> vis1; | ParamElemVisitor<typename Op::src_ctype, SIMDType::AVX2> vis1; | ||||
for (size_t b = 0; b < batch; b++) { | for (size_t b = 0; b < batch; b++) { | ||||
auto src0_ptr = src0; | auto src0_ptr = src0; | ||||
for (size_t cb = 0; cb < nr_blocks_in_channel; cb++) { | |||||
for (size_t cb = 0; cb < nr_channel_blocks; cb++) { | |||||
auto src0_block_ptr = src0_ptr + cb * channel_block_dim; | auto src0_block_ptr = src0_ptr + cb * channel_block_dim; | ||||
auto channel_block_vec = vis0(src0_block_ptr); | auto channel_block_vec = vis0(src0_block_ptr); | ||||
size_t img_index = 0; | size_t img_index = 0; | ||||
@@ -451,12 +451,12 @@ struct OpCallerBinary<Op, SIMDType::NONE, BCAST101x_VEC> { | |||||
const typename Op::src_ctype* src1, | const typename Op::src_ctype* src1, | ||||
typename Op::dst_ctype* dst, DType src0_dtype, | typename Op::dst_ctype* dst, DType src0_dtype, | ||||
DType src1_dtype, DType dst_dtype, size_t batch, | DType src1_dtype, DType dst_dtype, size_t batch, | ||||
size_t nr_blocks_in_channel, size_t channel_stride, | |||||
size_t nr_channel_blocks, size_t channel_stride, | |||||
size_t channel_block_dim) { | size_t channel_block_dim) { | ||||
Op op(src0_dtype, src1_dtype, dst_dtype); | Op op(src0_dtype, src1_dtype, dst_dtype); | ||||
for (size_t b = 0; b < batch; b++) { | for (size_t b = 0; b < batch; b++) { | ||||
auto src0_ptr = src0; | auto src0_ptr = src0; | ||||
for (size_t cb = 0; cb < nr_blocks_in_channel; cb++) { | |||||
for (size_t cb = 0; cb < nr_channel_blocks; cb++) { | |||||
auto src0_block_ptr = src0_ptr + cb * channel_block_dim; | auto src0_block_ptr = src0_ptr + cb * channel_block_dim; | ||||
for (size_t i = 0; i < channel_stride; i++) { | for (size_t i = 0; i < channel_stride; i++) { | ||||
for (size_t c_iter = 0; c_iter < channel_block_dim; | for (size_t c_iter = 0; c_iter < channel_block_dim; | ||||