From 58b8b14554a536a29c246aacf5fe23ccad264535 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 7 Jul 2021 13:30:49 +0800 Subject: [PATCH] refactor(mgb/gopt): add checker for reformat emitter GitOrigin-RevId: 53a8c128f57e05147a0acaffbf52fe55bcbad281 --- src/gopt/impl/reformat_emitter.cpp | 125 ++++++++++++++-------- src/gopt/include/megbrain/gopt/reformat_emitter.h | 42 ++++---- src/gopt/test/reformat_emitter.cpp | 14 ++- 3 files changed, 116 insertions(+), 65 deletions(-) diff --git a/src/gopt/impl/reformat_emitter.cpp b/src/gopt/impl/reformat_emitter.cpp index 21e22dd2..8c9bacfa 100644 --- a/src/gopt/impl/reformat_emitter.cpp +++ b/src/gopt/impl/reformat_emitter.cpp @@ -10,8 +10,8 @@ * implied. */ -#include #include "megbrain/gopt/reformat_emitter.h" +#include #include "megbrain/opr/tensor_manip.h" using namespace mgb; @@ -19,34 +19,7 @@ using namespace gopt; using Dimension = megdnn::Dimension; using NamedTensorShape = megdnn::NamedTensorShape; -ReshapeEmitter::Operator ReshapeEmitter::emit() const { - auto pattern = analyze(); - auto op = [pattern](VarNode* var) { - auto sym_var = SymbolVar(var); - auto shp = opr::GetVarShape::make(sym_var); - auto cv = [&sym_var](int c) { return sym_var.make_scalar(c); }; - auto sub = [&shp, &cv](int ax) { - return opr::IndexAt::make(shp, {{0, cv(ax)}}); - }; - SymbolVarArray axs; - for (auto i : pattern) { - if (std::get<0>(i) >= 0) { - if (std::get<2>(i)) - axs.emplace_back(sub(std::get<0>(i)) * std::get<1>(i)); - else - axs.emplace_back(sub(std::get<0>(i)) / std::get<1>(i)); - } else { - axs.emplace_back(cv(std::get<1>(i))); - } - } - auto tshp = opr::Concat::make(axs, 0); - auto ovar = opr::Reshape::make(sym_var, tshp); - return ovar.node(); - }; - return op; -} - -SmallVector> ReshapeEmitter::analyze() const { +ModifyShapeMixin::Pattern ModifyShapeMixin::mixin_analyze() const { static constexpr uint32_t UNDETERMINED_EXTENT = Dimension::UNDETERMINED_EXTENT; ThinHashMap name2dominant; @@ -58,7 +31,7 @@ SmallVector> ReshapeEmitter::analyze() const { } } - SmallVector> pattern(m_dest.ndim); + Pattern pattern(m_dest.ndim); for (size_t i = 0; i < m_dest.ndim; ++i) { auto name = m_dest[i].name(); if (m_dest[i].extent() == UNDETERMINED_EXTENT) { @@ -74,28 +47,90 @@ SmallVector> ReshapeEmitter::analyze() const { return pattern; } -DimshuffleEmitter::Operator DimshuffleEmitter::emit() const { - auto pattern = m_pattern; - auto op = [pattern](VarNode* var) { +ModifyShapeMixin::Checker ModifyShapeMixin::mixin_emit_checker( + const Pattern& pattern) const { + auto src = m_src; + auto checker = [src, pattern](VarNode* var) { + const auto& shp = var->shape(); + if (shp.ndim != src.ndim) + return false; + bool available = true; + for (size_t i = 0; i < shp.ndim; ++i) { + if (src[i].extent() != Dimension::UNDETERMINED_EXTENT) { + available &= (shp[i] == src[i].extent()); + } + } + for (auto&& i : pattern) { + int axis, factor; + bool mul; + std::tie(axis, factor, mul) = i; + if (axis >= 0 && !mul) { + available &= (shp[axis] % factor == 0); + } + } + return available; + }; + return checker; +} + +ReshapeEmitter::EmitResult ReshapeEmitter::emit() const { + auto pattern = mixin_analyze(); + auto builder = [pattern](VarNode* var) { + auto sym_var = SymbolVar(var); + auto shp = opr::GetVarShape::make(sym_var); + auto cv = [&sym_var](int c) { return sym_var.make_scalar(c); }; + auto sub = [&shp, &cv](int ax) { + return opr::IndexAt::make(shp, {{0, cv(ax)}}); + }; + SymbolVarArray axs; + for (auto&& i : pattern) { + int axis, factor; + bool mul; + std::tie(axis, factor, mul) = i; + if (axis >= 0) { + if (mul) + axs.emplace_back(sub(axis) * factor); + else + axs.emplace_back(sub(axis) / factor); + } else { + axs.emplace_back(cv(factor)); + } + } + auto tshp = opr::Concat::make(axs, 0); + auto ovar = opr::Reshape::make(sym_var, tshp); + return ovar.node(); + }; + auto checker = mixin_emit_checker(pattern); + return std::make_tuple(builder, checker); +} + +DimshuffleEmitter::EmitResult DimshuffleEmitter::emit() const { + auto&& pattern = m_pattern; + auto builder = [pattern](VarNode* var) { auto sym_var = SymbolVar(var); return opr::Dimshuffle::make(sym_var, pattern).node(); }; - return op; + auto checker = [pattern](VarNode* var) { + return var->shape().ndim == pattern.size(); + }; + return std::make_tuple(builder, checker); } -ReformatEmitter::Operator ReformatEmitter::emit() const { +ReformatEmitter::EmitResult ReformatEmitter::emit() const { auto ops = analyze(); - auto op = [ops](VarNode* var) { + auto builder = [ops](VarNode* var) { VarNode* ovar = var; - for (const auto& o : ops) { - ovar = o(ovar); + for (const auto& i : ops) { + ovar = i(ovar); } return ovar; }; - return op; + auto pattern = mixin_analyze(); + auto checker = mixin_emit_checker(pattern); + return std::make_tuple(builder, checker); } -SmallVector ReformatEmitter::analyze() const { +SmallVector ReformatEmitter::analyze() const { struct Dim { Dimension dim; int index; @@ -161,12 +196,12 @@ SmallVector ReformatEmitter::analyze() const { i1[i] = src_dims[src_perm[i]].dim; i2[i] = src_dims[src_perm[permute[i]]].dim; } - SmallVector ops; + SmallVector ops; if (!m_src.eq_shape(i1)) - ops.emplace_back(ReshapeEmitter(m_src, i1).emit()); - ops.emplace_back(DimshuffleEmitter(permute).emit()); + ops.emplace_back(std::get<0>(ReshapeEmitter(m_src, i1).emit())); + ops.emplace_back(std::get<0>(DimshuffleEmitter(permute).emit())); if (!m_dest.eq_shape(i2)) - ops.emplace_back(ReshapeEmitter(i2, m_dest).emit()); + ops.emplace_back(std::get<0>(ReshapeEmitter(i2, m_dest).emit())); return ops; } - +// vim: syntax=cpp.doxygen diff --git a/src/gopt/include/megbrain/gopt/reformat_emitter.h b/src/gopt/include/megbrain/gopt/reformat_emitter.h index 5f579fca..9d83cf85 100644 --- a/src/gopt/include/megbrain/gopt/reformat_emitter.h +++ b/src/gopt/include/megbrain/gopt/reformat_emitter.h @@ -20,45 +20,51 @@ namespace gopt { class Emitter { public: - using Operator = thin_function; + using Builder = thin_function; + using Checker = thin_function; + using EmitResult = std::tuple; virtual ~Emitter() = default; - virtual Operator emit() const = 0; + virtual EmitResult emit() const = 0; }; -class ReshapeEmitter final : public Emitter { +class ModifyShapeMixin { +protected: + using Pattern = SmallVector>; + using Checker = Emitter::Checker; + ModifyShapeMixin(const megdnn::NamedTensorShape& src, + const megdnn::NamedTensorShape& dest) + : m_src(src), m_dest(dest) {} + Pattern mixin_analyze() const; + Checker mixin_emit_checker(const Pattern& pattern) const; + megdnn::NamedTensorShape m_src, m_dest; +}; + +class ReshapeEmitter final : public Emitter, ModifyShapeMixin { public: - using Operator = typename Emitter::Operator; ReshapeEmitter(const megdnn::NamedTensorShape& src, const megdnn::NamedTensorShape& dest) - : m_src{src}, m_dest{dest} {} - Operator emit() const override; - -private: - SmallVector> analyze() const; - megdnn::NamedTensorShape m_src, m_dest; + : ModifyShapeMixin(src, dest) {} + EmitResult emit() const override; }; class DimshuffleEmitter final : public Emitter { public: - using Operator = typename Emitter::Operator; DimshuffleEmitter(const std::vector& pattern) : m_pattern{pattern} {} - Operator emit() const override; + EmitResult emit() const override; private: std::vector m_pattern; }; -class ReformatEmitter final : public Emitter { +class ReformatEmitter final : public Emitter, ModifyShapeMixin { public: - using Operator = typename Emitter::Operator; ReformatEmitter(const megdnn::NamedTensorShape& src, const megdnn::NamedTensorShape& dest) - : m_src{src}, m_dest{dest} {} - Operator emit() const override; + : ModifyShapeMixin(src, dest) {} + EmitResult emit() const override; private: - SmallVector analyze() const; - megdnn::NamedTensorShape m_src, m_dest; + SmallVector analyze() const; }; } // namespace gopt } // namespace mgb diff --git a/src/gopt/test/reformat_emitter.cpp b/src/gopt/test/reformat_emitter.cpp index e877ba0d..ab8ae1d0 100644 --- a/src/gopt/test/reformat_emitter.cpp +++ b/src/gopt/test/reformat_emitter.cpp @@ -25,7 +25,9 @@ TEST(TestReformatEmitter, Basic) { NamedTensorShape::Format::NCHW4); auto src = NamedTensorShape::make_named_tensor_shape( NamedTensorShape::Format::NCHW32); - auto reformat = gopt::ReformatEmitter(src, dest).emit(); + auto&& tuple = gopt::ReformatEmitter(src, dest).emit(); + auto reformat = std::get<0>(tuple); + auto checker = std::get<1>(tuple); auto graph = ComputingGraph::make(); graph->options().graph_opt_level = 0; @@ -51,6 +53,9 @@ TEST(TestReformatEmitter, Basic) { return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name); }; auto x = mkvar("x", {N, C / 32, H, W, 32}); + EXPECT_TRUE(checker(x.node())); + auto x_ = mkvar("x", {N, H, W, C}); + EXPECT_FALSE(checker(x_.node())); auto y1 = SymbolVar(reformat(x.node())); auto y2 = SymbolVar(nchw32_to_nchw4(x.node())); HostTensorND t1, t2; @@ -69,7 +74,9 @@ TEST(TestReformatEmitter, MoreComplicated) { NamedTensorShape::Format::NCHW64); auto dest = NamedTensorShape::make_named_tensor_shape( NamedTensorShape::Format::NCHW88); - auto reformat = gopt::ReformatEmitter(src, dest).emit(); + auto&& tuple = gopt::ReformatEmitter(src, dest).emit(); + auto reformat = std::get<0>(tuple); + auto checker = std::get<1>(tuple); auto graph = ComputingGraph::make(); graph->options().graph_opt_level = 0; @@ -77,6 +84,9 @@ TEST(TestReformatEmitter, MoreComplicated) { return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name); }; auto x = mkvar("x", {N, C / 64, H, W, 64}); + EXPECT_TRUE(checker(x.node())); + auto x_ = mkvar("x", {N, H, W, C}); + EXPECT_FALSE(checker(x_.node())); auto y = SymbolVar(reformat(x.node())); HostTensorND t; auto func = graph->compile({make_callback_copy(y, t)});