|
|
@@ -657,6 +657,85 @@ OP_TRAIT_REG(CompiledOp, CompiledOp) |
|
|
|
} // namespace compiled_op |
|
|
|
} // namespace |
|
|
|
|
|
|
|
namespace { |
|
|
|
namespace jit_fusion { |
|
|
|
|
|
|
|
static thread_local bool tm_enabled = true; |
|
|
|
|
|
|
|
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { |
|
|
|
auto& op = def.cast_final_safe<JITFusionOp>(); |
|
|
|
op.op->set_scope(op.scope()); |
|
|
|
auto outputs = OpDef::apply_on_var_node(*op.op, inputs); |
|
|
|
if (!tm_enabled) { |
|
|
|
// skip for dump (JITExecutor can not be dumped) |
|
|
|
return outputs; |
|
|
|
} |
|
|
|
for (auto& output : outputs) { |
|
|
|
jit::InternalGraphGenerator igg{output->owner_opr()}; |
|
|
|
std::vector<cg::OperatorNodeBase*> reverse_order; |
|
|
|
cg::DepOprIter iter{ |
|
|
|
[&](cg::OperatorNodeBase* opr) { reverse_order.push_back(opr); }}; |
|
|
|
for (auto&& input : inputs) { |
|
|
|
iter.set_visited(input->owner_opr()); |
|
|
|
} |
|
|
|
iter.add(output->owner_opr()); |
|
|
|
std::reverse(reverse_order.begin(), reverse_order.end()); |
|
|
|
for (auto&& opr : reverse_order) { |
|
|
|
igg.add_opr(opr); |
|
|
|
} |
|
|
|
auto ig = igg.generate(); |
|
|
|
output = jit::JITExecutor::make(ig, igg.orig_inps()).node(); |
|
|
|
} |
|
|
|
return outputs; |
|
|
|
} |
|
|
|
|
|
|
|
auto infer_output_attrs_fallible( |
|
|
|
const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) { |
|
|
|
return OpDef::infer_output_attrs_fallible( |
|
|
|
*def.cast_final_safe<JITFusionOp>().op, input_descs); |
|
|
|
} |
|
|
|
|
|
|
|
auto props(const OpDef& def) { |
|
|
|
return OpDef::props(*def.cast_final_safe<JITFusionOp>().op); |
|
|
|
} |
|
|
|
|
|
|
|
auto hash(const OpDef& def) { |
|
|
|
return def.cast_final_safe<JITFusionOp>().op->hash(); |
|
|
|
} |
|
|
|
|
|
|
|
auto is_samt_st(const OpDef& def, const OpDef& another) { |
|
|
|
if (!another.same_type<JITFusionOp>()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto& lhs = def.cast_final_safe<JITFusionOp>(); |
|
|
|
auto& rhs = another.cast_final_safe<JITFusionOp>(); |
|
|
|
return lhs.op->is_same(*rhs.op); |
|
|
|
} |
|
|
|
|
|
|
|
EncodedSubgraph make_backward_graph( |
|
|
|
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs, |
|
|
|
const SmallVector<bool>& input_requires_grad, |
|
|
|
const SmallVector<bool>& output_has_grad) { |
|
|
|
return {}; |
|
|
|
} |
|
|
|
|
|
|
|
OP_TRAIT_REG(JITFusionOp, JITFusionOp) |
|
|
|
.apply_on_var_node(apply_on_var_node) |
|
|
|
.infer_output_attrs_fallible(infer_output_attrs_fallible) |
|
|
|
.props(props) |
|
|
|
.hash(hash) |
|
|
|
.is_same_st(is_samt_st) |
|
|
|
.make_backward_graph(make_backward_graph) |
|
|
|
.fallback(); |
|
|
|
|
|
|
|
} // namespace jit_fusion |
|
|
|
} // namespace |
|
|
|
|
|
|
|
bool JITFusionOp::set_enabled(bool enabled) { |
|
|
|
std::swap(enabled, jit_fusion::tm_enabled); |
|
|
|
return enabled; |
|
|
|
} |
|
|
|
|
|
|
|
MGB_DYN_TYPE_OBJ_FINAL_IMPL(UniqueKey); |
|
|
|
|
|
|
|
MGB_DYN_TYPE_OBJ_FINAL_IMPL(SubgraphOp); |
|
|
@@ -665,4 +744,6 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardOpKey); |
|
|
|
|
|
|
|
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CompiledOp); |
|
|
|
|
|
|
|
MGB_DYN_TYPE_OBJ_FINAL_IMPL(JITFusionOp); |
|
|
|
|
|
|
|
} // namespace mgb::imperative |