Merge pull request !1928 from 张晓昆/r1.5.0tags/v1.3.0
@@ -13,15 +13,15 @@ | |||||
* See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
* limitations under the License. | * limitations under the License. | ||||
*/ | */ | ||||
#include "host_cpu_engine.h" | |||||
#include "graph/common/omg_util.h" | |||||
#include "ge_local_engine/engine/host_cpu_engine.h" | |||||
#include "graph/utils/op_desc_utils.h" | #include "graph/utils/op_desc_utils.h" | ||||
#include "graph/utils/tensor_adapter.h" | #include "graph/utils/tensor_adapter.h" | ||||
#include "graph/utils/node_utils.h" | |||||
#include "graph/utils/type_utils.h" | |||||
#include "register/op_kernel_registry.h" | #include "register/op_kernel_registry.h" | ||||
#include "register/host_cpu_context.h" | #include "register/host_cpu_context.h" | ||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "common/ge/plugin_manager.h" | #include "common/ge/plugin_manager.h" | ||||
#include "graph/utils/type_utils.h" | |||||
#include "common/fp16_t.h" | #include "common/fp16_t.h" | ||||
#include "common/math/math_util.h" | #include "common/math/math_util.h" | ||||
@@ -123,10 +123,7 @@ bool HostCpuEngine::CheckSupported(const string &op_type) { | |||||
} | } | ||||
Status HostCpuEngine::FindOpKernel(const ge::NodePtr &node, std::unique_ptr<HostCpuOp> &op_kernel) { | Status HostCpuEngine::FindOpKernel(const ge::NodePtr &node, std::unique_ptr<HostCpuOp> &op_kernel) { | ||||
std::string op_type; | |||||
auto status = GetOriginalType(node, op_type); | |||||
GE_CHK_BOOL_EXEC_NOLOG(status == SUCCESS, return status); | |||||
const std::string op_type = NodeUtils::GetNodeType(node); | |||||
auto kernel = OpKernelRegistry::GetInstance().CreateHostCpuOp(op_type); | auto kernel = OpKernelRegistry::GetInstance().CreateHostCpuOp(op_type); | ||||
if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
GELOGD("Op of type %s is not supported by host cpu engine", op_type.c_str()); | GELOGD("Op of type %s is not supported by host cpu engine", op_type.c_str()); | ||||
@@ -1378,7 +1378,9 @@ Status ModelManager::LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_ | |||||
Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) { | Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) { | ||||
GELOGD("Aicpu kernel launch task in, kernel name %s.", kernel_name.c_str()); | GELOGD("Aicpu kernel launch task in, kernel name %s.", kernel_name.c_str()); | ||||
std::lock_guard<std::mutex> lock(cust_aicpu_mutex_); | std::lock_guard<std::mutex> lock(cust_aicpu_mutex_); | ||||
if (cust_aicpu_so_.size() == 0) return SUCCESS; | |||||
if (cust_aicpu_so_.empty()) { | |||||
return SUCCESS; | |||||
} | |||||
// get current context | // get current context | ||||
rtContext_t rt_cur_ctx = nullptr; | rtContext_t rt_cur_ctx = nullptr; | ||||
auto rt_error = rtCtxGetCurrent(&rt_cur_ctx); | auto rt_error = rtCtxGetCurrent(&rt_cur_ctx); | ||||
@@ -1394,17 +1396,19 @@ Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
vector<void *> allocated_mem; | |||||
rtError_t status; | |||||
rtStream_t stream = nullptr; | rtStream_t stream = nullptr; | ||||
vector<void *> allocated_mem; | |||||
std::function<void()> callback = [&]() { | std::function<void()> callback = [&]() { | ||||
for (auto mem : allocated_mem) { | for (auto mem : allocated_mem) { | ||||
GE_CHK_RT(rtFree(mem)); | GE_CHK_RT(rtFree(mem)); | ||||
} | } | ||||
GE_CHK_RT(rtStreamDestroy(stream)); | |||||
if (stream != nullptr) { | |||||
GE_CHK_RT(rtStreamDestroy(stream)); | |||||
} | |||||
}; | }; | ||||
GE_MAKE_GUARD(release, callback); | GE_MAKE_GUARD(release, callback); | ||||
rtError_t status; | |||||
vector<CustAicpuSoBuf> v_cust_so; | vector<CustAicpuSoBuf> v_cust_so; | ||||
void *args = nullptr; | void *args = nullptr; | ||||
@@ -284,9 +284,6 @@ Status DynamicShapePartitioner::InitClusters() { | |||||
auto cluster = MakeShared<Cluster>(rank++, type, node, this); | auto cluster = MakeShared<Cluster>(rank++, type, node, this); | ||||
REQUIRE_NOT_NULL(cluster, "[New][Memory] for cluster failed."); | REQUIRE_NOT_NULL(cluster, "[New][Memory] for cluster failed."); | ||||
node_2_cluster_[node] = cluster; | node_2_cluster_[node] = cluster; | ||||
if (cluster->IsUnknownShape()) { | |||||
ordered_cluster_.push_back(cluster); | |||||
} | |||||
int64_t group_index = -1; | int64_t group_index = -1; | ||||
if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { | if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { | ||||
@@ -306,7 +303,7 @@ Status DynamicShapePartitioner::InitClusters() { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status DynamicShapePartitioner::TopologicalSortClusters() { | |||||
Status DynamicShapePartitioner::TopologicalSortClusters(const OrderedFilter &ordered_filter) { | |||||
ordered_cluster_.clear(); | ordered_cluster_.clear(); | ||||
// BFS topological sort clusters for known shape cluster | // BFS topological sort clusters for known shape cluster | ||||
std::queue<ClusterPtr> ready_clusters; | std::queue<ClusterPtr> ready_clusters; | ||||
@@ -331,7 +328,7 @@ Status DynamicShapePartitioner::TopologicalSortClusters() { | |||||
auto cluster = ready_clusters.front(); | auto cluster = ready_clusters.front(); | ||||
ready_clusters.pop(); | ready_clusters.pop(); | ||||
cluster->UpdateRank(rank++); | cluster->UpdateRank(rank++); | ||||
if (cluster->IsKnownShape() || cluster->IsInputNode()) { | |||||
if (ordered_filter == nullptr || ordered_filter(cluster)) { | |||||
ordered_cluster_.push_back(cluster); | ordered_cluster_.push_back(cluster); | ||||
} | } | ||||
for (const auto &out_cluster : cluster->Outputs()) { | for (const auto &out_cluster : cluster->Outputs()) { | ||||
@@ -378,7 +375,6 @@ void DynamicShapePartitioner::MergeClustersControlFlow() { | |||||
continue; | continue; | ||||
} | } | ||||
bool is_unknown_cluster = cluster->IsUnknownShape(); | |||||
for (++rit; rit != control_cluster.rend(); ++rit) { | for (++rit; rit != control_cluster.rend(); ++rit) { | ||||
const auto &cluster_from = *rit; | const auto &cluster_from = *rit; | ||||
if (all_merged_clusters.count(cluster_from) > 0) { | if (all_merged_clusters.count(cluster_from) > 0) { | ||||
@@ -395,11 +391,6 @@ void DynamicShapePartitioner::MergeClustersControlFlow() { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
if (!is_unknown_cluster && cluster->IsUnknownShape()) { | |||||
GELOGD("Add to ordered cluster: %s", cluster->DebugString().c_str()); | |||||
ordered_cluster_.push_back(cluster); | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -475,9 +466,19 @@ void DynamicShapePartitioner::MergeClustersInputData() { | |||||
} | } | ||||
Status DynamicShapePartitioner::MergeClusters() { | Status DynamicShapePartitioner::MergeClusters() { | ||||
const auto filter_known = [](const ClusterPtr &cluster) { | |||||
return cluster->IsKnownShape() || cluster->IsInputNode(); | |||||
}; | |||||
const auto filter_unknown = [](const ClusterPtr &cluster) { | |||||
return cluster->IsUnknownShape(); | |||||
}; | |||||
MergeClustersControlFlow(); | MergeClustersControlFlow(); | ||||
REQUIRE_SUCCESS(TopologicalSortClusters(filter_unknown), | |||||
"[TopologicalSort][Clusters] after merge control flow clusters failed."); | |||||
MergeClustersUnknownShape(); | MergeClustersUnknownShape(); | ||||
REQUIRE_SUCCESS(TopologicalSortClusters(), "[TopologicalSort][Clusters] after merge unknown shape clusters failed."); | |||||
REQUIRE_SUCCESS(TopologicalSortClusters(filter_known), | |||||
"[TopologicalSort][Clusters] after merge unknown shape clusters failed."); | |||||
MergeClustersKnownShape(); | MergeClustersKnownShape(); | ||||
MergeClustersInputData(); | MergeClustersInputData(); | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -111,6 +111,8 @@ class DynamicShapePartitioner { | |||||
Status Partition(); | Status Partition(); | ||||
using OrderedFilter = std::function<bool(const std::shared_ptr<Cluster> &cluster)>; | |||||
private: | private: | ||||
Status PartitionImpl(); | Status PartitionImpl(); | ||||
// Collect nodes that satisfy the unknowshape rules: | // Collect nodes that satisfy the unknowshape rules: | ||||
@@ -138,7 +140,7 @@ class DynamicShapePartitioner { | |||||
// Merge clusters step3 | // Merge clusters step3 | ||||
void MergeClustersInputData(); | void MergeClustersInputData(); | ||||
// Topological sort clusters after merge unknown shape clusters. | // Topological sort clusters after merge unknown shape clusters. | ||||
Status TopologicalSortClusters(); | |||||
Status TopologicalSortClusters(const OrderedFilter &ordered_filter); | |||||
// Deduplicate merged clusters | // Deduplicate merged clusters | ||||
void PruneUniqueClusters(); | void PruneUniqueClusters(); | ||||
// Establish the input-output anchors for each partition of the cluster and record links to other clusters | // Establish the input-output anchors for each partition of the cluster and record links to other clusters | ||||
@@ -16,8 +16,6 @@ | |||||
#include "mark_force_unknown_for_cond_pass.h" | #include "mark_force_unknown_for_cond_pass.h" | ||||
#include <queue> | |||||
#include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
#include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
@@ -26,17 +24,7 @@ namespace { | |||||
inline bool IsMergeInLoop(const NodePtr &node) { | inline bool IsMergeInLoop(const NodePtr &node) { | ||||
const static std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; | const static std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; | ||||
std::string node_type; | |||||
(void)GetOriginalType(node, node_type); | |||||
return kLoopMergeInputs.count(node_type) > 0; | |||||
} | |||||
inline bool IsSwitchInLoop(const NodePtr &node) { | |||||
const static std::set<std::string> kLoopSwitchInputs{ MERGE, REFMERGE, LOOPCOND }; | |||||
std::string node_type; | |||||
(void)GetOriginalType(node, node_type); | |||||
return kLoopSwitchInputs.count(node_type) > 0; | |||||
return kLoopMergeInputs.count(NodeUtils::GetNodeType(node)) > 0; | |||||
} | } | ||||
} | } | ||||
@@ -44,10 +32,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { | |||||
GELOGD("MarkForceUnknownForCondPass Enter"); | GELOGD("MarkForceUnknownForCondPass Enter"); | ||||
std::map<NodePtr, std::vector<NodePtr>> switch_groups; | std::map<NodePtr, std::vector<NodePtr>> switch_groups; | ||||
for (const auto &node : graph->GetDirectNode()) { | for (const auto &node : graph->GetDirectNode()) { | ||||
std::string node_type; | |||||
GE_CHK_STATUS_RET(GetOriginalType(node, node_type), | |||||
"[Get][OriginalType] of node in graph:%s failed.", graph->GetName().c_str()); | |||||
if (kMergeOpTypes.count(node_type) == 0) { | |||||
if (kMergeOpTypes.count(NodeUtils::GetNodeType(node)) == 0) { | |||||
continue; | continue; | ||||
} | } | ||||
@@ -65,6 +50,51 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { | |||||
} | } | ||||
/// | /// | ||||
/// @brief Deal with Switch node for LoopCond | |||||
/// @param [in] Switch node | |||||
/// @param [in] dest span | |||||
/// @param [out] Search queue | |||||
/// @return true: Switch In while loop / false: Not in while Loop. | |||||
/// | |||||
bool MarkForceUnknownForCondPass::DealAsLoopSwitch(const NodePtr &node, uint32_t dst_span, | |||||
std::queue<std::pair<NodePtr, uint32_t>> &search_queue) { | |||||
/// LoopCond --->\. | |||||
/// \. | |||||
/// Enter-----------+ \. | |||||
/// +--> Merge --> Switch --> Exit | |||||
/// NextIteration---+ | |||||
const auto is_loop_op = [](const NodePtr &n) { | |||||
return NodeUtils::GetNodeType(n) == LOOPCOND; | |||||
}; | |||||
const auto is_exit_op = [](const NodePtr &n) { | |||||
return kExitOpTypes.count(NodeUtils::GetNodeType(n)) > 0; | |||||
}; | |||||
const auto src_nodes = node->GetInAllNodes(); | |||||
const auto dst_nodes = node->GetOutAllNodes(); | |||||
if (std::none_of(src_nodes.begin(), src_nodes.end(), is_loop_op) && | |||||
std::none_of(dst_nodes.begin(), dst_nodes.end(), is_exit_op)) { | |||||
return false; | |||||
} | |||||
for (const auto &m : src_nodes) { | |||||
if (kMergeOpTypes.count(NodeUtils::GetNodeType(m)) > 0) { | |||||
for (const auto &n : m->GetInAllNodes()) { | |||||
if (kNextIterationOpTypes.count(NodeUtils::GetNodeType(n)) > 0) { | |||||
continue; | |||||
} | |||||
search_queue.push({n, dst_span}); | |||||
GELOGD("Travel in Loop: %s <-- %s <-- %s, span is: %u", node->GetName().c_str(), m->GetName().c_str(), | |||||
n->GetName().c_str(), dst_span); | |||||
} | |||||
} | |||||
} | |||||
return true; | |||||
} | |||||
/// | |||||
/// @brief Mark force unknown shape for Switch node | /// @brief Mark force unknown shape for Switch node | ||||
/// @param [in] merge node | /// @param [in] merge node | ||||
/// @param [out] switch group | /// @param [out] switch group | ||||
@@ -72,6 +102,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { | |||||
/// | /// | ||||
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std::vector<NodePtr> &switch_group) { | void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std::vector<NodePtr> &switch_group) { | ||||
// Switch --> {Switch --> Merge} --> Merge | // Switch --> {Switch --> Merge} --> Merge | ||||
GELOGD("Search Switch node for Merge: %s", node->GetName().c_str()); | |||||
std::unordered_set<NodePtr> nodes_seen; | std::unordered_set<NodePtr> nodes_seen; | ||||
std::queue<std::pair<NodePtr, uint32_t>> search_queue({{node, 0}}); | std::queue<std::pair<NodePtr, uint32_t>> search_queue({{node, 0}}); | ||||
while (!search_queue.empty()) { | while (!search_queue.empty()) { | ||||
@@ -79,43 +110,25 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: | |||||
const auto dst_span = search_queue.front().second; | const auto dst_span = search_queue.front().second; | ||||
search_queue.pop(); | search_queue.pop(); | ||||
// Switch --> Identity --> Constant | |||||
for (const auto &in_node : dst_node->GetInControlNodes()) { | |||||
if (nodes_seen.count(in_node) > 0) { | |||||
GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); | |||||
continue; | |||||
} | |||||
nodes_seen.insert(in_node); | |||||
if (in_node->GetType() == IDENTITY) { | |||||
GELOGD("Travel node: %s, In control: %s, span is: %u", dst_node->GetName().c_str(), | |||||
in_node->GetName().c_str(), dst_span); | |||||
search_queue.push({in_node, dst_span}); | |||||
} | |||||
} | |||||
for (const auto &in_node : dst_node->GetInDataNodes()) { | |||||
for (const auto &in_node : dst_node->GetInAllNodes()) { | |||||
if (nodes_seen.count(in_node) > 0) { | if (nodes_seen.count(in_node) > 0) { | ||||
GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); | GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); | ||||
continue; | continue; | ||||
} | } | ||||
nodes_seen.insert(in_node); | nodes_seen.insert(in_node); | ||||
std::string node_type; | |||||
(void)GetOriginalType(in_node, node_type); | |||||
const std::string node_type = NodeUtils::GetNodeType(in_node); | |||||
GELOGD("Travel node: %s, %s node: %s, span is: %u", dst_node->GetName().c_str(), node_type.c_str(), | GELOGD("Travel node: %s, %s node: %s, span is: %u", dst_node->GetName().c_str(), node_type.c_str(), | ||||
in_node->GetName().c_str(), dst_span); | in_node->GetName().c_str(), dst_span); | ||||
if (kSwitchOpTypes.count(node_type) > 0) { // Switch input node. | if (kSwitchOpTypes.count(node_type) > 0) { // Switch input node. | ||||
if (DealAsLoopSwitch(in_node, dst_span, search_queue)) { | |||||
continue; | |||||
} | |||||
if (dst_span > 0) { | if (dst_span > 0) { | ||||
search_queue.push({in_node, dst_span - 1}); | search_queue.push({in_node, dst_span - 1}); | ||||
} else { | } else { | ||||
const auto &all_in_nodes = in_node->GetInDataNodes(); | |||||
if (std::any_of(all_in_nodes.begin(), all_in_nodes.end(), IsSwitchInLoop)) { | |||||
GELOGW("Travel node: %s, %s node: %s, Skip LoopCond switch", dst_node->GetName().c_str(), node_type.c_str(), | |||||
in_node->GetName().c_str()); | |||||
} else { | |||||
switch_group.emplace_back(in_node); | |||||
} | |||||
switch_group.emplace_back(in_node); | |||||
} | } | ||||
} else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node. | } else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node. | ||||
search_queue.push({in_node, dst_span + 1}); | search_queue.push({in_node, dst_span + 1}); | ||||
@@ -19,6 +19,8 @@ | |||||
#include "inc/graph_pass.h" | #include "inc/graph_pass.h" | ||||
#include <queue> | |||||
namespace ge { | namespace ge { | ||||
class MarkForceUnknownForCondPass : public GraphPass { | class MarkForceUnknownForCondPass : public GraphPass { | ||||
public: | public: | ||||
@@ -26,6 +28,15 @@ class MarkForceUnknownForCondPass : public GraphPass { | |||||
private: | private: | ||||
/// | /// | ||||
/// @brief Deal with Switch node for LoopCond | |||||
/// @param [in] Switch node | |||||
/// @param [in] dest span | |||||
/// @param [out] Search queue | |||||
/// @return true: Switch In while loop / false: Not in while Loop. | |||||
/// | |||||
bool DealAsLoopSwitch(const NodePtr &node, uint32_t dst_span, std::queue<std::pair<NodePtr, uint32_t>> &search_queue); | |||||
/// | |||||
/// @brief Mark force unknown shape for Switch node | /// @brief Mark force unknown shape for Switch node | ||||
/// @param [in] merge node | /// @param [in] merge node | ||||
/// @param [out] switch group | /// @param [out] switch group | ||||
@@ -24,7 +24,9 @@ using std::string; | |||||
namespace ge { | namespace ge { | ||||
namespace { | namespace { | ||||
const int64_t kLoopType = 1; | |||||
constexpr int64_t kLoopType = 1; | |||||
constexpr uint8_t kMaxTransOp = 3; | |||||
constexpr uint8_t kTransOpIoSize = 1; | |||||
} | } | ||||
Status NextIterationPass::Run(ComputeGraphPtr graph) { | Status NextIterationPass::Run(ComputeGraphPtr graph) { | ||||
@@ -287,18 +289,25 @@ void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, i | |||||
std::string node_type; | std::string node_type; | ||||
for (const auto &switch_node : loop_group.switch_nodes) { | for (const auto &switch_node : loop_group.switch_nodes) { | ||||
SetControlFlowGroup(switch_node, group_index); | SetControlFlowGroup(switch_node, group_index); | ||||
for (const auto &node : switch_node->GetOutDataNodes()) { | |||||
(void)GetOriginalType(node, node_type); | |||||
if (kExitOpTypes.count(node_type) > 0) { | |||||
SetControlFlowGroup(node, group_index); | |||||
} else { | |||||
// For: Switch -> Cast -> Exit | |||||
for (const auto &n : node->GetOutDataNodes()) { | |||||
(void)GetOriginalType(n, node_type); | |||||
if (kExitOpTypes.count(node_type) > 0) { | |||||
SetControlFlowGroup(n, group_index); | |||||
} | |||||
for (auto node : switch_node->GetOutDataNodes()) { | |||||
// Switch --> Exit | |||||
// Switch --> Cast --> Exit | |||||
// Switch --> TransData --> Cast --> Exit | |||||
for (uint8_t i = 0; i < kMaxTransOp; ++i) { | |||||
if (node->GetInDataNodes().size() != kTransOpIoSize || node->GetAllOutDataAnchorsSize() != kTransOpIoSize) { | |||||
break; | |||||
} | } | ||||
if (kExitOpTypes.count(NodeUtils::GetNodeType(node)) > 0) { | |||||
SetControlFlowGroup(node, group_index); | |||||
break; | |||||
} | |||||
const auto &all_nodes = node->GetOutAllNodes(); | |||||
if (all_nodes.size() != kTransOpIoSize) { | |||||
break; | |||||
} | |||||
node = all_nodes.at(0); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -395,8 +395,9 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr & | |||||
peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str()); | peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str()); | ||||
int64_t group_index = -1; | int64_t group_index = -1; | ||||
(void)AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||||
SetControlFlowGroup(stream_switch, group_index); | |||||
if (AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { | |||||
SetControlFlowGroup(stream_switch, group_index); | |||||
} | |||||
return stream_switch; | return stream_switch; | ||||
} | } | ||||
@@ -326,17 +326,45 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() { | |||||
} | } | ||||
void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) { | void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) { | ||||
if (node_item_->root_data_.count(input_idx) > 0) { | |||||
GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx); | |||||
root_tensor_values_[input_idx] = tensor; | |||||
const auto is_persist_tensor = [](const std::map<const NodeItem *, std::set<int>> &items, int idx) { | |||||
const auto is_exist = [&idx](const std::pair<const NodeItem *, std::set<int>> &items) { | |||||
return items.second.count(idx) > 0; | |||||
}; | |||||
return std::any_of(items.begin(), items.end(), is_exist); | |||||
}; | |||||
if (root_tensor_values_.count(input_idx) > 0) { | |||||
return; | |||||
} | } | ||||
if (node_item_->enter_data_.count(input_idx) > 0) { | |||||
if (is_persist_tensor(node_item_->root_data_, input_idx)) { | |||||
GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx); | |||||
root_tensor_values_[input_idx] = tensor; | |||||
} else if (is_persist_tensor(node_item_->enter_data_, input_idx)) { | |||||
GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx); | GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx); | ||||
root_tensor_values_[input_idx] = tensor; | root_tensor_values_[input_idx] = tensor; | ||||
} | } | ||||
} | } | ||||
void NodeState::UpdatePersistTensor() { | |||||
const auto update_tensor = [&](const std::map<const NodeItem *, std::set<int>> &items) { | |||||
for (const auto &item : items) { | |||||
for (const auto idx : item.second) { | |||||
UpdatePersistTensor(idx); | |||||
} | |||||
} | |||||
}; | |||||
if (root_tensor_values_.empty()) { | |||||
return; | |||||
} | |||||
update_tensor(node_item_->root_data_); | |||||
if (iteration_count_ > 0) { | |||||
update_tensor(node_item_->enter_data_); | |||||
} | |||||
} | |||||
void NodeState::UpdatePersistTensor(int input_idx) { | void NodeState::UpdatePersistTensor(int input_idx) { | ||||
const auto it = root_tensor_values_.find(input_idx); | const auto it = root_tensor_values_.find(input_idx); | ||||
if (it == root_tensor_values_.end()) { | if (it == root_tensor_values_.end()) { | ||||
@@ -363,16 +391,9 @@ void NodeState::ResetContext(uint64_t iteration) { | |||||
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | ||||
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | ||||
for (auto item : node_item_->root_data_) { | |||||
UpdatePersistTensor(item.first); | |||||
} | |||||
if (iteration > 0) { | if (iteration > 0) { | ||||
data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size()); | data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size()); | ||||
ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size()); | ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size()); | ||||
for (auto item : node_item_->enter_data_) { | |||||
UpdatePersistTensor(item.first); | |||||
} | |||||
} | } | ||||
iteration_count_ = iteration; | iteration_count_ = iteration; | ||||
@@ -132,6 +132,7 @@ struct NodeState { | |||||
void RunNextIteration(); | void RunNextIteration(); | ||||
void SavePersistTensor(int input_idx, const TensorValue &tensor); | void SavePersistTensor(int input_idx, const TensorValue &tensor); | ||||
void UpdatePersistTensor(); | |||||
Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | ||||
@@ -373,6 +373,7 @@ Status ExecutionEngine::DoExecuteAsync(NodeState &node_state, | |||||
auto executor = node_item.node_executor; | auto executor = node_item.node_executor; | ||||
GE_CHECK_NOTNULL(executor); | GE_CHECK_NOTNULL(executor); | ||||
RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] Start"); | RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] Start"); | ||||
node_state.UpdatePersistTensor(); | |||||
GE_CHK_STATUS_RET(executor->PrepareTask(*task, task_context), "[Prepare][Task] for [%s] failed.", | GE_CHK_STATUS_RET(executor->PrepareTask(*task, task_context), "[Prepare][Task] for [%s] failed.", | ||||
node_state.GetName().c_str()); | node_state.GetName().c_str()); | ||||
RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] End"); | RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] End"); | ||||
@@ -288,10 +288,6 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
if (node->GetType() == MEMCPYASYNC) { // Convert MemcpyAsync to Identity. | |||||
node->GetOpDesc()->SetType(IDENTITY); | |||||
} | |||||
std::unique_ptr<NodeItem> new_node; | std::unique_ptr<NodeItem> new_node; | ||||
GE_CHK_STATUS_RET(NodeItem::Create(node, new_node), "[Invoke][Create] failed, model_name_:[%s]", GetGraphName()); | GE_CHK_STATUS_RET(NodeItem::Create(node, new_node), "[Invoke][Create] failed, model_name_:[%s]", GetGraphName()); | ||||
GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor)); | GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor)); | ||||
@@ -14,10 +14,8 @@ | |||||
* limitations under the License. | * limitations under the License. | ||||
*/ | */ | ||||
#include "node_item.h" | |||||
#include <sstream> | |||||
#include "common/debug/log.h" | |||||
#include "graph/common/omg_util.h" | |||||
#include "hybrid/model/node_item.h" | |||||
#include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "hybrid/executor/worker/shape_inference_engine.h" | #include "hybrid/executor/worker/shape_inference_engine.h" | ||||
@@ -26,6 +24,8 @@ | |||||
namespace ge { | namespace ge { | ||||
namespace hybrid { | namespace hybrid { | ||||
namespace { | namespace { | ||||
const uint8_t kMaxTransCount = 3; | |||||
const uint32_t kTransOpIoSize = 1; | |||||
const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; | const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; | ||||
const char *const kNodeTypeRetVal = "_RetVal"; | const char *const kNodeTypeRetVal = "_RetVal"; | ||||
const std::set<std::string> kControlOpTypes{ | const std::set<std::string> kControlOpTypes{ | ||||
@@ -41,6 +41,25 @@ const std::set<std::string> kMergeOpTypes{ | |||||
MERGE, REFMERGE, STREAMMERGE | MERGE, REFMERGE, STREAMMERGE | ||||
}; | }; | ||||
bool IsEnterFeedNode(NodePtr node) { | |||||
// For: Enter -> node | |||||
// For: Enter -> Cast -> node | |||||
// For: Enter -> TransData -> Cast -> node | |||||
for (uint8_t i = 0; i < kMaxTransCount; ++i) { | |||||
if (kEnterOpTypes.count(NodeUtils::GetNodeType(node)) > 0) { | |||||
GELOGD("Node[%s] is Enter feed node.", node->GetName().c_str()); | |||||
return true; | |||||
} | |||||
const auto all_nodes = node->GetInDataNodes(); | |||||
if (all_nodes.size() != kTransOpIoSize || node->GetAllInDataAnchorsSize() != kTransOpIoSize) { | |||||
return false; | |||||
} | |||||
node = all_nodes.at(0); | |||||
} | |||||
return false; | |||||
} | |||||
Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { | Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { | ||||
uint32_t parent_index = 0; | uint32_t parent_index = 0; | ||||
if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | ||||
@@ -98,8 +117,7 @@ Status ParseFusedSubgraph(NodeItem &node_item) { | |||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
std::string node_type; | |||||
GE_CHK_STATUS_RET(GetOriginalType(node, node_type)); | |||||
const std::string node_type = NodeUtils::GetNodeType(node); | |||||
if (node_type == DATA) { | if (node_type == DATA) { | ||||
GE_CHK_GRAPH_STATUS_RET(ParseInputMapping(*node, *op_desc, *fused_subgraph)); | GE_CHK_GRAPH_STATUS_RET(ParseInputMapping(*node, *op_desc, *fused_subgraph)); | ||||
} else if (node_type == kNodeTypeRetVal) { | } else if (node_type == kNodeTypeRetVal) { | ||||
@@ -398,19 +416,21 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { | |||||
data_send_.emplace(node_item); | data_send_.emplace(node_item); | ||||
node_item->data_recv_[this] = anchor_index; | node_item->data_recv_[this] = anchor_index; | ||||
if (is_root_node_) { | if (is_root_node_) { | ||||
node_item->root_data_[anchor_index] = this; | |||||
auto &data_anchors = node_item->root_data_[this]; | |||||
data_anchors.emplace(anchor_index); | |||||
} | } | ||||
// If Enter feed Not Merge, take as root Node. | // If Enter feed Not Merge, take as root Node. | ||||
if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) { | |||||
node_item->enter_data_[anchor_index] = this; | |||||
if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE)) { | |||||
auto &data_anchors = node_item->enter_data_[this]; | |||||
data_anchors.emplace(anchor_index); | |||||
} | } | ||||
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | ||||
} | } | ||||
void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { | void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { | ||||
if (switch_index < switch_groups_.size()) { | if (switch_index < switch_groups_.size()) { | ||||
std::vector<const NodeItem *> &switch_group = switch_groups_[switch_index]; | |||||
switch_group.emplace_back(node_item); | |||||
auto &switch_group = switch_groups_[switch_index]; | |||||
switch_group.emplace(node_item); | |||||
} else { | } else { | ||||
ctrl_send_.insert(node_item); | ctrl_send_.insert(node_item); | ||||
} | } | ||||
@@ -420,7 +440,7 @@ void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { | |||||
node_item->root_ctrl_.emplace(this); | node_item->root_ctrl_.emplace(this); | ||||
} | } | ||||
// If Enter feed control signal, take as root Node. | // If Enter feed control signal, take as root Node. | ||||
if (IsEnterOp() && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) { | |||||
if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) { | |||||
node_item->enter_ctrl_.emplace(this); | node_item->enter_ctrl_.emplace(this); | ||||
} | } | ||||
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | ||||
@@ -433,8 +453,8 @@ void NodeItem::SetMergeCtrl(NodeItem *node_item, uint32_t merge_index) { | |||||
} | } | ||||
// this is StreamMerge node, node_item is StreamActive node. | // this is StreamMerge node, node_item is StreamActive node. | ||||
std::vector<const NodeItem *> &switch_group = switch_groups_[merge_index]; | |||||
switch_group.emplace_back(node_item); | |||||
auto &switch_group = switch_groups_[merge_index]; | |||||
switch_group.emplace(node_item); | |||||
node_item->ctrl_send_.emplace(this); | node_item->ctrl_send_.emplace(this); | ||||
GELOGI("Node[%s] will control node[%s]", node_item->NodeName().c_str(), NodeName().c_str()); | GELOGI("Node[%s] will control node[%s]", node_item->NodeName().c_str(), NodeName().c_str()); | ||||
@@ -148,14 +148,14 @@ struct NodeItem { | |||||
int64_t frame_index_ = -1; | int64_t frame_index_ = -1; | ||||
int64_t parent_frame_ = -1; | int64_t parent_frame_ = -1; | ||||
std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node | std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node | ||||
std::map<int, const NodeItem *> root_data_; // Recv data from root node | |||||
std::map<const NodeItem *, std::set<int>> root_data_; // Recv data from root node | |||||
std::set<const NodeItem *> enter_ctrl_; // Recv ctrl from Enter node | std::set<const NodeItem *> enter_ctrl_; // Recv ctrl from Enter node | ||||
std::map<int, const NodeItem *> enter_data_; // Recv data from Enter node | |||||
std::map<const NodeItem *, std::set<int>> enter_data_; // Recv data from Enter node | |||||
std::set<const NodeItem *> data_send_; // Send data notify to | std::set<const NodeItem *> data_send_; // Send data notify to | ||||
std::map<const NodeItem *, int> data_recv_; // Recv data notify from | std::map<const NodeItem *, int> data_recv_; // Recv data notify from | ||||
std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | ||||
std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from | std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from | ||||
std::vector<std::vector<const NodeItem *>> switch_groups_; // Send ctrl notify to | |||||
std::vector<std::set<const NodeItem *>> switch_groups_; // Send ctrl notify to | |||||
std::shared_ptr<NodeTask> kernel_task; | std::shared_ptr<NodeTask> kernel_task; | ||||
std::unique_ptr<FusedSubgraph> fused_subgraph; | std::unique_ptr<FusedSubgraph> fused_subgraph; | ||||
@@ -342,6 +342,7 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do | |||||
GE_CHK_RT_RET(rtEventDestroy(evt)); | GE_CHK_RT_RET(rtEventDestroy(evt)); | ||||
} | } | ||||
GELOGI("rdma callback success."); | GELOGI("rdma callback success."); | ||||
return SUCCESS; | |||||
}; | }; | ||||
HcclResult hccl_ret = HcomExecEnqueueRemoteAccess(context.GetNodeItem().NodeType(), addr_infos, callback); | HcclResult hccl_ret = HcomExecEnqueueRemoteAccess(context.GetNodeItem().NodeType(), addr_infos, callback); | ||||
@@ -17,13 +17,9 @@ | |||||
#include "hybrid/node_executor/rts/rts_node_executor.h" | #include "hybrid/node_executor/rts/rts_node_executor.h" | ||||
#include "hybrid/node_executor/rts/rts_task_factory.h" | #include "hybrid/node_executor/rts/rts_task_factory.h" | ||||
#include "common/debug/log.h" | |||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "common/types.h" | |||||
#include "graph/common/omg_util.h" | |||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include "hybrid/model/hybrid_model.h" | #include "hybrid/model/hybrid_model.h" | ||||
#include "runtime/rt.h" | |||||
namespace ge { | namespace ge { | ||||
namespace hybrid { | namespace hybrid { | ||||
@@ -33,6 +29,7 @@ REGISTER_RTS_TASK_CREATOR(IDENTITY, IdentityNodeTask); | |||||
REGISTER_RTS_TASK_CREATOR(IDENTITYN, IdentityNNodeTask); | REGISTER_RTS_TASK_CREATOR(IDENTITYN, IdentityNNodeTask); | ||||
REGISTER_RTS_TASK_CREATOR(READVARIABLEOP, ReadVariableOpNodeTask); | REGISTER_RTS_TASK_CREATOR(READVARIABLEOP, ReadVariableOpNodeTask); | ||||
REGISTER_RTS_TASK_CREATOR(PROFILINGTRAININGTRACE, ProfilingTraceNodeTask); | REGISTER_RTS_TASK_CREATOR(PROFILINGTRAININGTRACE, ProfilingTraceNodeTask); | ||||
REGISTER_RTS_TASK_CREATOR(MEMCPYASYNC, IdentityNodeTask); | |||||
Status IdentityNodeTask::DoCopyTensor(TaskContext &context, int index) { | Status IdentityNodeTask::DoCopyTensor(TaskContext &context, int index) { | ||||
auto input_desc = context.MutableInputDesc(index); | auto input_desc = context.MutableInputDesc(index); | ||||
@@ -133,8 +130,7 @@ Status ProfilingTraceNodeTask::ExecuteAsync(TaskContext &context, std::function< | |||||
Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | ||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
GELOGD("[%s] Load for local task.", node->GetName().c_str()); | GELOGD("[%s] Load for local task.", node->GetName().c_str()); | ||||
std::string node_type; | |||||
GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed."); | |||||
const std::string node_type = NodeUtils::GetNodeType(node); | |||||
RtsNodeTaskPtr rts_task = RtsTaskFactory::GetInstance().Create(node_type); | RtsNodeTaskPtr rts_task = RtsTaskFactory::GetInstance().Create(node_type); | ||||
if (rts_task == nullptr) { | if (rts_task == nullptr) { | ||||
GELOGE(UNSUPPORTED, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), node_type.c_str()); | GELOGE(UNSUPPORTED, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), node_type.c_str()); | ||||
@@ -43,7 +43,6 @@ namespace hybrid { | |||||
REGISTER_RTS_TASK_CREATOR(STREAMACTIVE, StreamActiveNodeTask); | REGISTER_RTS_TASK_CREATOR(STREAMACTIVE, StreamActiveNodeTask); | ||||
REGISTER_RTS_TASK_CREATOR(STREAMSWITCH, StreamSwitchNodeTask); | REGISTER_RTS_TASK_CREATOR(STREAMSWITCH, StreamSwitchNodeTask); | ||||
REGISTER_RTS_TASK_CREATOR(STREAMMERGE, StreamMergeNodeTask); | REGISTER_RTS_TASK_CREATOR(STREAMMERGE, StreamMergeNodeTask); | ||||
REGISTER_RTS_TASK_CREATOR(MEMCPYASYNC, MemcpyAsyncNodeTask); | |||||
REGISTER_RTS_TASK_CREATOR(ENTER, PassThroughNodeTask); | REGISTER_RTS_TASK_CREATOR(ENTER, PassThroughNodeTask); | ||||
REGISTER_RTS_TASK_CREATOR(REFENTER, PassThroughNodeTask); | REGISTER_RTS_TASK_CREATOR(REFENTER, PassThroughNodeTask); | ||||
@@ -168,34 +167,6 @@ Status StreamMergeNodeTask::ExecuteAsync(TaskContext &task_context, std::functio | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status MemcpyAsyncNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) { | |||||
GELOGD("[%s] Start to execute.", task_context.GetNodeName()); | |||||
auto input_desc = task_context.MutableInputDesc(0); | |||||
GE_CHECK_NOTNULL(input_desc); | |||||
int64_t copy_size = 0; | |||||
GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorSizeInBytes(*input_desc, copy_size)); | |||||
// copy_size would not be negative since GetTensorSizeInBytes returned successfully. | |||||
if (copy_size > 0) { | |||||
const auto in_v = task_context.MutableInput(0); | |||||
const auto out_v = task_context.MutableOutput(0); | |||||
GE_CHECK_NOTNULL(in_v); | |||||
GE_CHECK_NOTNULL(out_v); | |||||
GELOGD("[%s] input size: %zu, output size: %zu, copy size: %ld", task_context.GetNodeName(), | |||||
in_v->GetSize(), out_v->GetSize(), copy_size); | |||||
GE_CHK_RT_RET(rtMemcpyAsync(out_v->MutableData(), out_v->GetSize(), in_v->GetData(), copy_size, | |||||
RT_MEMCPY_DEVICE_TO_DEVICE, task_context.GetStream())); | |||||
} else { | |||||
GELOGW("[%s] invalid copy size: %ld", task_context.GetNodeName(), copy_size); | |||||
} | |||||
if (done_callback) { | |||||
GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); | |||||
} | |||||
GELOGD("[%s] Done executing successfully.", task_context.GetNodeName()); | |||||
return SUCCESS; | |||||
} | |||||
Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) { | Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) { | ||||
GELOGD("[%s] Start to execute.", task_context.GetNodeName()); | GELOGD("[%s] Start to execute.", task_context.GetNodeName()); | ||||
const auto in_x = task_context.GetInput(0); // x | const auto in_x = task_context.GetInput(0); // x | ||||
@@ -60,11 +60,6 @@ class StreamMergeNodeTask : public RtsNodeTask { | |||||
Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override; | Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override; | ||||
}; | }; | ||||
class MemcpyAsyncNodeTask : public RtsNodeTask { | |||||
public: | |||||
Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override; | |||||
}; | |||||
class PassThroughNodeTask : public RtsNodeTask { | class PassThroughNodeTask : public RtsNodeTask { | ||||
public: | public: | ||||
Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override; | Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override; | ||||
@@ -458,10 +458,6 @@ Status TaskContext::PropagateOutputs() { | |||||
subgraph_context_->all_inputs_[input_offset].SetName( | subgraph_context_->all_inputs_[input_offset].SetName( | ||||
node_item_->NodeName() + "_in_" + std::to_string(dst_input_idx)); | node_item_->NodeName() + "_in_" + std::to_string(dst_input_idx)); | ||||
} | } | ||||
auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item); | |||||
GE_CHECK_NOTNULL(dst_node_state); | |||||
dst_node_state->SavePersistTensor(dst_input_idx, *tensor); | |||||
} | } | ||||
} | } | ||||
(void)guard; | (void)guard; | ||||
@@ -493,6 +489,7 @@ void TaskContext::ReleaseInputsAndOutputs() { | |||||
void TaskContext::ReleaseInput(int index) { | void TaskContext::ReleaseInput(int index) { | ||||
auto input_tensor = MutableInput(index); | auto input_tensor = MutableInput(index); | ||||
if (input_tensor != nullptr) { | if (input_tensor != nullptr) { | ||||
node_state_->SavePersistTensor(index, *input_tensor); | |||||
input_tensor->Destroy(); | input_tensor->Destroy(); | ||||
GELOGD("[%s] Tensor of input[%d] released", GetNodeName(), index); | GELOGD("[%s] Tensor of input[%d] released", GetNodeName(), index); | ||||
} | } | ||||
@@ -345,6 +345,10 @@ INT32 mmIsDir(const CHAR *fileName) | |||||
INT32 mmGetEnv(const CHAR *name, CHAR *value, UINT32 len) | INT32 mmGetEnv(const CHAR *name, CHAR *value, UINT32 len) | ||||
{ | { | ||||
const char *env = getenv(name); | |||||
if (env != nullptr) { | |||||
strcpy(value, env); | |||||
} | |||||
return 0; | return 0; | ||||
} | } | ||||
@@ -726,7 +726,6 @@ set(PASS_TEST_FILES | |||||
"graph/passes/memcpy_addr_async_unittest.cc" | "graph/passes/memcpy_addr_async_unittest.cc" | ||||
"graph/passes/hccl_continuous_pass_unittest.cc" | "graph/passes/hccl_continuous_pass_unittest.cc" | ||||
"graph/passes/hccl_memcpy_pass_unittest.cc" | "graph/passes/hccl_memcpy_pass_unittest.cc" | ||||
) | ) | ||||
set(KERNEL_TEST_FILES | set(KERNEL_TEST_FILES | ||||
@@ -859,7 +858,6 @@ set(HYBRID_TEST_FILES | |||||
"hybrid/executor/hybrid_model_async_executor_unittest.cc" | "hybrid/executor/hybrid_model_async_executor_unittest.cc" | ||||
"hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" | "hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" | ||||
"hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" | "hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" | ||||
) | ) | ||||
set(OTHERS_TEST_FILES | set(OTHERS_TEST_FILES | ||||
@@ -887,6 +885,7 @@ add_library(ge_ut_graph STATIC | |||||
target_compile_definitions(ge_ut_graph PRIVATE | target_compile_definitions(ge_ut_graph PRIVATE | ||||
google=ascend_private | google=ascend_private | ||||
FMK_SUPPORT_DUMP | |||||
) | ) | ||||
target_compile_options(ge_ut_graph PRIVATE | target_compile_options(ge_ut_graph PRIVATE | ||||
@@ -349,7 +349,7 @@ class UtestLogicalStreamAllocator : public testing::Test { | |||||
/// B --> C(AllReduce) --- D | /// B --> C(AllReduce) --- D | ||||
/// / | /// / | ||||
/// stream id: 0 A | /// stream id: 0 A | ||||
/// \ | |||||
/// \. | |||||
/// E --> F(AllReduce) --- G | /// E --> F(AllReduce) --- G | ||||
/// stream id: 2 2 2 | /// stream id: 2 2 2 | ||||
/// | /// | ||||
@@ -599,7 +599,7 @@ TEST_F(UtestLogicalStreamAllocator, test_label_not_reusable2) { | |||||
/// case of multi-output, then unuse stream | /// case of multi-output, then unuse stream | ||||
/// sub1 | /// sub1 | ||||
/// / | \ | |||||
/// / | \. | |||||
/// sub2 sub3 sub4 | /// sub2 sub3 sub4 | ||||
TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) { | TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) { | ||||
SubGraphInfoPtr data = CreateDataSubgraph(); | SubGraphInfoPtr data = CreateDataSubgraph(); | ||||
@@ -624,7 +624,7 @@ TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) { | |||||
/// if paralle id 1, then use stream | /// if paralle id 1, then use stream | ||||
/// sub1 | /// sub1 | ||||
/// / | | \ | |||||
/// / | | \. | |||||
/// sub2 sub3 sub4 sub5 | /// sub2 sub3 sub4 sub5 | ||||
TEST_F(UtestLogicalStreamAllocator, test_parallel_one) { | TEST_F(UtestLogicalStreamAllocator, test_parallel_one) { | ||||
SubGraphInfoPtr data = CreateDataSubgraph(); | SubGraphInfoPtr data = CreateDataSubgraph(); | ||||
@@ -653,7 +653,7 @@ TEST_F(UtestLogicalStreamAllocator, test_parallel_one) { | |||||
/// if the param of engine independent is true, then set independent stream | /// if the param of engine independent is true, then set independent stream | ||||
/// sub1 | /// sub1 | ||||
/// / | | \ | |||||
/// / | | \. | |||||
/// sub2 sub3 sub4 sub5 | /// sub2 sub3 sub4 sub5 | ||||
TEST_F(UtestLogicalStreamAllocator, test_independent) { | TEST_F(UtestLogicalStreamAllocator, test_independent) { | ||||
SubGraphInfoPtr data = CreateDataSubgraph(); | SubGraphInfoPtr data = CreateDataSubgraph(); | ||||
@@ -692,7 +692,7 @@ TEST_F(UtestLogicalStreamAllocator, test_independent) { | |||||
/// set stream based on stream label, and then based on independent | /// set stream based on stream label, and then based on independent | ||||
/// sub1 | /// sub1 | ||||
/// / | | \ | |||||
/// / | | \. | |||||
/// sub2 sub3 sub4 sub5 | /// sub2 sub3 sub4 sub5 | ||||
TEST_F(UtestLogicalStreamAllocator, test_independent_switch_label) { | TEST_F(UtestLogicalStreamAllocator, test_independent_switch_label) { | ||||
SubGraphInfoPtr data = CreateDataSubgraph(); | SubGraphInfoPtr data = CreateDataSubgraph(); | ||||
@@ -36,7 +36,7 @@ class UtestStreamAllocator : public testing::Test { | |||||
/// | /// | ||||
/// A | /// A | ||||
/// / \ | |||||
/// / \. | |||||
/// B C | /// B C | ||||
/// | | | /// | | | ||||
/// D 400 | /// D 400 | ||||
@@ -438,4 +438,22 @@ TEST_F(UtestModelManagerModelManager, test_data_input_tensor) { | |||||
auto ret = mm.DataInputTensor(model_id,inputs); | auto ret = mm.DataInputTensor(model_id,inputs); | ||||
EXPECT_EQ(PARAM_INVALID, ret); // HybridDavinciModel::impl_ is null. | EXPECT_EQ(PARAM_INVALID, ret); // HybridDavinciModel::impl_ is null. | ||||
} | } | ||||
TEST_F(UtestModelManagerModelManager, test_launch_kernel_cust_aicpu) { | |||||
ModelManager mm; | |||||
// cust_aicpu_so_ is empty. | |||||
EXPECT_EQ(mm.LaunchKernelCustAicpuSo("empty_cust_aicpu"), SUCCESS); | |||||
// deleteCustOp after Launch will deleted. | |||||
uintptr_t resource_id = 1; // for rtCtxGetCurrent stub | |||||
std::vector<char> kernel_bin(256); | |||||
auto &cust_resource_001 = mm.cust_aicpu_so_[resource_id]; | |||||
auto tbe_kernel = std::shared_ptr<OpKernelBin>(new OpKernelBin("deleteCustOp", std::move(kernel_bin))); | |||||
auto &cust_opkernel_001 = cust_resource_001["deleteCustOp"] = tbe_kernel; | |||||
EXPECT_FALSE(mm.cust_aicpu_so_.empty()); | |||||
EXPECT_EQ(mm.LaunchKernelCustAicpuSo("deleteCustOp"), SUCCESS); | |||||
EXPECT_TRUE(mm.cust_aicpu_so_.empty()); | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -55,7 +55,7 @@ class UtestGraphPassesAssertPass : public Test { | |||||
}; | }; | ||||
/// D E | /// D E | ||||
/// | \ | \ | |||||
/// | \ | \. | |||||
/// F C G | /// F C G | ||||
/// : | : | /// : | : | ||||
/// H A I | /// H A I | ||||
@@ -134,8 +134,8 @@ TEST_F(UtestGraphPassesAssertPass, assert_pass_test2) { | |||||
EXPECT_EQ(graph->FindNode("D"), nullptr); | EXPECT_EQ(graph->FindNode("D"), nullptr); | ||||
} | } | ||||
/// E F | |||||
/// | \ | \ | |||||
/// E F | |||||
/// | \ | \. | |||||
/// H C -> D G | /// H C -> D G | ||||
/// \ | : | /// \ | : | ||||
/// A I | /// A I | ||||
@@ -130,7 +130,7 @@ class UTESTGraphPassesBasePass : public testing::Test { | |||||
/// reshape1 | /// reshape1 | ||||
/// | | /// | | ||||
/// add1 | /// add1 | ||||
/// / \ | |||||
/// / \. | |||||
/// | | | /// | | | ||||
/// data1 const1 | /// data1 const1 | ||||
ComputeGraphPtr BuildGraph1() { | ComputeGraphPtr BuildGraph1() { | ||||
@@ -148,9 +148,9 @@ ComputeGraphPtr BuildGraph1() { | |||||
} | } | ||||
/// sum1 | /// sum1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
/// / \. | |||||
/// reshape1 addn1 | /// reshape1 addn1 | ||||
/// | c | | /// | c | | ||||
/// add1 <--- shape1 | /// add1 <--- shape1 | ||||
@@ -217,7 +217,7 @@ void CheckIterOrder(UtestTestPass *pass, std::vector<std::unordered_set<std::str | |||||
/// Op1 | /// Op1 | ||||
/// | | /// | | ||||
/// Merge | /// Merge | ||||
/// / \ | |||||
/// / \. | |||||
/// Op2 Op3 | /// Op2 Op3 | ||||
TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) { | TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) { | ||||
auto builder = ut::GraphBuilder("g1"); | auto builder = ut::GraphBuilder("g1"); | ||||
@@ -245,7 +245,7 @@ TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) { | |||||
/// Op1 | /// Op1 | ||||
/// | | /// | | ||||
/// Merge | /// Merge | ||||
/// / \ | |||||
/// / \. | |||||
/// Op2 Op3 | /// Op2 Op3 | ||||
TEST_F(UTESTGraphPassesBasePass, del_isolate_success) { | TEST_F(UTESTGraphPassesBasePass, del_isolate_success) { | ||||
auto builder = ut::GraphBuilder("g1"); | auto builder = ut::GraphBuilder("g1"); | ||||
@@ -459,7 +459,7 @@ TEST_F(UTESTGraphPassesBasePass, while_loop) { | |||||
/// data1 const | /// data1 const | ||||
/// \ / | /// \ / | ||||
/// while | /// while | ||||
/// / \ | |||||
/// / \. | |||||
/// | | | /// | | | ||||
/// cast1 cast2 | /// cast1 cast2 | ||||
ComputeGraphPtr BuildWhileGraph1() { | ComputeGraphPtr BuildWhileGraph1() { | ||||
@@ -34,11 +34,11 @@ namespace { | |||||
/// net_output | /// net_output | ||||
/// | | /// | | ||||
/// merge | /// merge | ||||
/// / \ | |||||
/// / \. | |||||
/// square add | /// square add | ||||
/// F| T/ T\ | |||||
/// F| T/ T\. | |||||
/// switch1 switch2 | /// switch1 switch2 | ||||
/// / \ / \ | |||||
/// / \ / \. | |||||
/// var1 var2 var3 | /// var1 var2 var3 | ||||
/// | /// | ||||
ComputeGraphPtr BuildGraph1() { | ComputeGraphPtr BuildGraph1() { | ||||
@@ -173,8 +173,8 @@ namespace { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | | /// | | ||||
/// addnYes1 | /// addnYes1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
/// const1 const2 | /// const1 const2 | ||||
ComputeGraphPtr BuildGraph1() { | ComputeGraphPtr BuildGraph1() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -223,8 +223,8 @@ ComputeGraphPtr BuildGraph2() { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | c | /// | c | ||||
/// addnYes1 <----- dataNo1 | /// addnYes1 <----- dataNo1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
/// const1 const2 | /// const1 const2 | ||||
ComputeGraphPtr BuildGraph3() { | ComputeGraphPtr BuildGraph3() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -249,8 +249,8 @@ ComputeGraphPtr BuildGraph3() { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | c | /// | c | ||||
/// addnYes1 <--------- | /// addnYes1 <--------- | ||||
/// / \ \ | |||||
/// / \ c \ | |||||
/// / \ \. | |||||
/// / \ c \. | |||||
/// const1 const2 <----- dataNo1 | /// const1 const2 <----- dataNo1 | ||||
ComputeGraphPtr BuildGraph4() { | ComputeGraphPtr BuildGraph4() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -276,7 +276,7 @@ ComputeGraphPtr BuildGraph4() { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | c | /// | c | ||||
/// addnYes1 <----- dataNo1 | /// addnYes1 <----- dataNo1 | ||||
/// / \ | |||||
/// / \. | |||||
/// / \ c | /// / \ c | ||||
/// const1 const2 <----- dataNo2 | /// const1 const2 <----- dataNo2 | ||||
ComputeGraphPtr BuildGraph5() { | ComputeGraphPtr BuildGraph5() { | ||||
@@ -306,8 +306,8 @@ ComputeGraphPtr BuildGraph5() { | |||||
/// addYes1 <---- const3 | /// addYes1 <---- const3 | ||||
/// | | /// | | ||||
/// addnYes1 <- | /// addnYes1 <- | ||||
/// / \ \ | |||||
/// / \ \ | |||||
/// / \ \. | |||||
/// / \ \. | |||||
/// const1 const2 const4 | /// const1 const2 const4 | ||||
ComputeGraphPtr BuildGraph6() { | ComputeGraphPtr BuildGraph6() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -332,12 +332,12 @@ ComputeGraphPtr BuildGraph6() { | |||||
} | } | ||||
/// netoutput1 | /// netoutput1 | ||||
/// / \ | |||||
/// / \. | |||||
/// shapeNo1 ShpaeNo2 | /// shapeNo1 ShpaeNo2 | ||||
/// \ / | /// \ / | ||||
/// huberLoss1 | /// huberLoss1 | ||||
/// / | \ | |||||
/// / | \ | |||||
/// / | \. | |||||
/// / | \. | |||||
/// const1 const2 const3 | /// const1 const2 const3 | ||||
ComputeGraphPtr BuildGraph7() { | ComputeGraphPtr BuildGraph7() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -365,8 +365,8 @@ ComputeGraphPtr BuildGraph7() { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | | /// | | ||||
/// addnNo1 | /// addnNo1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
/// const1 const2 | /// const1 const2 | ||||
ComputeGraphPtr BuildGraph8() { | ComputeGraphPtr BuildGraph8() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -389,8 +389,8 @@ ComputeGraphPtr BuildGraph8() { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | | /// | | ||||
/// addnYes1 | /// addnYes1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
/// const1 data1 | /// const1 data1 | ||||
ComputeGraphPtr BuildGraph9() { | ComputeGraphPtr BuildGraph9() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -409,12 +409,12 @@ ComputeGraphPtr BuildGraph9() { | |||||
} | } | ||||
/// netoutput1 | /// netoutput1 | ||||
/// / \ | |||||
/// / \. | |||||
/// addDim sqrt1 | /// addDim sqrt1 | ||||
/// \ / | /// \ / | ||||
/// switch1 | /// switch1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
/// const1 const2 | /// const1 const2 | ||||
ComputeGraphPtr BuildGraph10() { | ComputeGraphPtr BuildGraph10() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -63,8 +63,8 @@ namespace { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | | /// | | ||||
/// addnNo1 | /// addnNo1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
/// const1 const2 | /// const1 const2 | ||||
ComputeGraphPtr BuildGraph8() { | ComputeGraphPtr BuildGraph8() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -87,8 +87,8 @@ ComputeGraphPtr BuildGraph8() { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | | /// | | ||||
/// addnYes1 | /// addnYes1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
///const1 data1 | ///const1 data1 | ||||
ComputeGraphPtr BuildGraph9() { | ComputeGraphPtr BuildGraph9() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -46,7 +46,7 @@ class UtestGraphPassesFoldingKernelSsdPriorboxKernel : public testing::Test { | |||||
/// convolution data | /// convolution data | ||||
/// | / | /// | / | ||||
/// ssdpriorbox | /// ssdpriorbox | ||||
/// \ | |||||
/// \. | |||||
/// reshape | /// reshape | ||||
class NodeBuilder { | class NodeBuilder { | ||||
public: | public: | ||||
@@ -120,7 +120,7 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) { | |||||
/// graph with subgraph | /// graph with subgraph | ||||
/// const | /// const | ||||
/// / \ | |||||
/// / \. | |||||
/// cast1 cast1 | /// cast1 cast1 | ||||
/// \ / | /// \ / | ||||
/// case | /// case | ||||
@@ -69,62 +69,100 @@ static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string | |||||
return graph.AddNode(op_desc); | return graph.AddNode(op_desc); | ||||
} | } | ||||
static void CreateLoopGraph(ComputeGraphPtr &graph, NodePtr &merge) { | |||||
static void CreateLoopGraph(ComputeGraphPtr &graph, NodePtr &merge, vector<NodePtr> &loop, vector<NodePtr> &cond) { | |||||
/******************************************************************************* | /******************************************************************************* | ||||
* Exit Identify | |||||
* \ / \. | |||||
* \ / \. | |||||
* Switch Add | |||||
* / | | | |||||
* / | | | |||||
* / | | | |||||
* LoopCond | | | |||||
* \ | | | |||||
* \ | | | |||||
* \ | | | |||||
* Less | | | |||||
* \ | NextIteration | |||||
* \ | | | |||||
* \ | | | |||||
* Merge <---------| | |||||
* | | |||||
* | | |||||
* Enter | |||||
* | | |||||
* +--------------------- Merge ----------------------+ | |||||
* / | | |||||
* / | | |||||
* / | | |||||
* / | | |||||
* Exit Identify | | |||||
* \ / \. | | |||||
* \ / \. | | |||||
* Switch Add Add | |||||
* / | | | | |||||
* / | | | | |||||
* / | | | | |||||
* LoopCond | | | | |||||
* \ | | | | |||||
* \ | | | | |||||
* \ | | | | |||||
* Less | | | | |||||
* \ | NextIteration | | |||||
* \ | | | | |||||
* \ | | | | |||||
* Merge <---------| | | |||||
* | | | |||||
* | | | |||||
* Enter | | |||||
* \ | | |||||
* \ | | |||||
* Switch Switch | |||||
* | | | |||||
* +-----------------Equal----------------------+ | |||||
* | | |||||
******************************************************************************/ | ******************************************************************************/ | ||||
auto data1 = CreateNode(*graph, "data", DATA, 1, 1); | |||||
auto data1 = CreateNode(*graph, "data1", DATA, 1, 1); | |||||
auto data2 = CreateNode(*graph, "data2", DATA, 1, 1); | |||||
auto equal1 = CreateNode(*graph, "equal1", EQUAL, 2, 1); | |||||
auto switch1 = CreateNode(*graph, "switch1", SWITCH, 2, 2); | |||||
auto switch2 = CreateNode(*graph, "switch2", SWITCH, 2, 2); | |||||
auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); | auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); | ||||
auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2); | |||||
auto less1 = CreateNode(*graph, "less", LESS, 2, 1); | |||||
auto merge1 = CreateNode(*graph, "merge1", MERGE, 2, 2); | |||||
auto less1 = CreateNode(*graph, "less1", LESS, 2, 1); | |||||
auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1); | auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1); | ||||
auto switch1 = CreateNode(*graph, "switch", SWITCH, 2, 2); | |||||
auto switch3 = CreateNode(*graph, "switch3", SWITCH, 2, 2); | |||||
auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1); | auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1); | ||||
auto add1 = CreateNode(*graph, "add", ADD, 2, 1); | |||||
auto add1 = CreateNode(*graph, "add1", ADD, 2, 1); | |||||
auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); | auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); | ||||
auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); | auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); | ||||
auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1); | |||||
auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1); | |||||
auto value1 = CreateNode(*graph, "const1", CONSTANT, 0, 1); | |||||
auto value2 = CreateNode(*graph, "const2", CONSTANT, 0, 1); | |||||
auto add2 = CreateNode(*graph, "add2", ADD, 2, 1); | |||||
auto merge2 = CreateNode(*graph, "merge2", MERGE, 2, 2); | |||||
auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); | auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); | ||||
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), equal1->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(data2->GetOutDataAnchor(0), equal1->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(data2->GetOutDataAnchor(0), switch2->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(equal1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(equal1->GetOutDataAnchor(0), switch2->GetInDataAnchor(1)); | |||||
cond.emplace_back(switch1); | |||||
cond.emplace_back(switch2); | |||||
GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); // false | |||||
GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); | GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); | ||||
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); | GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); | ||||
GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); | GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); | ||||
GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0)); | GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0)); | ||||
GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch3->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch3->GetInDataAnchor(1)); | |||||
loop.emplace_back(merge1); | |||||
GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(switch1->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(switch3->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); // false | |||||
GraphUtils::AddEdge(switch3->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); // true | |||||
loop.emplace_back(switch3); | |||||
GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); | GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); | ||||
GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); | GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); | ||||
GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0)); | GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0)); | ||||
GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); | GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); | ||||
GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||||
merge = merge1; | |||||
GraphUtils::AddEdge(switch2->GetOutDataAnchor(1), add2->GetInDataAnchor(1)); // true | |||||
GraphUtils::AddEdge(value2->GetOutDataAnchor(0), add2->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), merge2->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(add2->GetOutDataAnchor(0), merge2->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(merge2->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||||
cond.emplace_back(merge2); | |||||
merge = merge2; | |||||
} | } | ||||
static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) { | static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) { | ||||
@@ -197,12 +235,24 @@ static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) { | |||||
TEST_F(UtestMarkForceUnknownForCondPass, skip_while_loop_merge) { | TEST_F(UtestMarkForceUnknownForCondPass, skip_while_loop_merge) { | ||||
auto graph = std::make_shared<ComputeGraph>("test_graph"); | auto graph = std::make_shared<ComputeGraph>("test_graph"); | ||||
NodePtr merge; | NodePtr merge; | ||||
CreateLoopGraph(graph, merge); | |||||
AttrUtils::SetBool(merge->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); | |||||
vector<NodePtr> loop; | |||||
vector<NodePtr> cond; | |||||
CreateLoopGraph(graph, merge, loop, cond); | |||||
MarkForceUnknownForCondPass mark_force_unknown_pass; | MarkForceUnknownForCondPass mark_force_unknown_pass; | ||||
EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip LoopCond | EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip LoopCond | ||||
EXPECT_EQ(loop.size(), 2); | |||||
for (const auto &node : loop) { | |||||
EXPECT_FALSE(node->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)); | |||||
} | |||||
EXPECT_EQ(cond.size(), 3); | |||||
for (const auto &node : cond) { | |||||
int64_t group_index = -1; | |||||
EXPECT_TRUE(AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)); | |||||
EXPECT_EQ(group_index, merge->GetOpDesc()->GetId()); | |||||
} | |||||
} | } | ||||
TEST_F(UtestMarkForceUnknownForCondPass, skip_known_shape_merge) { | TEST_F(UtestMarkForceUnknownForCondPass, skip_known_shape_merge) { | ||||
@@ -110,8 +110,8 @@ TEST_F(UtestGraphPassesMergePass, multiple_inputs) { | |||||
} | } | ||||
/// Merge | /// Merge | ||||
/// | \ | |||||
/// | \ | |||||
/// | \. | |||||
/// | \. | |||||
/// Op1 Op2 Merge2 | /// Op1 Op2 Merge2 | ||||
/// \ | | | /// \ | | | ||||
/// \ | Op3 | /// \ | Op3 | ||||
@@ -137,10 +137,10 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch_meet_net_output_with_da | |||||
} | } | ||||
/// Merge | /// Merge | ||||
/// | \ | |||||
/// | \ | |||||
/// | \. | |||||
/// | \. | |||||
/// Op1 Op2 Merge2 | /// Op1 Op2 Merge2 | ||||
/// \ | | \ | |||||
/// \ | | \. | |||||
/// \ | Op3 | /// \ | Op3 | ||||
/// \ | : | /// \ | : | ||||
/// NetOutput | /// NetOutput | ||||
@@ -165,8 +165,8 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch_meet_net_output_with_co | |||||
TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch) { | TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch) { | ||||
/// Merge | /// Merge | ||||
/// | \ | |||||
/// | \ | |||||
/// | \. | |||||
/// | \. | |||||
/// Op1 Op2 Merge2 | /// Op1 Op2 Merge2 | ||||
/// \ | | | /// \ | | | ||||
/// \ | Op3 | /// \ | Op3 | ||||
@@ -210,7 +210,7 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch) { | |||||
/// Op1 Op2 Merge2 | /// Op1 Op2 Merge2 | ||||
/// \ | | /// \ | | ||||
/// \ Op3 | /// \ Op3 | ||||
/// \ | |||||
/// \. | |||||
/// Merge3 | /// Merge3 | ||||
ret = pass_.Run(merge_node2); | ret = pass_.Run(merge_node2); | ||||
@@ -224,7 +224,7 @@ TEST_F(UtestGraphPassesMergePass, single_non_const_input) { | |||||
/// Op1 | /// Op1 | ||||
/// | | /// | | ||||
/// Merge | /// Merge | ||||
/// / \ | |||||
/// / \. | |||||
/// Op2 Op3 | /// Op2 Op3 | ||||
auto merge_node = NewNode("Merge", MERGE, 1, 2); | auto merge_node = NewNode("Merge", MERGE, 1, 2); | ||||
auto node1 = NewNode("Op1", RELU, 1, 1); | auto node1 = NewNode("Op1", RELU, 1, 1); | ||||
@@ -253,7 +253,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input) { | |||||
/// Const | /// Const | ||||
/// | | /// | | ||||
/// Merge Pass Const | /// Merge Pass Const | ||||
/// / \ ===> / \ | |||||
/// / \ ===> / \. | |||||
/// Op1 Op2 Op1 Op2 | /// Op1 Op2 Op1 Op2 | ||||
auto merge_node = NewNode("Merge", MERGE, 1, 2); | auto merge_node = NewNode("Merge", MERGE, 1, 2); | ||||
auto const_node = NewNode("Const", CONSTANT, 1, 1); | auto const_node = NewNode("Const", CONSTANT, 1, 1); | ||||
@@ -284,7 +284,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input_value_index_two_out_nodes) | |||||
/// / | ===> / \(control anchor) | /// / | ===> / \(control anchor) | ||||
/// Op1 | \ Op1 Constant | /// Op1 | \ Op1 Constant | ||||
/// Op2 Op3 | | /// Op2 Op3 | | ||||
/// / \ | |||||
/// / \. | |||||
/// Op2 Op3 | /// Op2 Op3 | ||||
auto merge_node = NewNode("Merge", MERGE, 1, 2); | auto merge_node = NewNode("Merge", MERGE, 1, 2); | ||||
auto const_node = NewNode("Const", CONSTANT, 1, 1); | auto const_node = NewNode("Const", CONSTANT, 1, 1); | ||||
@@ -329,7 +329,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input_value_index_two_out_nodes1) | |||||
/// / | ===> / \(control anchor) | /// / | ===> / \(control anchor) | ||||
/// Op1 | \ Op1 Constant | /// Op1 | \ Op1 Constant | ||||
/// Op2 Op3 | | /// Op2 Op3 | | ||||
/// / \ | |||||
/// / \. | |||||
/// Op2 Op3 | /// Op2 Op3 | ||||
auto merge_node = NewNode("Merge", MERGE, 1, 2); | auto merge_node = NewNode("Merge", MERGE, 1, 2); | ||||
auto const_node = NewNode("Const", CONSTANT, 1, 1); | auto const_node = NewNode("Const", CONSTANT, 1, 1); | ||||
@@ -357,7 +357,7 @@ TEST_F(UtestGraphPassesMergePass, const_with_control_input) { | |||||
/// C | /// C | ||||
/// | | /// | | ||||
/// Merge | /// Merge | ||||
/// / \ | |||||
/// / \. | |||||
/// Op1 Op2 | /// Op1 Op2 | ||||
auto switch_node = NewNode("Switch", SWITCH, 1, 2); | auto switch_node = NewNode("Switch", SWITCH, 1, 2); | ||||
auto identity_node = NewNode("Identity", SWITCH, 1, 1); | auto identity_node = NewNode("Identity", SWITCH, 1, 1); | ||||
@@ -381,7 +381,7 @@ TEST_F(UtestGraphPassesMergePass, const_with_control_input) { | |||||
/// . | /// . | ||||
/// . | /// . | ||||
/// C | /// C | ||||
/// / \ | |||||
/// / \. | |||||
/// Op1 Op2 | /// Op1 Op2 | ||||
auto ret = pass_.Run(merge_node); | auto ret = pass_.Run(merge_node); | ||||
EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
@@ -67,11 +67,11 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { | |||||
void BuildDefaultGraph() { | void BuildDefaultGraph() { | ||||
/// input | /// input | ||||
/// \ | |||||
/// \. | |||||
/// sqrt pred | /// sqrt pred | ||||
/// \ / | /// \ / | ||||
/// cast | /// cast | ||||
/// / \ | |||||
/// / \. | |||||
/// switch_t switch_f | /// switch_t switch_f | ||||
/// | | | /// | | | ||||
/// F T | /// F T | ||||
@@ -119,13 +119,13 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { | |||||
void BuildDefaultGraph1() { | void BuildDefaultGraph1() { | ||||
/// input | /// input | ||||
/// \ | |||||
/// \. | |||||
/// sqrt pred | /// sqrt pred | ||||
/// \ / | /// \ / | ||||
/// Switch | /// Switch | ||||
/// | | | /// | | | ||||
/// ----F T---- | /// ----F T---- | ||||
/// \ | / \ | |||||
/// \ | / \. | |||||
/// \ Merge1 Merge2 | /// \ Merge1 Merge2 | ||||
/// \_________| | /// \_________| | ||||
input_node_ = NewNode("input", RELU, 0, 1); | input_node_ = NewNode("input", RELU, 0, 1); | ||||
@@ -165,14 +165,14 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { | |||||
void BuildDefaultGraph2() { | void BuildDefaultGraph2() { | ||||
/// input input1 | /// input input1 | ||||
/// \ \ | |||||
/// \ \. | |||||
/// sqrt pred sqrt1 pred1 | /// sqrt pred sqrt1 pred1 | ||||
/// \ / \ / | /// \ / \ / | ||||
/// Switch Switch1 | /// Switch Switch1 | ||||
/// | | _______| | /// | | _______| | ||||
/// | | / | /// | | / | ||||
/// ____F T____ | /// ____F T____ | ||||
/// \ | / \ | |||||
/// \ | / \. | |||||
/// \ Merge1 Merge2 | /// \ Merge1 Merge2 | ||||
/// \__________| | /// \__________| | ||||
input_node_ = NewNode("input", RELU, 0, 2); | input_node_ = NewNode("input", RELU, 0, 2); | ||||
@@ -31,9 +31,9 @@ class UtestReshapeRecoveryPass : public testing::Test { | |||||
namespace { | namespace { | ||||
/// netoutput1 | /// netoutput1 | ||||
/// | \ | |||||
///transdata1 \ | |||||
/// | \ | |||||
/// | \. | |||||
///transdata1 \. | |||||
/// | \. | |||||
/// | transdata2 | /// | transdata2 | ||||
/// | / | /// | / | ||||
/// var1 const1 | /// var1 const1 | ||||
@@ -35,7 +35,7 @@ namespace { | |||||
/// transdata1 | /// transdata1 | ||||
/// | | /// | | ||||
/// reshape1 | /// reshape1 | ||||
/// | \ | |||||
/// | \. | |||||
/// var1 const1 | /// var1 const1 | ||||
ut::GraphBuilder Graph1Builder() { | ut::GraphBuilder Graph1Builder() { | ||||
ut::GraphBuilder builder = ut::GraphBuilder("g1"); | ut::GraphBuilder builder = ut::GraphBuilder("g1"); | ||||
@@ -55,11 +55,11 @@ ut::GraphBuilder Graph1Builder() { | |||||
} | } | ||||
/// netoutput1 | /// netoutput1 | ||||
/// | \ | |||||
///transdata1 \ | |||||
/// | \ | |||||
/// | \. | |||||
///transdata1 \. | |||||
/// | \. | |||||
/// reshape1 reshape2 | /// reshape1 reshape2 | ||||
/// | \ / \ | |||||
/// | \ / \. | |||||
/// var1 const1 var2 | /// var1 const1 var2 | ||||
ut::GraphBuilder Graph2Builder() { | ut::GraphBuilder Graph2Builder() { | ||||
ut::GraphBuilder builder = ut::GraphBuilder("g2"); | ut::GraphBuilder builder = ut::GraphBuilder("g2"); | ||||
@@ -83,9 +83,9 @@ ut::GraphBuilder Graph2Builder() { | |||||
} | } | ||||
/// netoutput1 | /// netoutput1 | ||||
/// | \ | |||||
///transdata1 \ | |||||
/// | \ | |||||
/// | \. | |||||
///transdata1 \. | |||||
/// | \. | |||||
/// reshape1 transdata2 | /// reshape1 transdata2 | ||||
/// | \ / | /// | \ / | ||||
/// var1 const1 | /// var1 const1 | ||||
@@ -34,7 +34,7 @@ class UtestResourcePairControlPass : public testing::Test { | |||||
namespace { | namespace { | ||||
/// netoutput1 | /// netoutput1 | ||||
/// | \ | |||||
/// | \. | |||||
/// StackPush StackPop | /// StackPush StackPop | ||||
/// | | | /// | | | ||||
/// var1 const1 | /// var1 const1 | ||||
@@ -63,9 +63,9 @@ ComputeGraphPtr BuildGraph1() { | |||||
/// netoutput1 | /// netoutput1 | ||||
/// | | /// | | ||||
/// merge1 | /// merge1 | ||||
/// / \ | |||||
/// / \. | |||||
/// / add1 | /// / add1 | ||||
/// / F| \ | |||||
/// / F| \. | |||||
/// addn1 swtich2 var3 | /// addn1 swtich2 var3 | ||||
/// \F T/ | | /// \F T/ | | ||||
/// switch1 | | /// switch1 | | ||||
@@ -101,9 +101,9 @@ ComputeGraphPtr BuildGraph2() { | |||||
/// add1 | /// add1 | ||||
/// / \T | /// / \T | ||||
/// var3 swtich2 | /// var3 swtich2 | ||||
/// T/ \ | |||||
/// switch1 \ | |||||
/// / \ \ | |||||
/// T/ \. | |||||
/// switch1 \. | |||||
/// / \ \. | |||||
/// var1 var2 var4 | /// var1 var2 var4 | ||||
ComputeGraphPtr BuildGraph3() { | ComputeGraphPtr BuildGraph3() { | ||||
auto builder = ut::GraphBuilder("g3"); | auto builder = ut::GraphBuilder("g3"); | ||||
@@ -129,7 +129,7 @@ ComputeGraphPtr BuildGraph3() { | |||||
/// netoutput1 | /// netoutput1 | ||||
/// | | /// | | ||||
/// merge1 | /// merge1 | ||||
/// / \ | |||||
/// / \. | |||||
/// add1 addn1 | /// add1 addn1 | ||||
/// / \T F/ | /// / \T F/ | ||||
/// var3 swtich2 | /// var3 swtich2 | ||||
@@ -402,7 +402,7 @@ TEST_F(UtestGraphPassesTransOpBreadthFusionPass, test_multi_anchor_case) { | |||||
} | } | ||||
/// ----> netoutput1 | /// ----> netoutput1 | ||||
/// / | \ | |||||
/// / | \. | |||||
/// transdata1 transdata2 transdata3 | /// transdata1 transdata2 transdata3 | ||||
/// \ / | | /// \ / | | ||||
/// var1-------------- | /// var1-------------- | ||||
@@ -432,7 +432,7 @@ static ComputeGraphPtr BuildGraph1() { | |||||
} | } | ||||
/// ---------> netoutput1 | /// ---------> netoutput1 | ||||
/// / | \ | |||||
/// / | \. | |||||
/// transdata1 transdata2(l1) transdata3(l1) | /// transdata1 transdata2(l1) transdata3(l1) | ||||
/// \ / | | /// \ / | | ||||
/// var1------------------ | /// var1------------------ | ||||
@@ -456,19 +456,19 @@ TEST_F(UtestGraphPassesTransOpDepthFusionPass, test_transop_with_multi_out_edge) | |||||
/// -->transpose1 -->transpose3-->sinh2 | /// -->transpose1 -->transpose3-->sinh2 | ||||
/// | \ / | /// | \ / | ||||
/// | -->transpose2 | /// | -->transpose2 | ||||
/// | \ | |||||
/// | \. | |||||
/// / -->cast3-->cast4-->sinh3 | /// / -->cast3-->cast4-->sinh3 | ||||
/// / | /// / | ||||
/// / -->transpose4-->transpose5-->sinh4 | /// / -->transpose4-->transpose5-->sinh4 | ||||
/// / / | /// / / | ||||
/// Node4D-->Cast1-->Cast2-->Cast5 -->reshape2-->sinh5 | /// Node4D-->Cast1-->Cast2-->Cast5 -->reshape2-->sinh5 | ||||
/// \ \ | |||||
/// \ \. | |||||
/// \ -->sinh6 | /// \ -->sinh6 | ||||
/// \ | |||||
/// \. | |||||
/// \ -->transpose6-->transpose7-->sinh9 | /// \ -->transpose6-->transpose7-->sinh9 | ||||
/// \ / | /// \ / | ||||
/// -->reshape-->cast6-->cast7-->sinh8 | /// -->reshape-->cast6-->cast7-->sinh8 | ||||
/// \ | |||||
/// \. | |||||
/// -->sinh7 | /// -->sinh7 | ||||
/// after optimized graph | /// after optimized graph | ||||
@@ -479,15 +479,15 @@ TEST_F(UtestGraphPassesTransOpDepthFusionPass, test_transop_with_multi_out_edge) | |||||
/// / /-->transpose3-->sinh2 | /// / /-->transpose3-->sinh2 | ||||
/// -->Cast1 | /// -->Cast1 | ||||
/// / \-->sinh7 | /// / \-->sinh7 | ||||
/// / \ | |||||
/// / \. | |||||
/// / -->sinh9 | /// / -->sinh9 | ||||
/// Node4D | /// Node4D | ||||
/// \ -->sinh4 | /// \ -->sinh4 | ||||
/// \ / | /// \ / | ||||
/// -->Cast5-->sinh5 | /// -->Cast5-->sinh5 | ||||
/// \ \ | |||||
/// \ \. | |||||
/// \ -->sinh6 | /// \ -->sinh6 | ||||
/// \ | |||||
/// \. | |||||
/// -->Cast7-->sinh8 | /// -->Cast7-->sinh8 | ||||
ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | ||||
@@ -180,7 +180,7 @@ ComputeGraphPtr GetGraph7(size_t symmetric_transdata_num, size_t asymmetric_tran | |||||
/// TransData TransData ... MatMul ... | /// TransData TransData ... MatMul ... | ||||
/// \ | / / / | /// \ | / / / | ||||
/// HcomAllReduce | /// HcomAllReduce | ||||
/// / | \ \ \ | |||||
/// / | \ \ \. | |||||
/// TransData TransData ... RealDiv ... | /// TransData TransData ... RealDiv ... | ||||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | ||||
NodePtr allreduce = | NodePtr allreduce = | ||||
@@ -340,7 +340,7 @@ TEST(UtestTransopNearbyAllreduceFusionPass, test7_all_reduce_with_multiple_trans | |||||
/// TransData TransData ... MatMul ... | /// TransData TransData ... MatMul ... | ||||
/// \ | / / / | /// \ | / / / | ||||
/// HcomAllReduce | /// HcomAllReduce | ||||
/// / | \ \ \ | |||||
/// / | \ \ \. | |||||
/// TransData TransData ... RealDiv ... | /// TransData TransData ... RealDiv ... | ||||
size_t symmetric_transdata_num = 20; | size_t symmetric_transdata_num = 20; | ||||
size_t asymmetric_transdata_num = 20; | size_t asymmetric_transdata_num = 20; | ||||
@@ -66,7 +66,7 @@ namespace { | |||||
/// transdata2 | /// transdata2 | ||||
/// | | /// | | ||||
/// assign1 | /// assign1 | ||||
/// / \ | |||||
/// / \. | |||||
/// transdata1 | | /// transdata1 | | ||||
/// | | | /// | | | ||||
/// var1 const1 | /// var1 const1 | ||||
@@ -35,8 +35,8 @@ namespace { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | | /// | | ||||
/// addnYes1 | /// addnYes1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
/// const1 const2 | /// const1 const2 | ||||
ComputeGraphPtr BuildGraph1() { | ComputeGraphPtr BuildGraph1() { | ||||
@@ -57,9 +57,9 @@ ComputeGraphPtr BuildGraph1() { | |||||
/// | /// | ||||
/// netoutput1 | /// netoutput1 | ||||
/// / \ \ | |||||
/// add1 assign1 \ | |||||
/// / \ / \ \ | |||||
/// / \ \. | |||||
/// add1 assign1 \. | |||||
/// / \ / \ \. | |||||
/// var1 var2 const1 var3 | /// var1 var2 const1 var3 | ||||
ComputeGraphPtr BuildGraph2() { | ComputeGraphPtr BuildGraph2() { | ||||