|
|
@@ -9,21 +9,26 @@ |
|
|
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
*/ |
|
|
|
|
|
|
|
#include <queue> |
|
|
|
#include <deque> |
|
|
|
|
|
|
|
#include "../op_trait.h" |
|
|
|
#include "megbrain/imperative/graph_cache.h" |
|
|
|
#include "megbrain/imperative/opr_utility.h" |
|
|
|
#include "megbrain/imperative/ops/autogen.h" |
|
|
|
#include "megbrain/imperative/ops/opr_attr.h" |
|
|
|
#include "megbrain/imperative/ops/utility.h" |
|
|
|
#include "megbrain/imperative/subgraph_detail.h" |
|
|
|
#include "megbrain/jit/executor_opr.h" |
|
|
|
#include "megbrain/opr/io.h" |
|
|
|
#include "megbrain/opr/tensor_gen.h" |
|
|
|
#include "megbrain/opr/tensor_manip.h" |
|
|
|
#include "megbrain/opr/utility.h" |
|
|
|
|
|
|
|
#if MGB_JIT |
|
|
|
#include "megbrain/jit/executor_opr.h" |
|
|
|
#endif |
|
|
|
|
|
|
|
#include "../event_pool.h" |
|
|
|
#include "../op_trait.h" |
|
|
|
|
|
|
|
namespace mgb::imperative { |
|
|
|
|
|
|
|
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp); |
|
|
@@ -309,7 +314,7 @@ struct ComputingGraphHolder { |
|
|
|
SmallVector<VarNode*> input_vars; |
|
|
|
SmallVector<VarNode*> output_vars; |
|
|
|
std::shared_ptr<DeviceMemoryAllocatorImpl> allocator; |
|
|
|
SmallVector<std::unique_ptr<CompNode::Event>> events; |
|
|
|
SmallVector<std::shared_ptr<CompNode::Event>> events; |
|
|
|
std::unique_ptr<cg::static_infer::StaticInferUpdater> updater; |
|
|
|
|
|
|
|
void initialize( |
|
|
@@ -402,7 +407,7 @@ struct ComputingGraphHolder { |
|
|
|
return true; |
|
|
|
}); |
|
|
|
for (auto&& comp_node : comp_nodes) { |
|
|
|
events.push_back(comp_node.create_event()); |
|
|
|
events.push_back(EventPool::without_timer().alloc_shared(comp_node)); |
|
|
|
events.back()->record(); |
|
|
|
} |
|
|
|
} |
|
|
@@ -510,7 +515,7 @@ ComputingGraphHolder<Kind>& get_computing_graph( |
|
|
|
std::shared_ptr<OpDef> compiled_op, |
|
|
|
const SmallVector<LogicalTensorDesc>& descs) { |
|
|
|
using ComputingGraphHolderCache = |
|
|
|
OpMethResultCache<std::queue<std::unique_ptr<ComputingGraphHolder<Kind>>>>; |
|
|
|
OpMethResultCache<std::deque<std::unique_ptr<ComputingGraphHolder<Kind>>>>; |
|
|
|
thread_local auto cache = std::make_unique<ComputingGraphHolderCache>(); |
|
|
|
thread_local size_t nr_cg_holders = 0; |
|
|
|
typename ComputingGraphHolderCache::key_t cache_key = {compiled_op, descs}; |
|
|
@@ -540,20 +545,28 @@ ComputingGraphHolder<Kind>& get_computing_graph( |
|
|
|
} |
|
|
|
} |
|
|
|
if (holder) { |
|
|
|
cg_holder_queue.pop(); |
|
|
|
cg_holder_queue.pop_front(); |
|
|
|
} |
|
|
|
} |
|
|
|
if (!holder) { |
|
|
|
// create new computing graph |
|
|
|
holder = std::make_unique<ComputingGraphHolder<Kind>>(); |
|
|
|
auto& cg_holder = *holder; |
|
|
|
cg_holder.initialize(compiled_op->cast_final_safe<CompiledOp>(), descs); |
|
|
|
nr_cg_holders++; |
|
|
|
mgb_log_debug( |
|
|
|
"add new computing graph for compiled op, now %zu graphs", |
|
|
|
nr_cg_holders); |
|
|
|
auto create_holder = [&] { |
|
|
|
auto holder = std::make_unique<ComputingGraphHolder<Kind>>(); |
|
|
|
auto& cg_holder = *holder; |
|
|
|
cg_holder.initialize(compiled_op->cast_final_safe<CompiledOp>(), descs); |
|
|
|
nr_cg_holders++; |
|
|
|
mgb_log_debug( |
|
|
|
"add new computing graph for compiled op, now %zu graphs", |
|
|
|
nr_cg_holders); |
|
|
|
return holder; |
|
|
|
}; |
|
|
|
size_t nr_graphs = std::max(cg_holder_queue.size(), (size_t)1); |
|
|
|
for (size_t i = 1; i < nr_graphs; ++i) { |
|
|
|
cg_holder_queue.push_front(create_holder()); |
|
|
|
} |
|
|
|
holder = create_holder(); |
|
|
|
} |
|
|
|
cg_holder_queue.push(std::move(holder)); |
|
|
|
cg_holder_queue.push_back(std::move(holder)); |
|
|
|
return *cg_holder_queue.back(); |
|
|
|
} |
|
|
|
|
|
|
@@ -670,6 +683,7 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { |
|
|
|
// skip for dump (JITExecutor can not be dumped) |
|
|
|
return outputs; |
|
|
|
} |
|
|
|
#if MGB_JIT |
|
|
|
for (auto& output : outputs) { |
|
|
|
jit::InternalGraphGenerator igg{output->owner_opr()}; |
|
|
|
std::vector<cg::OperatorNodeBase*> reverse_order; |
|
|
@@ -686,6 +700,9 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { |
|
|
|
auto ig = igg.generate(); |
|
|
|
output = jit::JITExecutor::make(ig, igg.orig_inps()).node(); |
|
|
|
} |
|
|
|
#else |
|
|
|
mgb_assert(false, "MGB_WITH_JIT was disabled"); |
|
|
|
#endif |
|
|
|
return outputs; |
|
|
|
} |
|
|
|
|
|
|
|