GitOrigin-RevId: 844b7e8d39
@@ -19,6 +19,7 @@ using namespace gopt;
 using Dimension = megdnn::Dimension;
 using NamedTensorShape = megdnn::NamedTensorShape;
 
 // =================== ModifyShapeMixin ====================*/
 ModifyShapeMixin::Pattern ModifyShapeMixin::mixin_analyze() const {
     static constexpr uint32_t UNDETERMINED_EXTENT =
             Dimension::UNDETERMINED_EXTENT;
@@ -50,7 +51,9 @@ ModifyShapeMixin::Pattern ModifyShapeMixin::mixin_analyze() const {
 ModifyShapeMixin::Checker ModifyShapeMixin::mixin_emit_checker(
         const Pattern& pattern) const {
     auto src = m_src;
-    auto checker = [src, pattern](VarNode* var) {
+    auto checker = [src, pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() >= 1);
+        const auto& var = input.front();
         const auto& shp = var->shape();
         if (shp.ndim != src.ndim)
             return false;
@@ -73,10 +76,14 @@ ModifyShapeMixin::Checker ModifyShapeMixin::mixin_emit_checker(
     return checker;
 }
 
-ReshapeEmitter::EmitResult ReshapeEmitter::emit() const {
+// =================== MakeShapeEmitter ====================*/
+MakeShapeEmitter::EmitResult MakeShapeEmitter::emit() const {
     auto pattern = mixin_analyze();
-    auto builder = [pattern](VarNode* var) {
-        auto sym_var = SymbolVar(var);
+    auto builder = [pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() == 1,
+                   "number of input of MakeShapeBuilder should be 1(got:%zu)",
+                   input.size());
+        auto sym_var = SymbolVar(input.front());
         auto shp = opr::GetVarShape::make(sym_var);
         auto cv = [&sym_var](int c) { return sym_var.make_scalar(c); };
         auto sub = [&shp, &cv](int ax) {
@@ -97,31 +104,59 @@ ReshapeEmitter::EmitResult ReshapeEmitter::emit() const {
             }
         }
         auto tshp = opr::Concat::make(axs, 0);
-        auto ovar = opr::Reshape::make(sym_var, tshp);
-        return ovar.node();
+        return tshp.node();
     };
     auto checker = mixin_emit_checker(pattern);
     return std::make_tuple(builder, checker);
 }
+
+// =================== ReshapeEmitter ====================*/
+ReshapeEmitter::EmitResult ReshapeEmitter::emit() const {
+    auto pattern = mixin_analyze();
+    auto builder = [pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() == 2,
+                   "number of input of Reshape should be 2(got:%zu)",
+                   input.size());
+        auto ovar = opr::Reshape::make(input[0], input[1]);
+        return ovar.node();
+    };
+    auto checker = mixin_emit_checker(pattern);
+    return std::make_tuple(builder, checker);
+}
 
 // =================== DimshuffleEmitter ====================*/
 DimshuffleEmitter::EmitResult DimshuffleEmitter::emit() const {
     auto&& pattern = m_pattern;
-    auto builder = [pattern](VarNode* var) {
-        auto sym_var = SymbolVar(var);
+    auto builder = [pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() == 1,
+                   "number of input of Dimshuffle should be 1(got:%zu)",
+                   input.size());
+        auto sym_var = SymbolVar(input.front());
         return opr::Dimshuffle::make(sym_var, pattern).node();
     };
-    auto checker = [pattern](VarNode* var) {
-        return var->shape().ndim == pattern.size();
+    auto checker = [pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() == 1,
+                   "number of input of Dimshuffle should be 1(got:%zu)",
+                   input.size());
+        return input.front()->shape().ndim == pattern.size();
     };
     return std::make_tuple(builder, checker);
 }
 
 // =================== ReformatEmitter ====================*/
 ReformatEmitter::EmitResult ReformatEmitter::emit() const {
-    auto ops = analyze();
-    auto builder = [ops](VarNode* var) {
-        VarNode* ovar = var;
-        for (const auto& i : ops) {
-            ovar = i(ovar);
+    auto builders = analyze();
+    auto builder = [builders](const VarNodeArray& input) {
+        VarNode *var, *ovar;
+        var = ovar = input.front();
+        if (builders.make_shape1) {
+            auto shp1 = builders.make_shape1({var});
+            ovar = builders.reshape1({ovar, shp1});
         }
+        ovar = builders.dimshuffle({ovar});
+        if (builders.make_shape2) {
+            auto shp2 = builders.make_shape2({var});
+            ovar = builders.reshape2({ovar, shp2});
+        }
         return ovar;
     };
@@ -130,7 +165,7 @@ ReformatEmitter::EmitResult ReformatEmitter::emit() const {
     return std::make_tuple(builder, checker);
 }
 
-SmallVector<ReformatEmitter::Builder> ReformatEmitter::analyze() const {
+ReformatEmitter::UnderlyingBuilders ReformatEmitter::analyze() const {
     struct Dim {
         Dimension dim;
         int index;
@@ -196,12 +231,21 @@ SmallVector<ReformatEmitter::Builder> ReformatEmitter::analyze() const {
         i1[i] = src_dims[src_perm[i]].dim;
         i2[i] = src_dims[src_perm[permute[i]]].dim;
     }
-    SmallVector<Builder> ops;
-    if (!m_src.eq_shape(i1))
-        ops.emplace_back(std::get<0>(ReshapeEmitter(m_src, i1).emit()));
-    ops.emplace_back(std::get<0>(DimshuffleEmitter(permute).emit()));
-    if (!m_dest.eq_shape(i2))
-        ops.emplace_back(std::get<0>(ReshapeEmitter(i2, m_dest).emit()));
-    return ops;
+    UnderlyingBuilders builders;
+    if (!m_src.eq_shape(i1)) {
+        builders.make_shape1 =
+                std::move(std::get<0>(MakeShapeEmitter(m_src, i1).emit()));
+        builders.reshape1 =
+                std::move(std::get<0>(ReshapeEmitter(m_src, i1).emit()));
+    }
+    builders.dimshuffle =
+            std::move(std::get<0>(DimshuffleEmitter(permute).emit()));
+    if (!m_dest.eq_shape(i2)) {
+        builders.make_shape2 =
+                std::move(std::get<0>(MakeShapeEmitter(m_src, m_dest).emit()));
+        builders.reshape2 =
+                std::move(std::get<0>(ReshapeEmitter(i2, m_dest).emit()));
+    }
+    return builders;
 }
 
 // vim: syntax=cpp.doxygen
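Note on the new interface: after this change every emitted Builder and Checker consumes a VarNodeArray rather than a single VarNode*, which is what lets a reshape builder take the data var and the shape var as two separate inputs. A minimal usage sketch, assuming an existing VarNode* x holding an NCHW32-laid-out tensor (this mirrors the updated tests further below):

    auto src = megdnn::NamedTensorShape::make_named_tensor_shape(
            megdnn::NamedTensorShape::Format::NCHW32);
    auto dest = megdnn::NamedTensorShape::make_named_tensor_shape(
            megdnn::NamedTensorShape::Format::NCHW4);
    auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
    auto reformat = std::get<0>(tuple);  // Builder: VarNode* (const VarNodeArray&)
    auto checker = std::get<1>(tuple);   // Checker: bool (const VarNodeArray&)
    if (checker({x})) {              // x's shape matches the NCHW32 pattern
        VarNode* y = reformat({x});  // emits the reshape/dimshuffle/reshape chain
    }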
@@ -20,8 +20,8 @@ namespace gopt {
 class Emitter {
 public:
-    using Builder = thin_function<VarNode*(VarNode*)>;
-    using Checker = thin_function<bool(VarNode*)>;
+    using Builder = thin_function<VarNode*(const VarNodeArray&)>;
+    using Checker = thin_function<bool(const VarNodeArray&)>;
     using EmitResult = std::tuple<Builder, Checker>;
     virtual ~Emitter() = default;
     virtual EmitResult emit() const = 0;
@@ -39,6 +39,14 @@ protected:
     megdnn::NamedTensorShape m_src, m_dest;
 };
 
+class MakeShapeEmitter final : public Emitter, ModifyShapeMixin {
+public:
+    MakeShapeEmitter(const megdnn::NamedTensorShape& src,
+                     const megdnn::NamedTensorShape& dest)
+            : ModifyShapeMixin(src, dest) {}
+    EmitResult emit() const override;
+};
+
 class ReshapeEmitter final : public Emitter, ModifyShapeMixin {
 public:
     ReshapeEmitter(const megdnn::NamedTensorShape& src,
@@ -64,7 +72,10 @@ public:
     EmitResult emit() const override;
 
 private:
-    SmallVector<Builder> analyze() const;
+    struct UnderlyingBuilders {
+        Builder make_shape1, make_shape2, reshape1, reshape2, dimshuffle;
+    };
+    UnderlyingBuilders analyze() const;
 };
 } // namespace gopt
 } // namespace mgb
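The header now splits shape manipulation in two: MakeShapeEmitter only emits the computation of the target shape from the input var, while ReshapeEmitter applies opr::Reshape to a (data, shape) pair. A sketch of how the two compose, following the pattern used by ReformatEmitter::emit() above (x is an assumed input VarNode*; src and dest are NamedTensorShapes):

    auto make_shape = std::get<0>(gopt::MakeShapeEmitter(src, dest).emit());
    auto reshape = std::get<0>(gopt::ReshapeEmitter(src, dest).emit());
    VarNode* shp = make_shape({x});  // symbolic target shape derived from x
    VarNode* y = reshape({x, shp});  // expands to opr::Reshape::make(x, shp)

Keeping the shape computation separate lets analyze() skip a reshape pair entirely when source and destination already agree on an intermediate shape; and since both shape builders read the same input var, only one GetVarShape operator ends up in the graph (the Basic test below asserts nr_shapeof == 1 even with two reshapes).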
@@ -21,10 +21,10 @@ TEST(TestReformatEmitter, Basic) {
     constexpr size_t N = 12, C = 64, H = 7, W = 7;
     HostTensorGenerator<> gen;
     using NamedTensorShape = megdnn::NamedTensorShape;
-    auto dest = NamedTensorShape::make_named_tensor_shape(
-            NamedTensorShape::Format::NCHW4);
     auto src = NamedTensorShape::make_named_tensor_shape(
             NamedTensorShape::Format::NCHW32);
+    auto dest = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NCHW4);
     auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
     auto reformat = std::get<0>(tuple);
     auto checker = std::get<1>(tuple);
@@ -53,10 +53,21 @@ TEST(TestReformatEmitter, Basic) {
         return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
     };
     auto x = mkvar("x", {N, C / 32, H, W, 32});
-    EXPECT_TRUE(checker(x.node()));
+    EXPECT_TRUE(checker({x.node()}));
     auto x_ = mkvar("x", {N, H, W, C});
-    EXPECT_FALSE(checker(x_.node()));
-    auto y1 = SymbolVar(reformat(x.node()));
+    EXPECT_FALSE(checker({x_.node()}));
+    auto y1 = SymbolVar(reformat({x.node()}));
+    size_t nr_shapeof = 0;
+    size_t nr_reshape = 0;
+    cg::DepOprIter{[&nr_shapeof, &nr_reshape](cg::OperatorNodeBase* o) {
+        if (o->same_type<opr::GetVarShape>())
+            nr_shapeof++;
+        if (o->same_type<opr::Reshape>())
+            nr_reshape++;
+    }}
+            .add(y1.node()->owner_opr());
+    ASSERT_EQ(nr_shapeof, 1);
+    ASSERT_EQ(nr_reshape, 2);
     auto y2 = SymbolVar(nchw32_to_nchw4(x.node()));
     HostTensorND t1, t2;
     auto func1 = graph->compile({make_callback_copy(y1, t1)});
@@ -84,12 +95,116 @@ TEST(TestReformatEmitter, MoreComplicated) {
         return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
     };
     auto x = mkvar("x", {N, C / 64, H, W, 64});
-    EXPECT_TRUE(checker(x.node()));
+    EXPECT_TRUE(checker({x.node()}));
     auto x_ = mkvar("x", {N, H, W, C});
-    EXPECT_FALSE(checker(x_.node()));
-    auto y = SymbolVar(reformat(x.node()));
+    EXPECT_FALSE(checker({x_.node()}));
+    auto y = SymbolVar(reformat({x.node()}));
     HostTensorND t;
     auto func = graph->compile({make_callback_copy(y, t)});
     func->execute();
 }
+
+TEST(TestReformatEmitter, EliminateRedundantReshape) {
+    constexpr size_t N = 16, C = 64, H = 7, W = 7;
+    HostTensorGenerator<> gen;
+    using NamedTensorShape = megdnn::NamedTensorShape;
+    auto src = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NCHW);
+    auto dest = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NHWC);
+    auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
+    auto reformat = std::get<0>(tuple);
+    auto checker = std::get<1>(tuple);
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto nchw_to_nhwc = [](VarNode* in) {
+        auto x = SymbolVar(in);
+        auto y = opr::Dimshuffle::make(x, {0, 2, 3, 1});
+        return y.node();
+    };
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
+    };
+    auto x = mkvar("x", {N, C, H, W});
+    EXPECT_TRUE(checker({x.node()}));
+    auto y1 = SymbolVar(reformat({x.node()}));
+    size_t nr_reshape = 0;
+    cg::DepOprIter{[&nr_reshape](cg::OperatorNodeBase* o) {
+        if (o->same_type<opr::Reshape>())
+            nr_reshape++;
+    }}
+            .add(y1.node()->owner_opr());
+    ASSERT_EQ(nr_reshape, 0);
+    HostTensorND t1, t2;
+    auto func1 = graph->compile({make_callback_copy(y1, t1)});
+    func1->execute();
+    auto y2 = SymbolVar(nchw_to_nhwc(x.node()));
+    auto func2 = graph->compile({make_callback_copy(y2, t2)});
+    func2->execute();
+    MGB_ASSERT_TENSOR_EQ(t1, t2);
+}
+
+TEST(TestReformatEmitter, Nchw4ToNchw) {
+    constexpr size_t N = 12, C = 64, H = 7, W = 7;
+    HostTensorGenerator<> gen;
+    using NamedTensorShape = megdnn::NamedTensorShape;
+    auto src = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NCHW4);
+    auto dest = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NCHW);
+    auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
+    auto reformat = std::get<0>(tuple);
+    auto checker = std::get<1>(tuple);
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto nchw4_to_nchw = [](VarNode* in) {
+        auto x = SymbolVar(in);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
+        auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
+        auto y1 = opr::Reshape::make(y0, tshp);
+        return y1.node();
+    };
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
+    };
+    auto x = mkvar("x", {N, C / 4, H, W, 4});
+    EXPECT_TRUE(checker({x.node()}));
+    auto y1 = SymbolVar(reformat({x.node()}));
+    SmallVector<VarNode*> reshapes;
+    VarNode* dimshuffle = nullptr;
+    cg::DepOprIter{[&dimshuffle, &reshapes](cg::OperatorNodeBase* o) {
+        if (o->same_type<opr::Reshape>()) {
+            reshapes.push_back(o->output(0));
+        }
+        if (o->same_type<opr::Dimshuffle>())
+            dimshuffle = o->output(0);
+    }}
+            .add(y1.node()->owner_opr());
+    ASSERT_EQ(reshapes.size(), 1);
+    {
+        gopt::SubGraph graph({y1});
+        gopt::UniqReaderCheck check(graph);
+        EXPECT_TRUE(check(reshapes[0]));
+        EXPECT_TRUE(dimshuffle);
+    }
+    auto y2 = SymbolVar(nchw4_to_nchw(x.node()));
+    HostTensorND t1, t2;
+    auto func1 = graph->compile({make_callback_copy(y1, t1)});
+    func1->execute();
+    auto func2 = graph->compile({make_callback_copy(y2, t2)});
+    func2->execute();
+    MGB_ASSERT_TENSOR_EQ(t1, t2);
+}
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
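For the degenerate case exercised by EliminateRedundantReshape, NCHW and NHWC differ only by a permutation, so analyze() leaves all four shape/reshape builders empty and the emitted builder reduces to a bare Dimshuffle (hence nr_reshape == 0 above). The same operator can be emitted directly; a sketch, assuming DimshuffleEmitter's constructor accepts the permutation as a braced list (as DimshuffleEmitter(permute) in analyze() suggests):

    auto&& tuple = gopt::DimshuffleEmitter({0, 2, 3, 1}).emit();
    auto builder = std::get<0>(tuple);
    auto checker = std::get<1>(tuple);
    // checker({x}) holds iff x->shape().ndim == 4 (the pattern size);
    // builder({x}) emits opr::Dimshuffle::make(x, {0, 2, 3, 1}).node().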