|
|
@@ -1588,45 +1588,61 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { |
|
|
|
auto replace_elemwise_opr = [&relayout_inp_to_chw]( |
|
|
|
OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
enum class TENSORPROPERTY { |
|
|
|
SCALA = 0, |
|
|
|
DEFAULT_MOD4 = 1, |
|
|
|
NHWCD4 = 2, |
|
|
|
UNKNOW = 3, |
|
|
|
}; |
|
|
|
auto get_property = [](VarNode* node) -> TENSORPROPERTY { |
|
|
|
auto&& shape = node->shape(); |
|
|
|
auto&& format = node->format(); |
|
|
|
if (shape.ndim == 4 && format.is_default() && shape[1] % 4 == 0) { |
|
|
|
return TENSORPROPERTY::DEFAULT_MOD4; |
|
|
|
} |
|
|
|
if (shape.is_scalar()) { |
|
|
|
return TENSORPROPERTY::SCALA; |
|
|
|
} |
|
|
|
if (shape.ndim == 5 && format.type() == TensorFormat::Type::IMAGE2D_PACK4) { |
|
|
|
return TENSORPROPERTY::NHWCD4; |
|
|
|
} |
|
|
|
return TENSORPROPERTY::UNKNOW; |
|
|
|
}; |
|
|
|
|
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
bool has_inp_changed = false; |
|
|
|
bool can_exec_cd4 = true; |
|
|
|
for (size_t i = 0; i < opr->input().size(); i++) { |
|
|
|
if (!new_inp[i]->format().is_default()) { |
|
|
|
has_inp_changed = true; |
|
|
|
} else if (new_inp[i]->shape().ndim == 4) { |
|
|
|
if (new_inp[i]->shape()[1] % 4 != 0) { |
|
|
|
can_exec_cd4 = false; |
|
|
|
} |
|
|
|
//! cd4 elemwise with scaler is unsupported |
|
|
|
} else if (!new_inp[i]->shape().is_scalar()) { |
|
|
|
auto property = get_property(new_inp[i]); |
|
|
|
if (property == TENSORPROPERTY::UNKNOW) { |
|
|
|
can_exec_cd4 = false; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (!can_exec_cd4) { |
|
|
|
return relayout_inp_to_chw(opr, new_inp); |
|
|
|
} |
|
|
|
if (has_inp_changed) { |
|
|
|
// assumption: all inputs are changed from nchw to nhwcd4 |
|
|
|
auto t_inp = new_inp; |
|
|
|
for (size_t i = 0; i < opr->input().size(); i++) { |
|
|
|
if (new_inp[i]->shape().ndim == 4) { |
|
|
|
auto param = megdnn::param::RelayoutFormat(); |
|
|
|
param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I; |
|
|
|
auto rf = opr::RelayoutFormat::make(new_inp[i], param); |
|
|
|
t_inp[i] = rf.node(); |
|
|
|
} else { |
|
|
|
mgb_assert( |
|
|
|
(new_inp[i]->shape().ndim == 5 && |
|
|
|
new_inp[i]->format().type() == |
|
|
|
TensorFormat::Type::IMAGE2D_PACK4) || |
|
|
|
new_inp[i]->shape().is_scalar()); |
|
|
|
} |
|
|
|
|
|
|
|
//! check and change all inputs to cd4 |
|
|
|
auto t_inp = new_inp; |
|
|
|
for (size_t i = 0; i < opr->input().size(); i++) { |
|
|
|
auto property = get_property(new_inp[i]); |
|
|
|
if (property == TENSORPROPERTY::DEFAULT_MOD4) { |
|
|
|
auto param = megdnn::param::RelayoutFormat(); |
|
|
|
param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I; |
|
|
|
auto rf = opr::RelayoutFormat::make(new_inp[i], param); |
|
|
|
t_inp[i] = rf.node(); |
|
|
|
} else { |
|
|
|
mgb_assert( |
|
|
|
property == TENSORPROPERTY::SCALA || |
|
|
|
property == TENSORPROPERTY::NHWCD4, |
|
|
|
"This node should be scala ir CD4 format, but got shape = %s, " |
|
|
|
"format = %s", |
|
|
|
new_inp[i]->shape().to_string().c_str(), |
|
|
|
new_inp[i]->format().to_string().c_str()); |
|
|
|
} |
|
|
|
return serialization::copy_opr_shallow(*opr, t_inp, opr->config()); |
|
|
|
} else { |
|
|
|
return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); |
|
|
|
} |
|
|
|
return serialization::copy_opr_shallow(*opr, t_inp, opr->config()); |
|
|
|
}; |
|
|
|
|
|
|
|
/* This helper function converts the first input to the NCHW format to |
|
|
@@ -1661,7 +1677,6 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { |
|
|
|
replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_chw; |
|
|
|
replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_chw; |
|
|
|
replace_func[opr::AxisAddRemove::typeinfo()] = relayout_inp_to_chw; |
|
|
|
replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr; |
|
|
|
replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr; |
|
|
|
replace_func[opr::WarpPerspectiveForward::typeinfo()] = |
|
|
|
replace_warp_perspective_opr; |
|
|
|