|
|
@@ -80,8 +80,7 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW> { |
|
|
|
|
|
|
|
template <> |
|
|
|
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW44> { |
|
|
|
static Maybe<OprTensorFormatsConfiguration> dispatch( |
|
|
|
const OperatorNodeBase* opr) { |
|
|
|
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { |
|
|
|
OprTensorFormatsConfiguration config; |
|
|
|
config.typeinfo = opr->dyn_typeinfo(); |
|
|
|
config.opr_format = OprFormat::NCHW44; |
|
|
@@ -101,8 +100,7 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW44> { |
|
|
|
#if !MEGDNN_DISABLE_FLOAT16 |
|
|
|
template <> |
|
|
|
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW88> { |
|
|
|
static Maybe<OprTensorFormatsConfiguration> dispatch( |
|
|
|
const OperatorNodeBase* opr) { |
|
|
|
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { |
|
|
|
OprTensorFormatsConfiguration config; |
|
|
|
config.typeinfo = opr->dyn_typeinfo(); |
|
|
|
config.opr_format = OprFormat::NCHW88; |
|
|
@@ -440,8 +438,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::CHWN4> { |
|
|
|
|
|
|
|
template <typename Opr> |
|
|
|
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44> { |
|
|
|
static Maybe<OprTensorFormatsConfiguration> dispatch( |
|
|
|
const OperatorNodeBase* opr) { |
|
|
|
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { |
|
|
|
const auto& conv = opr->cast_final_safe<Opr>(); |
|
|
|
OprTensorFormatsConfiguration config; |
|
|
|
config.typeinfo = opr->dyn_typeinfo(); |
|
|
@@ -451,8 +448,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44> { |
|
|
|
for (size_t i = 0; i < opr->input().size(); ++i) { |
|
|
|
available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32; |
|
|
|
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); |
|
|
|
TensorType tensor_type = |
|
|
|
i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; |
|
|
|
TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; |
|
|
|
config.input_tensor_types.emplace_back(tensor_type); |
|
|
|
} |
|
|
|
available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32; |
|
|
@@ -484,8 +480,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44> { |
|
|
|
#if !MEGDNN_DISABLE_FLOAT16 |
|
|
|
template <typename Opr> |
|
|
|
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW88> { |
|
|
|
static Maybe<OprTensorFormatsConfiguration> dispatch( |
|
|
|
const OperatorNodeBase* opr) { |
|
|
|
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { |
|
|
|
const auto& conv = opr->cast_final_safe<Opr>(); |
|
|
|
OprTensorFormatsConfiguration config; |
|
|
|
config.typeinfo = opr->dyn_typeinfo(); |
|
|
@@ -495,8 +490,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW88> { |
|
|
|
for (size_t i = 0; i < opr->input().size(); ++i) { |
|
|
|
available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float16; |
|
|
|
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); |
|
|
|
TensorType tensor_type = |
|
|
|
i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; |
|
|
|
TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; |
|
|
|
config.input_tensor_types.emplace_back(tensor_type); |
|
|
|
} |
|
|
|
available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float16; |
|
|
@@ -528,8 +522,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW88> { |
|
|
|
|
|
|
|
template <typename Opr> |
|
|
|
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44_DOT> { |
|
|
|
static Maybe<OprTensorFormatsConfiguration> dispatch( |
|
|
|
const OperatorNodeBase* opr) { |
|
|
|
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { |
|
|
|
const auto& conv = opr->cast_final_safe<Opr>(); |
|
|
|
OprTensorFormatsConfiguration config; |
|
|
|
config.typeinfo = opr->dyn_typeinfo(); |
|
|
@@ -538,22 +531,18 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44_DOT> { |
|
|
|
// setup dtypes |
|
|
|
for (size_t i = 0; i < opr->input().size(); ++i) { |
|
|
|
if (i == 2) { |
|
|
|
available &= opr->input(i)->dtype().enumv() == |
|
|
|
DTypeEnum::QuantizedS32; |
|
|
|
available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS32; |
|
|
|
} else { |
|
|
|
available &= opr->input(i)->dtype().enumv() == |
|
|
|
DTypeEnum::QuantizedS8 || |
|
|
|
opr->input(i)->dtype().enumv() == |
|
|
|
DTypeEnum::Quantized8Asymm; |
|
|
|
available &= |
|
|
|
opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8 || |
|
|
|
opr->input(i)->dtype().enumv() == DTypeEnum::Quantized8Asymm; |
|
|
|
} |
|
|
|
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); |
|
|
|
TensorType tensor_type = |
|
|
|
i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; |
|
|
|
TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; |
|
|
|
config.input_tensor_types.emplace_back(tensor_type); |
|
|
|
} |
|
|
|
available &= |
|
|
|
opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8 || |
|
|
|
opr->output(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm; |
|
|
|
available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8 || |
|
|
|
opr->output(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm; |
|
|
|
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); |
|
|
|
// setup tensor formats |
|
|
|
if (conv.param().sparse == Opr::Param::Sparse::DENSE) { |
|
|
@@ -747,7 +736,7 @@ StaticData::StaticData() { |
|
|
|
|
|
|
|
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW); |
|
|
|
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW44); |
|
|
|
#if !MEGDNN_DISABLE_FLOAT16 |
|
|
|
#if !MEGDNN_DISABLE_FLOAT16 |
|
|
|
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW88); |
|
|
|
#endif |
|
|
|
|
|
|
|