@@ -35,12 +35,14 @@ ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item(
          node_item.NodeName().c_str(),
          this->num_pending_shapes_);
 
-  for (int i = 0; i < node_item.num_inputs; ++i){
-    input_tensor_desc.emplace_back(*node_item.MutableInputDesc(i));
+  input_tensor_desc.resize(node_item.num_inputs);
+  for (int i = 0; i < node_item.num_inputs; ++i) {
+    node_item.GetInputDesc(i, input_tensor_desc[i]);
   }
 
-  for (int i = 0; i < node_item.num_outputs; ++i){
-    output_tensor_desc.emplace_back(*node_item.MutableOutputDesc(i));
+  output_tensor_desc.resize(node_item.num_outputs);
+  for (int i = 0; i < node_item.num_outputs; ++i) {
+    node_item.GetOutputDesc(i, output_tensor_desc[i]);
   }
 }
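
Note on the hunk above: instead of emplace_back-ing dereferenced MutableInputDesc()/MutableOutputDesc() pointers, the constructor now pre-sizes the two vectors and lets NodeItem copy each descriptor into its slot through the new GetInputDesc()/GetOutputDesc() accessors added further down, so every copy is taken while NodeItem's mutex is held. A minimal standalone sketch of the pattern, with simplified stand-in types rather than the GE classes:

#include <vector>

struct Desc { std::vector<long> dims; };

struct Node {
  std::vector<Desc> inputs = std::vector<Desc>(2);
  std::vector<Desc> outputs = std::vector<Desc>(1);
  // Copy-out accessors; in the real code these copy under the node's lock.
  void GetInputDesc(int i, Desc &out) const { out = inputs[i]; }
  void GetOutputDesc(int i, Desc &out) const { out = outputs[i]; }
};

struct ShapeState {
  std::vector<Desc> input_tensor_desc, output_tensor_desc;
  explicit ShapeState(const Node &n) {
    // Pre-size, then fill each slot by value instead of keeping pointers into the node.
    input_tensor_desc.resize(n.inputs.size());
    for (size_t i = 0; i < n.inputs.size(); ++i) {
      n.GetInputDesc(static_cast<int>(i), input_tensor_desc[i]);
    }
    output_tensor_desc.resize(n.outputs.size());
    for (size_t i = 0; i < n.outputs.size(); ++i) {
      n.GetOutputDesc(static_cast<int>(i), output_tensor_desc[i]);
    }
  }
};
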
@@ -297,7 +297,7 @@ void NodeItem::SetToDynamic() {
   }
 }
 
-GeTensorDescPtr NodeItem::MutableInputDesc(int index) const {
+GeTensorDescPtr NodeItem::DoGetInputDesc(int index) const {
   if (!has_optional_inputs) {
     return op_desc->MutableInputDesc(static_cast<uint32_t>(index));
   }
@@ -314,6 +314,40 @@ GeTensorDescPtr NodeItem::MutableInputDesc(int index) const {
   return op_desc->MutableInputDesc(input_desc_indices_[index]);
 }
 
+GeTensorDescPtr NodeItem::MutableInputDesc(int index) const {
+  std::lock_guard<std::mutex> lk(mu_);
+  return DoGetInputDesc(index);
+}
+
+Status NodeItem::GetInputDesc(int index, GeTensorDesc &tensor_desc) const {
+  std::lock_guard<std::mutex> lk(mu_);
+  auto input_desc = DoGetInputDesc(index);
+  GE_CHECK_NOTNULL(input_desc);
+  tensor_desc = *input_desc;
+  return SUCCESS;
+}
+
+Status NodeItem::GetOutputDesc(int index, GeTensorDesc &tensor_desc) const {
+  std::lock_guard<std::mutex> lk(mu_);
+  auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
+  GE_CHECK_NOTNULL(output_desc);
+  tensor_desc = *output_desc;
+  return SUCCESS;
+}
+
+GeTensorDescPtr NodeItem::MutableOutputDesc(int index) const {
+  std::lock_guard<std::mutex> lk(mu_);
+  return op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
+}
+
+Status NodeItem::UpdateInputDesc(int index, const GeTensorDesc &tensor_desc) {
+  std::lock_guard<std::mutex> lk(mu_);
+  auto input_desc = DoGetInputDesc(index);
+  GE_CHECK_NOTNULL(input_desc);
+  *input_desc = tensor_desc;
+  return SUCCESS;
+}
+
 Status NodeItem::GetCanonicalInputIndex(uint32_t index, int &canonical_index) const {
   if (!has_optional_inputs) {
     canonical_index = index;
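
The locking scheme introduced in this hunk: the old MutableInputDesc() body becomes the private, lock-free DoGetInputDesc(), and every public accessor takes mu_ exactly once before delegating, so GetInputDesc() and UpdateInputDesc() can reuse the index lookup without recursive locking. A standalone sketch of that structure (simplified types, not the GE classes):

#include <mutex>
#include <vector>

struct Desc { std::vector<long> dims; };

class Item {
 public:
  Desc *MutableDesc(int i) const {            // public: lock once, then delegate
    std::lock_guard<std::mutex> lk(mu_);
    return DoGetDesc(i);
  }
  bool GetDesc(int i, Desc &out) const {      // public: copy out under the same single lock
    std::lock_guard<std::mutex> lk(mu_);
    Desc *d = DoGetDesc(i);
    if (d == nullptr) return false;
    out = *d;
    return true;
  }
  bool UpdateDesc(int i, const Desc &in) {    // public: write back under the lock
    std::lock_guard<std::mutex> lk(mu_);
    Desc *d = DoGetDesc(i);
    if (d == nullptr) return false;
    *d = in;
    return true;
  }

 private:
  Desc *DoGetDesc(int i) const {              // private: assumes the caller already holds mu_
    return (i >= 0 && i < static_cast<int>(descs_.size())) ? &descs_[i] : nullptr;
  }
  mutable std::mutex mu_;                     // mutable so const readers can lock it
  mutable std::vector<Desc> descs_ = std::vector<Desc>(4);
};

Note that MutableInputDesc()/MutableOutputDesc() still hand out raw descriptor pointers, so the lock covers only the lookup itself; the copy-based GetInputDesc()/GetOutputDesc()/UpdateInputDesc() are the accessors to prefer when another thread may update the descriptors concurrently.
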
@@ -17,6 +17,7 @@
 #ifndef GE_HYBRID_MODEL_NODE_ITEM_H_
 #define GE_HYBRID_MODEL_NODE_ITEM_H_
 
+#include <mutex>
 #include <vector>
 #include "external/ge/ge_api_error_codes.h"
 #include "graph/node.h"
@@ -57,12 +58,16 @@ struct NodeItem {
   bool IsInputShapeStatic(int index) const;
 
-  GeTensorDescPtr MutableOutputDesc(int index) const {
-    return op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
-  }
+  GeTensorDescPtr MutableOutputDesc(int index) const;
+
+  Status UpdateInputDesc(int index, const GeTensorDesc &tensor_desc);
 
   GeTensorDescPtr MutableInputDesc(int index) const;
 
+  Status GetInputDesc(int index, GeTensorDesc &tensor_desc) const;
+
+  Status GetOutputDesc(int index, GeTensorDesc &tensor_desc) const;
+
   Status GetCanonicalInputIndex(uint32_t index, int &canonical_index) const;
 
   bool IsControlOp() const;
@@ -113,9 +118,11 @@ struct NodeItem {
   Status ResolveDynamicState();
   Status ResolveStaticInputsAndOutputs();
   void ResolveUnknownShapeType();
+  GeTensorDescPtr DoGetInputDesc(int index) const;
 
   std::vector<bool> is_input_shape_static_;
   std::vector<uint32_t> input_desc_indices_;
+  mutable std::mutex mu_;
 };
 }  // namespace hybrid
 }  // namespace ge
@@ -237,8 +237,8 @@ Status WhileOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::fun
   }
 
   bool is_continue = false;
-  GE_CHK_STATUS_RET(ExecuteOneLoop(task_context, is_continue),
-                    "[%s] Failed to execute iteration 0.",
+  GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue),
+                    "[%s] Failed to execute cond-subgraph",
                     task_context.GetNodeName());
   if (!is_continue) {
     for (int i = 0; i < task_context.NumInputs(); ++i) {
@@ -259,42 +259,28 @@ Status WhileOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::fun
   }
 
   // backup original input tensor desc
-  std::vector<GeTensorDesc> ori_input_desc;
+  std::vector<GeTensorDesc> ori_input_desc(task_context.NumInputs());
   for (int i = 0; i < task_context.NumInputs(); ++i) {
-    auto tensor_desc = task_context.GetInputDesc(i);
-    GE_CHECK_NOTNULL(tensor_desc);
-    ori_input_desc.emplace_back(*tensor_desc);
+    GE_CHK_STATUS_RET_NOLOG(task_context.GetInputDesc(i, ori_input_desc[i]));
   }
 
-  int iteration = 1;
-  while (true) {
+  int iteration = 0;
+  while (is_continue) {
+    ++iteration;
     GELOGD("[%s] Start to execute, iteration = %d", task_context.GetNodeName(), iteration);
     GE_CHK_STATUS_RET(ExecuteOneLoop(task_context, is_continue),
                       "[%s] Failed to execute iteration %d.",
                       task_context.GetNodeName(),
                       iteration);
-    if (!is_continue) {
-      GELOGD("[%s] Quit from loop. current iteration = %d", task_context.GetNodeName(), iteration);
-      break;
-    }
-
-    ++iteration;
   }
 
-  for (int i = 0; i < task_context.NumInputs(); ++i) {
-    auto input_tensor = task_context.GetInput(i);
-    auto tensor_desc = task_context.MutableInputDesc(i);
-    GE_CHECK_NOTNULL(input_tensor);
-    GE_CHECK_NOTNULL(tensor_desc);
-    // restore original input tensor desc
-    *tensor_desc = std::move(ori_input_desc[i]);
-    GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(i, *input_tensor));
-  }
-
+  GELOGD("[%s] Quit from loop. current iteration = %d", task_context.GetNodeName(), iteration);
   if (done_callback) {
     done_callback();
   }
 
+  for (int i = 0; i < task_context.NumInputs(); ++i) {
+    GE_CHK_STATUS_RET_NOLOG(task_context.UpdateInputDesc(i, ori_input_desc[i]));
+  }
   return SUCCESS;
 }
@@ -379,13 +365,6 @@ Status WhileOpNodeTask::MoveOutputs2Inputs(TaskContext &task_context) {
 }
 
 Status WhileOpNodeTask::ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const {
-  GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue),
-                    "[%s] Failed to execute cond-subgraph",
-                    task_context.GetNodeName());
-  if (!is_continue) {
-    return SUCCESS;
-  }
-
   GELOGD("[%s] Start to execute body-subgraph.", task_context.GetNodeName());
   GE_CHK_STATUS_RET(ExecuteSubgraph(body_, task_context, nullptr),
                     "[%s] Failed to execute cond-subgraph", task_context.GetNodeName());
@@ -396,6 +375,17 @@ Status WhileOpNodeTask::ExecuteOneLoop(TaskContext &task_context, bool &is_conti
                     "[%s] Failed to move outputs to inputs",
                     task_context.GetNodeName());
 
+  GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue),
+                    "[%s] Failed to execute cond-subgraph",
+                    task_context.GetNodeName());
+
+  if (!is_continue) {
+    for (int i = 0; i < task_context.NumInputs(); ++i) {
+      auto input_desc = task_context.GetInput(i);
+      GE_CHECK_NOTNULL(input_desc);
+      GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(i, *input_desc));
+    }
+  }
+
   return SUCCESS;
 }
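
Taken together, the three hunks above restructure the while-node execution: the condition is evaluated once before the loop (a false condition never runs the body), DoExecuteAsync() then loops while is_continue is set, and ExecuteOneLoop() runs the body, moves outputs back to inputs, and re-evaluates the condition at its tail, forwarding the current input tensors (the GetInput() values) to the outputs when the condition says stop. The original input descriptors are restored through the new UpdateInputDesc() only after the done callback. A standalone sketch of the resulting control flow, with plain functions standing in for the cond/body subgraphs (none of this is GE API):

#include <cstdio>

static bool ExecuteCond(int state) { return state < 3; }   // stand-in for the cond subgraph
static int  ExecuteBody(int state) { return state + 1; }    // stand-in for the body subgraph

int main() {
  int state = 0;
  bool is_continue = ExecuteCond(state);  // cond first: a false condition skips the body entirely
  int iteration = 0;
  while (is_continue) {
    ++iteration;
    state = ExecuteBody(state);           // body, then outputs are fed back as the next inputs
    is_continue = ExecuteCond(state);     // cond re-evaluated at the tail of each iteration
  }
  std::printf("quit after %d iterations\n", iteration);
  return 0;
}
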
@@ -80,7 +80,6 @@ class WhileOpNodeTask : public ControlOpNodeTask {
   Status ExecuteCond(TaskContext &task_context, bool &is_continue) const;
 
   static Status MoveOutputs2Inputs(TaskContext &task_context);
 
   Status ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const;
-
  private:
@@ -554,5 +554,16 @@ NodeState *TaskContext::GetNodeState() const {
   return node_state_;
 }
 
+Status TaskContext::GetInputDesc(int index, GeTensorDesc &tensor_desc) const {
+  return node_item_->GetInputDesc(index, tensor_desc);
+}
+
+Status TaskContext::UpdateInputDesc(int index, const GeTensorDesc &tensor_desc) {
+  return const_cast<NodeItem *>(node_item_)->UpdateInputDesc(index, tensor_desc);
+}
+
+Status TaskContext::GetOutputDesc(int index, GeTensorDesc &tensor_desc) const {
+  return node_item_->GetOutputDesc(index, tensor_desc);
+}
 }  // namespace hybrid
 }  // namespace ge
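
These TaskContext methods simply forward to the NodeItem accessors added earlier, so the locking stays inside NodeItem. UpdateInputDesc() needs a const_cast because, judging from the cast in this hunk, node_item_ is held as a pointer-to-const and this is the one mutating call. A minimal sketch of that forwarding shape (simplified stand-in types and hypothetical names, not the GE classes):

struct Desc { long dims[4]; };

struct Item {
  Desc desc{};
  void Get(Desc &out) const { out = desc; }
  void Update(const Desc &in) { desc = in; }
};

struct Context {
  const Item *item = nullptr;                                                 // held as pointer-to-const
  void GetDesc(Desc &out) const { item->Get(out); }                           // const path forwards directly
  void UpdateDesc(const Desc &in) { const_cast<Item *>(item)->Update(in); }   // lone mutating path needs the cast
};
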
@@ -50,9 +50,12 @@ class TaskContext {
   const char *GetNodeName() const;
   TensorValue *MutableInput(int index);
   ConstGeTensorDescPtr GetInputDesc(int index) const;
+  Status GetInputDesc(int index, GeTensorDesc &tensor_desc) const;
   ConstGeTensorDescPtr GetOutputDesc(int index) const;
+  Status GetOutputDesc(int index, GeTensorDesc &tensor_desc) const;
   GeTensorDescPtr MutableInputDesc(int index) const;
   GeTensorDescPtr MutableOutputDesc(int index) const;
+  Status UpdateInputDesc(int index, const GeTensorDesc &tensor_desc);
   void ReleaseInputsAndOutputs();
   bool NeedCallback();
   void ReleaseInput(int index);