From bc95e873ef8c3fc6a0846c9f40d3ed2a0c01aeb7 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 11 Jun 2020 21:52:07 +0800 Subject: [PATCH] fix(jit): fix jit grad a) fix shape mismatch when take grad of JITExecutor including Dimshuffle b) avoid redundant computation in the grad of JITExecutor c) not pass unused vars as inputs to the grad of JITExecutor to save device memory d) traverse internal graph only once in JITExecutor ctor instead of traverse whole graph in each call of setup_args() e) expand the gradient graph into the origin graph if all inputs are const GitOrigin-RevId: ba6a2b29e975c7f63a21785efad87dbda76143d4 --- src/jit/impl/executor_opr.cpp | 353 ++++++++++++++++++++------ src/jit/impl/placeholder_opr.cpp | 1 - src/jit/include/megbrain/jit/executor_opr.h | 12 + src/jit/include/megbrain/jit/internal_graph.h | 2 - src/jit/test/fusion.cpp | 10 + src/jit/test/helper.cpp | 2 +- src/jit/test/helper.h | 8 + 7 files changed, 300 insertions(+), 88 deletions(-) diff --git a/src/jit/impl/executor_opr.cpp b/src/jit/impl/executor_opr.cpp index d4c948c2..744af62f 100644 --- a/src/jit/impl/executor_opr.cpp +++ b/src/jit/impl/executor_opr.cpp @@ -88,15 +88,34 @@ JITExecutor::JITExecutor(const InternalGraphPtr& internal_graph, cg::add_workspace_output(this); } + // check if output of internal_graph is depend on all placeholders + size_t nr_placeholders = internal_graph_ptr()->placeholders().size(); + std::vector used(nr_placeholders, false); // check if there is reduce or dimshuffle opr - cg::DepOprIter{[this](cg::OperatorNodeBase* opr) { + cg::DepOprIter{[this, nr_placeholders, &used](cg::OperatorNodeBase* opr) { if (opr->same_type()) { m_feature_bits |= JITFeatureBits::REDUCE; } if (opr->same_type()) { m_feature_bits |= JITFeatureBits::DIMSHUFFLE; } + if (auto ph = opr->try_cast_final()) { + mgb_assert(ph->input_id() < nr_placeholders, + "bad placeholders %s in JITExecutor %s", + ph->cname(), cname()); + used[ph->input_id()] = true; + } }}.add(internal_graph->output()); + + for (size_t i = 0; i < nr_placeholders; ++ i) { + mgb_assert(used[i], + "placeholder %s is not depended on the output of %s", + internal_graph_ptr()->placeholders()[i]->cname(), cname()); + } + + if (has_dimshuffle()) { + prepare_dimshuffle(); + } } void JITExecutor::add_input_layout_constraint() { @@ -151,14 +170,14 @@ void JITExecutor::scn_do_execute() { //! can be ignored void JITExecutor::do_dimshuffle() { - auto get_dimshuffled_layout = [](const TensorLayout& ily, int32_t* pattern, - size_t pattern_len) { + static auto get_dimshuffled_layout = [](const TensorLayout& ily, + std::vector pattern) { TensorLayout oly{ily.dtype}; - oly.ndim = pattern_len; + oly.ndim = pattern.size(); bool input_used[TensorLayout::MAX_NDIM] = {0}; - for (uint32_t idx = 0; idx < pattern_len; ++idx) { + for (uint32_t idx = 0; idx < pattern.size(); ++idx) { auto i = pattern[idx]; if (i < 0) { oly.shape[idx] = 1; @@ -179,53 +198,20 @@ void JITExecutor::do_dimshuffle() { return oly; }; - // DFS to make sure traverse the dimshuffles in one branch - std::unordered_set visited; - std::vector stack(0); - std::vector idx(0); // input index - stack.push_back(m_internal_graph->output()->owner_opr()); - idx.push_back(0); - - while (!stack.empty()) { - if (idx.back() < stack.back()->input().size() && - !visited.count(stack.back()->input(idx.back()))) { - visited.insert(stack.back()->input(idx.back())); - stack.push_back(stack.back()->input(idx.back())->owner_opr()); - if (stack.back()->same_type()) { - auto jitph = gopt::try_cast_as_op(stack.back()); - size_t input_id = jitph->input_id(); - auto&& input = m_args.inputs[input_id]; - - for (int i = stack.size() - 1; i >= 0; --i) { - if (stack[i]->same_type()) { - auto param = - stack[i]->cast_final_safe() - .param(); - - mgb_assert(input.layout.ndim == param.ndim, - "input ndim mismatch for Dimshuffle: " - "expect=%u " - "actual=%zu", - param.ndim, input.layout.ndim); - auto dimshuffled_layout = get_dimshuffled_layout( - input.layout, param.pattern, param.pattern_len); - input.layout = dimshuffled_layout; - } - } - - stack.pop_back(); - ++idx.back(); - } else { - idx.push_back(0); - } - } else { - stack.pop_back(); - idx.pop_back(); - if (!stack.empty()) - ++idx.back(); - } + for (auto&& i : m_internal_graph->placeholders()) { + auto&& input = m_args.inputs[i->input_id()]; + auto&& iter = m_jitph2dimshuffle.find(i); + if (iter == m_jitph2dimshuffle.end()) continue; + auto&& param = iter->second; + mgb_assert(input.layout.ndim == param.second, + "input ndim mismatch for Dimshuffle: " + "expect=%u " + "actual=%zu", + param.second, input.layout.ndim); + auto dimshuffled_layout = get_dimshuffled_layout( + input.layout, param.first); + input.layout = dimshuffled_layout; } - } void JITExecutor::update_args() { @@ -259,7 +245,9 @@ void JITExecutor::update_args() { } //! dimshuffle opr need to change the input. - do_dimshuffle(); + if (has_dimshuffle()) { + do_dimshuffle(); + } if (m_compiler->property().contain_flag(CPFlag::NEED_INPUT_COLLAPSE)) { // collective collapse datum layout, try to reduce the output ndim @@ -304,6 +292,82 @@ void JITExecutor::update_args() { m_args.need_update = false; } +void JITExecutor::prepare_dimshuffle() { + std::unordered_set visited; + std::vector stack(0); + std::vector idx(0); // input index + using Param = DimshuffleParam; + std::vector dimshuffle_stack; + + auto merge_dimshuffle = [&](const opr::Dimshuffle::Param& p) { + if (dimshuffle_stack.empty()) { + dimshuffle_stack.emplace_back(); + auto&& param = dimshuffle_stack.back(); + param.first.insert(param.first.end(), p.pattern, p.pattern + p.pattern_len); + param.second = p.ndim; + } else { + // merge(p, src) -> param and it has performing dimshuffle(dimshuffle(x, p), src) + // is equivalent to dimshuffle(x, param) + dimshuffle_stack.emplace_back(); + auto&& param = dimshuffle_stack.back(); + auto&& src = dimshuffle_stack[dimshuffle_stack.size() - 2]; + mgb_assert(p.pattern_len == src.second); + param.first.resize(src.first.size()); + for (size_t i = 0; i < src.first.size(); ++ i) { + if (src.first[i] == -1) { + param.first[i] = -1; + } else { + param.first[i] = p.pattern[src.first[i]]; + } + } + param.second = p.ndim; + } + }; + auto push_back = [&](cg::OperatorNodeBase* op) { + mgb_assert(!op->same_type()); + if (auto o = op->try_cast_final()) { + merge_dimshuffle(o->param()); + } + stack.push_back(op); + idx.push_back(0); + }; + auto pop_back = [&]() { + auto&& op = stack.back(); + if (op->same_type()) { + dimshuffle_stack.pop_back(); + } + stack.pop_back(); + idx.pop_back(); + }; + + push_back(m_internal_graph->output()->owner_opr()); + + while (!stack.empty()) { + if (idx.back() < stack.back()->input().size()) { + auto cur_opr = stack.back()->input(idx.back())->owner_opr(); + if (visited.insert(cur_opr).second) { + if (auto jitph = cur_opr->try_cast_final()) { + if (!dimshuffle_stack.empty()) { + mgb_assert( + m_jitph2dimshuffle.emplace(jitph, dimshuffle_stack.back()).second, + "already visited JITPlaceholder %s", + jitph->cname()); + } + ++ idx.back(); + } else { + push_back(cur_opr); + } + } else { + ++ idx.back(); + } + } else { + pop_back(); + if (!stack.empty()) + ++ idx.back(); + } + } +} + const JITExecutor::Args& JITExecutor::args() const { if (m_args.need_update) { const_cast(this)->update_args(); @@ -383,6 +447,56 @@ megdnn::TensorShape JITExecutor::broadcasted_input_shape() const { #if MGB_ENABLE_GRAD +namespace { +class InternalGraphRewriter { + ThinHashMap m_var_map; + VarNode* m_dest_var; + VarNodeArray m_new_inp; + VarNode* get_var(VarNode* var) { + auto&& iter = m_var_map.find(var); + if (iter != m_var_map.end()) { + return iter->second; + } + return var; + } +public: + InternalGraphRewriter(VarNode* dest_var) + :m_dest_var{dest_var}{} + void iter(thin_function&& cb) { + m_var_map.clear(); + cg::DepOprIter{std::move(cb)}.add(m_dest_var->owner_opr()); + m_dest_var = get_var(m_dest_var); + } + VarNode* dest_var() { + return m_dest_var; + } + void replace_var(VarNode* src, VarNode* dst) { + // Note: do not perform var replacing recursively + // when we extract used placeholders from internal graph, we don't + // consider placeholder replacement pair (a to b), (b to c) as a + // var replacing chain (a to b to c) but as a injective function + // from (a, b) to (b, c) + // in other cases, each var node would be passed as \p src or + // \p dst at most once + m_var_map[src] = dst; + } + void auto_replace_outputs(cg::OperatorNodeBase* opr) { + // in JIT internal graph, output size of opr is always 1 + mgb_assert(opr->usable_output().size() == 1); + m_new_inp.clear(); + bool need_replace = false; + for (auto&& i : opr->input()) { + auto inp = get_var(i); + m_new_inp.push_back(inp); + need_replace |= (inp != i); + } + if (need_replace) { + auto new_op = serialization::copy_opr_shallow(*opr, m_new_inp); + replace_var(opr->output(0), new_op->output(0)); + } + } +}; +} // anonymous namespace MGB_IMPL_OPR_GRAD(JITExecutor) { VarNodeArray grad_inputs; for (auto input : opr.input()) @@ -404,49 +518,120 @@ MGB_IMPL_OPR_GRAD(JITExecutor) { if (gx.node()->owner_opr()->same_type()) { return opr::InvalidGrad::make(opr, wrt_idx); } + // early return if grad expression is single node + for (size_t i = 0; i < fwd_igraph_ptr->placeholders().size(); ++i) { + if (gx.node() == fwd_igraph_ptr->placeholders()[i]->output(0)) { + return grad_inputs[i]; + } + } + if (gx.node() == og_ph.node()) { + return out_grad[0]; + } + if (gx.node() == fwd_igraph_ptr->output()) { + return opr.output(0); + } + if (auto imm = gopt::try_cast_as_op(gx.node()->owner_opr())) { + HostTensorND hval{grad_inputs[0]->comp_node()}; + hval.copy_from(imm->value()).sync(); + return opr::ImmutableTensor::make(*imm->owner_graph(), hval).node(); + } + + // replace output var in internal graph with output placeholder, so + // we could forward opr.output(computeed by forward JITExecutor) into + // placeholder to avoid redundant computation + InternalGraphRewriter rewriter{gx.node()}; + rewriter.iter([&rewriter, &fwd_igraph_ptr, + &output_ph](cg::OperatorNodeBase* opr) { + if (opr == fwd_igraph_ptr->output()->owner_opr()) { + rewriter.replace_var(opr->output(0), output_ph.node()); + return; + } + rewriter.auto_replace_outputs(opr); + }); + + static auto expand_into_origin_graph = [](cg::OperatorNodeBase* opr, + InternalGraphRewriter& rewriter, const VarNodeArray& grad_inputs) { + if (auto ph = gopt::try_cast_as_op(opr)) { + rewriter.replace_var( + opr->output(0), grad_inputs.at(ph->input_id())); + return; + } + if (auto imm = gopt::try_cast_as_op(opr)) { + HostTensorND hval{grad_inputs[0]->comp_node()}; + hval.copy_from(imm->value()).sync(); + rewriter.replace_var(opr->output(0), + opr::ImmutableTensor::make(*opr->owner_graph(), hval).node()); + return; + } + rewriter.auto_replace_outputs(opr); + }; + if (opr.compiler()->property().feature_bits & JITFeatureBits::REDUCE) { // expand the gradient graph into the original graph to handle bcast // oprs - ThinHashMap old2new; - VarNodeArray new_inp; - auto on_opr = [&old2new, &grad_inputs, - &new_inp](cg::OperatorNodeBase* opr) { + using namespace std::placeholders; + rewriter.iter(std::bind(expand_into_origin_graph, _1, + std::ref(rewriter), std::cref(grad_inputs))); + return rewriter.dest_var(); + } else { + VarNodeArray new_grad_inputs; + PlaceholderArray placeholders; + bool all_inp_const = true; + // gx was not depend on all JITPlaceholders so we need to extract used + // placeholders and build a new internal graph + rewriter.iter([&rewriter, &grad_inputs, &new_grad_inputs, + &placeholders, &all_inp_const](cg::OperatorNodeBase* opr) { if (auto ph = gopt::try_cast_as_op(opr)) { - old2new[opr->output(0)] = grad_inputs.at(ph->input_id()); - return; - } - if (auto imm = gopt::try_cast_as_op(opr)) { - HostTensorND hval{grad_inputs[0]->comp_node()}; - hval.copy_from(imm->value()).sync(); - old2new[opr->output(0)] = - opr::ImmutableTensor::make(*opr->owner_graph(), hval) - .node(); + new_grad_inputs.push_back(grad_inputs[ph->input_id()]); + auto new_ph = JITPlaceholder::make( + new_grad_inputs.back(), placeholders.size()) + .node()->owner_opr(); + placeholders.push_back(new_ph->try_cast_final()); + mgb_assert(placeholders.back()); + rewriter.replace_var(opr->output(0), new_ph->output(0)); + if (!cg::is_const_var_value(new_grad_inputs.back())) { + all_inp_const = false; + } return; } - new_inp.clear(); - for (auto inp : opr->input()) { - new_inp.push_back(old2new.at(inp)); - } - auto new_opr = serialization::copy_opr_shallow(*opr, new_inp); - old2new[opr->output(0)] = new_opr->output(0); - }; - cg::DepOprIter{on_opr}.add(gx.node()); - return old2new.at(gx.node()); - } else { - PlaceholderArray placeholders = fwd_igraph_ptr->placeholders(); - for (SymbolVar i : {output_ph, og_ph}) { - placeholders.push_back( - &i.node()->owner_opr()->cast_final_safe()); + rewriter.auto_replace_outputs(opr); + }); + if (all_inp_const) { + // if all_inp_const, expand grad graph into origin graph by replace + // placeholders with const inputs, so it could benefit from static + // infer and const folding mechanism + using namespace std::placeholders; + rewriter.iter(std::bind(expand_into_origin_graph, _1, + std::ref(rewriter), std::cref(new_grad_inputs))); + return rewriter.dest_var(); } - for (size_t i = 0; i < placeholders.size(); ++i) { - if (gx.node() == placeholders[i]->output(0)) { - return grad_inputs[i]; + gx = rewriter.dest_var(); + + auto shape_infer = fwd_igraph_ptr->shape_infer(); + if (opr.has_dimshuffle()) { + auto&& iter = opr.dimshuffle_params().find( + fwd_igraph_ptr->placeholders()[wrt_idx]); + if (iter != opr.dimshuffle_params().end()) { + auto&& pattern = iter->second.first; + auto&& ndim = iter->second.second; + std::vector back(ndim, -1); + for (size_t i = 0; i < pattern.size(); i ++) { + // outdim[i] is indim[j] + auto j = pattern[i]; + if (j >= 0) { + mgb_assert(back[j] == -1, + "taking grad for Dimshuffle with duplicated " + "input axis unsupported"); + back[j] = i; + } + } + shape_infer = opr::Dimshuffle::make(shape_infer, back, pattern.size()).node(); } } auto grad_ig = std::make_shared( - gx.node(), fwd_igraph_ptr->shape_infer(), nullptr, + gx.node(), shape_infer, nullptr, std::move(placeholders)); - auto grad_jit = JITExecutor::make(grad_ig, grad_inputs); + auto grad_jit = JITExecutor::make(grad_ig, new_grad_inputs); if (opr.input_broadcastable()[wrt_idx]) { grad_jit = opr::reduce_sum( diff --git a/src/jit/impl/placeholder_opr.cpp b/src/jit/impl/placeholder_opr.cpp index 85a1000e..9de43199 100644 --- a/src/jit/impl/placeholder_opr.cpp +++ b/src/jit/impl/placeholder_opr.cpp @@ -26,7 +26,6 @@ JITPlaceholder::JITPlaceholder(VarNode* src_var, size_t id, InpType inp_type) {}), m_inp_type{inp_type}, m_id{id} { - add_equivalence_component>(m_id); mgb_assert(src_var->dtype().category() == DTypeCategory::FLOAT || src_var->dtype().category() == DTypeCategory::INT, "JIT can only be applied to float/int operators, got %s", diff --git a/src/jit/include/megbrain/jit/executor_opr.h b/src/jit/include/megbrain/jit/executor_opr.h index 9f9f5159..dabfede6 100644 --- a/src/jit/include/megbrain/jit/executor_opr.h +++ b/src/jit/include/megbrain/jit/executor_opr.h @@ -35,6 +35,7 @@ MGB_DEFINE_OPR_CLASS(JITExecutor, cg::SingleCNOperatorNodeBase) // { using ModeTrait = megdnn::Elemwise::ModeTrait; InternalGraphPtr m_internal_graph; + using DimshuffleParam = std::pair, uint32_t>; public: using Mode = opr::Elemwise::Mode; @@ -112,6 +113,11 @@ public: return static_cast(m_feature_bits & JITFeatureBits::DIMSHUFFLE); } + const ThinHashMap& + dimshuffle_params() const { + return m_jitph2dimshuffle; + } + //! get broadcasted shape of inputs megdnn::TensorShape broadcasted_input_shape() const; @@ -124,8 +130,14 @@ private: Compiler* const m_compiler = nullptr; Executable* m_executable = nullptr; std::vector m_input_broadcastable; + // JITPlaceHolder -> pair of (dimshuffle pattern, ndim) + // do DFS on internal graph only once in prepare_dimshuffle(), so we can + // easily get the dimshuffle param which should be applied on given + // JITPlaceholder + ThinHashMap m_jitph2dimshuffle; void update_args(); void do_dimshuffle(); + void prepare_dimshuffle(); NodeProp* do_make_node_prop() const override; }; diff --git a/src/jit/include/megbrain/jit/internal_graph.h b/src/jit/include/megbrain/jit/internal_graph.h index 5b41d0df..94715536 100644 --- a/src/jit/include/megbrain/jit/internal_graph.h +++ b/src/jit/include/megbrain/jit/internal_graph.h @@ -61,8 +61,6 @@ public: const PlaceholderArray& placeholders() const { return m_placeholders; } - static InternalGraphPtr expand_excutor_op(const InternalGraphPtr&); - private: // For compilation cache, if the output_for_cache is same means the // expression tree is same. diff --git a/src/jit/test/fusion.cpp b/src/jit/test/fusion.cpp index 60a4a7ac..82606b5b 100644 --- a/src/jit/test/fusion.cpp +++ b/src/jit/test/fusion.cpp @@ -1435,6 +1435,16 @@ TEST(TestJITNvrtc, DimshuffleGrad) { funcs.second->execute(); MGB_ASSERT_TENSOR_NEAR(host_y1, host_y2, 1e-3); } + { + FusionChecker checker{2, + [](const SymbolVarArray& inp) -> SymbolVar { + auto var = opr::Dimshuffle::make(inp[0], {1, 2, 3, 0}); + return inp[1] * var; + }, + CompNode::load("gpu0")}; + checker.set_jit_level(1) + .run({TensorShape{1, 2, 3, 4}, {2, 3, 4, 1}}); + } } #endif // MGB_JIT diff --git a/src/jit/test/helper.cpp b/src/jit/test/helper.cpp index 6889c04d..c0311a5d 100644 --- a/src/jit/test/helper.cpp +++ b/src/jit/test/helper.cpp @@ -98,7 +98,7 @@ void FusionChecker::ensure_init_graph() { } else { ComputingGraph::Options opt; opt.graph_opt_level = 3; - opt.graph_opt.jit = 2; + opt.graph_opt.jit = m_jit_level; unpack_vector(gopt::GraphOptimizer{} .add_preset_passes(true, nullptr, &opt) .apply({{m_truth_y}}) diff --git a/src/jit/test/helper.h b/src/jit/test/helper.h index 6c3b8eca..33d3e1b7 100644 --- a/src/jit/test/helper.h +++ b/src/jit/test/helper.h @@ -65,6 +65,13 @@ public: return *this; } + //! set jit level, default is 2, see graph_opt.jit in graph options + //! for more details + FusionChecker& set_jit_level(uint8_t jit_level) { + m_jit_level = jit_level; + return *this; + } + /*! * \brief run and check correctness * @@ -76,6 +83,7 @@ private: bool m_check_opr_type = true; bool m_direct_build = false; const size_t m_nr_input; + uint8_t m_jit_level = 2; const CompNode m_comp_node; HostTensorGenerator<> m_input_gen; SmallVector> m_inputs_val;