@@ -92,7 +92,7 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, | |||||
size_t sw = this->param().stride_w; | size_t sw = this->param().stride_w; | ||||
size_t ph = this->param().pad_h; | size_t ph = this->param().pad_h; | ||||
size_t pw = this->param().pad_w; | size_t pw = this->param().pad_w; | ||||
if (ph < fh && pw < fw) { | |||||
if (ph >= fh || pw >= fw) { | |||||
megdnn_log_error( | megdnn_log_error( | ||||
"pooling padding size (%zu %zu) should not be bigger than " | "pooling padding size (%zu %zu) should not be bigger than " | ||||
"window size (%zu %zu), it only can be used in CaffePooling", | "window size (%zu %zu), it only can be used in CaffePooling", | ||||
@@ -135,6 +135,13 @@ void JITExecutor::init_output_mem_plan(bool dynamic) { | |||||
m_args.need_update = true; | m_args.need_update = true; | ||||
} | } | ||||
void JITExecutor::mem_plan_fwd_in2out_writable() { | |||||
//! currently mem fwd only support elemwise fusion | |||||
if (m_feature_bits != JITFeatureBits::NONE) return; | |||||
mixin_mem_plan_fwd_in2out_writable(*this); | |||||
} | |||||
SymbolVar JITExecutor::make(const InternalGraphPtr& internal_graph, | SymbolVar JITExecutor::make(const InternalGraphPtr& internal_graph, | ||||
const VarNodeArray& inputs, | const VarNodeArray& inputs, | ||||
const OperatorNodeConfig& config) { | const OperatorNodeConfig& config) { | ||||
@@ -13,6 +13,7 @@ | |||||
#include "megbrain/graph/operator_node.h" | #include "megbrain/graph/operator_node.h" | ||||
#include "megbrain/jit/internal_graph.h" | #include "megbrain/jit/internal_graph.h" | ||||
#include "megbrain/opr/internal/identical_fwd.h" | |||||
#if MGB_JIT | #if MGB_JIT | ||||
@@ -31,7 +32,8 @@ class Compiler; | |||||
* JITExecutor generates runtime Args for this specific inputs, and calls | * JITExecutor generates runtime Args for this specific inputs, and calls | ||||
* methods in Compiler to get the Executable object for actual computing. | * methods in Compiler to get the Executable object for actual computing. | ||||
*/ | */ | ||||
MGB_DEFINE_OPR_CLASS(JITExecutor, cg::SingleCNOperatorNodeBase) // { | |||||
MGB_DEFINE_OPR_CLASS(JITExecutor, cg::SingleCNOperatorNodeBase, | |||||
opr::mixin::FwdIn2OutWritableHelper) // { | |||||
using ModeTrait = megdnn::Elemwise::ModeTrait; | using ModeTrait = megdnn::Elemwise::ModeTrait; | ||||
InternalGraphPtr m_internal_graph; | InternalGraphPtr m_internal_graph; | ||||
@@ -57,6 +59,8 @@ public: | |||||
void init_output_mem_plan(bool dynamic) override; | void init_output_mem_plan(bool dynamic) override; | ||||
void mem_plan_fwd_in2out_writable() override; | |||||
const InternalGraph& internal_graph() const { return *m_internal_graph; } | const InternalGraph& internal_graph() const { return *m_internal_graph; } | ||||
const InternalGraphPtr internal_graph_ptr() const { | const InternalGraphPtr internal_graph_ptr() const { | ||||
@@ -137,6 +137,12 @@ void run<basic>(Backend backend, CompNode cn) { | |||||
// only one broadcast is allowed in JIT fusion | // only one broadcast is allowed in JIT fusion | ||||
ASSERT_EQ(1u, jits[0]->input().size()); | ASSERT_EQ(1u, jits[0]->input().size()); | ||||
ASSERT_EQ(4u, jits[1]->input().size()); | ASSERT_EQ(4u, jits[1]->input().size()); | ||||
//! check memfwd | |||||
ASSERT_EQ(prev_dev_ptr(jits[0]->input(0)), | |||||
prev_dev_ptr(jits[0]->output(0))); | |||||
ASSERT_EQ(prev_dev_ptr(jits[1]->input(0)), | |||||
prev_dev_ptr(jits[1]->output(0))); | |||||
} | } | ||||
template <> | template <> | ||||
@@ -338,32 +338,7 @@ void Elemwise::broadcast_collective_collapse( | |||||
} | } | ||||
void Elemwise::mem_plan_fwd_in2out_writable() { | void Elemwise::mem_plan_fwd_in2out_writable() { | ||||
auto &&inp = input(); | |||||
auto isize = inp.size(); | |||||
mgb_assert(isize <= 6); | |||||
bool have_conflict[6] = {false}; | |||||
for (size_t i = 0; i < isize; ++i) { | |||||
for (size_t j = i + 1; j < isize; ++j) { | |||||
auto type = cg::get_mem_plan_intersection_type(inp[i], inp[j]); | |||||
using Type = cg::MemPlanIntersectionType; | |||||
bool overlap = type == Type::OVERLAP; | |||||
bool self_fwd = type == Type::IDENTICAL && | |||||
(!inp[i]->layout().is_contiguous() || | |||||
!inp[j]->layout().is_contiguous()); | |||||
if (overlap || self_fwd) { | |||||
have_conflict[i] = true; | |||||
have_conflict[j] = true; | |||||
} | |||||
} | |||||
} | |||||
auto o = output(0); | |||||
for (size_t idx = 0; idx < isize; ++ idx) { | |||||
auto i = inp[idx]; | |||||
// equal shape means no broadcast | |||||
if (!have_conflict[idx] && | |||||
o->shape().eq_shape(i->shape()) && i->layout().is_contiguous()) | |||||
o->set_fwd_in2out_writable(i); | |||||
} | |||||
mixin_mem_plan_fwd_in2out_writable(*this); | |||||
} | } | ||||
void Elemwise::scn_do_execute() { | void Elemwise::scn_do_execute() { | ||||
@@ -33,6 +33,37 @@ void mixin::init_rt_force_dynamic_mem_alloc_imply_chain_for_dyn_pass_i2o( | |||||
valid_out->add_rt_force_dynamic_mem_alloc_imply_chain(opr.input(0)); | valid_out->add_rt_force_dynamic_mem_alloc_imply_chain(opr.input(0)); | ||||
} | } | ||||
/* ===================== FwdIn2OutWritableHelper ===================== */ | |||||
/*!
 * \brief decide which inputs of \p opr may be forwarded to its first output
 *
 * Step 1: compare the mem plans of every pair of inputs. Both inputs of a
 * pair are marked as conflicting when their plans overlap, or when the plans
 * are identical and at least one of the two layouts is not contiguous.
 *
 * Step 2: call set_fwd_in2out_writable() on output(0) for every
 * non-conflicting input that has the same shape and dtype enum as the output
 * and a contiguous layout.
 *
 * \param opr operator whose inputs and output(0) are examined
 */
void FwdIn2OutWritableHelper::mixin_mem_plan_fwd_in2out_writable(
        OperatorNodeBase& opr) {
    using IntersectionType = cg::MemPlanIntersectionType;

    auto&& inputs = opr.input();
    const auto nr_input = inputs.size();

    // conflicted[k] is true when input k must not be forwarded
    std::vector<bool> conflicted(nr_input, false);
    for (size_t a = 0; a < nr_input; ++a) {
        for (size_t b = a + 1; b < nr_input; ++b) {
            const auto kind =
                    cg::get_mem_plan_intersection_type(inputs[a], inputs[b]);

            bool clash = false;
            if (kind == IntersectionType::OVERLAP) {
                clash = true;
            } else if (kind == IntersectionType::IDENTICAL) {
                // identical plans are only acceptable when both layouts are
                // contiguous
                clash = !(inputs[a]->layout().is_contiguous() &&
                          inputs[b]->layout().is_contiguous());
            }

            if (clash) {
                conflicted[a] = true;
                conflicted[b] = true;
            }
        }
    }

    auto out = opr.output(0);
    for (size_t k = 0; k < nr_input; ++k) {
        if (conflicted[k]) {
            continue;
        }
        auto var = inputs[k];
        // equal shape means no broadcast
        if (!out->shape().eq_shape(var->shape())) {
            continue;
        }
        if (out->dtype().enumv() != var->dtype().enumv()) {
            continue;
        }
        if (!var->layout().is_contiguous()) {
            continue;
        }
        out->set_fwd_in2out_writable(var);
    }
}
/* ===================== ReadonlyFwdHelper ===================== */ | /* ===================== ReadonlyFwdHelper ===================== */ | ||||
void ReadonlyFwdHelper::mixin_rofwd_init_mem_plan(OperatorNodeBase &opr) { | void ReadonlyFwdHelper::mixin_rofwd_init_mem_plan(OperatorNodeBase &opr) { | ||||
@@ -58,7 +58,8 @@ namespace intl { | |||||
* The operands are broadcasted automatically on dimensions of shape one to | * The operands are broadcasted automatically on dimensions of shape one to | ||||
* match shapes of each other; it works like broadcasting in numpy. | * match shapes of each other; it works like broadcasting in numpy. | ||||
*/ | */ | ||||
MGB_DEFINE_OPR_CLASS(Elemwise, intl::ElemwiseBase) // { | |||||
MGB_DEFINE_OPR_CLASS(Elemwise, intl::ElemwiseBase, | |||||
mixin::FwdIn2OutWritableHelper) // { | |||||
using ModeTrait = megdnn::Elemwise::ModeTrait; | using ModeTrait = megdnn::Elemwise::ModeTrait; | ||||
public: | public: | ||||
@@ -19,6 +19,19 @@ namespace opr { | |||||
namespace mixin { | namespace mixin { | ||||
/*! | |||||
* \brief mixin for operators which essentially works by forward input to output | |||||
*/ | |||||
class FwdIn2OutWritableHelper : public cg::OperatorNodeMixinBase { | |||||
protected: | |||||
/*! | |||||
* \brief call this function in mem_plan_fwd_in2out_writable(), | |||||
* this function will check if the input have conflict to find if the | |||||
* output can be forward. | |||||
*/ | |||||
void mixin_mem_plan_fwd_in2out_writable(OperatorNodeBase &opr); | |||||
}; | |||||
//! for internal use by DynamicOutputIfInputDynamic | //! for internal use by DynamicOutputIfInputDynamic | ||||
void init_rt_force_dynamic_mem_alloc_imply_chain_for_dyn_pass_i2o( | void init_rt_force_dynamic_mem_alloc_imply_chain_for_dyn_pass_i2o( | ||||
OperatorNodeBase &opr); | OperatorNodeBase &opr); | ||||