|
|
@@ -1534,12 +1534,12 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { |
|
|
|
"The origin filter is not NCHW mode"); |
|
|
|
VarNodeArray temp_inp = new_inp; |
|
|
|
//! if src is nchwxx, should RelayoutPlaceholder to nchw |
|
|
|
if (temp_inp[0]->shape().ndim == 5) { |
|
|
|
if (new_inp[0]->shape().ndim == 5) { |
|
|
|
auto new_src = RelayoutPlaceholder::make(new_inp[0], src_to_nchw_mode); |
|
|
|
temp_inp[0] = new_src.node(); |
|
|
|
} |
|
|
|
//! the bias is nchwxx |
|
|
|
if (new_inp.size() > 2 && temp_inp[2]->shape().ndim == 5) { |
|
|
|
if (new_inp.size() > 2 && new_inp[2]->shape().ndim == 5) { |
|
|
|
auto new_bias = RelayoutPlaceholder::make(new_inp[2], src_to_nchw_mode); |
|
|
|
temp_inp[2] = new_bias.node(); |
|
|
|
} |
|
|
|