|
|
@@ -77,16 +77,16 @@ EncodedSubgraph OpDef::make_backward_graph( |
|
|
|
const SmallVector<bool>& output_has_grad) { |
|
|
|
using BackwardGraphCache = |
|
|
|
OpMethResultCache<EncodedSubgraph, SmallVector<bool>, SmallVector<bool>>; |
|
|
|
thread_local BackwardGraphCache cache; |
|
|
|
decltype(cache)::key_t cache_key{ |
|
|
|
thread_local auto cache = std::make_unique<BackwardGraphCache>(); |
|
|
|
BackwardGraphCache::key_t cache_key{ |
|
|
|
const_cast<OpDef&>(def).shared_from_this(), |
|
|
|
inputs, |
|
|
|
{input_requires_grad, output_has_grad}}; |
|
|
|
auto iter = cache.find(cache_key); |
|
|
|
if (iter == cache.end()) { |
|
|
|
iter = cache.insert({cache_key, def.trait()->make_backward_graph( |
|
|
|
def, inputs, input_requires_grad, |
|
|
|
output_has_grad)}) |
|
|
|
auto iter = cache->find(cache_key); |
|
|
|
if (iter == cache->end()) { |
|
|
|
iter = cache->insert({cache_key, def.trait()->make_backward_graph( |
|
|
|
def, inputs, input_requires_grad, |
|
|
|
output_has_grad)}) |
|
|
|
.first; |
|
|
|
} |
|
|
|
return iter->second; |
|
|
@@ -100,12 +100,12 @@ EncodedSubgraph OpDef::make_forward_graph( |
|
|
|
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { |
|
|
|
using ForwardGraphCache = |
|
|
|
OpMethResultCache<EncodedSubgraph, SmallVector<bool>, SmallVector<bool>>; |
|
|
|
thread_local ForwardGraphCache cache; |
|
|
|
decltype(cache)::key_t cache_key{ |
|
|
|
thread_local auto cache = std::make_unique<ForwardGraphCache>(); |
|
|
|
ForwardGraphCache::key_t cache_key{ |
|
|
|
const_cast<OpDef&>(def).shared_from_this(), inputs}; |
|
|
|
auto iter = cache.find(cache_key); |
|
|
|
if (iter == cache.end()) { |
|
|
|
iter = cache.insert({cache_key, def.trait()->make_forward_graph(def, inputs)}) |
|
|
|
auto iter = cache->find(cache_key); |
|
|
|
if (iter == cache->end()) { |
|
|
|
iter = cache->insert({cache_key, def.trait()->make_forward_graph(def, inputs)}) |
|
|
|
.first; |
|
|
|
} |
|
|
|
return iter->second; |
|
|
|