Browse Source

!1928 cherry-pick fix for dynamic shape V1

Merge pull request !1928 from 张晓昆/r1.5.0
tags/v1.3.0
i-robot Gitee 4 years ago
parent
commit
cc7175217c
43 changed files with 403 additions and 295 deletions
  1. +4
    -7
      ge/ge_local_engine/engine/host_cpu_engine.cc
  2. +8
    -4
      ge/graph/load/model_manager/model_manager.cc
  3. +13
    -12
      ge/graph/partition/dynamic_shape_partition.cc
  4. +3
    -1
      ge/graph/partition/dynamic_shape_partition.h
  5. +55
    -42
      ge/graph/passes/mark_force_unknown_for_cond_pass.cc
  6. +11
    -0
      ge/graph/passes/mark_force_unknown_for_cond_pass.h
  7. +21
    -12
      ge/graph/passes/next_iteration_pass.cc
  8. +3
    -2
      ge/graph/passes/switch_to_stream_switch_pass.cc
  9. +32
    -11
      ge/hybrid/executor/node_state.cc
  10. +1
    -0
      ge/hybrid/executor/node_state.h
  11. +1
    -0
      ge/hybrid/executor/worker/execution_engine.cc
  12. +0
    -4
      ge/hybrid/model/hybrid_model_builder.cc
  13. +34
    -14
      ge/hybrid/model/node_item.cc
  14. +3
    -3
      ge/hybrid/model/node_item.h
  15. +1
    -0
      ge/hybrid/node_executor/hccl/hccl_node_executor.cc
  16. +2
    -6
      ge/hybrid/node_executor/rts/rts_node_executor.cc
  17. +0
    -29
      ge/hybrid/node_executor/rts/rts_node_task.cc
  18. +0
    -5
      ge/hybrid/node_executor/rts/rts_node_task.h
  19. +1
    -4
      ge/hybrid/node_executor/task_context.cc
  20. +4
    -0
      tests/depends/mmpa/src/mmpa_stub.cc
  21. +1
    -2
      tests/ut/ge/CMakeLists.txt
  22. +5
    -5
      tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc
  23. +1
    -1
      tests/ut/ge/graph/build/stream_allocator_unittest.cc
  24. +18
    -0
      tests/ut/ge/graph/load/model_manager_unittest.cc
  25. +3
    -3
      tests/ut/ge/graph/passes/assert_pass_unittest.cc
  26. +7
    -7
      tests/ut/ge/graph/passes/base_pass_unittest.cc
  27. +3
    -3
      tests/ut/ge/graph/passes/cond_branch_v1_unittest.cc
  28. +19
    -19
      tests/ut/ge/graph/passes/constant_folding_pass_unittest.cc
  29. +4
    -4
      tests/ut/ge/graph/passes/dimension_compute_pass_unittest.cc
  30. +1
    -1
      tests/ut/ge/graph/passes/folding_kernel/ssd_prior_box_kernel_unittest.cc
  31. +1
    -1
      tests/ut/ge/graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc
  32. +88
    -38
      tests/ut/ge/graph/passes/mark_force_unknown_for_cond_pass_unittest.cc
  33. +14
    -14
      tests/ut/ge/graph/passes/merge_pass_unittest.cc
  34. +6
    -6
      tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc
  35. +3
    -3
      tests/ut/ge/graph/passes/reshape_recovery_pass_unittest.cc
  36. +8
    -8
      tests/ut/ge/graph/passes/reshape_remove_pass_unittest.cc
  37. +1
    -1
      tests/ut/ge/graph/passes/resource_pair_control_pass_unittest.cc
  38. +6
    -6
      tests/ut/ge/graph/passes/switch_logic_remove_pass_unittest.cc
  39. +2
    -2
      tests/ut/ge/graph/passes/trans_op_breadth_fusion_pass_unittest.cc
  40. +7
    -7
      tests/ut/ge/graph/passes/trans_op_depth_fusion_pass_unittest.cc
  41. +2
    -2
      tests/ut/ge/graph/passes/transop_nearby_allreduce_fusion_pass_unittest.cc
  42. +1
    -1
      tests/ut/ge/graph/passes/variable_op_pass_unittest.cc
  43. +5
    -5
      tests/ut/ge/graph/variable_accelerate_ctrl_unittest.cc

+ 4
- 7
ge/ge_local_engine/engine/host_cpu_engine.cc View File

@@ -13,15 +13,15 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "host_cpu_engine.h"
#include "graph/common/omg_util.h"
#include "ge_local_engine/engine/host_cpu_engine.h"
#include "graph/utils/op_desc_utils.h" #include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_adapter.h" #include "graph/utils/tensor_adapter.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/type_utils.h"
#include "register/op_kernel_registry.h" #include "register/op_kernel_registry.h"
#include "register/host_cpu_context.h" #include "register/host_cpu_context.h"
#include "common/ge/ge_util.h" #include "common/ge/ge_util.h"
#include "common/ge/plugin_manager.h" #include "common/ge/plugin_manager.h"
#include "graph/utils/type_utils.h"
#include "common/fp16_t.h" #include "common/fp16_t.h"
#include "common/math/math_util.h" #include "common/math/math_util.h"


@@ -123,10 +123,7 @@ bool HostCpuEngine::CheckSupported(const string &op_type) {
} }


Status HostCpuEngine::FindOpKernel(const ge::NodePtr &node, std::unique_ptr<HostCpuOp> &op_kernel) { Status HostCpuEngine::FindOpKernel(const ge::NodePtr &node, std::unique_ptr<HostCpuOp> &op_kernel) {
std::string op_type;
auto status = GetOriginalType(node, op_type);
GE_CHK_BOOL_EXEC_NOLOG(status == SUCCESS, return status);

const std::string op_type = NodeUtils::GetNodeType(node);
auto kernel = OpKernelRegistry::GetInstance().CreateHostCpuOp(op_type); auto kernel = OpKernelRegistry::GetInstance().CreateHostCpuOp(op_type);
if (kernel == nullptr) { if (kernel == nullptr) {
GELOGD("Op of type %s is not supported by host cpu engine", op_type.c_str()); GELOGD("Op of type %s is not supported by host cpu engine", op_type.c_str());


+ 8
- 4
ge/graph/load/model_manager/model_manager.cc View File

@@ -1378,7 +1378,9 @@ Status ModelManager::LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_
Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) { Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) {
GELOGD("Aicpu kernel launch task in, kernel name %s.", kernel_name.c_str()); GELOGD("Aicpu kernel launch task in, kernel name %s.", kernel_name.c_str());
std::lock_guard<std::mutex> lock(cust_aicpu_mutex_); std::lock_guard<std::mutex> lock(cust_aicpu_mutex_);
if (cust_aicpu_so_.size() == 0) return SUCCESS;
if (cust_aicpu_so_.empty()) {
return SUCCESS;
}
// get current context // get current context
rtContext_t rt_cur_ctx = nullptr; rtContext_t rt_cur_ctx = nullptr;
auto rt_error = rtCtxGetCurrent(&rt_cur_ctx); auto rt_error = rtCtxGetCurrent(&rt_cur_ctx);
@@ -1394,17 +1396,19 @@ Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) {
return SUCCESS; return SUCCESS;
} }


vector<void *> allocated_mem;
rtError_t status;
rtStream_t stream = nullptr; rtStream_t stream = nullptr;
vector<void *> allocated_mem;
std::function<void()> callback = [&]() { std::function<void()> callback = [&]() {
for (auto mem : allocated_mem) { for (auto mem : allocated_mem) {
GE_CHK_RT(rtFree(mem)); GE_CHK_RT(rtFree(mem));
} }
GE_CHK_RT(rtStreamDestroy(stream));
if (stream != nullptr) {
GE_CHK_RT(rtStreamDestroy(stream));
}
}; };
GE_MAKE_GUARD(release, callback); GE_MAKE_GUARD(release, callback);


rtError_t status;
vector<CustAicpuSoBuf> v_cust_so; vector<CustAicpuSoBuf> v_cust_so;
void *args = nullptr; void *args = nullptr;




+ 13
- 12
ge/graph/partition/dynamic_shape_partition.cc View File

@@ -284,9 +284,6 @@ Status DynamicShapePartitioner::InitClusters() {
auto cluster = MakeShared<Cluster>(rank++, type, node, this); auto cluster = MakeShared<Cluster>(rank++, type, node, this);
REQUIRE_NOT_NULL(cluster, "[New][Memory] for cluster failed."); REQUIRE_NOT_NULL(cluster, "[New][Memory] for cluster failed.");
node_2_cluster_[node] = cluster; node_2_cluster_[node] = cluster;
if (cluster->IsUnknownShape()) {
ordered_cluster_.push_back(cluster);
}


int64_t group_index = -1; int64_t group_index = -1;
if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) {
@@ -306,7 +303,7 @@ Status DynamicShapePartitioner::InitClusters() {
return SUCCESS; return SUCCESS;
} }


Status DynamicShapePartitioner::TopologicalSortClusters() {
Status DynamicShapePartitioner::TopologicalSortClusters(const OrderedFilter &ordered_filter) {
ordered_cluster_.clear(); ordered_cluster_.clear();
// BFS topological sort clusters for known shape cluster // BFS topological sort clusters for known shape cluster
std::queue<ClusterPtr> ready_clusters; std::queue<ClusterPtr> ready_clusters;
@@ -331,7 +328,7 @@ Status DynamicShapePartitioner::TopologicalSortClusters() {
auto cluster = ready_clusters.front(); auto cluster = ready_clusters.front();
ready_clusters.pop(); ready_clusters.pop();
cluster->UpdateRank(rank++); cluster->UpdateRank(rank++);
if (cluster->IsKnownShape() || cluster->IsInputNode()) {
if (ordered_filter == nullptr || ordered_filter(cluster)) {
ordered_cluster_.push_back(cluster); ordered_cluster_.push_back(cluster);
} }
for (const auto &out_cluster : cluster->Outputs()) { for (const auto &out_cluster : cluster->Outputs()) {
@@ -378,7 +375,6 @@ void DynamicShapePartitioner::MergeClustersControlFlow() {
continue; continue;
} }


bool is_unknown_cluster = cluster->IsUnknownShape();
for (++rit; rit != control_cluster.rend(); ++rit) { for (++rit; rit != control_cluster.rend(); ++rit) {
const auto &cluster_from = *rit; const auto &cluster_from = *rit;
if (all_merged_clusters.count(cluster_from) > 0) { if (all_merged_clusters.count(cluster_from) > 0) {
@@ -395,11 +391,6 @@ void DynamicShapePartitioner::MergeClustersControlFlow() {
} }
} }
} }

if (!is_unknown_cluster && cluster->IsUnknownShape()) {
GELOGD("Add to ordered cluster: %s", cluster->DebugString().c_str());
ordered_cluster_.push_back(cluster);
}
} }
} }


@@ -475,9 +466,19 @@ void DynamicShapePartitioner::MergeClustersInputData() {
} }


Status DynamicShapePartitioner::MergeClusters() { Status DynamicShapePartitioner::MergeClusters() {
const auto filter_known = [](const ClusterPtr &cluster) {
return cluster->IsKnownShape() || cluster->IsInputNode();
};
const auto filter_unknown = [](const ClusterPtr &cluster) {
return cluster->IsUnknownShape();
};

MergeClustersControlFlow(); MergeClustersControlFlow();
REQUIRE_SUCCESS(TopologicalSortClusters(filter_unknown),
"[TopologicalSort][Clusters] after merge control flow clusters failed.");
MergeClustersUnknownShape(); MergeClustersUnknownShape();
REQUIRE_SUCCESS(TopologicalSortClusters(), "[TopologicalSort][Clusters] after merge unknown shape clusters failed.");
REQUIRE_SUCCESS(TopologicalSortClusters(filter_known),
"[TopologicalSort][Clusters] after merge unknown shape clusters failed.");
MergeClustersKnownShape(); MergeClustersKnownShape();
MergeClustersInputData(); MergeClustersInputData();
return SUCCESS; return SUCCESS;


+ 3
- 1
ge/graph/partition/dynamic_shape_partition.h View File

@@ -111,6 +111,8 @@ class DynamicShapePartitioner {


Status Partition(); Status Partition();


using OrderedFilter = std::function<bool(const std::shared_ptr<Cluster> &cluster)>;

private: private:
Status PartitionImpl(); Status PartitionImpl();
// Collect nodes that satisfy the unknown-shape rules: // Collect nodes that satisfy the unknown-shape rules:
@@ -138,7 +140,7 @@ class DynamicShapePartitioner {
// Merge clusters step3 // Merge clusters step3
void MergeClustersInputData(); void MergeClustersInputData();
// Topological sort clusters after merge unknown shape clusters. // Topological sort clusters after merge unknown shape clusters.
Status TopologicalSortClusters();
Status TopologicalSortClusters(const OrderedFilter &ordered_filter);
// Deduplicate merged clusters // Deduplicate merged clusters
void PruneUniqueClusters(); void PruneUniqueClusters();
// Establish the input-output anchors for each partition of the cluster and record links to other clusters // Establish the input-output anchors for each partition of the cluster and record links to other clusters


+ 55
- 42
ge/graph/passes/mark_force_unknown_for_cond_pass.cc View File

@@ -16,8 +16,6 @@


#include "mark_force_unknown_for_cond_pass.h" #include "mark_force_unknown_for_cond_pass.h"


#include <queue>

#include "graph/utils/node_utils.h" #include "graph/utils/node_utils.h"
#include "graph/common/omg_util.h" #include "graph/common/omg_util.h"


@@ -26,17 +24,7 @@ namespace {
inline bool IsMergeInLoop(const NodePtr &node) { inline bool IsMergeInLoop(const NodePtr &node) {
const static std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; const static std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION };


std::string node_type;
(void)GetOriginalType(node, node_type);
return kLoopMergeInputs.count(node_type) > 0;
}

inline bool IsSwitchInLoop(const NodePtr &node) {
const static std::set<std::string> kLoopSwitchInputs{ MERGE, REFMERGE, LOOPCOND };

std::string node_type;
(void)GetOriginalType(node, node_type);
return kLoopSwitchInputs.count(node_type) > 0;
return kLoopMergeInputs.count(NodeUtils::GetNodeType(node)) > 0;
} }
} }


@@ -44,10 +32,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) {
GELOGD("MarkForceUnknownForCondPass Enter"); GELOGD("MarkForceUnknownForCondPass Enter");
std::map<NodePtr, std::vector<NodePtr>> switch_groups; std::map<NodePtr, std::vector<NodePtr>> switch_groups;
for (const auto &node : graph->GetDirectNode()) { for (const auto &node : graph->GetDirectNode()) {
std::string node_type;
GE_CHK_STATUS_RET(GetOriginalType(node, node_type),
"[Get][OriginalType] of node in graph:%s failed.", graph->GetName().c_str());
if (kMergeOpTypes.count(node_type) == 0) {
if (kMergeOpTypes.count(NodeUtils::GetNodeType(node)) == 0) {
continue; continue;
} }


@@ -65,6 +50,51 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) {
} }


/// ///
/// @brief Handle a Switch node that takes part in a while loop
/// @param [in] node: Switch node to inspect
/// @param [in] dst_span: span value queued together with every node pushed by this call
/// @param [out] search_queue: receives the non-NextIteration inputs of each Merge feeding the Switch
/// @return true: Switch has a LoopCond input or an Exit output / false: neither, nothing was queued.
///
bool MarkForceUnknownForCondPass::DealAsLoopSwitch(const NodePtr &node, uint32_t dst_span,
                                                   std::queue<std::pair<NodePtr, uint32_t>> &search_queue) {
  // Pattern being matched:
  //
  //   Enter ----------+            LoopCond
  //                   |               |
  //                   +--> Merge --> Switch --> Exit
  //                   |
  //   NextIteration --+
  const auto inputs = node->GetInAllNodes();
  const auto outputs = node->GetOutAllNodes();

  // The Switch counts as a loop Switch when LoopCond feeds it or when it feeds an Exit.
  const bool fed_by_loop_cond = std::any_of(inputs.begin(), inputs.end(), [](const NodePtr &in) {
    return NodeUtils::GetNodeType(in) == LOOPCOND;
  });
  const bool in_loop = fed_by_loop_cond || std::any_of(outputs.begin(), outputs.end(), [](const NodePtr &out) {
    return kExitOpTypes.count(NodeUtils::GetNodeType(out)) > 0;
  });
  if (!in_loop) {
    return false;
  }

  // Keep searching upwards through every Merge input of the Switch, but do not follow the
  // NextIteration inputs of that Merge.
  for (const auto &merge : inputs) {
    if (kMergeOpTypes.count(NodeUtils::GetNodeType(merge)) == 0) {
      continue;
    }

    for (const auto &feed : merge->GetInAllNodes()) {
      if (kNextIterationOpTypes.count(NodeUtils::GetNodeType(feed)) == 0) {
        search_queue.push({feed, dst_span});
        GELOGD("Travel in Loop: %s <-- %s <-- %s, span is: %u", node->GetName().c_str(), merge->GetName().c_str(),
               feed->GetName().c_str(), dst_span);
      }
    }
  }

  return true;
}

///
/// @brief Mark force unknown shape for Switch node /// @brief Mark force unknown shape for Switch node
/// @param [in] merge node /// @param [in] merge node
/// @param [out] switch group /// @param [out] switch group
@@ -72,6 +102,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) {
/// ///
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std::vector<NodePtr> &switch_group) { void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std::vector<NodePtr> &switch_group) {
// Switch --> {Switch --> Merge} --> Merge // Switch --> {Switch --> Merge} --> Merge
GELOGD("Search Switch node for Merge: %s", node->GetName().c_str());
std::unordered_set<NodePtr> nodes_seen; std::unordered_set<NodePtr> nodes_seen;
std::queue<std::pair<NodePtr, uint32_t>> search_queue({{node, 0}}); std::queue<std::pair<NodePtr, uint32_t>> search_queue({{node, 0}});
while (!search_queue.empty()) { while (!search_queue.empty()) {
@@ -79,43 +110,25 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std:
const auto dst_span = search_queue.front().second; const auto dst_span = search_queue.front().second;
search_queue.pop(); search_queue.pop();


// Switch --> Identity --> Constant
for (const auto &in_node : dst_node->GetInControlNodes()) {
if (nodes_seen.count(in_node) > 0) {
GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str());
continue;
}
nodes_seen.insert(in_node);

if (in_node->GetType() == IDENTITY) {
GELOGD("Travel node: %s, In control: %s, span is: %u", dst_node->GetName().c_str(),
in_node->GetName().c_str(), dst_span);
search_queue.push({in_node, dst_span});
}
}

for (const auto &in_node : dst_node->GetInDataNodes()) {
for (const auto &in_node : dst_node->GetInAllNodes()) {
if (nodes_seen.count(in_node) > 0) { if (nodes_seen.count(in_node) > 0) {
GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str());
continue; continue;
} }
nodes_seen.insert(in_node); nodes_seen.insert(in_node);


std::string node_type;
(void)GetOriginalType(in_node, node_type);
const std::string node_type = NodeUtils::GetNodeType(in_node);
GELOGD("Travel node: %s, %s node: %s, span is: %u", dst_node->GetName().c_str(), node_type.c_str(), GELOGD("Travel node: %s, %s node: %s, span is: %u", dst_node->GetName().c_str(), node_type.c_str(),
in_node->GetName().c_str(), dst_span); in_node->GetName().c_str(), dst_span);
if (kSwitchOpTypes.count(node_type) > 0) { // Switch input node. if (kSwitchOpTypes.count(node_type) > 0) { // Switch input node.
if (DealAsLoopSwitch(in_node, dst_span, search_queue)) {
continue;
}

if (dst_span > 0) { if (dst_span > 0) {
search_queue.push({in_node, dst_span - 1}); search_queue.push({in_node, dst_span - 1});
} else { } else {
const auto &all_in_nodes = in_node->GetInDataNodes();
if (std::any_of(all_in_nodes.begin(), all_in_nodes.end(), IsSwitchInLoop)) {
GELOGW("Travel node: %s, %s node: %s, Skip LoopCond switch", dst_node->GetName().c_str(), node_type.c_str(),
in_node->GetName().c_str());
} else {
switch_group.emplace_back(in_node);
}
switch_group.emplace_back(in_node);
} }
} else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node. } else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node.
search_queue.push({in_node, dst_span + 1}); search_queue.push({in_node, dst_span + 1});


+ 11
- 0
ge/graph/passes/mark_force_unknown_for_cond_pass.h View File

@@ -19,6 +19,8 @@


#include "inc/graph_pass.h" #include "inc/graph_pass.h"


#include <queue>

namespace ge { namespace ge {
class MarkForceUnknownForCondPass : public GraphPass { class MarkForceUnknownForCondPass : public GraphPass {
public: public:
@@ -26,6 +28,15 @@ class MarkForceUnknownForCondPass : public GraphPass {


private: private:
/// ///
/// @brief Deal with Switch node for LoopCond
/// @param [in] Switch node
/// @param [in] dest span
/// @param [out] Search queue
/// @return true: Switch In while loop / false: Not in while Loop.
///
bool DealAsLoopSwitch(const NodePtr &node, uint32_t dst_span, std::queue<std::pair<NodePtr, uint32_t>> &search_queue);

///
/// @brief Mark force unknown shape for Switch node /// @brief Mark force unknown shape for Switch node
/// @param [in] merge node /// @param [in] merge node
/// @param [out] switch group /// @param [out] switch group


+ 21
- 12
ge/graph/passes/next_iteration_pass.cc View File

@@ -24,7 +24,9 @@ using std::string;


namespace ge { namespace ge {
namespace { namespace {
const int64_t kLoopType = 1;
constexpr int64_t kLoopType = 1;
constexpr uint8_t kMaxTransOp = 3;
constexpr uint8_t kTransOpIoSize = 1;
} }


Status NextIterationPass::Run(ComputeGraphPtr graph) { Status NextIterationPass::Run(ComputeGraphPtr graph) {
@@ -287,18 +289,25 @@ void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, i
std::string node_type; std::string node_type;
for (const auto &switch_node : loop_group.switch_nodes) { for (const auto &switch_node : loop_group.switch_nodes) {
SetControlFlowGroup(switch_node, group_index); SetControlFlowGroup(switch_node, group_index);
for (const auto &node : switch_node->GetOutDataNodes()) {
(void)GetOriginalType(node, node_type);
if (kExitOpTypes.count(node_type) > 0) {
SetControlFlowGroup(node, group_index);
} else {
// For: Switch -> Cast -> Exit
for (const auto &n : node->GetOutDataNodes()) {
(void)GetOriginalType(n, node_type);
if (kExitOpTypes.count(node_type) > 0) {
SetControlFlowGroup(n, group_index);
}
for (auto node : switch_node->GetOutDataNodes()) {
// Switch --> Exit
// Switch --> Cast --> Exit
// Switch --> TransData --> Cast --> Exit
for (uint8_t i = 0; i < kMaxTransOp; ++i) {
if (node->GetInDataNodes().size() != kTransOpIoSize || node->GetAllOutDataAnchorsSize() != kTransOpIoSize) {
break;
} }

if (kExitOpTypes.count(NodeUtils::GetNodeType(node)) > 0) {
SetControlFlowGroup(node, group_index);
break;
}

const auto &all_nodes = node->GetOutAllNodes();
if (all_nodes.size() != kTransOpIoSize) {
break;
}
node = all_nodes.at(0);
} }
} }
} }


+ 3
- 2
ge/graph/passes/switch_to_stream_switch_pass.cc View File

@@ -395,8 +395,9 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr &
peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str()); peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str());


int64_t group_index = -1; int64_t group_index = -1;
(void)AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
SetControlFlowGroup(stream_switch, group_index);
if (AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) {
SetControlFlowGroup(stream_switch, group_index);
}
return stream_switch; return stream_switch;
} }




+ 32
- 11
ge/hybrid/executor/node_state.cc View File

@@ -326,17 +326,45 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() {
} }


void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) { void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) {
if (node_item_->root_data_.count(input_idx) > 0) {
GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx);
root_tensor_values_[input_idx] = tensor;
const auto is_persist_tensor = [](const std::map<const NodeItem *, std::set<int>> &items, int idx) {
const auto is_exist = [&idx](const std::pair<const NodeItem *, std::set<int>> &items) {
return items.second.count(idx) > 0;
};
return std::any_of(items.begin(), items.end(), is_exist);
};

if (root_tensor_values_.count(input_idx) > 0) {
return;
} }


if (node_item_->enter_data_.count(input_idx) > 0) {
if (is_persist_tensor(node_item_->root_data_, input_idx)) {
GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx);
root_tensor_values_[input_idx] = tensor;
} else if (is_persist_tensor(node_item_->enter_data_, input_idx)) {
GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx); GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx);
root_tensor_values_[input_idx] = tensor; root_tensor_values_[input_idx] = tensor;
} }
} }


void NodeState::UpdatePersistTensor() {
const auto update_tensor = [&](const std::map<const NodeItem *, std::set<int>> &items) {
for (const auto &item : items) {
for (const auto idx : item.second) {
UpdatePersistTensor(idx);
}
}
};

if (root_tensor_values_.empty()) {
return;
}

update_tensor(node_item_->root_data_);
if (iteration_count_ > 0) {
update_tensor(node_item_->enter_data_);
}
}

void NodeState::UpdatePersistTensor(int input_idx) { void NodeState::UpdatePersistTensor(int input_idx) {
const auto it = root_tensor_values_.find(input_idx); const auto it = root_tensor_values_.find(input_idx);
if (it == root_tensor_values_.end()) { if (it == root_tensor_values_.end()) {
@@ -363,16 +391,9 @@ void NodeState::ResetContext(uint64_t iteration) {


data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size());
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size());
for (auto item : node_item_->root_data_) {
UpdatePersistTensor(item.first);
}

if (iteration > 0) { if (iteration > 0) {
data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size()); data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size());
ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size()); ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size());
for (auto item : node_item_->enter_data_) {
UpdatePersistTensor(item.first);
}
} }


iteration_count_ = iteration; iteration_count_ = iteration;


+ 1
- 0
ge/hybrid/executor/node_state.h View File

@@ -132,6 +132,7 @@ struct NodeState {
void RunNextIteration(); void RunNextIteration();


void SavePersistTensor(int input_idx, const TensorValue &tensor); void SavePersistTensor(int input_idx, const TensorValue &tensor);
void UpdatePersistTensor();


Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const;




+ 1
- 0
ge/hybrid/executor/worker/execution_engine.cc View File

@@ -373,6 +373,7 @@ Status ExecutionEngine::DoExecuteAsync(NodeState &node_state,
auto executor = node_item.node_executor; auto executor = node_item.node_executor;
GE_CHECK_NOTNULL(executor); GE_CHECK_NOTNULL(executor);
RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] Start"); RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] Start");
node_state.UpdatePersistTensor();
GE_CHK_STATUS_RET(executor->PrepareTask(*task, task_context), "[Prepare][Task] for [%s] failed.", GE_CHK_STATUS_RET(executor->PrepareTask(*task, task_context), "[Prepare][Task] for [%s] failed.",
node_state.GetName().c_str()); node_state.GetName().c_str());
RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] End"); RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] End");


+ 0
- 4
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -288,10 +288,6 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n
return SUCCESS; return SUCCESS;
} }


if (node->GetType() == MEMCPYASYNC) { // Convert MemcpyAsync to Identity.
node->GetOpDesc()->SetType(IDENTITY);
}

std::unique_ptr<NodeItem> new_node; std::unique_ptr<NodeItem> new_node;
GE_CHK_STATUS_RET(NodeItem::Create(node, new_node), "[Invoke][Create] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET(NodeItem::Create(node, new_node), "[Invoke][Create] failed, model_name_:[%s]", GetGraphName());
GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor)); GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor));


+ 34
- 14
ge/hybrid/model/node_item.cc View File

@@ -14,10 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */


#include "node_item.h"
#include <sstream>
#include "common/debug/log.h"
#include "graph/common/omg_util.h"
#include "hybrid/model/node_item.h"

#include "graph/compute_graph.h" #include "graph/compute_graph.h"
#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_attr_define.h"
#include "hybrid/executor/worker/shape_inference_engine.h" #include "hybrid/executor/worker/shape_inference_engine.h"
@@ -26,6 +24,8 @@
namespace ge { namespace ge {
namespace hybrid { namespace hybrid {
namespace { namespace {
const uint8_t kMaxTransCount = 3;
const uint32_t kTransOpIoSize = 1;
const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph";
const char *const kNodeTypeRetVal = "_RetVal"; const char *const kNodeTypeRetVal = "_RetVal";
const std::set<std::string> kControlOpTypes{ const std::set<std::string> kControlOpTypes{
@@ -41,6 +41,25 @@ const std::set<std::string> kMergeOpTypes{
MERGE, REFMERGE, STREAMMERGE MERGE, REFMERGE, STREAMMERGE
}; };


/// @brief Check whether `node` is an Enter node, or is fed by one through a short chain of
///        single-input nodes (e.g. Enter -> Cast, Enter -> TransData -> Cast).
///        At most kMaxTransCount nodes are examined, starting with `node` itself.
/// @param [in] node: node to start from; taken by value because the walk re-seats it
/// @return true: an Enter type node was found on the chain / false: otherwise
bool IsEnterFeedNode(NodePtr node) {
  uint8_t checks_left = kMaxTransCount;
  while (checks_left > 0) {
    --checks_left;

    const bool is_enter = kEnterOpTypes.count(NodeUtils::GetNodeType(node)) != 0;
    if (is_enter) {
      GELOGD("Node[%s] is Enter feed node.", node->GetName().c_str());
      return true;
    }

    // Only step upwards through a node that has exactly one data input anchor and one data producer.
    const auto producers = node->GetInDataNodes();
    const bool single_input =
        (producers.size() == kTransOpIoSize) && (node->GetAllInDataAnchorsSize() == kTransOpIoSize);
    if (!single_input) {
      return false;
    }
    node = producers.at(0);
  }

  // Chain is longer than the supported depth.
  return false;
}

Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) {
uint32_t parent_index = 0; uint32_t parent_index = 0;
if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
@@ -98,8 +117,7 @@ Status ParseFusedSubgraph(NodeItem &node_item) {
GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node);
auto op_desc = node->GetOpDesc(); auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc); GE_CHECK_NOTNULL(op_desc);
std::string node_type;
GE_CHK_STATUS_RET(GetOriginalType(node, node_type));
const std::string node_type = NodeUtils::GetNodeType(node);
if (node_type == DATA) { if (node_type == DATA) {
GE_CHK_GRAPH_STATUS_RET(ParseInputMapping(*node, *op_desc, *fused_subgraph)); GE_CHK_GRAPH_STATUS_RET(ParseInputMapping(*node, *op_desc, *fused_subgraph));
} else if (node_type == kNodeTypeRetVal) { } else if (node_type == kNodeTypeRetVal) {
@@ -398,19 +416,21 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) {
data_send_.emplace(node_item); data_send_.emplace(node_item);
node_item->data_recv_[this] = anchor_index; node_item->data_recv_[this] = anchor_index;
if (is_root_node_) { if (is_root_node_) {
node_item->root_data_[anchor_index] = this;
auto &data_anchors = node_item->root_data_[this];
data_anchors.emplace(anchor_index);
} }
// If Enter feed Not Merge, take as root Node. // If Enter feed Not Merge, take as root Node.
if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) {
node_item->enter_data_[anchor_index] = this;
if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE)) {
auto &data_anchors = node_item->enter_data_[this];
data_anchors.emplace(anchor_index);
} }
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str());
} }


void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) {
if (switch_index < switch_groups_.size()) { if (switch_index < switch_groups_.size()) {
std::vector<const NodeItem *> &switch_group = switch_groups_[switch_index];
switch_group.emplace_back(node_item);
auto &switch_group = switch_groups_[switch_index];
switch_group.emplace(node_item);
} else { } else {
ctrl_send_.insert(node_item); ctrl_send_.insert(node_item);
} }
@@ -420,7 +440,7 @@ void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) {
node_item->root_ctrl_.emplace(this); node_item->root_ctrl_.emplace(this);
} }
// If Enter feed control signal, take as root Node. // If Enter feed control signal, take as root Node.
if (IsEnterOp() && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) {
if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) {
node_item->enter_ctrl_.emplace(this); node_item->enter_ctrl_.emplace(this);
} }
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str());
@@ -433,8 +453,8 @@ void NodeItem::SetMergeCtrl(NodeItem *node_item, uint32_t merge_index) {
} }


// this is StreamMerge node, node_item is StreamActive node. // this is StreamMerge node, node_item is StreamActive node.
std::vector<const NodeItem *> &switch_group = switch_groups_[merge_index];
switch_group.emplace_back(node_item);
auto &switch_group = switch_groups_[merge_index];
switch_group.emplace(node_item);


node_item->ctrl_send_.emplace(this); node_item->ctrl_send_.emplace(this);
GELOGI("Node[%s] will control node[%s]", node_item->NodeName().c_str(), NodeName().c_str()); GELOGI("Node[%s] will control node[%s]", node_item->NodeName().c_str(), NodeName().c_str());


+ 3
- 3
ge/hybrid/model/node_item.h View File

@@ -148,14 +148,14 @@ struct NodeItem {
int64_t frame_index_ = -1; int64_t frame_index_ = -1;
int64_t parent_frame_ = -1; int64_t parent_frame_ = -1;
std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node
std::map<int, const NodeItem *> root_data_; // Recv data from root node
std::map<const NodeItem *, std::set<int>> root_data_; // Recv data from root node
std::set<const NodeItem *> enter_ctrl_; // Recv ctrl from Enter node std::set<const NodeItem *> enter_ctrl_; // Recv ctrl from Enter node
std::map<int, const NodeItem *> enter_data_; // Recv data from Enter node
std::map<const NodeItem *, std::set<int>> enter_data_; // Recv data from Enter node
std::set<const NodeItem *> data_send_; // Send data notify to std::set<const NodeItem *> data_send_; // Send data notify to
std::map<const NodeItem *, int> data_recv_; // Recv data notify from std::map<const NodeItem *, int> data_recv_; // Recv data notify from
std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to
std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from
std::vector<std::vector<const NodeItem *>> switch_groups_; // Send ctrl notify to
std::vector<std::set<const NodeItem *>> switch_groups_; // Send ctrl notify to


std::shared_ptr<NodeTask> kernel_task; std::shared_ptr<NodeTask> kernel_task;
std::unique_ptr<FusedSubgraph> fused_subgraph; std::unique_ptr<FusedSubgraph> fused_subgraph;


+ 1
- 0
ge/hybrid/node_executor/hccl/hccl_node_executor.cc View File

@@ -342,6 +342,7 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do
GE_CHK_RT_RET(rtEventDestroy(evt)); GE_CHK_RT_RET(rtEventDestroy(evt));
} }
GELOGI("rdma callback success."); GELOGI("rdma callback success.");
return SUCCESS;
}; };


HcclResult hccl_ret = HcomExecEnqueueRemoteAccess(context.GetNodeItem().NodeType(), addr_infos, callback); HcclResult hccl_ret = HcomExecEnqueueRemoteAccess(context.GetNodeItem().NodeType(), addr_infos, callback);


+ 2
- 6
ge/hybrid/node_executor/rts/rts_node_executor.cc View File

@@ -17,13 +17,9 @@
#include "hybrid/node_executor/rts/rts_node_executor.h" #include "hybrid/node_executor/rts/rts_node_executor.h"
#include "hybrid/node_executor/rts/rts_task_factory.h" #include "hybrid/node_executor/rts/rts_task_factory.h"


#include "common/debug/log.h"
#include "common/ge/ge_util.h" #include "common/ge/ge_util.h"
#include "common/types.h"
#include "graph/common/omg_util.h"
#include "graph/utils/tensor_utils.h" #include "graph/utils/tensor_utils.h"
#include "hybrid/model/hybrid_model.h" #include "hybrid/model/hybrid_model.h"
#include "runtime/rt.h"


namespace ge { namespace ge {
namespace hybrid { namespace hybrid {
@@ -33,6 +29,7 @@ REGISTER_RTS_TASK_CREATOR(IDENTITY, IdentityNodeTask);
REGISTER_RTS_TASK_CREATOR(IDENTITYN, IdentityNNodeTask); REGISTER_RTS_TASK_CREATOR(IDENTITYN, IdentityNNodeTask);
REGISTER_RTS_TASK_CREATOR(READVARIABLEOP, ReadVariableOpNodeTask); REGISTER_RTS_TASK_CREATOR(READVARIABLEOP, ReadVariableOpNodeTask);
REGISTER_RTS_TASK_CREATOR(PROFILINGTRAININGTRACE, ProfilingTraceNodeTask); REGISTER_RTS_TASK_CREATOR(PROFILINGTRAININGTRACE, ProfilingTraceNodeTask);
REGISTER_RTS_TASK_CREATOR(MEMCPYASYNC, IdentityNodeTask);


Status IdentityNodeTask::DoCopyTensor(TaskContext &context, int index) { Status IdentityNodeTask::DoCopyTensor(TaskContext &context, int index) {
auto input_desc = context.MutableInputDesc(index); auto input_desc = context.MutableInputDesc(index);
@@ -133,8 +130,7 @@ Status ProfilingTraceNodeTask::ExecuteAsync(TaskContext &context, std::function<
Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const {
GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node);
GELOGD("[%s] Load for local task.", node->GetName().c_str()); GELOGD("[%s] Load for local task.", node->GetName().c_str());
std::string node_type;
GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed.");
const std::string node_type = NodeUtils::GetNodeType(node);
RtsNodeTaskPtr rts_task = RtsTaskFactory::GetInstance().Create(node_type); RtsNodeTaskPtr rts_task = RtsTaskFactory::GetInstance().Create(node_type);
if (rts_task == nullptr) { if (rts_task == nullptr) {
GELOGE(UNSUPPORTED, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), node_type.c_str()); GELOGE(UNSUPPORTED, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), node_type.c_str());


+ 0
- 29
ge/hybrid/node_executor/rts/rts_node_task.cc View File

@@ -43,7 +43,6 @@ namespace hybrid {
REGISTER_RTS_TASK_CREATOR(STREAMACTIVE, StreamActiveNodeTask); REGISTER_RTS_TASK_CREATOR(STREAMACTIVE, StreamActiveNodeTask);
REGISTER_RTS_TASK_CREATOR(STREAMSWITCH, StreamSwitchNodeTask); REGISTER_RTS_TASK_CREATOR(STREAMSWITCH, StreamSwitchNodeTask);
REGISTER_RTS_TASK_CREATOR(STREAMMERGE, StreamMergeNodeTask); REGISTER_RTS_TASK_CREATOR(STREAMMERGE, StreamMergeNodeTask);
REGISTER_RTS_TASK_CREATOR(MEMCPYASYNC, MemcpyAsyncNodeTask);


REGISTER_RTS_TASK_CREATOR(ENTER, PassThroughNodeTask); REGISTER_RTS_TASK_CREATOR(ENTER, PassThroughNodeTask);
REGISTER_RTS_TASK_CREATOR(REFENTER, PassThroughNodeTask); REGISTER_RTS_TASK_CREATOR(REFENTER, PassThroughNodeTask);
@@ -168,34 +167,6 @@ Status StreamMergeNodeTask::ExecuteAsync(TaskContext &task_context, std::functio
return SUCCESS; return SUCCESS;
} }


/// Copies input tensor 0 to output tensor 0 with rtMemcpyAsync on the task's stream.
///
/// @param task_context   Supplies the input descriptor, the input/output tensors, the node
///                       name used for logging, and the stream used for the copy.
/// @param done_callback  If non-empty, registered through task_context.RegisterCallback().
/// @return SUCCESS at the end of the function. NOTE(review): the GE_CHECK_* / GE_CHK_* macros
///         presumably return an error status early on failure - confirm against their definitions.
Status MemcpyAsyncNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
GELOGD("[%s] Start to execute.", task_context.GetNodeName());
// The number of bytes to copy is derived from the descriptor of input 0.
auto input_desc = task_context.MutableInputDesc(0);
GE_CHECK_NOTNULL(input_desc);
int64_t copy_size = 0;
GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorSizeInBytes(*input_desc, copy_size));
// copy_size would not be negative since GetTensorSizeInBytes returned successfully.
if (copy_size > 0) {
const auto in_v = task_context.MutableInput(0);
const auto out_v = task_context.MutableOutput(0);
GE_CHECK_NOTNULL(in_v);
GE_CHECK_NOTNULL(out_v);
GELOGD("[%s] input size: %zu, output size: %zu, copy size: %ld", task_context.GetNodeName(),
in_v->GetSize(), out_v->GetSize(), copy_size);
// The output tensor's size is passed as the destination capacity; copy_size is the byte count.
GE_CHK_RT_RET(rtMemcpyAsync(out_v->MutableData(), out_v->GetSize(), in_v->GetData(), copy_size,
RT_MEMCPY_DEVICE_TO_DEVICE, task_context.GetStream()));
} else {
// Nothing is copied in this case; the task only logs a warning and carries on.
GELOGW("[%s] invalid copy size: %ld", task_context.GetNodeName(), copy_size);
}

// The callback is registered whether or not a copy was issued above.
if (done_callback) {
GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
}

GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
return SUCCESS;
}

Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) { Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
GELOGD("[%s] Start to execute.", task_context.GetNodeName()); GELOGD("[%s] Start to execute.", task_context.GetNodeName());
const auto in_x = task_context.GetInput(0); // x const auto in_x = task_context.GetInput(0); // x


+ 0
- 5
ge/hybrid/node_executor/rts/rts_node_task.h View File

@@ -60,11 +60,6 @@ class StreamMergeNodeTask : public RtsNodeTask {
Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override; Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override;
}; };


/// RTS node task for MemcpyAsync nodes; it only overrides ExecuteAsync.
class MemcpyAsyncNodeTask : public RtsNodeTask {
public:
/// Runs the task asynchronously. done_callback is the completion callback supplied by the caller.
Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override;
};

class PassThroughNodeTask : public RtsNodeTask { class PassThroughNodeTask : public RtsNodeTask {
public: public:
Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override; Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override;


+ 1
- 4
ge/hybrid/node_executor/task_context.cc View File

@@ -458,10 +458,6 @@ Status TaskContext::PropagateOutputs() {
subgraph_context_->all_inputs_[input_offset].SetName( subgraph_context_->all_inputs_[input_offset].SetName(
node_item_->NodeName() + "_in_" + std::to_string(dst_input_idx)); node_item_->NodeName() + "_in_" + std::to_string(dst_input_idx));
} }

auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item);
GE_CHECK_NOTNULL(dst_node_state);
dst_node_state->SavePersistTensor(dst_input_idx, *tensor);
} }
} }
(void)guard; (void)guard;
@@ -493,6 +489,7 @@ void TaskContext::ReleaseInputsAndOutputs() {
void TaskContext::ReleaseInput(int index) { void TaskContext::ReleaseInput(int index) {
auto input_tensor = MutableInput(index); auto input_tensor = MutableInput(index);
if (input_tensor != nullptr) { if (input_tensor != nullptr) {
node_state_->SavePersistTensor(index, *input_tensor);
input_tensor->Destroy(); input_tensor->Destroy();
GELOGD("[%s] Tensor of input[%d] released", GetNodeName(), index); GELOGD("[%s] Tensor of input[%d] released", GetNodeName(), index);
} }


+ 4
- 0
tests/depends/mmpa/src/mmpa_stub.cc View File

@@ -345,6 +345,10 @@ INT32 mmIsDir(const CHAR *fileName)


// Test stub for mmGetEnv: looks up environment variable `name` and, when it is set and fits,
// copies its NUL-terminated value into `value`.
//
// name  - name of the environment variable to look up.
// value - destination buffer that receives the value.
// len   - capacity of `value` in bytes, including room for the terminating NUL.
//
// Always returns 0, as the stub did before. `value` is left untouched when an argument is
// invalid, the variable is not set, or the value (plus terminator) does not fit in `len` bytes.
INT32 mmGetEnv(const CHAR *name, CHAR *value, UINT32 len)
{
  // getenv(nullptr) and writing through a null/zero-sized buffer are undefined behavior.
  if ((name == nullptr) || (value == nullptr) || (len == 0)) {
    return 0;
  }
  const char *env = getenv(name);
  if (env != nullptr) {
    const size_t env_len = strlen(env);
    // Bounded copy: the previous unbounded strcpy ignored `len` and could overflow `value`.
    if (env_len < len) {
      memcpy(value, env, env_len + 1);  // +1 copies the terminating NUL as well
    }
  }
  return 0;
}




+ 1
- 2
tests/ut/ge/CMakeLists.txt View File

@@ -726,7 +726,6 @@ set(PASS_TEST_FILES
"graph/passes/memcpy_addr_async_unittest.cc" "graph/passes/memcpy_addr_async_unittest.cc"
"graph/passes/hccl_continuous_pass_unittest.cc" "graph/passes/hccl_continuous_pass_unittest.cc"
"graph/passes/hccl_memcpy_pass_unittest.cc" "graph/passes/hccl_memcpy_pass_unittest.cc"
) )


set(KERNEL_TEST_FILES set(KERNEL_TEST_FILES
@@ -859,7 +858,6 @@ set(HYBRID_TEST_FILES
"hybrid/executor/hybrid_model_async_executor_unittest.cc" "hybrid/executor/hybrid_model_async_executor_unittest.cc"
"hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" "hybrid/executor/hybrid_model_pipeline_executor_unittest.cc"
"hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" "hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc"

) )


set(OTHERS_TEST_FILES set(OTHERS_TEST_FILES
@@ -887,6 +885,7 @@ add_library(ge_ut_graph STATIC


target_compile_definitions(ge_ut_graph PRIVATE target_compile_definitions(ge_ut_graph PRIVATE
google=ascend_private google=ascend_private
FMK_SUPPORT_DUMP
) )


target_compile_options(ge_ut_graph PRIVATE target_compile_options(ge_ut_graph PRIVATE


+ 5
- 5
tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc View File

@@ -349,7 +349,7 @@ class UtestLogicalStreamAllocator : public testing::Test {
/// B --> C(AllReduce) --- D /// B --> C(AllReduce) --- D
/// / /// /
/// stream id: 0 A /// stream id: 0 A
/// \
/// \.
/// E --> F(AllReduce) --- G /// E --> F(AllReduce) --- G
/// stream id: 2 2 2 /// stream id: 2 2 2
/// ///
@@ -599,7 +599,7 @@ TEST_F(UtestLogicalStreamAllocator, test_label_not_reusable2) {


/// case of multi-output, then unuse stream /// case of multi-output, then unuse stream
/// sub1 /// sub1
/// / | \
/// / | \.
/// sub2 sub3 sub4 /// sub2 sub3 sub4
TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) { TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) {
SubGraphInfoPtr data = CreateDataSubgraph(); SubGraphInfoPtr data = CreateDataSubgraph();
@@ -624,7 +624,7 @@ TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) {


/// if parallel id is 1, then use stream /// if parallel id is 1, then use stream
/// sub1 /// sub1
/// / | | \
/// / | | \.
/// sub2 sub3 sub4 sub5 /// sub2 sub3 sub4 sub5
TEST_F(UtestLogicalStreamAllocator, test_parallel_one) { TEST_F(UtestLogicalStreamAllocator, test_parallel_one) {
SubGraphInfoPtr data = CreateDataSubgraph(); SubGraphInfoPtr data = CreateDataSubgraph();
@@ -653,7 +653,7 @@ TEST_F(UtestLogicalStreamAllocator, test_parallel_one) {


/// if the param of engine independent is true, then set independent stream /// if the param of engine independent is true, then set independent stream
/// sub1 /// sub1
/// / | | \
/// / | | \.
/// sub2 sub3 sub4 sub5 /// sub2 sub3 sub4 sub5
TEST_F(UtestLogicalStreamAllocator, test_independent) { TEST_F(UtestLogicalStreamAllocator, test_independent) {
SubGraphInfoPtr data = CreateDataSubgraph(); SubGraphInfoPtr data = CreateDataSubgraph();
@@ -692,7 +692,7 @@ TEST_F(UtestLogicalStreamAllocator, test_independent) {


/// set stream based on stream label, and then based on independent /// set stream based on stream label, and then based on independent
/// sub1 /// sub1
/// / | | \
/// / | | \.
/// sub2 sub3 sub4 sub5 /// sub2 sub3 sub4 sub5
TEST_F(UtestLogicalStreamAllocator, test_independent_switch_label) { TEST_F(UtestLogicalStreamAllocator, test_independent_switch_label) {
SubGraphInfoPtr data = CreateDataSubgraph(); SubGraphInfoPtr data = CreateDataSubgraph();


+ 1
- 1
tests/ut/ge/graph/build/stream_allocator_unittest.cc View File

@@ -36,7 +36,7 @@ class UtestStreamAllocator : public testing::Test {


/// ///
/// A /// A
/// / \
/// / \.
/// B C /// B C
/// | | /// | |
/// D 400 /// D 400


+ 18
- 0
tests/ut/ge/graph/load/model_manager_unittest.cc View File

@@ -438,4 +438,22 @@ TEST_F(UtestModelManagerModelManager, test_data_input_tensor) {
auto ret = mm.DataInputTensor(model_id,inputs); auto ret = mm.DataInputTensor(model_id,inputs);
EXPECT_EQ(PARAM_INVALID, ret); // HybridDavinciModel::impl_ is null. EXPECT_EQ(PARAM_INVALID, ret); // HybridDavinciModel::impl_ is null.
} }

// Checks LaunchKernelCustAicpuSo both with no registered custom AICPU kernels and with one
// registered kernel, and that cust_aicpu_so_ is empty again after the second launch.
TEST_F(UtestModelManagerModelManager, test_launch_kernel_cust_aicpu) {
  ModelManager mm;

  // cust_aicpu_so_ is empty.
  EXPECT_EQ(mm.LaunchKernelCustAicpuSo("empty_cust_aicpu"), SUCCESS);

  // The registered "deleteCustOp" entry is expected to be gone from cust_aicpu_so_ after launch.
  uintptr_t resource_id = 1;  // for rtCtxGetCurrent stub
  std::vector<char> kernel_bin(256);
  auto &cust_resource_001 = mm.cust_aicpu_so_[resource_id];
  auto tbe_kernel = std::shared_ptr<OpKernelBin>(new OpKernelBin("deleteCustOp", std::move(kernel_bin)));
  // Register the kernel under its name; the former unused reference to the map slot was dropped.
  cust_resource_001["deleteCustOp"] = tbe_kernel;

  EXPECT_FALSE(mm.cust_aicpu_so_.empty());
  EXPECT_EQ(mm.LaunchKernelCustAicpuSo("deleteCustOp"), SUCCESS);
  EXPECT_TRUE(mm.cust_aicpu_so_.empty());
}
} // namespace ge } // namespace ge

+ 3
- 3
tests/ut/ge/graph/passes/assert_pass_unittest.cc View File

@@ -55,7 +55,7 @@ class UtestGraphPassesAssertPass : public Test {
}; };


/// D E /// D E
/// | \ | \
/// | \ | \.
/// F C G /// F C G
/// : | : /// : | :
/// H A I /// H A I
@@ -134,8 +134,8 @@ TEST_F(UtestGraphPassesAssertPass, assert_pass_test2) {
EXPECT_EQ(graph->FindNode("D"), nullptr); EXPECT_EQ(graph->FindNode("D"), nullptr);
} }


/// E F
/// | \ | \
/// E F
/// | \ | \.
/// H C -> D G /// H C -> D G
/// \ | : /// \ | :
/// A I /// A I


+ 7
- 7
tests/ut/ge/graph/passes/base_pass_unittest.cc View File

@@ -130,7 +130,7 @@ class UTESTGraphPassesBasePass : public testing::Test {
/// reshape1 /// reshape1
/// | /// |
/// add1 /// add1
/// / \
/// / \.
/// | | /// | |
/// data1 const1 /// data1 const1
ComputeGraphPtr BuildGraph1() { ComputeGraphPtr BuildGraph1() {
@@ -148,9 +148,9 @@ ComputeGraphPtr BuildGraph1() {
} }


/// sum1 /// sum1
/// / \
/// / \
/// / \
/// / \.
/// / \.
/// / \.
/// reshape1 addn1 /// reshape1 addn1
/// | c | /// | c |
/// add1 <--- shape1 /// add1 <--- shape1
@@ -217,7 +217,7 @@ void CheckIterOrder(UtestTestPass *pass, std::vector<std::unordered_set<std::str
/// Op1 /// Op1
/// | /// |
/// Merge /// Merge
/// / \
/// / \.
/// Op2 Op3 /// Op2 Op3
TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) { TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) {
auto builder = ut::GraphBuilder("g1"); auto builder = ut::GraphBuilder("g1");
@@ -245,7 +245,7 @@ TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) {
/// Op1 /// Op1
/// | /// |
/// Merge /// Merge
/// / \
/// / \.
/// Op2 Op3 /// Op2 Op3
TEST_F(UTESTGraphPassesBasePass, del_isolate_success) { TEST_F(UTESTGraphPassesBasePass, del_isolate_success) {
auto builder = ut::GraphBuilder("g1"); auto builder = ut::GraphBuilder("g1");
@@ -459,7 +459,7 @@ TEST_F(UTESTGraphPassesBasePass, while_loop) {
/// data1 const /// data1 const
/// \ / /// \ /
/// while /// while
/// / \
/// / \.
/// | | /// | |
/// cast1 cast2 /// cast1 cast2
ComputeGraphPtr BuildWhileGraph1() { ComputeGraphPtr BuildWhileGraph1() {


+ 3
- 3
tests/ut/ge/graph/passes/cond_branch_v1_unittest.cc View File

@@ -34,11 +34,11 @@ namespace {
/// net_output /// net_output
/// | /// |
/// merge /// merge
/// / \
/// / \.
/// square add /// square add
/// F| T/ T\
/// F| T/ T\.
/// switch1 switch2 /// switch1 switch2
/// / \ / \
/// / \ / \.
/// var1 var2 var3 /// var1 var2 var3
/// ///
ComputeGraphPtr BuildGraph1() { ComputeGraphPtr BuildGraph1() {


+ 19
- 19
tests/ut/ge/graph/passes/constant_folding_pass_unittest.cc View File

@@ -173,8 +173,8 @@ namespace {
/// shapeNo1 /// shapeNo1
/// | /// |
/// addnYes1 /// addnYes1
/// / \
/// / \
/// / \.
/// / \.
/// const1 const2 /// const1 const2
ComputeGraphPtr BuildGraph1() { ComputeGraphPtr BuildGraph1() {
auto builder = ut::GraphBuilder("test"); auto builder = ut::GraphBuilder("test");
@@ -223,8 +223,8 @@ ComputeGraphPtr BuildGraph2() {
/// shapeNo1 /// shapeNo1
/// | c /// | c
/// addnYes1 <----- dataNo1 /// addnYes1 <----- dataNo1
/// / \
/// / \
/// / \.
/// / \.
/// const1 const2 /// const1 const2
ComputeGraphPtr BuildGraph3() { ComputeGraphPtr BuildGraph3() {
auto builder = ut::GraphBuilder("test"); auto builder = ut::GraphBuilder("test");
@@ -249,8 +249,8 @@ ComputeGraphPtr BuildGraph3() {
/// shapeNo1 /// shapeNo1
/// | c /// | c
/// addnYes1 <--------- /// addnYes1 <---------
/// / \ \
/// / \ c \
/// / \ \.
/// / \ c \.
/// const1 const2 <----- dataNo1 /// const1 const2 <----- dataNo1
ComputeGraphPtr BuildGraph4() { ComputeGraphPtr BuildGraph4() {
auto builder = ut::GraphBuilder("test"); auto builder = ut::GraphBuilder("test");
@@ -276,7 +276,7 @@ ComputeGraphPtr BuildGraph4() {
/// shapeNo1 /// shapeNo1
/// | c /// | c
/// addnYes1 <----- dataNo1 /// addnYes1 <----- dataNo1
/// / \
/// / \.
/// / \ c /// / \ c
/// const1 const2 <----- dataNo2 /// const1 const2 <----- dataNo2
ComputeGraphPtr BuildGraph5() { ComputeGraphPtr BuildGraph5() {
@@ -306,8 +306,8 @@ ComputeGraphPtr BuildGraph5() {
/// addYes1 <---- const3 /// addYes1 <---- const3
/// | /// |
/// addnYes1 <- /// addnYes1 <-
/// / \ \
/// / \ \
/// / \ \.
/// / \ \.
/// const1 const2 const4 /// const1 const2 const4
ComputeGraphPtr BuildGraph6() { ComputeGraphPtr BuildGraph6() {
auto builder = ut::GraphBuilder("test"); auto builder = ut::GraphBuilder("test");
@@ -332,12 +332,12 @@ ComputeGraphPtr BuildGraph6() {
} }


/// netoutput1 /// netoutput1
/// / \
/// / \.
/// shapeNo1 ShpaeNo2 /// shapeNo1 ShpaeNo2
/// \ / /// \ /
/// huberLoss1 /// huberLoss1
/// / | \
/// / | \
/// / | \.
/// / | \.
/// const1 const2 const3 /// const1 const2 const3
ComputeGraphPtr BuildGraph7() { ComputeGraphPtr BuildGraph7() {
auto builder = ut::GraphBuilder("test"); auto builder = ut::GraphBuilder("test");
@@ -365,8 +365,8 @@ ComputeGraphPtr BuildGraph7() {
/// shapeNo1 /// shapeNo1
/// | /// |
/// addnNo1 /// addnNo1
/// / \
/// / \
/// / \.
/// / \.
/// const1 const2 /// const1 const2
ComputeGraphPtr BuildGraph8() { ComputeGraphPtr BuildGraph8() {
auto builder = ut::GraphBuilder("test"); auto builder = ut::GraphBuilder("test");
@@ -389,8 +389,8 @@ ComputeGraphPtr BuildGraph8() {
/// shapeNo1 /// shapeNo1
/// | /// |
/// addnYes1 /// addnYes1
/// / \
/// / \
/// / \.
/// / \.
/// const1 data1 /// const1 data1
ComputeGraphPtr BuildGraph9() { ComputeGraphPtr BuildGraph9() {
auto builder = ut::GraphBuilder("test"); auto builder = ut::GraphBuilder("test");
@@ -409,12 +409,12 @@ ComputeGraphPtr BuildGraph9() {
} }


/// netoutput1 /// netoutput1
/// / \
/// / \.
/// addDim sqrt1 /// addDim sqrt1
/// \ / /// \ /
/// switch1 /// switch1
/// / \
/// / \
/// / \.
/// / \.
/// const1 const2 /// const1 const2
ComputeGraphPtr BuildGraph10() { ComputeGraphPtr BuildGraph10() {
auto builder = ut::GraphBuilder("test"); auto builder = ut::GraphBuilder("test");


+ 4
- 4
tests/ut/ge/graph/passes/dimension_compute_pass_unittest.cc View File

@@ -63,8 +63,8 @@ namespace {
/// shapeNo1 /// shapeNo1
/// | /// |
/// addnNo1 /// addnNo1
/// / \
/// / \
/// / \.
/// / \.
/// const1 const2 /// const1 const2
ComputeGraphPtr BuildGraph8() { ComputeGraphPtr BuildGraph8() {
auto builder = ut::GraphBuilder("test"); auto builder = ut::GraphBuilder("test");
@@ -87,8 +87,8 @@ ComputeGraphPtr BuildGraph8() {
/// shapeNo1 /// shapeNo1
/// | /// |
/// addnYes1 /// addnYes1
/// / \
/// / \
/// / \.
/// / \.
///const1 data1 ///const1 data1
ComputeGraphPtr BuildGraph9() { ComputeGraphPtr BuildGraph9() {
auto builder = ut::GraphBuilder("test"); auto builder = ut::GraphBuilder("test");


+ 1
- 1
tests/ut/ge/graph/passes/folding_kernel/ssd_prior_box_kernel_unittest.cc View File

@@ -46,7 +46,7 @@ class UtestGraphPassesFoldingKernelSsdPriorboxKernel : public testing::Test {
/// convolution data /// convolution data
/// | / /// | /
/// ssdpriorbox /// ssdpriorbox
/// \
/// \.
/// reshape /// reshape
class NodeBuilder { class NodeBuilder {
public: public:


+ 1
- 1
tests/ut/ge/graph/passes/fuse_data_nodes_with_common_input_pass_unittest.cc View File

@@ -120,7 +120,7 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) {


/// graph with subgraph /// graph with subgraph
/// const /// const
/// / \
/// / \.
/// cast1 cast1 /// cast1 cast1
/// \ / /// \ /
/// case /// case


+ 88
- 38
tests/ut/ge/graph/passes/mark_force_unknown_for_cond_pass_unittest.cc View File

@@ -69,62 +69,100 @@ static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string
return graph.AddNode(op_desc); return graph.AddNode(op_desc);
} }


static void CreateLoopGraph(ComputeGraphPtr &graph, NodePtr &merge) {
static void CreateLoopGraph(ComputeGraphPtr &graph, NodePtr &merge, vector<NodePtr> &loop, vector<NodePtr> &cond) {
/******************************************************************************* /*******************************************************************************
* Exit Identify
* \ / \.
* \ / \.
* Switch Add
* / | |
* / | |
* / | |
* LoopCond | |
* \ | |
* \ | |
* \ | |
* Less | |
* \ | NextIteration
* \ | |
* \ | |
* Merge <---------|
* |
* |
* Enter
* |
* +--------------------- Merge ----------------------+
* / |
* / |
* / |
* / |
* Exit Identify |
* \ / \. |
* \ / \. |
* Switch Add Add
* / | | |
* / | | |
* / | | |
* LoopCond | | |
* \ | | |
* \ | | |
* \ | | |
* Less | | |
* \ | NextIteration |
* \ | | |
* \ | | |
* Merge <---------| |
* | |
* | |
* Enter |
* \ |
* \ |
* Switch Switch
* | |
* +-----------------Equal----------------------+
* |
******************************************************************************/ ******************************************************************************/
auto data1 = CreateNode(*graph, "data", DATA, 1, 1);
auto data1 = CreateNode(*graph, "data1", DATA, 1, 1);
auto data2 = CreateNode(*graph, "data2", DATA, 1, 1);

auto equal1 = CreateNode(*graph, "equal1", EQUAL, 2, 1);
auto switch1 = CreateNode(*graph, "switch1", SWITCH, 2, 2);
auto switch2 = CreateNode(*graph, "switch2", SWITCH, 2, 2);

auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1);
auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2);
auto less1 = CreateNode(*graph, "less", LESS, 2, 1);
auto merge1 = CreateNode(*graph, "merge1", MERGE, 2, 2);
auto less1 = CreateNode(*graph, "less1", LESS, 2, 1);
auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1); auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1);
auto switch1 = CreateNode(*graph, "switch", SWITCH, 2, 2);
auto switch3 = CreateNode(*graph, "switch3", SWITCH, 2, 2);
auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1); auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1);
auto add1 = CreateNode(*graph, "add", ADD, 2, 1);
auto add1 = CreateNode(*graph, "add1", ADD, 2, 1);
auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1);
auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1);
auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1);
auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1);
auto value1 = CreateNode(*graph, "const1", CONSTANT, 0, 1);

auto value2 = CreateNode(*graph, "const2", CONSTANT, 0, 1);
auto add2 = CreateNode(*graph, "add2", ADD, 2, 1);
auto merge2 = CreateNode(*graph, "merge2", MERGE, 2, 2);
auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1);


GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0));
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), equal1->GetInDataAnchor(0));
GraphUtils::AddEdge(data2->GetOutDataAnchor(0), equal1->GetInDataAnchor(1));
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0));
GraphUtils::AddEdge(data2->GetOutDataAnchor(0), switch2->GetInDataAnchor(0));
GraphUtils::AddEdge(equal1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1));
GraphUtils::AddEdge(equal1->GetOutDataAnchor(0), switch2->GetInDataAnchor(1));
cond.emplace_back(switch1);
cond.emplace_back(switch2);

GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); // false
GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0));
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0));
GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1));
GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0)); GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0));


GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0));
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1));
GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch3->GetInDataAnchor(0));
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch3->GetInDataAnchor(1));
loop.emplace_back(merge1);


GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0));
GraphUtils::AddEdge(switch1->GetOutDataAnchor(1), ident1->GetInDataAnchor(0));
GraphUtils::AddEdge(switch3->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); // false
GraphUtils::AddEdge(switch3->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); // true
loop.emplace_back(switch3);


GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0));
GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1));
GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0)); GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0));

GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1));
GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0));


merge = merge1;
GraphUtils::AddEdge(switch2->GetOutDataAnchor(1), add2->GetInDataAnchor(1)); // true
GraphUtils::AddEdge(value2->GetOutDataAnchor(0), add2->GetInDataAnchor(0));

GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), merge2->GetInDataAnchor(0));
GraphUtils::AddEdge(add2->GetOutDataAnchor(0), merge2->GetInDataAnchor(1));
GraphUtils::AddEdge(merge2->GetOutDataAnchor(0), output1->GetInDataAnchor(0));

cond.emplace_back(merge2);
merge = merge2;
} }


static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) { static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) {
@@ -197,12 +235,24 @@ static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) {
TEST_F(UtestMarkForceUnknownForCondPass, skip_while_loop_merge) { TEST_F(UtestMarkForceUnknownForCondPass, skip_while_loop_merge) {
auto graph = std::make_shared<ComputeGraph>("test_graph"); auto graph = std::make_shared<ComputeGraph>("test_graph");
NodePtr merge; NodePtr merge;
CreateLoopGraph(graph, merge);
AttrUtils::SetBool(merge->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true);
vector<NodePtr> loop;
vector<NodePtr> cond;
CreateLoopGraph(graph, merge, loop, cond);


MarkForceUnknownForCondPass mark_force_unknown_pass; MarkForceUnknownForCondPass mark_force_unknown_pass;
EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip LoopCond EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip LoopCond

EXPECT_EQ(loop.size(), 2);
for (const auto &node : loop) {
EXPECT_FALSE(node->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP));
}

EXPECT_EQ(cond.size(), 3);
for (const auto &node : cond) {
int64_t group_index = -1;
EXPECT_TRUE(AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index));
EXPECT_EQ(group_index, merge->GetOpDesc()->GetId());
}
} }


TEST_F(UtestMarkForceUnknownForCondPass, skip_known_shape_merge) { TEST_F(UtestMarkForceUnknownForCondPass, skip_known_shape_merge) {


+ 14
- 14
tests/ut/ge/graph/passes/merge_pass_unittest.cc View File

@@ -110,8 +110,8 @@ TEST_F(UtestGraphPassesMergePass, multiple_inputs) {
} }


/// Merge /// Merge
/// | \
/// | \
/// | \.
/// | \.
/// Op1 Op2 Merge2 /// Op1 Op2 Merge2
/// \ | | /// \ | |
/// \ | Op3 /// \ | Op3
@@ -137,10 +137,10 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch_meet_net_output_with_da
} }


/// Merge /// Merge
/// | \
/// | \
/// | \.
/// | \.
/// Op1 Op2 Merge2 /// Op1 Op2 Merge2
/// \ | | \
/// \ | | \.
/// \ | Op3 /// \ | Op3
/// \ | : /// \ | :
/// NetOutput /// NetOutput
@@ -165,8 +165,8 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch_meet_net_output_with_co


TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch) { TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch) {
/// Merge /// Merge
/// | \
/// | \
/// | \.
/// | \.
/// Op1 Op2 Merge2 /// Op1 Op2 Merge2
/// \ | | /// \ | |
/// \ | Op3 /// \ | Op3
@@ -210,7 +210,7 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch) {
/// Op1 Op2 Merge2 /// Op1 Op2 Merge2
/// \ | /// \ |
/// \ Op3 /// \ Op3
/// \
/// \.
/// Merge3 /// Merge3


ret = pass_.Run(merge_node2); ret = pass_.Run(merge_node2);
@@ -224,7 +224,7 @@ TEST_F(UtestGraphPassesMergePass, single_non_const_input) {
/// Op1 /// Op1
/// | /// |
/// Merge /// Merge
/// / \
/// / \.
/// Op2 Op3 /// Op2 Op3
auto merge_node = NewNode("Merge", MERGE, 1, 2); auto merge_node = NewNode("Merge", MERGE, 1, 2);
auto node1 = NewNode("Op1", RELU, 1, 1); auto node1 = NewNode("Op1", RELU, 1, 1);
@@ -253,7 +253,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input) {
/// Const /// Const
/// | /// |
/// Merge Pass Const /// Merge Pass Const
/// / \ ===> / \
/// / \ ===> / \.
/// Op1 Op2 Op1 Op2 /// Op1 Op2 Op1 Op2
auto merge_node = NewNode("Merge", MERGE, 1, 2); auto merge_node = NewNode("Merge", MERGE, 1, 2);
auto const_node = NewNode("Const", CONSTANT, 1, 1); auto const_node = NewNode("Const", CONSTANT, 1, 1);
@@ -284,7 +284,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input_value_index_two_out_nodes)
/// / | ===> / \(control anchor) /// / | ===> / \(control anchor)
/// Op1 | \ Op1 Constant /// Op1 | \ Op1 Constant
/// Op2 Op3 | /// Op2 Op3 |
/// / \
/// / \.
/// Op2 Op3 /// Op2 Op3
auto merge_node = NewNode("Merge", MERGE, 1, 2); auto merge_node = NewNode("Merge", MERGE, 1, 2);
auto const_node = NewNode("Const", CONSTANT, 1, 1); auto const_node = NewNode("Const", CONSTANT, 1, 1);
@@ -329,7 +329,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input_value_index_two_out_nodes1)
/// / | ===> / \(control anchor) /// / | ===> / \(control anchor)
/// Op1 | \ Op1 Constant /// Op1 | \ Op1 Constant
/// Op2 Op3 | /// Op2 Op3 |
/// / \
/// / \.
/// Op2 Op3 /// Op2 Op3
auto merge_node = NewNode("Merge", MERGE, 1, 2); auto merge_node = NewNode("Merge", MERGE, 1, 2);
auto const_node = NewNode("Const", CONSTANT, 1, 1); auto const_node = NewNode("Const", CONSTANT, 1, 1);
@@ -357,7 +357,7 @@ TEST_F(UtestGraphPassesMergePass, const_with_control_input) {
/// C /// C
/// | /// |
/// Merge /// Merge
/// / \
/// / \.
/// Op1 Op2 /// Op1 Op2
auto switch_node = NewNode("Switch", SWITCH, 1, 2); auto switch_node = NewNode("Switch", SWITCH, 1, 2);
auto identity_node = NewNode("Identity", SWITCH, 1, 1); auto identity_node = NewNode("Identity", SWITCH, 1, 1);
@@ -381,7 +381,7 @@ TEST_F(UtestGraphPassesMergePass, const_with_control_input) {
/// . /// .
/// . /// .
/// C /// C
/// / \
/// / \.
/// Op1 Op2 /// Op1 Op2
auto ret = pass_.Run(merge_node); auto ret = pass_.Run(merge_node);
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);


+ 6
- 6
tests/ut/ge/graph/passes/parallel_group_pass_unittest.cc View File

@@ -67,11 +67,11 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test {


void BuildDefaultGraph() { void BuildDefaultGraph() {
/// input /// input
/// \
/// \.
/// sqrt pred /// sqrt pred
/// \ / /// \ /
/// cast /// cast
/// / \
/// / \.
/// switch_t switch_f /// switch_t switch_f
/// | | /// | |
/// F T /// F T
@@ -119,13 +119,13 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test {


void BuildDefaultGraph1() { void BuildDefaultGraph1() {
/// input /// input
/// \
/// \.
/// sqrt pred /// sqrt pred
/// \ / /// \ /
/// Switch /// Switch
/// | | /// | |
/// ----F T---- /// ----F T----
/// \ | / \
/// \ | / \.
/// \ Merge1 Merge2 /// \ Merge1 Merge2
/// \_________| /// \_________|
input_node_ = NewNode("input", RELU, 0, 1); input_node_ = NewNode("input", RELU, 0, 1);
@@ -165,14 +165,14 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test {


void BuildDefaultGraph2() { void BuildDefaultGraph2() {
/// input input1 /// input input1
/// \ \
/// \ \.
/// sqrt pred sqrt1 pred1 /// sqrt pred sqrt1 pred1
/// \ / \ / /// \ / \ /
/// Switch Switch1 /// Switch Switch1
/// | | _______| /// | | _______|
/// | | / /// | | /
/// ____F T____ /// ____F T____
/// \ | / \
/// \ | / \.
/// \ Merge1 Merge2 /// \ Merge1 Merge2
/// \__________| /// \__________|
input_node_ = NewNode("input", RELU, 0, 2); input_node_ = NewNode("input", RELU, 0, 2);


+ 3
- 3
tests/ut/ge/graph/passes/reshape_recovery_pass_unittest.cc View File

@@ -31,9 +31,9 @@ class UtestReshapeRecoveryPass : public testing::Test {


namespace { namespace {
/// netoutput1 /// netoutput1
/// | \
///transdata1 \
/// | \
/// | \.
///transdata1 \.
/// | \.
/// | transdata2 /// | transdata2
/// | / /// | /
/// var1 const1 /// var1 const1


+ 8
- 8
tests/ut/ge/graph/passes/reshape_remove_pass_unittest.cc View File

@@ -35,7 +35,7 @@ namespace {
/// transdata1 /// transdata1
/// | /// |
/// reshape1 /// reshape1
/// | \
/// | \.
/// var1 const1 /// var1 const1
ut::GraphBuilder Graph1Builder() { ut::GraphBuilder Graph1Builder() {
ut::GraphBuilder builder = ut::GraphBuilder("g1"); ut::GraphBuilder builder = ut::GraphBuilder("g1");
@@ -55,11 +55,11 @@ ut::GraphBuilder Graph1Builder() {
} }


/// netoutput1 /// netoutput1
/// | \
///transdata1 \
/// | \
/// | \.
///transdata1 \.
/// | \.
/// reshape1 reshape2 /// reshape1 reshape2
/// | \ / \
/// | \ / \.
/// var1 const1 var2 /// var1 const1 var2
ut::GraphBuilder Graph2Builder() { ut::GraphBuilder Graph2Builder() {
ut::GraphBuilder builder = ut::GraphBuilder("g2"); ut::GraphBuilder builder = ut::GraphBuilder("g2");
@@ -83,9 +83,9 @@ ut::GraphBuilder Graph2Builder() {
} }


/// netoutput1 /// netoutput1
/// | \
///transdata1 \
/// | \
/// | \.
///transdata1 \.
/// | \.
/// reshape1 transdata2 /// reshape1 transdata2
/// | \ / /// | \ /
/// var1 const1 /// var1 const1


+ 1
- 1
tests/ut/ge/graph/passes/resource_pair_control_pass_unittest.cc View File

@@ -34,7 +34,7 @@ class UtestResourcePairControlPass : public testing::Test {


namespace { namespace {
/// netoutput1 /// netoutput1
/// | \
/// | \.
/// StackPush StackPop /// StackPush StackPop
/// | | /// | |
/// var1 const1 /// var1 const1


+ 6
- 6
tests/ut/ge/graph/passes/switch_logic_remove_pass_unittest.cc View File

@@ -63,9 +63,9 @@ ComputeGraphPtr BuildGraph1() {
/// netoutput1 /// netoutput1
/// | /// |
/// merge1 /// merge1
/// / \
/// / \.
/// / add1 /// / add1
/// / F| \
/// / F| \.
/// addn1 swtich2 var3 /// addn1 swtich2 var3
/// \F T/ | /// \F T/ |
/// switch1 | /// switch1 |
@@ -101,9 +101,9 @@ ComputeGraphPtr BuildGraph2() {
/// add1 /// add1
/// / \T /// / \T
/// var3 swtich2 /// var3 swtich2
/// T/ \
/// switch1 \
/// / \ \
/// T/ \.
/// switch1 \.
/// / \ \.
/// var1 var2 var4 /// var1 var2 var4
ComputeGraphPtr BuildGraph3() { ComputeGraphPtr BuildGraph3() {
auto builder = ut::GraphBuilder("g3"); auto builder = ut::GraphBuilder("g3");
@@ -129,7 +129,7 @@ ComputeGraphPtr BuildGraph3() {
/// netoutput1 /// netoutput1
/// | /// |
/// merge1 /// merge1
/// / \
/// / \.
/// add1 addn1 /// add1 addn1
/// / \T F/ /// / \T F/
/// var3 swtich2 /// var3 swtich2


+ 2
- 2
tests/ut/ge/graph/passes/trans_op_breadth_fusion_pass_unittest.cc View File

@@ -402,7 +402,7 @@ TEST_F(UtestGraphPassesTransOpBreadthFusionPass, test_multi_anchor_case) {
} }


/// ----> netoutput1 /// ----> netoutput1
/// / | \
/// / | \.
/// transdata1 transdata2 transdata3 /// transdata1 transdata2 transdata3
/// \ / | /// \ / |
/// var1-------------- /// var1--------------
@@ -432,7 +432,7 @@ static ComputeGraphPtr BuildGraph1() {
} }


/// ---------> netoutput1 /// ---------> netoutput1
/// / | \
/// / | \.
/// transdata1 transdata2(l1) transdata3(l1) /// transdata1 transdata2(l1) transdata3(l1)
/// \ / | /// \ / |
/// var1------------------ /// var1------------------


+ 7
- 7
tests/ut/ge/graph/passes/trans_op_depth_fusion_pass_unittest.cc View File

@@ -456,19 +456,19 @@ TEST_F(UtestGraphPassesTransOpDepthFusionPass, test_transop_with_multi_out_edge)
/// -->transpose1 -->transpose3-->sinh2 /// -->transpose1 -->transpose3-->sinh2
/// | \ / /// | \ /
/// | -->transpose2 /// | -->transpose2
/// | \
/// | \.
/// / -->cast3-->cast4-->sinh3 /// / -->cast3-->cast4-->sinh3
/// / /// /
/// / -->transpose4-->transpose5-->sinh4 /// / -->transpose4-->transpose5-->sinh4
/// / / /// / /
/// Node4D-->Cast1-->Cast2-->Cast5 -->reshape2-->sinh5 /// Node4D-->Cast1-->Cast2-->Cast5 -->reshape2-->sinh5
/// \ \
/// \ \.
/// \ -->sinh6 /// \ -->sinh6
/// \
/// \.
/// \ -->transpose6-->transpose7-->sinh9 /// \ -->transpose6-->transpose7-->sinh9
/// \ / /// \ /
/// -->reshape-->cast6-->cast7-->sinh8 /// -->reshape-->cast6-->cast7-->sinh8
/// \
/// \.
/// -->sinh7 /// -->sinh7


/// after optimized graph /// after optimized graph
@@ -479,15 +479,15 @@ TEST_F(UtestGraphPassesTransOpDepthFusionPass, test_transop_with_multi_out_edge)
/// / /-->transpose3-->sinh2 /// / /-->transpose3-->sinh2
/// -->Cast1 /// -->Cast1
/// / \-->sinh7 /// / \-->sinh7
/// / \
/// / \.
/// / -->sinh9 /// / -->sinh9
/// Node4D /// Node4D
/// \ -->sinh4 /// \ -->sinh4
/// \ / /// \ /
/// -->Cast5-->sinh5 /// -->Cast5-->sinh5
/// \ \
/// \ \.
/// \ -->sinh6 /// \ -->sinh6
/// \
/// \.
/// -->Cast7-->sinh8 /// -->Cast7-->sinh8
ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");




+ 2
- 2
tests/ut/ge/graph/passes/transop_nearby_allreduce_fusion_pass_unittest.cc View File

@@ -180,7 +180,7 @@ ComputeGraphPtr GetGraph7(size_t symmetric_transdata_num, size_t asymmetric_tran
/// TransData TransData ... MatMul ... /// TransData TransData ... MatMul ...
/// \ | / / / /// \ | / / /
/// HcomAllReduce /// HcomAllReduce
/// / | \ \ \
/// / | \ \ \.
/// TransData TransData ... RealDiv ... /// TransData TransData ... RealDiv ...
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
NodePtr allreduce = NodePtr allreduce =
@@ -340,7 +340,7 @@ TEST(UtestTransopNearbyAllreduceFusionPass, test7_all_reduce_with_multiple_trans
/// TransData TransData ... MatMul ... /// TransData TransData ... MatMul ...
/// \ | / / / /// \ | / / /
/// HcomAllReduce /// HcomAllReduce
/// / | \ \ \
/// / | \ \ \.
/// TransData TransData ... RealDiv ... /// TransData TransData ... RealDiv ...
size_t symmetric_transdata_num = 20; size_t symmetric_transdata_num = 20;
size_t asymmetric_transdata_num = 20; size_t asymmetric_transdata_num = 20;


+ 1
- 1
tests/ut/ge/graph/passes/variable_op_pass_unittest.cc View File

@@ -66,7 +66,7 @@ namespace {
/// transdata2 /// transdata2
/// | /// |
/// assign1 /// assign1
/// / \
/// / \.
/// transdata1 | /// transdata1 |
/// | | /// | |
/// var1 const1 /// var1 const1


+ 5
- 5
tests/ut/ge/graph/variable_accelerate_ctrl_unittest.cc View File

@@ -35,8 +35,8 @@ namespace {
/// shapeNo1 /// shapeNo1
/// | /// |
/// addnYes1 /// addnYes1
/// / \
/// / \
/// / \.
/// / \.
/// const1 const2 /// const1 const2


ComputeGraphPtr BuildGraph1() { ComputeGraphPtr BuildGraph1() {
@@ -57,9 +57,9 @@ ComputeGraphPtr BuildGraph1() {


/// ///
/// netoutput1 /// netoutput1
/// / \ \
/// add1 assign1 \
/// / \ / \ \
/// / \ \.
/// add1 assign1 \.
/// / \ / \ \.
/// var1 var2 const1 var3 /// var1 var2 const1 var3


ComputeGraphPtr BuildGraph2() { ComputeGraphPtr BuildGraph2() {


Loading…
Cancel
Save