GitOrigin-RevId: 350c90fb86
tags/v1.0.0-rc1
@@ -376,6 +376,7 @@ public: | |||||
ProxyGraphImpl(ProxyGraph* owner) : m_owner(owner) { | ProxyGraphImpl(ProxyGraph* owner) : m_owner(owner) { | ||||
options().imperative_proxy_graph = true; | options().imperative_proxy_graph = true; | ||||
options().no_force_inplace = true; | |||||
options().log_level = 0; | options().log_level = 0; | ||||
m_var_receiver_info.dev_value = 1; | m_var_receiver_info.dev_value = 1; | ||||
m_var_receiver_info.allow_empty_value = 1; | m_var_receiver_info.allow_empty_value = 1; | ||||
@@ -464,6 +464,17 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>, | |||||
bool imperative_proxy_graph = false; | bool imperative_proxy_graph = false; | ||||
/*! | |||||
* Request that operators should not force update their inputs. | |||||
* | |||||
* THIS FLAG IS RESERVED FOR INTERNAL USE | |||||
* | |||||
* When this flag is set, operators like AddUpdate and BatchNorm | |||||
* will still attempt to inplace update their inputs, but failing | |||||
* to do so will not be considered as an error. | |||||
*/ | |||||
bool no_force_inplace = false; | |||||
//! add extra deps for the comp seq if a specific var is dependent | //! add extra deps for the comp seq if a specific var is dependent | ||||
ThinHashMap<VarNode*, VarNodeArray> extra_vardeps; | ThinHashMap<VarNode*, VarNodeArray> extra_vardeps; | ||||
@@ -40,7 +40,7 @@ BatchNormForward::BatchNormForward(VarNode *x, | |||||
Super{x->owner_graph(), config, "batch_norm", | Super{x->owner_graph(), config, "batch_norm", | ||||
{x, scale, bias, mean, variance}} | {x, scale, bias, mean, variance}} | ||||
{ | { | ||||
if(owner_graph()->options().imperative_proxy_graph) { | |||||
if(owner_graph()->options().no_force_inplace) { | |||||
m_force_inplace = false; | m_force_inplace = false; | ||||
} | } | ||||