Browse Source

feat(jit): add memfwd in jit executor opr

GitOrigin-RevId: b58860bbe8
release-1.2
Megvii Engine Team 4 years ago
parent
commit
8f7f52ae4d
8 changed files with 66 additions and 29 deletions
  1. +1
    -1
      dnn/src/common/pooling.cpp
  2. +7
    -0
      src/jit/impl/executor_opr.cpp
  3. +5
    -1
      src/jit/include/megbrain/jit/executor_opr.h
  4. +6
    -0
      src/jit/test/fusion.cpp
  5. +1
    -26
      src/opr/impl/basic_arith.cpp
  6. +31
    -0
      src/opr/impl/internal/identical_fwd.cpp
  7. +2
    -1
      src/opr/include/megbrain/opr/basic_arith.h
  8. +13
    -0
      src/opr/include/megbrain/opr/internal/identical_fwd.h

+ 1
- 1
dnn/src/common/pooling.cpp View File

@@ -92,7 +92,7 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src,
size_t sw = this->param().stride_w;
size_t ph = this->param().pad_h;
size_t pw = this->param().pad_w;
if (ph < fh && pw < fw) {
if (ph >= fh || pw >= fw) {
megdnn_log_error(
"pooling padding size (%zu %zu) should not be bigger than "
"window size (%zu %zu), it only can be used in CaffePooling",


+ 7
- 0
src/jit/impl/executor_opr.cpp View File

@@ -135,6 +135,13 @@ void JITExecutor::init_output_mem_plan(bool dynamic) {
m_args.need_update = true;
}

void JITExecutor::mem_plan_fwd_in2out_writable() {
    //! memory forwarding is currently supported for elemwise fusion only, so
    //! any feature bit other than NONE disables it
    const bool has_extra_feature = m_feature_bits != JITFeatureBits::NONE;
    if (!has_extra_feature) {
        mixin_mem_plan_fwd_in2out_writable(*this);
    }
}


SymbolVar JITExecutor::make(const InternalGraphPtr& internal_graph,
const VarNodeArray& inputs,
const OperatorNodeConfig& config) {


+ 5
- 1
src/jit/include/megbrain/jit/executor_opr.h View File

@@ -13,6 +13,7 @@

#include "megbrain/graph/operator_node.h"
#include "megbrain/jit/internal_graph.h"
#include "megbrain/opr/internal/identical_fwd.h"

#if MGB_JIT

@@ -31,7 +32,8 @@ class Compiler;
* JITExecutor generates runtime Args for this specific inputs, and calls
* methods in Compiler to get the Executable object for actual computing.
*/
MGB_DEFINE_OPR_CLASS(JITExecutor, cg::SingleCNOperatorNodeBase) // {
MGB_DEFINE_OPR_CLASS(JITExecutor, cg::SingleCNOperatorNodeBase,
opr::mixin::FwdIn2OutWritableHelper) // {
using ModeTrait = megdnn::Elemwise::ModeTrait;

InternalGraphPtr m_internal_graph;
@@ -57,6 +59,8 @@ public:

void init_output_mem_plan(bool dynamic) override;

void mem_plan_fwd_in2out_writable() override;

const InternalGraph& internal_graph() const { return *m_internal_graph; }

const InternalGraphPtr internal_graph_ptr() const {


+ 6
- 0
src/jit/test/fusion.cpp View File

@@ -137,6 +137,12 @@ void run<basic>(Backend backend, CompNode cn) {
// only one broadcast is allowed in JIT fusion
ASSERT_EQ(1u, jits[0]->input().size());
ASSERT_EQ(4u, jits[1]->input().size());

//! check memfwd
ASSERT_EQ(prev_dev_ptr(jits[0]->input(0)),
prev_dev_ptr(jits[0]->output(0)));
ASSERT_EQ(prev_dev_ptr(jits[1]->input(0)),
prev_dev_ptr(jits[1]->output(0)));
}

template <>


+ 1
- 26
src/opr/impl/basic_arith.cpp View File

@@ -338,32 +338,7 @@ void Elemwise::broadcast_collective_collapse(
}

// Decides which inputs of this Elemwise operator may be forwarded to
// output(0) by calling set_fwd_in2out_writable() on the output.
//
// NOTE(review): this span looks like a diff rendering whose +/- markers were
// lost. The change summary for basic_arith.cpp lists "+1 / -26", which matches
// the 26 lines of inline logic below followed by the single call to
// mixin_mem_plan_fwd_in2out_writable(). Presumably only that last call remains
// after the commit -- confirm against the real file before relying on this.
void Elemwise::mem_plan_fwd_in2out_writable() {
auto &&inp = input();
auto isize = inp.size();
// the conflict table below is a fixed array of 6 entries, hence the assert
mgb_assert(isize <= 6);
bool have_conflict[6] = {false};
// pairwise scan: flag both inputs of every conflicting pair
for (size_t i = 0; i < isize; ++i) {
for (size_t j = i + 1; j < isize; ++j) {
auto type = cg::get_mem_plan_intersection_type(inp[i], inp[j]);
using Type = cg::MemPlanIntersectionType;
// conflict case 1: the two memory plans overlap
bool overlap = type == Type::OVERLAP;
// conflict case 2: the two memory plans are identical and at least one of
// the two layouts is not contiguous
bool self_fwd = type == Type::IDENTICAL &&
(!inp[i]->layout().is_contiguous() ||
!inp[j]->layout().is_contiguous());
if (overlap || self_fwd) {
have_conflict[i] = true;
have_conflict[j] = true;
}
}
}
auto o = output(0);
// offer every conflict-free, contiguous input with the output's shape
for (size_t idx = 0; idx < isize; ++ idx) {
auto i = inp[idx];
// equal shape means no broadcast
if (!have_conflict[idx] &&
o->shape().eq_shape(i->shape()) && i->layout().is_contiguous())
o->set_fwd_in2out_writable(i);
}
// shared helper from mixin::FwdIn2OutWritableHelper; it takes the operator
// itself as argument
mixin_mem_plan_fwd_in2out_writable(*this);
}

void Elemwise::scn_do_execute() {


+ 31
- 0
src/opr/impl/internal/identical_fwd.cpp View File

@@ -33,6 +33,37 @@ void mixin::init_rt_force_dynamic_mem_alloc_imply_chain_for_dyn_pass_i2o(
valid_out->add_rt_force_dynamic_mem_alloc_imply_chain(opr.input(0));
}

/* ===================== FwdIn2OutWritableHelper ===================== */
void FwdIn2OutWritableHelper::mixin_mem_plan_fwd_in2out_writable(
        OperatorNodeBase& opr) {
    using IntersectionType = cg::MemPlanIntersectionType;

    auto&& inputs = opr.input();
    auto nr_input = inputs.size();

    // conflicted[k] becomes true once input k clashes with any other input
    std::vector<bool> conflicted(nr_input, false);
    for (size_t lhs = 0; lhs < nr_input; ++lhs) {
        for (size_t rhs = lhs + 1; rhs < nr_input; ++rhs) {
            auto relation = cg::get_mem_plan_intersection_type(inputs[lhs],
                                                               inputs[rhs]);
            bool clash = false;
            if (relation == IntersectionType::OVERLAP) {
                // overlapping memory plans always conflict
                clash = true;
            } else if (relation == IntersectionType::IDENTICAL) {
                // identical memory plans conflict unless both layouts are
                // contiguous
                bool both_contiguous =
                        inputs[lhs]->layout().is_contiguous() &&
                        inputs[rhs]->layout().is_contiguous();
                clash = !both_contiguous;
            }
            if (!clash) {
                continue;
            }
            conflicted[lhs] = true;
            conflicted[rhs] = true;
        }
    }

    auto out_var = opr.output(0);
    for (size_t k = 0; k < nr_input; ++k) {
        auto in_var = inputs[k];
        if (conflicted[k]) {
            continue;
        }
        // a shape mismatch means the input is broadcast, so it is skipped
        if (!out_var->shape().eq_shape(in_var->shape())) {
            continue;
        }
        // the output must also have the same dtype enum as the input
        if (!(out_var->dtype().enumv() == in_var->dtype().enumv())) {
            continue;
        }
        if (!in_var->layout().is_contiguous()) {
            continue;
        }
        out_var->set_fwd_in2out_writable(in_var);
    }
}

/* ===================== ReadonlyFwdHelper ===================== */

void ReadonlyFwdHelper::mixin_rofwd_init_mem_plan(OperatorNodeBase &opr) {


+ 2
- 1
src/opr/include/megbrain/opr/basic_arith.h View File

@@ -58,7 +58,8 @@ namespace intl {
* The operands are broadcasted automatically on dimensions of shape one to
* match shapes of each other; it works like broadcasting in numpy.
*/
MGB_DEFINE_OPR_CLASS(Elemwise, intl::ElemwiseBase) // {
MGB_DEFINE_OPR_CLASS(Elemwise, intl::ElemwiseBase,
mixin::FwdIn2OutWritableHelper) // {
using ModeTrait = megdnn::Elemwise::ModeTrait;

public:


+ 13
- 0
src/opr/include/megbrain/opr/internal/identical_fwd.h View File

@@ -19,6 +19,19 @@ namespace opr {

namespace mixin {

/*!
 * \brief mixin for operators that essentially work by forwarding an input to
 *      the output
 */
class FwdIn2OutWritableHelper : public cg::OperatorNodeMixinBase {
protected:
/*!
 * \brief call this function in mem_plan_fwd_in2out_writable()
 *
 * It checks every pair of inputs of \p opr for memory plan conflicts
 * (overlapping plans, or identical plans where either layout is not
 * contiguous). For each conflict-free input that is contiguous and has the
 * same shape and dtype enum as output(0), it calls
 * set_fwd_in2out_writable() on output(0) with that input.
 *
 * \param opr the operator whose inputs and first output are examined
 */
void mixin_mem_plan_fwd_in2out_writable(OperatorNodeBase &opr);
};

//! for internal use by DynamicOutputIfInputDynamic
void init_rt_force_dynamic_mem_alloc_imply_chain_for_dyn_pass_i2o(
OperatorNodeBase &opr);


Loading…
Cancel
Save