From a09fc5f784ba3ba279bbd805a8b59c0e8b603a3c Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sat, 23 Jan 2021 16:56:06 +0800 Subject: [PATCH] fix(mgb/serialization): disable inplace arith graph opt in graph load GitOrigin-RevId: d63baf8356d345013886692e464f8e4f49594887 --- imperative/python/megengine/jit/tracing.py | 6 ++-- imperative/python/test/unit/test_tracing.py | 1 - src/core/include/megbrain/graph/cg.h | 8 +++++ src/opr/impl/basic_arith.cpp | 3 +- src/serialization/impl/serializer_oss.cpp | 12 +++++-- .../include/megbrain/serialization/serializer.h | 15 +++++++++ src/serialization/test/serializer_oss.cpp | 37 ++++++++++++++++++++++ 7 files changed, 75 insertions(+), 7 deletions(-) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 1455acb5..8d5ae719 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -755,8 +755,10 @@ class trace: h2v = {} graph = G.Graph() - # only graph_opt_level takes effect in dump - self._apply_graph_options(graph) + + # apply graph_opt_level in dump + if self._graph_opt_level is not None: + graph.options.graph_opt_level = self._graph_opt_level for i, h in enumerate(self._arg_bindings): info = self._tinfo[h] diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index 94b3c189..9e637af6 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -244,7 +244,6 @@ def test_goptions_log_sum_exp(): np.testing.assert_almost_equal(g(d, o), val) -@pytest.mark.skip(reason="could not use opt_level=0 with dump") def test_goptions_log_exp(): @trace(symbolic=True, opt_level=0, capture_as_const=True) def f(x): diff --git a/src/core/include/megbrain/graph/cg.h b/src/core/include/megbrain/graph/cg.h index 428e8ceb..004db94f 100644 --- a/src/core/include/megbrain/graph/cg.h +++ b/src/core/include/megbrain/graph/cg.h @@ -356,6 +356,14 @@ 
class ComputingGraph : public std::enable_shared_from_this, int16_t graph_opt_level = 2; /*! + * disable inplace arith transformations during graph + * construction + * it effectively disable level-1 graph optimization + * only for internal use during de-serialization + */ + bool disable_inplace_arith_opt = false; + + /*! * max size of allreduce packs in MB * set this option to zero to disable PackAllReducePass */ diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp index 7570f714..bf6f2ab7 100644 --- a/src/opr/impl/basic_arith.cpp +++ b/src/opr/impl/basic_arith.cpp @@ -221,7 +221,8 @@ SymbolVar Elemwise::make(const VarNodeArrayView& inputs, Param param, trait.name, cg::dump_var_info(inputs).c_str()); #if !MGB_BUILD_SLIM_SERVING - if (inputs[0]->owner_graph()->options().graph_opt_level) { + auto&& options = inputs[0]->owner_graph()->options(); + if (options.graph_opt_level && !(options.disable_inplace_arith_opt)) { auto repl = gopt::optimize_elemwise_expr_inplace(dtp.get_vars(), param, config); if (repl) diff --git a/src/serialization/impl/serializer_oss.cpp b/src/serialization/impl/serializer_oss.cpp index 1f2727c3..9935089b 100644 --- a/src/serialization/impl/serializer_oss.cpp +++ b/src/serialization/impl/serializer_oss.cpp @@ -756,9 +756,15 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr( GraphLoader::LoadResult GraphLoaderOSS::OprLoadContextImpl::load_oprs() { // load oprs const auto* oprs = m_loader->m_graph->oprs(); - for (flatbuffers::uoffset_t i = 0; i < oprs->size(); ++i) { - m_current_opr = oprs->Get(i); - load_single_opr(m_current_opr); + { + // inplace arith graph optimization is disabled during opr load + // it tries to restore the same graph as it was dumped + // see test TestSerializer2.LOGEXP for example + GraphLoader::ScopedGraphOptDisabler _(m_graph); + for (flatbuffers::uoffset_t i = 0; i < oprs->size(); ++i) { + m_current_opr = oprs->Get(i); + load_single_opr(m_current_opr); + } } // batched loading device values 
diff --git a/src/serialization/include/megbrain/serialization/serializer.h b/src/serialization/include/megbrain/serialization/serializer.h
index 30706e2a..752fe740 100644
--- a/src/serialization/include/megbrain/serialization/serializer.h
+++ b/src/serialization/include/megbrain/serialization/serializer.h
@@ -61,6 +61,32 @@ namespace serialization {
                     const ComputingGraph::OutputSpec &outspec);
             };
 
+            //! RAII helper that disables inplace arith graph optimization
+            //! during de-serialization: the constructor sets
+            //! disable_inplace_arith_opt on the graph and the destructor
+            //! restores the value the option had before
+            struct ScopedGraphOptDisabler {
+                //! option value saved at construction, restored on destruction
+                bool option_saved;
+                //! graph whose option is toggled; held so it outlives this guard
+                std::shared_ptr<ComputingGraph> cg;
+                explicit ScopedGraphOptDisabler(
+                        const std::shared_ptr<ComputingGraph>& cg_p)
+                        : option_saved(true), cg(cg_p) {
+                    // swap in `true` and keep the previous value for restore
+                    std::swap(option_saved,
+                              cg->options().disable_inplace_arith_opt);
+                }
+                // non-copyable: a copy would restore the option a second time
+                // and could clobber a value set by a later guard
+                ScopedGraphOptDisabler(const ScopedGraphOptDisabler&) = delete;
+                ScopedGraphOptDisabler& operator=(
+                        const ScopedGraphOptDisabler&) = delete;
+                ~ScopedGraphOptDisabler() {
+                    cg->options().disable_inplace_arith_opt = option_saved;
+                }
+            };
+
             //! mem_node => tensor_value
             using SharedTensorMapEntry =
                     ThinHashMap<MemNode, std::shared_ptr<DeviceTensorND>>;
diff --git a/src/serialization/test/serializer_oss.cpp b/src/serialization/test/serializer_oss.cpp
index bdb82734..cc6a7dff 100644
--- a/src/serialization/test/serializer_oss.cpp
+++ b/src/serialization/test/serializer_oss.cpp
@@ -761,4 +761,46 @@ TEST(TestSerializer2, HasOutputDtype) {
     load();
 }
 
+// Dump log(exp(x)) once with inplace arith optimization enabled and once
+// with graph_opt_level = 0; the dump asserts the operator count in each
+// case and each dump is then loaded back.
+TEST(TestSerializer2, LOGEXP) {
+    auto fname = GET_OUTPUT_FILE();
+    TensorShape shape{2, 3};
+    using Mode = opr::Elemwise::Mode;
+    bool inplace_opt = true;
+    auto dump = [&]() {
+        auto cn = CompNode::load("xpu0");
+        auto host_x = std::make_shared<HostTensorND>(cn, shape);
+        for (size_t i = 0, it = shape.total_nr_elems(); i < it; ++i)
+            host_x->ptr<float>()[i] = 0.0;  // To avoid NAN
+        auto graph = ComputingGraph::make();
+        if (!inplace_opt)
+            graph->options().graph_opt_level = 0;
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x, {"x"});
+        auto y = opr::Elemwise::make({x}, Mode::EXP);
+        auto z = opr::Elemwise::make({y}, Mode::LOG);
+
+        auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()),
+                                        GraphDumpFormat::FLATBUFFERS);
+        auto rst = dumper->dump({z.rename("z"), z});
+        // 1 operator is expected when inplace optimization is on, 3
+        // (Host2DeviceCopy, EXP, LOG) when it is off
+        size_t expected_nr_opr = inplace_opt ? 1 : 3;
+        ASSERT_EQ(expected_nr_opr, rst.nr_opr);
+    };
+
+    auto load = [&]() {
+        auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()),
+                                        GraphDumpFormat::FLATBUFFERS);
+        auto rst = loader->load();
+    };
+
+    dump();
+    load();
+
+    inplace_opt = !inplace_opt;
+    dump();
+    load();
+}
 #endif