GitOrigin-RevId: 844b7e8d39
tags/v1.6.0-rc1
@@ -19,6 +19,7 @@ using namespace gopt; | |||||
using Dimension = megdnn::Dimension; | using Dimension = megdnn::Dimension; | ||||
using NamedTensorShape = megdnn::NamedTensorShape; | using NamedTensorShape = megdnn::NamedTensorShape; | ||||
// =================== ModifyShapeMixin ====================
ModifyShapeMixin::Pattern ModifyShapeMixin::mixin_analyze() const { | ModifyShapeMixin::Pattern ModifyShapeMixin::mixin_analyze() const { | ||||
static constexpr uint32_t UNDETERMINED_EXTENT = | static constexpr uint32_t UNDETERMINED_EXTENT = | ||||
Dimension::UNDETERMINED_EXTENT; | Dimension::UNDETERMINED_EXTENT; | ||||
@@ -50,7 +51,9 @@ ModifyShapeMixin::Pattern ModifyShapeMixin::mixin_analyze() const { | |||||
ModifyShapeMixin::Checker ModifyShapeMixin::mixin_emit_checker( | ModifyShapeMixin::Checker ModifyShapeMixin::mixin_emit_checker( | ||||
const Pattern& pattern) const { | const Pattern& pattern) const { | ||||
auto src = m_src; | auto src = m_src; | ||||
auto checker = [src, pattern](VarNode* var) { | |||||
auto checker = [src, pattern](const VarNodeArray& input) { | |||||
mgb_assert(input.size() >= 1); | |||||
const auto& var = input.front(); | |||||
const auto& shp = var->shape(); | const auto& shp = var->shape(); | ||||
if (shp.ndim != src.ndim) | if (shp.ndim != src.ndim) | ||||
return false; | return false; | ||||
@@ -73,10 +76,14 @@ ModifyShapeMixin::Checker ModifyShapeMixin::mixin_emit_checker( | |||||
return checker; | return checker; | ||||
} | } | ||||
ReshapeEmitter::EmitResult ReshapeEmitter::emit() const { | |||||
// =================== MakeShapeEmitter ====================
MakeShapeEmitter::EmitResult MakeShapeEmitter::emit() const { | |||||
auto pattern = mixin_analyze(); | auto pattern = mixin_analyze(); | ||||
auto builder = [pattern](VarNode* var) { | |||||
auto sym_var = SymbolVar(var); | |||||
auto builder = [pattern](const VarNodeArray& input) { | |||||
mgb_assert(input.size() == 1, | |||||
"number of input of MakeShapeBuilder should be 1(got:%zu)", | |||||
input.size()); | |||||
auto sym_var = SymbolVar(input.front()); | |||||
auto shp = opr::GetVarShape::make(sym_var); | auto shp = opr::GetVarShape::make(sym_var); | ||||
auto cv = [&sym_var](int c) { return sym_var.make_scalar(c); }; | auto cv = [&sym_var](int c) { return sym_var.make_scalar(c); }; | ||||
auto sub = [&shp, &cv](int ax) { | auto sub = [&shp, &cv](int ax) { | ||||
@@ -97,31 +104,59 @@ ReshapeEmitter::EmitResult ReshapeEmitter::emit() const { | |||||
} | } | ||||
} | } | ||||
auto tshp = opr::Concat::make(axs, 0); | auto tshp = opr::Concat::make(axs, 0); | ||||
auto ovar = opr::Reshape::make(sym_var, tshp); | |||||
return tshp.node(); | |||||
}; | |||||
auto checker = mixin_emit_checker(pattern); | |||||
return std::make_tuple(builder, checker); | |||||
} | |||||
// =================== ReshapeEmitter ====================
/*!
 * \brief Emit the builder/checker pair for a reshape.
 *
 * The builder requires exactly two inputs: input[0] is the var to reshape and
 * input[1] is the var handed to opr::Reshape as its second argument (the
 * target shape var). The checker is the one produced by
 * mixin_emit_checker() for the analyzed pattern.
 *
 * \return tuple of (builder, checker)
 */
ReshapeEmitter::EmitResult ReshapeEmitter::emit() const {
    auto pattern = mixin_analyze();
    // The builder never consults the pattern, so it captures nothing; only
    // the checker depends on it.
    auto builder = [](const VarNodeArray& input) {
        mgb_assert(input.size() == 2,
                   "number of input of Reshape should be 2(got:%zu)",
                   input.size());
        auto ovar = opr::Reshape::make(input[0], input[1]);
        return ovar.node();
    };
    auto checker = mixin_emit_checker(pattern);
    return std::make_tuple(builder, checker);
}
// =================== DimshuffleEmitter ====================
/*!
 * \brief Emit the builder/checker pair for a dimshuffle using m_pattern.
 *
 * Both callbacks require exactly one input var and assert on any other
 * count. The builder applies opr::Dimshuffle with the stored pattern; the
 * checker accepts the input only when its ndim equals the pattern length.
 *
 * \return tuple of (builder, checker)
 */
DimshuffleEmitter::EmitResult DimshuffleEmitter::emit() const {
    auto&& pattern = m_pattern;
    // Each lambda captures the pattern by copy.
    auto builder = [pattern](const VarNodeArray& input) {
        mgb_assert(input.size() == 1,
                   "number of input of Dimshuffle should be 1(got:%zu)",
                   input.size());
        auto sym_var = SymbolVar(input.front());
        return opr::Dimshuffle::make(sym_var, pattern).node();
    };
    auto checker = [pattern](const VarNodeArray& input) {
        mgb_assert(input.size() == 1,
                   "number of input of Dimshuffle should be 1(got:%zu)",
                   input.size());
        return input.front()->shape().ndim == pattern.size();
    };
    return std::make_tuple(builder, checker);
}
// =================== ReformatEmitter ====================
ReformatEmitter::EmitResult ReformatEmitter::emit() const { | ReformatEmitter::EmitResult ReformatEmitter::emit() const { | ||||
auto ops = analyze(); | |||||
auto builder = [ops](VarNode* var) { | |||||
VarNode* ovar = var; | |||||
for (const auto& i : ops) { | |||||
ovar = i(ovar); | |||||
auto builders = analyze(); | |||||
auto builder = [builders](const VarNodeArray& input) { | |||||
VarNode *var, *ovar; | |||||
var = ovar = input.front(); | |||||
if (builders.make_shape1) { | |||||
auto shp1 = builders.make_shape1({var}); | |||||
ovar = builders.reshape1({ovar, shp1}); | |||||
} | |||||
ovar = builders.dimshuffle({ovar}); | |||||
if (builders.make_shape2) { | |||||
auto shp2 = builders.make_shape2({var}); | |||||
ovar = builders.reshape2({ovar, shp2}); | |||||
} | } | ||||
return ovar; | return ovar; | ||||
}; | }; | ||||
@@ -130,7 +165,7 @@ ReformatEmitter::EmitResult ReformatEmitter::emit() const { | |||||
return std::make_tuple(builder, checker); | return std::make_tuple(builder, checker); | ||||
} | } | ||||
SmallVector<ReformatEmitter::Builder> ReformatEmitter::analyze() const { | |||||
ReformatEmitter::UnderlyingBuilders ReformatEmitter::analyze() const { | |||||
struct Dim { | struct Dim { | ||||
Dimension dim; | Dimension dim; | ||||
int index; | int index; | ||||
@@ -196,12 +231,21 @@ SmallVector<ReformatEmitter::Builder> ReformatEmitter::analyze() const { | |||||
i1[i] = src_dims[src_perm[i]].dim; | i1[i] = src_dims[src_perm[i]].dim; | ||||
i2[i] = src_dims[src_perm[permute[i]]].dim; | i2[i] = src_dims[src_perm[permute[i]]].dim; | ||||
} | } | ||||
SmallVector<Builder> ops; | |||||
if (!m_src.eq_shape(i1)) | |||||
ops.emplace_back(std::get<0>(ReshapeEmitter(m_src, i1).emit())); | |||||
ops.emplace_back(std::get<0>(DimshuffleEmitter(permute).emit())); | |||||
if (!m_dest.eq_shape(i2)) | |||||
ops.emplace_back(std::get<0>(ReshapeEmitter(i2, m_dest).emit())); | |||||
return ops; | |||||
UnderlyingBuilders builders; | |||||
if (!m_src.eq_shape(i1)) { | |||||
builders.make_shape1 = | |||||
std::move(std::get<0>(MakeShapeEmitter(m_src, i1).emit())); | |||||
builders.reshape1 = | |||||
std::move(std::get<0>(ReshapeEmitter(m_src, i1).emit())); | |||||
} | |||||
builders.dimshuffle = | |||||
std::move(std::get<0>(DimshuffleEmitter(permute).emit())); | |||||
if (!m_dest.eq_shape(i2)) { | |||||
builders.make_shape2 = | |||||
std::move(std::get<0>(MakeShapeEmitter(m_src, m_dest).emit())); | |||||
builders.reshape2 = | |||||
std::move(std::get<0>(ReshapeEmitter(i2, m_dest).emit())); | |||||
} | |||||
return builders; | |||||
} | } | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -20,8 +20,8 @@ namespace gopt { | |||||
class Emitter { | class Emitter { | ||||
public: | public: | ||||
using Builder = thin_function<VarNode*(VarNode*)>; | |||||
using Checker = thin_function<bool(VarNode*)>; | |||||
using Builder = thin_function<VarNode*(const VarNodeArray&)>; | |||||
using Checker = thin_function<bool(const VarNodeArray&)>; | |||||
using EmitResult = std::tuple<Builder, Checker>; | using EmitResult = std::tuple<Builder, Checker>; | ||||
virtual ~Emitter() = default; | virtual ~Emitter() = default; | ||||
virtual EmitResult emit() const = 0; | virtual EmitResult emit() const = 0; | ||||
@@ -39,6 +39,14 @@ protected: | |||||
megdnn::NamedTensorShape m_src, m_dest; | megdnn::NamedTensorShape m_src, m_dest; | ||||
}; | }; | ||||
class MakeShapeEmitter final : public Emitter, ModifyShapeMixin { | |||||
public: | |||||
MakeShapeEmitter(const megdnn::NamedTensorShape& src, | |||||
const megdnn::NamedTensorShape& dest) | |||||
: ModifyShapeMixin(src, dest) {} | |||||
EmitResult emit() const override; | |||||
}; | |||||
class ReshapeEmitter final : public Emitter, ModifyShapeMixin { | class ReshapeEmitter final : public Emitter, ModifyShapeMixin { | ||||
public: | public: | ||||
ReshapeEmitter(const megdnn::NamedTensorShape& src, | ReshapeEmitter(const megdnn::NamedTensorShape& src, | ||||
@@ -64,7 +72,10 @@ public: | |||||
EmitResult emit() const override; | EmitResult emit() const override; | ||||
private: | private: | ||||
SmallVector<Builder> analyze() const; | |||||
struct UnderlyingBuilders { | |||||
Builder make_shape1, make_shape2, reshape1, reshape2, dimshuffle; | |||||
}; | |||||
UnderlyingBuilders analyze() const; | |||||
}; | }; | ||||
} // namespace gopt | } // namespace gopt | ||||
} // namespace mgb | } // namespace mgb | ||||
@@ -21,10 +21,10 @@ TEST(TestReformatEmitter, Basic) { | |||||
constexpr size_t N = 12, C = 64, H = 7, W = 7; | constexpr size_t N = 12, C = 64, H = 7, W = 7; | ||||
HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
using NamedTensorShape = megdnn::NamedTensorShape; | using NamedTensorShape = megdnn::NamedTensorShape; | ||||
auto dest = NamedTensorShape::make_named_tensor_shape( | |||||
NamedTensorShape::Format::NCHW4); | |||||
auto src = NamedTensorShape::make_named_tensor_shape( | auto src = NamedTensorShape::make_named_tensor_shape( | ||||
NamedTensorShape::Format::NCHW32); | NamedTensorShape::Format::NCHW32); | ||||
auto dest = NamedTensorShape::make_named_tensor_shape( | |||||
NamedTensorShape::Format::NCHW4); | |||||
auto&& tuple = gopt::ReformatEmitter(src, dest).emit(); | auto&& tuple = gopt::ReformatEmitter(src, dest).emit(); | ||||
auto reformat = std::get<0>(tuple); | auto reformat = std::get<0>(tuple); | ||||
auto checker = std::get<1>(tuple); | auto checker = std::get<1>(tuple); | ||||
@@ -53,10 +53,21 @@ TEST(TestReformatEmitter, Basic) { | |||||
return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name); | return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name); | ||||
}; | }; | ||||
auto x = mkvar("x", {N, C / 32, H, W, 32}); | auto x = mkvar("x", {N, C / 32, H, W, 32}); | ||||
EXPECT_TRUE(checker(x.node())); | |||||
EXPECT_TRUE(checker({x.node()})); | |||||
auto x_ = mkvar("x", {N, H, W, C}); | auto x_ = mkvar("x", {N, H, W, C}); | ||||
EXPECT_FALSE(checker(x_.node())); | |||||
auto y1 = SymbolVar(reformat(x.node())); | |||||
EXPECT_FALSE(checker({x_.node()})); | |||||
auto y1 = SymbolVar(reformat({x.node()})); | |||||
size_t nr_shapeof = 0; | |||||
size_t nr_reshape = 0; | |||||
cg::DepOprIter{[&nr_shapeof, &nr_reshape](cg::OperatorNodeBase* o) { | |||||
if (o->same_type<opr::GetVarShape>()) | |||||
nr_shapeof++; | |||||
if (o->same_type<opr::Reshape>()) | |||||
nr_reshape++; | |||||
}} | |||||
.add(y1.node()->owner_opr()); | |||||
ASSERT_EQ(nr_shapeof, 1); | |||||
ASSERT_EQ(nr_reshape, 2); | |||||
auto y2 = SymbolVar(nchw32_to_nchw4(x.node())); | auto y2 = SymbolVar(nchw32_to_nchw4(x.node())); | ||||
HostTensorND t1, t2; | HostTensorND t1, t2; | ||||
auto func1 = graph->compile({make_callback_copy(y1, t1)}); | auto func1 = graph->compile({make_callback_copy(y1, t1)}); | ||||
@@ -84,12 +95,116 @@ TEST(TestReformatEmitter, MoreComplicated) { | |||||
return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name); | return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name); | ||||
}; | }; | ||||
auto x = mkvar("x", {N, C / 64, H, W, 64}); | auto x = mkvar("x", {N, C / 64, H, W, 64}); | ||||
EXPECT_TRUE(checker(x.node())); | |||||
EXPECT_TRUE(checker({x.node()})); | |||||
auto x_ = mkvar("x", {N, H, W, C}); | auto x_ = mkvar("x", {N, H, W, C}); | ||||
EXPECT_FALSE(checker(x_.node())); | |||||
auto y = SymbolVar(reformat(x.node())); | |||||
EXPECT_FALSE(checker({x_.node()})); | |||||
auto y = SymbolVar(reformat({x.node()})); | |||||
HostTensorND t; | HostTensorND t; | ||||
auto func = graph->compile({make_callback_copy(y, t)}); | auto func = graph->compile({make_callback_copy(y, t)}); | ||||
func->execute(); | func->execute(); | ||||
} | } | ||||
/*!
 * NCHW -> NHWC reformat: the test expects the emitted graph to contain no
 * Reshape operator at all, and its output to equal a plain Dimshuffle with
 * pattern {0, 2, 3, 1} applied to the same input.
 */
TEST(TestReformatEmitter, EliminateRedudantReshape) {
    using NamedTensorShape = megdnn::NamedTensorShape;
    constexpr size_t N = 16, C = 64, H = 7, W = 7;
    HostTensorGenerator<> gen;

    auto src = NamedTensorShape::make_named_tensor_shape(
            NamedTensorShape::Format::NCHW);
    auto dest = NamedTensorShape::make_named_tensor_shape(
            NamedTensorShape::Format::NHWC);
    auto emitted = gopt::ReformatEmitter(src, dest).emit();
    auto reformat = std::get<0>(emitted);
    auto checker = std::get<1>(emitted);

    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;

    // Hand-written reference implementation of the same layout change.
    auto nchw_to_nhwc = [](VarNode* in) {
        return opr::Dimshuffle::make(SymbolVar(in), {0, 2, 3, 1}).node();
    };
    auto mkvar = [&](const char* name, const TensorShape& shp) {
        return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
    };

    auto x = mkvar("x", {N, C, H, W});
    EXPECT_TRUE(checker({x.node()}));
    auto y1 = SymbolVar(reformat({x.node()}));

    // Walk every operator y1 depends on and count the Reshape oprs.
    size_t nr_reshape = 0;
    auto count_reshape = [&nr_reshape](cg::OperatorNodeBase* opr) {
        if (opr->same_type<opr::Reshape>()) {
            ++nr_reshape;
        }
    };
    cg::DepOprIter{count_reshape}.add(y1.node()->owner_opr());
    ASSERT_EQ(nr_reshape, 0);

    HostTensorND t1, t2;
    auto func1 = graph->compile({make_callback_copy(y1, t1)});
    func1->execute();

    auto y2 = SymbolVar(nchw_to_nhwc(x.node()));
    auto func2 = graph->compile({make_callback_copy(y2, t2)});
    func2->execute();

    MGB_ASSERT_TENSOR_EQ(t1, t2);
}
/*!
 * NCHW4 -> NCHW reformat: the test expects the emitted graph to contain
 * exactly one Reshape (which passes UniqReaderCheck) and at least one
 * Dimshuffle, and its output to equal a hand-written Dimshuffle + Reshape
 * reference.
 */
TEST(TestReformatEmitter, Nchw4ToNchw) {
    constexpr size_t N = 12, C = 64, H = 7, W = 7;
    HostTensorGenerator<> gen;
    using NamedTensorShape = megdnn::NamedTensorShape;
    auto src = NamedTensorShape::make_named_tensor_shape(
            NamedTensorShape::Format::NCHW4);
    auto dest = NamedTensorShape::make_named_tensor_shape(
            NamedTensorShape::Format::NCHW);
    auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
    auto reformat = std::get<0>(tuple);
    auto checker = std::get<1>(tuple);

    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;

    // Reference: permute with pattern {0, 1, 4, 2, 3}, then reshape to a
    // 4-dim shape whose second extent is 4x the input's second extent.
    auto nchw4_to_nchw = [](VarNode* in) {
        auto x = SymbolVar(in);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
        auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
        auto y1 = opr::Reshape::make(y0, tshp);
        return y1.node();
    };
    auto mkvar = [&](const char* name, const TensorShape& shp) {
        return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
    };

    auto x = mkvar("x", {N, C / 4, H, W, 4});
    EXPECT_TRUE(checker({x.node()}));
    auto y1 = SymbolVar(reformat({x.node()}));

    SmallVector<VarNode*> reshapes;
    // Must start as nullptr: the callback assigns it only when a Dimshuffle
    // opr is visited, and the value is checked afterwards.
    VarNode* dimshuffle = nullptr;
    cg::DepOprIter{[&dimshuffle, &reshapes](cg::OperatorNodeBase* o) {
        if (o->same_type<opr::Reshape>()) {
            reshapes.push_back(o->output(0));
        }
        if (o->same_type<opr::Dimshuffle>())
            dimshuffle = o->output(0);
    }}
            .add(y1.node()->owner_opr());
    ASSERT_EQ(reshapes.size(), 1u);
    {
        // Named sub_graph so it does not shadow the ComputingGraph above.
        gopt::SubGraph sub_graph({y1});
        gopt::UniqReaderCheck check(sub_graph);
        EXPECT_TRUE(check(reshapes[0]));
        EXPECT_TRUE(dimshuffle != nullptr);
    }

    auto y2 = SymbolVar(nchw4_to_nchw(x.node()));
    HostTensorND t1, t2;
    auto func1 = graph->compile({make_callback_copy(y1, t1)});
    func1->execute();
    auto func2 = graph->compile({make_callback_copy(y2, t2)});
    func2->execute();
    MGB_ASSERT_TENSOR_EQ(t1, t2);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |