Browse Source

fix(mgb/fallback): disable nchw44 in conv1x1 and im2col in x86

GitOrigin-RevId: 603d2eb94a
tags/v1.0.0-rc1
Megvii Engine Team Xinran Xu 4 years ago
parent
commit
2272abe18d
5 changed files with 103 additions and 91 deletions
  1. +0
    -85
      dnn/src/common/conv_bias.cpp
  2. +0
    -6
      dnn/src/common/conv_bias.h
  3. +12
    -0
      dnn/src/fallback/conv_bias/conv1x1/algos.cpp
  4. +86
    -0
      dnn/src/naive/conv_bias/opr_impl.cpp
  5. +5
    -0
      dnn/src/naive/conv_bias/opr_impl.h

+ 0
- 85
dnn/src/common/conv_bias.cpp View File

@@ -444,91 +444,6 @@ void handle_bias_and_nonlinear(Handle* handle, param::ConvBias args,
}
}

//! Add the optional z input to a conv+bias result, apply the activation and
//! store the outcome in \p dst_tensor.
//!
//! Only used for naive implementation. DO NOT use the following function in
//! other backends.
//!
//! \param nonline_mode activation to apply: RELU, H_SWISH, SIGMOID or
//!     IDENTITY; any other value trips megdnn_assert.
//! \param conv_bias_tensor result of convolution plus bias.
//! \param z_tensor optional addend; treated as absent when its layout has
//!     ndim == 0.
//! \param dst_tensor destination of the final values.
//! \param workspace_ptr scratch memory, only touched when z is present with a
//!     non-float dtype: a Float32 result buffer is placed at workspace_ptr
//!     and a Float32 copy of z directly after it (offset by the byte span of
//!     a Float32 layout of z's shape).
void handle_z_inp_and_activation_naive(
        param::ConvBias::NonlineMode nonline_mode,
        const TensorND& conv_bias_tensor, const TensorND& z_tensor,
        const TensorND& dst_tensor, dt_byte* workspace_ptr) {
    using NonlineMode = param::ConvBias::NonlineMode;

    const bool has_z = z_tensor.layout.ndim > 0;
    const bool z_needs_cvt =
            has_z && z_tensor.layout.dtype.category() != DTypeCategory::FLOAT;

    //! running value the activation is applied to, and the z operand of ADD
    TensorND acc = dst_tensor;
    TensorND z_operand = z_tensor;

    //! all helper operators are created on an inplace cpu handle
    auto handle = inplace_cpu_handle(2);

    if (z_needs_cvt) {
        //! non-float z: do the addition in Float32 inside the workspace
        megdnn_assert(z_tensor.layout.eq_shape(dst_tensor.layout));
        dt_byte* acc_ptr = workspace_ptr;
        dt_byte* z_ptr = acc_ptr +
                         TensorLayout{z_tensor.layout, dtype::Float32()}
                                 .span()
                                 .dist_byte();
        acc = TensorND{acc_ptr,
                       TensorLayout{dst_tensor.layout, dtype::Float32()}};
        z_operand = TensorND{z_ptr,
                             TensorLayout{z_tensor.layout, dtype::Float32()}};

        auto type_cvt = handle->create_operator<TypeCvt>();
        type_cvt->exec(conv_bias_tensor, acc);
        type_cvt->exec(z_tensor, z_operand);
    }

    if (has_z) {
        //! NOTE(review): with a float z the first operand is dst_tensor, so
        //! the caller presumably passes conv_bias_tensor aliasing dst_tensor
        //! in that case -- confirm against the call site.
        auto add_opr = handle->create_operator<ElemwiseForward>();
        add_opr->param().mode = Elemwise::Param::Mode::ADD;
        add_opr->exec({acc, z_operand}, acc);
    } else {
        acc = conv_bias_tensor;
    }

    //! activations that have both a plain and a quantized (Q-prefixed) mode
    auto apply_activation = [&](Elemwise::Param::Mode plain_mode,
                                ElemwiseMultiType::Param::Mode quant_mode) {
        if (acc.layout.dtype.category() == DTypeCategory::QUANTIZED) {
            auto quant_opr = handle->create_operator<ElemwiseMultiType>();
            quant_opr->param().mode = quant_mode;
            quant_opr->exec({acc}, dst_tensor);
            return;
        }
        auto plain_opr = handle->create_operator<ElemwiseForward>();
        plain_opr->param().mode = plain_mode;
        if (acc.layout.dtype == dst_tensor.layout.dtype) {
            plain_opr->exec({acc}, dst_tensor);
        } else {
            //! activate in place, then convert into the dst dtype
            plain_opr->exec({acc}, acc);
            handle->create_operator<TypeCvt>()->exec(acc, dst_tensor);
        }
    };

    //! TypeCvt into dst unless acc already points at dst's memory
    auto store_if_not_in_dst = [&]() {
        if (acc.raw_ptr != dst_tensor.raw_ptr) {
            handle->create_operator<TypeCvt>()->exec(acc, dst_tensor);
        }
    };

    switch (nonline_mode) {
        case NonlineMode::RELU:
            apply_activation(Elemwise::Param::Mode::RELU,
                             ElemwiseMultiType::Param::Mode::QRELU);
            break;
        case NonlineMode::H_SWISH:
            apply_activation(Elemwise::Param::Mode::H_SWISH,
                             ElemwiseMultiType::Param::Mode::QH_SWISH);
            break;
        case NonlineMode::SIGMOID: {
            megdnn_assert(acc.layout.dtype.category() !=
                          DTypeCategory::QUANTIZED);
            auto sigmoid_opr = handle->create_operator<ElemwiseForward>();
            sigmoid_opr->param().mode = Elemwise::Param::Mode::SIGMOID;
            sigmoid_opr->exec({acc}, acc);
            store_if_not_in_dst();
            break;
        }
        case NonlineMode::IDENTITY:
            store_if_not_in_dst();
            break;
        default:
            megdnn_assert(false);
    }
}

} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 0
- 6
dnn/src/common/conv_bias.h View File

@@ -21,12 +21,6 @@ void handle_bias_and_nonlinear(Handle* handle, param::ConvBias args,
const TensorND* conv_dst_tensor,
const TensorND* dst_tensor,
const TensorND* bias_tensor);

//! Add the optional \p z_tensor (absent when its layout has ndim == 0) to
//! \p conv_bias_tensor, apply the activation selected by \p nonline_mode and
//! write the result to \p dst_tensor.
//! \p workspace_ptr is scratch memory used for Float32 intermediates when
//! z_tensor has a non-float dtype.
//! Only intended for the naive implementation.
void handle_z_inp_and_activation_naive(
param::ConvBias::NonlineMode nonline_mode,
const TensorND& conv_bias_tensor, const TensorND& z_tensor,
const TensorND& dst_tensor, dt_byte* workspace_ptr);

} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 12
- 0
dnn/src/fallback/conv_bias/conv1x1/algos.cpp View File

@@ -204,6 +204,18 @@ ConvBiasImpl::AlgoConv1x1::dispatch_preprocess_kerns(
bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_fallback_conv1x1, 0, 2) {
//! x86 only supports the NCHW format
#if MEGDNN_X86
if (param.filter_meta.format != param::ConvBias::Format::NCHW) {
return false;
}
#else
if (param.filter_meta.format != param::ConvBias::Format::NCHW &&
param.filter_meta.format != param::ConvBias::Format::NCHW44 &&
param.filter_meta.format != param::ConvBias::Format::NCHW44_DOT) {
return false;
}
#endif
size_t FH = param.filter_meta.spatial[0],
FW = param.filter_meta.spatial[1];
size_t PH = param.filter_meta.padding[0],


+ 86
- 0
dnn/src/naive/conv_bias/opr_impl.cpp View File

@@ -17,6 +17,7 @@
#include "src/naive/handle.h"
#include "src/naive/lowbit_utils.h"
#include "src/common/conv_bias.h"
#include "src/common/opr_delegate.h"

#include "midout.h"
MIDOUT_DECL(megdnn_naive_conv_bias_fwd)
@@ -24,6 +25,91 @@ MIDOUT_DECL(megdnn_naive_conv_bias_fwd)
namespace megdnn {
namespace naive {

//! Add the optional z input to a conv+bias result, apply the activation and
//! store the outcome in \p dst_tensor.
//!
//! Only used for naive implementation. DO NOT use the following function in
//! other backends.
//!
//! \param nonline_mode activation to apply: RELU, H_SWISH, SIGMOID or
//!     IDENTITY; any other value trips megdnn_assert.
//! \param conv_bias_tensor result of convolution plus bias.
//! \param z_tensor optional addend; treated as absent when its layout has
//!     ndim == 0.
//! \param dst_tensor destination of the final values.
//! \param workspace_ptr scratch memory, only touched when z is present with a
//!     non-float dtype: a Float32 result buffer is placed at workspace_ptr
//!     and a Float32 copy of z directly after it (offset by the byte span of
//!     a Float32 layout of z's shape).
void handle_z_inp_and_activation_naive(
        param::ConvBias::NonlineMode nonline_mode,
        const TensorND& conv_bias_tensor, const TensorND& z_tensor,
        const TensorND& dst_tensor, dt_byte* workspace_ptr) {
    using NonlineMode = param::ConvBias::NonlineMode;

    const bool has_z = z_tensor.layout.ndim > 0;
    const bool z_needs_cvt =
            has_z && z_tensor.layout.dtype.category() != DTypeCategory::FLOAT;

    //! running value the activation is applied to, and the z operand of ADD
    TensorND acc = dst_tensor;
    TensorND z_operand = z_tensor;

    //! all helper operators are created on an inplace cpu handle
    auto handle = inplace_cpu_handle(2);

    if (z_needs_cvt) {
        //! non-float z: do the addition in Float32 inside the workspace
        megdnn_assert(z_tensor.layout.eq_shape(dst_tensor.layout));
        dt_byte* acc_ptr = workspace_ptr;
        dt_byte* z_ptr = acc_ptr +
                         TensorLayout{z_tensor.layout, dtype::Float32()}
                                 .span()
                                 .dist_byte();
        acc = TensorND{acc_ptr,
                       TensorLayout{dst_tensor.layout, dtype::Float32()}};
        z_operand = TensorND{z_ptr,
                             TensorLayout{z_tensor.layout, dtype::Float32()}};

        auto type_cvt = handle->create_operator<TypeCvt>();
        type_cvt->exec(conv_bias_tensor, acc);
        type_cvt->exec(z_tensor, z_operand);
    }

    if (has_z) {
        //! NOTE(review): with a float z the first operand is dst_tensor, so
        //! the caller presumably passes conv_bias_tensor aliasing dst_tensor
        //! in that case -- confirm against the call site.
        auto add_opr = handle->create_operator<ElemwiseForward>();
        add_opr->param().mode = Elemwise::Param::Mode::ADD;
        add_opr->exec({acc, z_operand}, acc);
    } else {
        acc = conv_bias_tensor;
    }

    //! activations that have both a plain and a quantized (Q-prefixed) mode
    auto apply_activation = [&](Elemwise::Param::Mode plain_mode,
                                ElemwiseMultiType::Param::Mode quant_mode) {
        if (acc.layout.dtype.category() == DTypeCategory::QUANTIZED) {
            auto quant_opr = handle->create_operator<ElemwiseMultiType>();
            quant_opr->param().mode = quant_mode;
            quant_opr->exec({acc}, dst_tensor);
            return;
        }
        auto plain_opr = handle->create_operator<ElemwiseForward>();
        plain_opr->param().mode = plain_mode;
        if (acc.layout.dtype == dst_tensor.layout.dtype) {
            plain_opr->exec({acc}, dst_tensor);
        } else {
            //! activate in place, then convert into the dst dtype
            plain_opr->exec({acc}, acc);
            handle->create_operator<TypeCvt>()->exec(acc, dst_tensor);
        }
    };

    //! TypeCvt into dst unless acc already points at dst's memory
    auto store_if_not_in_dst = [&]() {
        if (acc.raw_ptr != dst_tensor.raw_ptr) {
            handle->create_operator<TypeCvt>()->exec(acc, dst_tensor);
        }
    };

    switch (nonline_mode) {
        case NonlineMode::RELU:
            apply_activation(Elemwise::Param::Mode::RELU,
                             ElemwiseMultiType::Param::Mode::QRELU);
            break;
        case NonlineMode::H_SWISH:
            apply_activation(Elemwise::Param::Mode::H_SWISH,
                             ElemwiseMultiType::Param::Mode::QH_SWISH);
            break;
        case NonlineMode::SIGMOID: {
            megdnn_assert(acc.layout.dtype.category() !=
                          DTypeCategory::QUANTIZED);
            auto sigmoid_opr = handle->create_operator<ElemwiseForward>();
            sigmoid_opr->param().mode = Elemwise::Param::Mode::SIGMOID;
            sigmoid_opr->exec({acc}, acc);
            store_if_not_in_dst();
            break;
        }
        case NonlineMode::IDENTITY:
            store_if_not_in_dst();
            break;
        default:
            megdnn_assert(false);
    }
}

namespace convolution {

template <>


+ 5
- 0
dnn/src/naive/conv_bias/opr_impl.h View File

@@ -66,6 +66,11 @@ public:
const char* get_algorithm_set_name() const override;
};

//! Add the optional \p z_tensor (absent when its layout has ndim == 0) to
//! \p conv_bias_tensor, apply the activation selected by \p nonline_mode and
//! write the result to \p dst_tensor.
//! \p workspace_ptr is scratch memory used for Float32 intermediates when
//! z_tensor has a non-float dtype.
//! Only intended for the naive implementation.
void handle_z_inp_and_activation_naive(
param::ConvBias::NonlineMode nonline_mode,
const TensorND& conv_bias_tensor, const TensorND& z_tensor,
const TensorND& dst_tensor, dt_byte* workspace_ptr);

} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen

Loading…
Cancel
Save