@@ -35,12 +35,14 @@ ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item(
          node_item.NodeName().c_str(),
          this->num_pending_shapes_);
-  for (int i = 0; i < node_item.num_inputs; ++i){
-    input_tensor_desc.emplace_back(*node_item.MutableInputDesc(i));
+  input_tensor_desc.resize(node_item.num_inputs);
+  for (int i = 0; i < node_item.num_inputs; ++i) {
+    node_item.GetInputDesc(i, input_tensor_desc[i]);
   }
-  for (int i = 0; i < node_item.num_outputs; ++i){
-    output_tensor_desc.emplace_back(*node_item.MutableOutputDesc(i));
+  output_tensor_desc.resize(node_item.num_outputs);
+  for (int i = 0; i < node_item.num_outputs; ++i) {
+    node_item.GetOutputDesc(i, output_tensor_desc[i]);
   }
 }
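
For context: this hunk replaces per-element emplace_back of dereferenced MutableInputDesc()/MutableOutputDesc() pointers with resize() followed by GetInputDesc()/GetOutputDesc(), so each cached descriptor is now copied out while NodeItem's new mutex is held instead of being read through an unsynchronized pointer. A minimal stand-alone sketch of that copy-under-lock snapshot pattern, using simplified stand-in types rather than the real NodeItem/GeTensorDesc:

  #include <cstddef>
  #include <mutex>
  #include <vector>

  struct TensorDesc { std::vector<long> dims; };  // stand-in for GeTensorDesc

  class Node {  // stand-in for NodeItem
   public:
    // Copy the descriptor out while holding the lock, so the caller gets a
    // stable snapshot even if another thread updates the descriptor later.
    bool GetInputDesc(std::size_t i, TensorDesc &out) const {
      std::lock_guard<std::mutex> lk(mu_);
      if (i >= inputs_.size()) {
        return false;
      }
      out = inputs_[i];
      return true;
    }

   private:
    mutable std::mutex mu_;
    std::vector<TensorDesc> inputs_ = std::vector<TensorDesc>(2);  // two empty descriptors for illustration
  };

  // Mirrors the constructor change: pre-size the cache, then fill each slot by copy.
  std::vector<TensorDesc> SnapshotInputDescs(const Node &node, std::size_t num_inputs) {
    std::vector<TensorDesc> cache(num_inputs);
    for (std::size_t i = 0; i < num_inputs; ++i) {
      (void)node.GetInputDesc(i, cache[i]);
    }
    return cache;
  }
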
@@ -297,7 +297,7 @@ void NodeItem::SetToDynamic() {
   }
 }
-GeTensorDescPtr NodeItem::MutableInputDesc(int index) const {
+GeTensorDescPtr NodeItem::DoGetInputDesc(int index) const {
   if (!has_optional_inputs) {
     return op_desc->MutableInputDesc(static_cast<uint32_t>(index));
   }
@@ -314,6 +314,40 @@ GeTensorDescPtr NodeItem::MutableInputDesc(int index) const {
   return op_desc->MutableInputDesc(input_desc_indices_[index]);
 }
+GeTensorDescPtr NodeItem::MutableInputDesc(int index) const {
+  std::lock_guard<std::mutex> lk(mu_);
+  return DoGetInputDesc(index);
+}
+Status NodeItem::GetInputDesc(int index, GeTensorDesc &tensor_desc) const {
+  std::lock_guard<std::mutex> lk(mu_);
+  auto input_desc = DoGetInputDesc(index);
+  GE_CHECK_NOTNULL(input_desc);
+  tensor_desc = *input_desc;
+  return SUCCESS;
+}
+Status NodeItem::GetOutputDesc(int index, GeTensorDesc &tensor_desc) const {
+  std::lock_guard<std::mutex> lk(mu_);
+  auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
+  GE_CHECK_NOTNULL(output_desc);
+  tensor_desc = *output_desc;
+  return SUCCESS;
+}
+GeTensorDescPtr NodeItem::MutableOutputDesc(int index) const {
+  std::lock_guard<std::mutex> lk(mu_);
+  return op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
+}
+Status NodeItem::UpdateInputDesc(int index, const GeTensorDesc &tensor_desc) {
+  std::lock_guard<std::mutex> lk(mu_);
+  auto input_desc = DoGetInputDesc(index);
+  GE_CHECK_NOTNULL(input_desc);
+  *input_desc = tensor_desc;
+  return SUCCESS;
+}
 Status NodeItem::GetCanonicalInputIndex(uint32_t index, int &canonical_index) const {
   if (!has_optional_inputs) {
     canonical_index = index;
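
A hedged usage sketch of the accessors added above (fragment, not from the source; it assumes a NodeItem reference named node_item, and that GeShape and GeTensorDesc::SetShape behave as in the usual GE graph API):

  // Read a snapshot under the lock, modify it locally, then publish it back.
  GeTensorDesc desc;
  GE_CHK_STATUS_RET(node_item.GetInputDesc(0, desc), "Failed to get input desc");
  desc.SetShape(GeShape({-1, 224, 224, 3}));  // hypothetical new shape
  GE_CHK_STATUS_RET(node_item.UpdateInputDesc(0, desc), "Failed to update input desc");

Compared with holding the GeTensorDescPtr returned by MutableInputDesc(), the copy-out/copy-in pair keeps the actual reads and writes of op_desc behind mu_ rather than only the pointer lookup.
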
@@ -17,6 +17,7 @@
 #ifndef GE_HYBRID_MODEL_NODE_ITEM_H_
 #define GE_HYBRID_MODEL_NODE_ITEM_H_
+#include <mutex>
 #include <vector>
 #include "external/ge/ge_api_error_codes.h"
 #include "graph/node.h"
@@ -57,12 +58,16 @@ struct NodeItem {
   bool IsInputShapeStatic(int index) const;
-  GeTensorDescPtr MutableOutputDesc(int index) const {
-    return op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
-  }
+  GeTensorDescPtr MutableOutputDesc(int index) const;
+  Status UpdateInputDesc(int index, const GeTensorDesc &tensor_desc);
   GeTensorDescPtr MutableInputDesc(int index) const;
+  Status GetInputDesc(int index, GeTensorDesc &tensor_desc) const;
+  Status GetOutputDesc(int index, GeTensorDesc &tensor_desc) const;
   Status GetCanonicalInputIndex(uint32_t index, int &canonical_index) const;
   bool IsControlOp() const;
@@ -113,9 +118,11 @@ struct NodeItem {
   Status ResolveDynamicState();
   Status ResolveStaticInputsAndOutputs();
   void ResolveUnknownShapeType();
+  GeTensorDescPtr DoGetInputDesc(int index) const;
   std::vector<bool> is_input_shape_static_;
   std::vector<uint32_t> input_desc_indices_;
+  mutable std::mutex mu_;
 };
 } // namespace hybrid
 } // namespace ge
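
Briefly annotated, the new private members above are (comments are editorial, not from the source):

  GeTensorDescPtr DoGetInputDesc(int index) const;  // optional-input index translation; callers take mu_ first
  mutable std::mutex mu_;                           // guards access to op_desc's input/output tensor descriptors
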
@@ -237,8 +237,8 @@ Status WhileOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::fun
   }
   bool is_continue = false;
-  GE_CHK_STATUS_RET(ExecuteOneLoop(task_context, is_continue),
-                    "[%s] Failed to execute iteration 0.",
+  GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue),
+                    "[%s] Failed to execute cond-subgraph",
                     task_context.GetNodeName());
   if (!is_continue) {
     for (int i = 0; i < task_context.NumInputs(); ++i) {
@@ -259,42 +259,28 @@ Status WhileOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::fun
   }
   // backup original input tensor desc
-  std::vector<GeTensorDesc> ori_input_desc;
+  std::vector<GeTensorDesc> ori_input_desc(task_context.NumInputs());
   for (int i = 0; i < task_context.NumInputs(); ++i) {
-    auto tensor_desc = task_context.GetInputDesc(i);
-    GE_CHECK_NOTNULL(tensor_desc);
-    ori_input_desc.emplace_back(*tensor_desc);
+    GE_CHK_STATUS_RET_NOLOG(task_context.GetInputDesc(i, ori_input_desc[i]));
   }
-  int iteration = 1;
-  while (true) {
+  int iteration = 0;
+  while (is_continue) {
+    ++iteration;
     GELOGD("[%s] Start to execute, iteration = %d", task_context.GetNodeName(), iteration);
     GE_CHK_STATUS_RET(ExecuteOneLoop(task_context, is_continue),
                       "[%s] Failed to execute iteration %d.",
                       task_context.GetNodeName(),
                       iteration);
-    if (!is_continue) {
-      GELOGD("[%s] Quit from loop. current iteration = %d", task_context.GetNodeName(), iteration);
-      break;
-    }
-    ++iteration;
   }
-  for (int i = 0; i < task_context.NumInputs(); ++i) {
-    auto input_tensor = task_context.GetInput(i);
-    auto tensor_desc = task_context.MutableInputDesc(i);
-    GE_CHECK_NOTNULL(input_tensor);
-    GE_CHECK_NOTNULL(tensor_desc);
-    // restore original input tensor desc
-    *tensor_desc = std::move(ori_input_desc[i]);
-    GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(i, *input_tensor));
-  }
+  GELOGD("[%s] Quit from loop. current iteration = %d", task_context.GetNodeName(), iteration);
   if (done_callback) {
     done_callback();
   }
+  for (int i = 0; i < task_context.NumInputs(); ++i) {
+    GE_CHK_STATUS_RET_NOLOG(task_context.UpdateInputDesc(i, ori_input_desc[i]));
+  }
   return SUCCESS;
 }
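
The reworked DoExecuteAsync above evaluates cond before the first iteration and drives the loop off is_continue; a condensed editorial sketch of the resulting control flow (stand-in helper names, not the real API):

  // bool is_continue = false;
  // RunCond(is_continue);                  // cond now runs before iteration 0
  // if (!is_continue) {
  //   CopyInputsToOutputs();               // zero-trip case, then early return (context elided above)
  // } else {
  //   auto backup = SnapshotInputDescs();  // copies taken via TaskContext::GetInputDesc
  //   int iteration = 0;
  //   while (is_continue) {                // ExecuteOneLoop re-evaluates cond at its end
  //     ++iteration;
  //     RunOneLoop(is_continue);           // body -> MoveOutputs2Inputs -> cond
  //   }
  //   NotifyDone();                        // done_callback, if provided
  //   RestoreInputDescs(backup);           // published back via TaskContext::UpdateInputDesc
  // }
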
@@ -379,13 +365,6 @@ Status WhileOpNodeTask::MoveOutputs2Inputs(TaskContext &task_context) {
 }
 Status WhileOpNodeTask::ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const {
-  GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue),
-                    "[%s] Failed to execute cond-subgraph",
-                    task_context.GetNodeName());
-  if (!is_continue) {
-    return SUCCESS;
-  }
GELOGD("[%s] Start to execute body-subgraph.", task_context.GetNodeName()); | |||
GE_CHK_STATUS_RET(ExecuteSubgraph(body_, task_context, nullptr), | |||
"[%s] Failed to execute cond-subgraph", task_context.GetNodeName()); | |||
@@ -396,6 +375,17 @@ Status WhileOpNodeTask::ExecuteOneLoop(TaskContext &task_context, bool &is_conti
                     "[%s] Failed to move outputs to inputs",
                     task_context.GetNodeName());
+  GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue),
+                    "[%s] Failed to execute cond-subgraph",
+                    task_context.GetNodeName());
+  if (!is_continue) {
+    for (int i = 0; i < task_context.NumInputs(); ++i) {
+      auto input_desc = task_context.GetInput(i);
+      GE_CHECK_NOTNULL(input_desc);
+      GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(i, *input_desc));
+    }
+  }
   return SUCCESS;
 }
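
Net effect of the two ExecuteOneLoop hunks: the cond evaluation moves from the head of ExecuteOneLoop to its tail, so a single call now runs the body, MoveOutputs2Inputs, and the cond check for the next iteration, and when cond comes back false the freshly moved inputs are forwarded to the node's outputs. DoExecuteAsync's while (is_continue) then simply falls out of the loop.
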
@@ -80,7 +80,6 @@ class WhileOpNodeTask : public ControlOpNodeTask {
   Status ExecuteCond(TaskContext &task_context, bool &is_continue) const;
   static Status MoveOutputs2Inputs(TaskContext &task_context);
   Status ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const;
  private:
@@ -554,5 +554,16 @@ NodeState *TaskContext::GetNodeState() const {
   return node_state_;
 }
+Status TaskContext::GetInputDesc(int index, GeTensorDesc &tensor_desc) const {
+  return node_item_->GetInputDesc(index, tensor_desc);
+}
+Status TaskContext::UpdateInputDesc(int index, const GeTensorDesc &tensor_desc) {
+  return const_cast<NodeItem *>(node_item_)->UpdateInputDesc(index, tensor_desc);
+}
+Status TaskContext::GetOutputDesc(int index, GeTensorDesc &tensor_desc) const {
+  return node_item_->GetOutputDesc(index, tensor_desc);
+}
 } // namespace hybrid
 } // namespace ge
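
The three TaskContext wrappers above delegate straight to NodeItem (with a const_cast on the update path, since TaskContext holds the NodeItem through a pointer-to-const). A hedged fragment showing the backup/restore pattern they enable inside a node task, mirroring the While-op change (assumes a TaskContext &task_context; error handling abbreviated):

  std::vector<GeTensorDesc> backup(task_context.NumInputs());
  for (int i = 0; i < task_context.NumInputs(); ++i) {
    GE_CHK_STATUS_RET_NOLOG(task_context.GetInputDesc(i, backup[i]));     // snapshot
  }
  // ... execute subgraphs that may overwrite the input shapes ...
  for (int i = 0; i < task_context.NumInputs(); ++i) {
    GE_CHK_STATUS_RET_NOLOG(task_context.UpdateInputDesc(i, backup[i]));  // restore
  }
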
@@ -50,9 +50,12 @@ class TaskContext {
   const char *GetNodeName() const;
   TensorValue *MutableInput(int index);
   ConstGeTensorDescPtr GetInputDesc(int index) const;
+  Status GetInputDesc(int index, GeTensorDesc &tensor_desc) const;
   ConstGeTensorDescPtr GetOutputDesc(int index) const;
+  Status GetOutputDesc(int index, GeTensorDesc &tensor_desc) const;
   GeTensorDescPtr MutableInputDesc(int index) const;
   GeTensorDescPtr MutableOutputDesc(int index) const;
+  Status UpdateInputDesc(int index, const GeTensorDesc &tensor_desc);
   void ReleaseInputsAndOutputs();
   bool NeedCallback();
   void ReleaseInput(int index);