|
|
@@ -435,13 +435,10 @@ void TensorReformatPass::translate_pass(OptState& opt) const { |
|
|
|
return opr::IndexAt::make(xshp, {{0, cv(idx)}}); |
|
|
|
}; |
|
|
|
auto tshp0 = opr::Concat::make( |
|
|
|
{sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0), |
|
|
|
tshp1 = opr::Concat::make( |
|
|
|
{sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0); |
|
|
|
{sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0); |
|
|
|
auto y0 = opr::Reshape::make(x, tshp0); |
|
|
|
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2}); |
|
|
|
auto y2 = opr::Reshape::make(y1, tshp1); |
|
|
|
return y2.node(); |
|
|
|
return y1.node(); |
|
|
|
}; |
|
|
|
reformat[LayoutType::NCHW4_TO_NCHW] = [](VarNode* inp) -> VarNode* { |
|
|
|
auto x = SymbolVar(inp); |
|
|
@@ -455,7 +452,8 @@ void TensorReformatPass::translate_pass(OptState& opt) const { |
|
|
|
auto y1 = opr::Reshape::make(y0, tshp0); |
|
|
|
return y1.node(); |
|
|
|
}; |
|
|
|
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE] = [](VarNode* inp) -> VarNode* { |
|
|
|
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE] = |
|
|
|
[](VarNode* inp) -> VarNode* { |
|
|
|
auto x = SymbolVar(inp); |
|
|
|
auto xshp = opr::GetVarShape::make(x); |
|
|
|
auto cv = [&x](int v) { return x.make_scalar(v); }; |
|
|
@@ -471,7 +469,8 @@ void TensorReformatPass::translate_pass(OptState& opt) const { |
|
|
|
auto y2 = opr::Reshape::make(y1, tshp1); |
|
|
|
return y2.node(); |
|
|
|
}; |
|
|
|
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP] = [](VarNode* inp) -> VarNode* { |
|
|
|
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP] = |
|
|
|
[](VarNode* inp) -> VarNode* { |
|
|
|
auto x = SymbolVar(inp); |
|
|
|
auto xshp = opr::GetVarShape::make(x); |
|
|
|
auto cv = [&x](int v) { return x.make_scalar(v); }; |
|
|
|