
fix(imperative): reduce tls usage

GitOrigin-RevId: a716b2ae98
tags/v1.8.0
Megvii Engine Team, 3 years ago
parent
commit 7f03ae9aed
2 changed files with 26 additions and 14 deletions
  1. imperative/src/impl/op_def.cpp (+12, -12)
  2. imperative/src/impl/ops/utility.cpp (+14, -2)

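The change is the same in all three hunks below: a thread_local cache object becomes a thread_local std::unique_ptr to that cache, so each thread's TLS block holds only a pointer and the cache itself is heap-allocated lazily on first use. A minimal sketch of that pattern, using a hypothetical DemoCache type rather than MegEngine's OpMethResultCache:

#include <map>
#include <memory>
#include <string>

struct DemoCache {
    std::map<std::string, int> entries;  // stand-in for the real cached values
};

int lookup_or_compute(const std::string& key) {
    // Before the change, `thread_local DemoCache cache;` would keep the whole
    // object in thread-local storage. With a unique_ptr, only the pointer sits
    // in TLS and the cache is allocated lazily, once per thread, on first use.
    thread_local auto cache = std::make_unique<DemoCache>();
    auto iter = cache->entries.find(key);
    if (iter == cache->entries.end()) {
        int computed = static_cast<int>(key.size());  // placeholder "expensive" work
        iter = cache->entries.emplace(key, computed).first;
    }
    return iter->second;
}

int main() {
    return lookup_or_compute("conv") == 4 ? 0 : 1;
}
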
imperative/src/impl/op_def.cpp (+12, -12)

@@ -77,16 +77,16 @@ EncodedSubgraph OpDef::make_backward_graph(
         const SmallVector<bool>& output_has_grad) {
     using BackwardGraphCache =
             OpMethResultCache<EncodedSubgraph, SmallVector<bool>, SmallVector<bool>>;
-    thread_local BackwardGraphCache cache;
-    decltype(cache)::key_t cache_key{
+    thread_local auto cache = std::make_unique<BackwardGraphCache>();
+    BackwardGraphCache::key_t cache_key{
             const_cast<OpDef&>(def).shared_from_this(),
             inputs,
             {input_requires_grad, output_has_grad}};
-    auto iter = cache.find(cache_key);
-    if (iter == cache.end()) {
-        iter = cache.insert({cache_key, def.trait()->make_backward_graph(
-                                                def, inputs, input_requires_grad,
-                                                output_has_grad)})
+    auto iter = cache->find(cache_key);
+    if (iter == cache->end()) {
+        iter = cache->insert({cache_key, def.trait()->make_backward_graph(
+                                                 def, inputs, input_requires_grad,
+                                                 output_has_grad)})
                        .first;
     }
     return iter->second;
@@ -100,12 +100,12 @@ EncodedSubgraph OpDef::make_forward_graph(
         const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
     using ForwardGraphCache =
             OpMethResultCache<EncodedSubgraph, SmallVector<bool>, SmallVector<bool>>;
-    thread_local ForwardGraphCache cache;
-    decltype(cache)::key_t cache_key{
+    thread_local auto cache = std::make_unique<ForwardGraphCache>();
+    ForwardGraphCache::key_t cache_key{
             const_cast<OpDef&>(def).shared_from_this(), inputs};
-    auto iter = cache.find(cache_key);
-    if (iter == cache.end()) {
-        iter = cache.insert({cache_key, def.trait()->make_forward_graph(def, inputs)})
+    auto iter = cache->find(cache_key);
+    if (iter == cache->end()) {
+        iter = cache->insert({cache_key, def.trait()->make_forward_graph(def, inputs)})
                        .first;
     }
     return iter->second;


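A side effect of the unique_ptr change visible in the hunks above: decltype(cache) now names the smart-pointer type, which has no nested key_t, so the key type is spelled via the cache class itself (BackwardGraphCache::key_t, ForwardGraphCache::key_t). A hypothetical illustration of that point, not MegEngine code:

#include <memory>
#include <unordered_map>

struct ToyCache {
    using key_t = int;
    std::unordered_map<key_t, double> storage;
};

int main() {
    thread_local auto cache = std::make_unique<ToyCache>();
    // decltype(cache)::key_t would not compile here: decltype(cache) is
    // std::unique_ptr<ToyCache>, which has no nested key_t.
    ToyCache::key_t key = 42;     // name the nested type via the cache class
    (*cache).storage[key] = 1.0;  // equivalently: cache->storage[key]
    return 0;
}
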
imperative/src/impl/ops/utility.cpp (+14, -2)

@@ -34,8 +34,20 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     return inputs;
 }
 
+auto make_backward_graph(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
+        const SmallVector<bool>& input_requires_grad,
+        const SmallVector<bool>& output_has_grad) {
+    Subgraph graph;
+    graph.inputs = {1, 2, 3};
+    graph.outputs = {3};
+    graph.exprs = {};
+    return EncodedSubgraph::make(graph);
+}
+
 OP_TRAIT_REG(FastpathCopy, FastpathCopy)
         .apply_on_var_node(apply_on_var_node)
+        .make_backward_graph(make_backward_graph)
         .fallback();
 } // namespace fastpathcopy
 } // namespace
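The newly registered make_backward_graph for FastpathCopy builds a subgraph with no expressions whose only output is its third input; assuming the conventional (input, output, output_grad) ordering of backward-graph inputs, this passes the incoming gradient through unchanged. A toy sketch of that passthrough idea, with made-up ToySubgraph/run helpers rather than MegEngine's Subgraph and EncodedSubgraph:

#include <cassert>
#include <cstddef>
#include <vector>

struct ToySubgraph {
    std::vector<std::size_t> inputs;   // 1-based variable ids bound to the arguments
    std::vector<std::size_t> outputs;  // variable ids returned; no exprs means no compute
};

std::vector<double> run(const ToySubgraph& g, const std::vector<double>& args) {
    std::vector<double> vars(args.size() + 1);  // id 0 unused; ids are 1-based
    for (std::size_t i = 0; i < g.inputs.size(); ++i) vars[g.inputs[i]] = args[i];
    std::vector<double> out;
    for (auto id : g.outputs) out.push_back(vars[id]);  // just read variables back
    return out;
}

int main() {
    ToySubgraph backward{{1, 2, 3}, {3}};
    // args ~ (input, output, output_grad); the result is the gradient, unmodified.
    auto result = run(backward, {10.0, 10.0, 0.5});
    assert(result.size() == 1 && result[0] == 0.5);
    return 0;
}
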
@@ -290,10 +302,10 @@ ComputingGraphHolder& get_computing_graph(
         std::shared_ptr<OpDef> compiled_op, SmallVector<LogicalTensorDesc> descs) {
     using ComputingGraphHolderCache =
             OpMethResultCache<std::queue<std::unique_ptr<ComputingGraphHolder>>>;
-    thread_local ComputingGraphHolderCache cache;
+    thread_local auto cache = std::make_unique<ComputingGraphHolderCache>();
     thread_local size_t nr_cg_holders = 0;
     ComputingGraphHolderCache::key_t cache_key = {compiled_op, descs};
-    auto& cg_holder_queue = cache[cache_key];
+    auto& cg_holder_queue = (*cache)[cache_key];
     std::unique_ptr<ComputingGraphHolder> holder;
     if (!cg_holder_queue.empty()) {
         // pick one

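The last hunk also shows why the lookup becomes (*cache)[cache_key]: std::unique_ptr for a non-array type does not forward operator[], so the owned map must be dereferenced before indexing. A small sketch of that point, with a std::map standing in for ComputingGraphHolderCache:

#include <map>
#include <memory>
#include <queue>

int main() {
    thread_local auto cache = std::make_unique<std::map<int, std::queue<int>>>();
    auto& queue_for_key = (*cache)[7];  // deref the unique_ptr, then index the map
    queue_for_key.push(1);              // cache[7] would not compile on a unique_ptr
    return queue_for_key.size() == 1 ? 0 : 1;
}
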
