@@ -88,15 +88,34 @@ JITExecutor::JITExecutor(const InternalGraphPtr& internal_graph,
        cg::add_workspace_output(this);
    }

    // check that the output of internal_graph depends on all placeholders
    size_t nr_placeholders = internal_graph_ptr()->placeholders().size();
    std::vector<bool> used(nr_placeholders, false);
    // check whether there is a reduce or dimshuffle opr
    cg::DepOprIter{[this, nr_placeholders, &used](cg::OperatorNodeBase* opr) {
        if (opr->same_type<opr::Reduce>()) {
            m_feature_bits |= JITFeatureBits::REDUCE;
        }
        if (opr->same_type<opr::Dimshuffle>()) {
            m_feature_bits |= JITFeatureBits::DIMSHUFFLE;
        }
        if (auto ph = opr->try_cast_final<JITPlaceholder>()) {
            mgb_assert(ph->input_id() < nr_placeholders,
                       "bad placeholder %s in JITExecutor %s",
                       ph->cname(), cname());
            used[ph->input_id()] = true;
        }
    }}.add(internal_graph->output());
    for (size_t i = 0; i < nr_placeholders; ++i) {
        mgb_assert(used[i],
                   "placeholder %s is not depended on by the output of %s",
                   internal_graph_ptr()->placeholders()[i]->cname(), cname());
    }
    if (has_dimshuffle()) {
        prepare_dimshuffle();
    }
}

void JITExecutor::add_input_layout_constraint() {
@@ -151,14 +170,14 @@ void JITExecutor::scn_do_execute() {
//! can be ignored
void JITExecutor::do_dimshuffle() {
    static auto get_dimshuffled_layout = [](const TensorLayout& ily,
                                            std::vector<int> pattern) {
        TensorLayout oly{ily.dtype};
        oly.ndim = pattern.size();
        bool input_used[TensorLayout::MAX_NDIM] = {0};
        for (uint32_t idx = 0; idx < pattern.size(); ++idx) {
            auto i = pattern[idx];
            if (i < 0) {
                oly.shape[idx] = 1;
@@ -179,53 +198,20 @@ void JITExecutor::do_dimshuffle() {
        return oly;
    };
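    // e.g. an input layout of shape {2, 3} with pattern = {1, -1, 0} yields a
    // dimshuffled layout of shape {3, 1, 2}; entries of -1 insert a new axis
    // of extent 1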
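    // apply the dimshuffle pattern recorded by prepare_dimshuffle() (if any)
    // to the layout of each placeholder's input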
    for (auto&& i : m_internal_graph->placeholders()) {
        auto&& input = m_args.inputs[i->input_id()];
        auto&& iter = m_jitph2dimshuffle.find(i);
        if (iter == m_jitph2dimshuffle.end())
            continue;
        auto&& param = iter->second;
        mgb_assert(input.layout.ndim == param.second,
                   "input ndim mismatch for Dimshuffle: expect=%u actual=%zu",
                   param.second, input.layout.ndim);
        auto dimshuffled_layout = get_dimshuffled_layout(input.layout,
                                                         param.first);
        input.layout = dimshuffled_layout;
    }
}

void JITExecutor::update_args() {
@@ -259,7 +245,9 @@ void JITExecutor::update_args() {
    }
    //! dimshuffle oprs need the input layouts to be rewritten
    if (has_dimshuffle()) {
        do_dimshuffle();
    }
    if (m_compiler->property().contain_flag(CPFlag::NEED_INPUT_COLLAPSE)) {
        // collectively collapse the input layouts, trying to reduce the
        // output ndim
@@ -304,6 +292,82 @@ void JITExecutor::update_args() {
    m_args.need_update = false;
}
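
//! record, for every JITPlaceholder reachable from the internal-graph output,
//! the dimshuffle pattern accumulated along the path from the output to that
//! placeholder; nested dimshuffles are merged into one equivalent pattern,
//! and the result (m_jitph2dimshuffle) is consumed by do_dimshuffle() when
//! update_args() refreshes the input layouts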
void JITExecutor::prepare_dimshuffle() {
    std::unordered_set<OperatorNodeBase*> visited;
    std::vector<OperatorNodeBase*> stack(0);
    std::vector<uint8_t> idx(0);  // input index
    using Param = DimshuffleParam;
    std::vector<Param> dimshuffle_stack;
    auto merge_dimshuffle = [&](const opr::Dimshuffle::Param& p) {
        if (dimshuffle_stack.empty()) {
            dimshuffle_stack.emplace_back();
            auto&& param = dimshuffle_stack.back();
            param.first.insert(param.first.end(), p.pattern,
                               p.pattern + p.pattern_len);
            param.second = p.ndim;
        } else {
            // merge(p, src) -> param, such that
            // dimshuffle(dimshuffle(x, p), src) is equivalent to
            // dimshuffle(x, param)
            dimshuffle_stack.emplace_back();
            auto&& param = dimshuffle_stack.back();
            auto&& src = dimshuffle_stack[dimshuffle_stack.size() - 2];
            mgb_assert(p.pattern_len == src.second);
            param.first.resize(src.first.size());
            for (size_t i = 0; i < src.first.size(); ++i) {
                if (src.first[i] == -1) {
                    param.first[i] = -1;
                } else {
                    param.first[i] = p.pattern[src.first[i]];
                }
            }
            param.second = p.ndim;
        }
    };
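    // e.g. merging a newly met p with pattern {1, 0} into an accumulated
    // src = ({1, -1, 0}, 2) gives param.first = {p[1], -1, p[0]} = {0, -1, 1}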
    auto push_back = [&](cg::OperatorNodeBase* op) {
        mgb_assert(!op->same_type<jit::JITPlaceholder>());
        if (auto o = op->try_cast_final<opr::Dimshuffle>()) {
            merge_dimshuffle(o->param());
        }
        stack.push_back(op);
        idx.push_back(0);
    };
    auto pop_back = [&]() {
        auto&& op = stack.back();
        if (op->same_type<opr::Dimshuffle>()) {
            dimshuffle_stack.pop_back();
        }
        stack.pop_back();
        idx.pop_back();
    };
    push_back(m_internal_graph->output()->owner_opr());
    while (!stack.empty()) {
        if (idx.back() < stack.back()->input().size()) {
            auto cur_opr = stack.back()->input(idx.back())->owner_opr();
            if (visited.insert(cur_opr).second) {
                if (auto jitph =
                            cur_opr->try_cast_final<jit::JITPlaceholder>()) {
                    if (!dimshuffle_stack.empty()) {
                        mgb_assert(m_jitph2dimshuffle
                                           .emplace(jitph,
                                                    dimshuffle_stack.back())
                                           .second,
                                   "already visited JITPlaceholder %s",
                                   jitph->cname());
                    }
                    ++idx.back();
                } else {
                    push_back(cur_opr);
                }
            } else {
                ++idx.back();
            }
        } else {
            pop_back();
            if (!stack.empty())
                ++idx.back();
        }
    }
}

const JITExecutor::Args& JITExecutor::args() const {
    if (m_args.need_update) {
        const_cast<JITExecutor*>(this)->update_args();
@@ -383,6 +447,56 @@ megdnn::TensorShape JITExecutor::broadcasted_input_shape() const {
#if MGB_ENABLE_GRAD
namespace {
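
// rewrites a grad internal graph in place: iterates oprs in topological
// order via cg::DepOprIter, lets the callback substitute vars through
// replace_var(), and shallow-copies any opr whose inputs have been replaced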
class InternalGraphRewriter {
    ThinHashMap<VarNode*, VarNode*> m_var_map;
    VarNode* m_dest_var;
    VarNodeArray m_new_inp;

    VarNode* get_var(VarNode* var) {
        auto&& iter = m_var_map.find(var);
        if (iter != m_var_map.end()) {
            return iter->second;
        }
        return var;
    }

public:
    InternalGraphRewriter(VarNode* dest_var) : m_dest_var{dest_var} {}

    void iter(thin_function<void(cg::OperatorNodeBase*)>&& cb) {
        m_var_map.clear();
        cg::DepOprIter{std::move(cb)}.add(m_dest_var->owner_opr());
        m_dest_var = get_var(m_dest_var);
    }

    VarNode* dest_var() { return m_dest_var; }

    void replace_var(VarNode* src, VarNode* dst) {
        // Note: var replacement is not applied recursively. When the used
        // placeholders are extracted from the internal graph, replacement
        // pairs (a -> b) and (b -> c) are treated as an injective mapping
        // from {a, b} to {b, c}, not as a chain (a -> b -> c); in all other
        // cases each var node is passed as \p src or \p dst at most once.
        m_var_map[src] = dst;
    }

    void auto_replace_outputs(cg::OperatorNodeBase* opr) {
        // in a JIT internal graph, every opr has exactly one usable output
        mgb_assert(opr->usable_output().size() == 1);
        m_new_inp.clear();
        bool need_replace = false;
        for (auto&& i : opr->input()) {
            auto inp = get_var(i);
            m_new_inp.push_back(inp);
            need_replace |= (inp != i);
        }
        if (need_replace) {
            auto new_op = serialization::copy_opr_shallow(*opr, m_new_inp);
            replace_var(opr->output(0), new_op->output(0));
        }
    }
};
} // anonymous namespace

MGB_IMPL_OPR_GRAD(JITExecutor) {
    VarNodeArray grad_inputs;
    for (auto input : opr.input())
@@ -404,49 +518,120 @@ MGB_IMPL_OPR_GRAD(JITExecutor) {
    if (gx.node()->owner_opr()->same_type<opr::InvalidGrad>()) {
        return opr::InvalidGrad::make(opr, wrt_idx);
    }
    // early return if the grad expression is a single node
    for (size_t i = 0; i < fwd_igraph_ptr->placeholders().size(); ++i) {
        if (gx.node() == fwd_igraph_ptr->placeholders()[i]->output(0)) {
            return grad_inputs[i];
        }
    }
    if (gx.node() == og_ph.node()) {
        return out_grad[0];
    }
    if (gx.node() == fwd_igraph_ptr->output()) {
        return opr.output(0);
    }
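    // if gx itself is a compile-time constant of the internal graph,
    // materialize its value and return it as an ImmutableTensor in the
    // outer graph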
    if (auto imm = gopt::try_cast_as_op<opr::ImmutableTensor>(
                gx.node()->owner_opr())) {
        HostTensorND hval{grad_inputs[0]->comp_node()};
        hval.copy_from(imm->value()).sync();
        return opr::ImmutableTensor::make(*imm->owner_graph(), hval).node();
    }
    // replace the output var of the internal graph with an output
    // placeholder, so we can forward opr.output(0) (already computed by the
    // forward JITExecutor) into the placeholder and avoid redundant
    // computation
    InternalGraphRewriter rewriter{gx.node()};
    rewriter.iter([&rewriter, &fwd_igraph_ptr,
                   &output_ph](cg::OperatorNodeBase* opr) {
        if (opr == fwd_igraph_ptr->output()->owner_opr()) {
            rewriter.replace_var(opr->output(0), output_ph.node());
            return;
        }
        rewriter.auto_replace_outputs(opr);
    });
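    // replays a grad internal graph inside the outer (original) graph:
    // JITPlaceholders are substituted by the corresponding grad inputs,
    // ImmutableTensors are copied into the outer graph, and every other opr
    // is shallow-copied with its replaced inputs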
    static auto expand_into_origin_graph = [](cg::OperatorNodeBase* opr,
                                              InternalGraphRewriter& rewriter,
                                              const VarNodeArray& grad_inputs) {
        if (auto ph = gopt::try_cast_as_op<JITPlaceholder>(opr)) {
            rewriter.replace_var(opr->output(0),
                                 grad_inputs.at(ph->input_id()));
            return;
        }
        if (auto imm = gopt::try_cast_as_op<opr::ImmutableTensor>(opr)) {
            HostTensorND hval{grad_inputs[0]->comp_node()};
            hval.copy_from(imm->value()).sync();
            rewriter.replace_var(
                    opr->output(0),
                    opr::ImmutableTensor::make(*opr->owner_graph(), hval)
                            .node());
            return;
        }
        rewriter.auto_replace_outputs(opr);
    };
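    // shared by the REDUCE branch below and the all-constant-input fallback;
    // both call sites invoke it via std::bind with the rewriter and the
    // corresponding grad inputs bound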
    if (opr.compiler()->property().feature_bits & JITFeatureBits::REDUCE) {
        // expand the gradient graph into the original graph to handle
        // broadcast oprs
        using namespace std::placeholders;
        rewriter.iter(std::bind(expand_into_origin_graph, _1,
                                std::ref(rewriter), std::cref(grad_inputs)));
        return rewriter.dest_var();
    } else {
        VarNodeArray new_grad_inputs;
        PlaceholderArray placeholders;
        bool all_inp_const = true;
        // gx may not depend on all JITPlaceholders, so we extract the used
        // placeholders and build a new internal graph
        rewriter.iter([&rewriter, &grad_inputs, &new_grad_inputs,
                       &placeholders,
                       &all_inp_const](cg::OperatorNodeBase* opr) {
            if (auto ph = gopt::try_cast_as_op<JITPlaceholder>(opr)) {
                new_grad_inputs.push_back(grad_inputs[ph->input_id()]);
                auto new_ph = JITPlaceholder::make(new_grad_inputs.back(),
                                                   placeholders.size())
                                      .node()
                                      ->owner_opr();
                placeholders.push_back(
                        new_ph->try_cast_final<JITPlaceholder>());
                mgb_assert(placeholders.back());
                rewriter.replace_var(opr->output(0), new_ph->output(0));
                if (!cg::is_const_var_value(new_grad_inputs.back())) {
                    all_inp_const = false;
                }
                return;
            }
            rewriter.auto_replace_outputs(opr);
        });
        if (all_inp_const) {
            // if all inputs are const, expand the grad graph into the origin
            // graph by replacing the placeholders with the const inputs, so
            // it benefits from static inference and const folding
            using namespace std::placeholders;
            rewriter.iter(std::bind(expand_into_origin_graph, _1,
                                    std::ref(rewriter),
                                    std::cref(new_grad_inputs)));
            return rewriter.dest_var();
        }
        gx = rewriter.dest_var();
        auto shape_infer = fwd_igraph_ptr->shape_infer();
        if (opr.has_dimshuffle()) {
            auto&& iter = opr.dimshuffle_params().find(
                    fwd_igraph_ptr->placeholders()[wrt_idx]);
            if (iter != opr.dimshuffle_params().end()) {
                auto&& pattern = iter->second.first;
                auto&& ndim = iter->second.second;
                std::vector<int> back(ndim, -1);
                for (size_t i = 0; i < pattern.size(); ++i) {
                    // output axis i of the forward shuffle reads input axis j
                    auto j = pattern[i];
                    if (j >= 0) {
                        mgb_assert(back[j] == -1,
                                   "taking grad for Dimshuffle with "
                                   "duplicated input axis is unsupported");
                        back[j] = i;
                    }
                }
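                // 'back' inverts the forward pattern (back[pattern[i]] == i);
                // e.g. pattern = {1, 2, 0} yields back = {2, 0, 1}, shuffling
                // the shape-infer var back into the placeholder's original
                // axis order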
                shape_infer = opr::Dimshuffle::make(shape_infer, back,
                                                    pattern.size())
                                      .node();
            }
        }
        auto grad_ig = std::make_shared<InternalGraph>(
                gx.node(), shape_infer, nullptr, std::move(placeholders));
        auto grad_jit = JITExecutor::make(grad_ig, new_grad_inputs);
        if (opr.input_broadcastable()[wrt_idx]) {
            grad_jit = opr::reduce_sum(