Merge pull request !1938 from mindspore_ding/code_sync_0705 (tags/v1.3.0)
@@ -161,6 +161,7 @@ Status ExceptionDumper::DumpExceptionInfo(const std::vector<rtExceptionInfo> &ex | |||||
uint64_t proto_size = dump_data.ByteSizeLong(); | uint64_t proto_size = dump_data.ByteSizeLong(); | ||||
std::unique_ptr<char[]> proto_msg(new (std::nothrow) char[proto_size]); | std::unique_ptr<char[]> proto_msg(new (std::nothrow) char[proto_size]); | ||||
GE_CHECK_NOTNULL(proto_msg); | |||||
bool ret = dump_data.SerializeToArray(proto_msg.get(), proto_size); | bool ret = dump_data.SerializeToArray(proto_msg.get(), proto_size); | ||||
if (!ret || proto_size == 0) { | if (!ret || proto_size == 0) { | ||||
REPORT_INNER_ERROR("E19999", "Serialize proto to string fail"); | REPORT_INNER_ERROR("E19999", "Serialize proto to string fail"); | ||||
@@ -22,6 +22,7 @@ | |||||
#include "graph/load/graph_loader.h" | #include "graph/load/graph_loader.h" | ||||
#include "init/gelib.h" | #include "init/gelib.h" | ||||
#include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
#include "model/ge_model.h" | |||||
namespace { | namespace { | ||||
const uint32_t kDeviceListIndex = 3; | const uint32_t kDeviceListIndex = 3; | ||||
@@ -42,6 +43,10 @@ const std::map<ProfCommandHandleType, std::string> kProfCommandTypeMap = { | |||||
{kProfCommandhandleFinalize, kProfilingFinalize}, | {kProfCommandhandleFinalize, kProfilingFinalize}, | ||||
{kProfCommandhandleModelSubscribe, kProfModelSubscribe}, | {kProfCommandhandleModelSubscribe, kProfModelSubscribe}, | ||||
{kProfCommandhandleModelUnsubscribe, kProfModelUnsubscribe}}; | {kProfCommandhandleModelUnsubscribe, kProfModelUnsubscribe}}; | ||||
const uint64_t kModelId = ge::INVALID_MODEL_ID; | |||||
const uint16_t kStepStart = 0; | |||||
const uint16_t kStepEnd = 1; | |||||
} // namespace | } // namespace | ||||
bool TransProfConfigToParam(const ProfCommandHandleData &profCommand, vector<string> &prof_config_params) { | bool TransProfConfigToParam(const ProfCommandHandleData &profCommand, vector<string> &prof_config_params) { | ||||
@@ -216,6 +221,36 @@ ge::Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t le | |||||
return ge::SUCCESS; | return ge::SUCCESS; | ||||
} | } | ||||
GE_FUNC_VISIBILITY ge::Status ProfSetStepInfo(uint64_t index_id, uint16_t tag_id, rtStream_t stream) { | |||||
return ge::SUCCESS; | |||||
///
/// @brief Report a profiling step boundary (start or end) for the current device.
/// @param [in] index_id: forwarded unchanged to ProfilingManager::ProfileStepInfo
/// @param [in] tag_id: kStepStart (0) or kStepEnd (1); must alternate start, end, start, ...
/// @param [in] stream: forwarded unchanged to ProfilingManager::ProfileStepInfo
/// @return ge::SUCCESS when the step info was recorded;
///         ge::FAILED when the device id cannot be queried or tag_id is out of sequence;
///         otherwise the status returned by ProfileStepInfo.
///
ge::Status ProfSetStepInfo(uint64_t index_id, uint16_t tag_id, rtStream_t stream) {
  // Remembers which tag is expected next: true -> kStepStart, false -> kStepEnd.
  // NOTE(review): function-local static, so the expectation is shared by every caller.
  static bool is_first_run = true;

  int32_t device_id = 0;
  const rtError_t rt_ret = rtGetDevice(&device_id);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(rt_ret, "[Get][LogicDeviceId]Failed, ret 0x%X", rt_ret);
    REPORT_CALL_ERROR("E19999", "Get logic device id failed, ret 0x%X", rt_ret);
    return ge::FAILED;
  }

  // Only two combinations are legal: a start tag while waiting for a start,
  // or an end tag while waiting for an end.
  const bool start_expected = is_first_run && (tag_id == kStepStart);
  const bool end_expected = (!is_first_run) && (tag_id == kStepEnd);
  if (!start_expected && !end_expected) {
    GELOGE(ge::FAILED, "Param tag_id:%u invalid when is_first_run is %d", tag_id, is_first_run);
    REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"value", "parameter", "reason"}),
                       std::vector<std::string>({std::to_string(tag_id), "tag_id",
                                                 "tag id must be 0 when first run, must be 1 when second run"}));
    return ge::FAILED;
  }

  // On failure the macro returns early and the expected-tag state is left untouched.
  GE_CHK_STATUS_RET_NOLOG(
      ge::ProfilingManager::Instance().ProfileStepInfo(index_id, kModelId, tag_id, stream, device_id));

  // A recorded start flips the expectation to "end", a recorded end flips it back to "start".
  is_first_run = !is_first_run;
  return ge::SUCCESS;
}
@@ -13,15 +13,15 @@ | |||||
* See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
* limitations under the License. | * limitations under the License. | ||||
*/ | */ | ||||
#include "host_cpu_engine.h" | |||||
#include "graph/common/omg_util.h" | |||||
#include "ge_local_engine/engine/host_cpu_engine.h" | |||||
#include "graph/utils/op_desc_utils.h" | #include "graph/utils/op_desc_utils.h" | ||||
#include "graph/utils/tensor_adapter.h" | #include "graph/utils/tensor_adapter.h" | ||||
#include "graph/utils/node_utils.h" | |||||
#include "graph/utils/type_utils.h" | |||||
#include "register/op_kernel_registry.h" | #include "register/op_kernel_registry.h" | ||||
#include "register/host_cpu_context.h" | #include "register/host_cpu_context.h" | ||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "common/ge/plugin_manager.h" | #include "common/ge/plugin_manager.h" | ||||
#include "graph/utils/type_utils.h" | |||||
#include "common/fp16_t.h" | #include "common/fp16_t.h" | ||||
#include "common/math/math_util.h" | #include "common/math/math_util.h" | ||||
@@ -123,10 +123,7 @@ bool HostCpuEngine::CheckSupported(const string &op_type) { | |||||
} | } | ||||
Status HostCpuEngine::FindOpKernel(const ge::NodePtr &node, std::unique_ptr<HostCpuOp> &op_kernel) { | Status HostCpuEngine::FindOpKernel(const ge::NodePtr &node, std::unique_ptr<HostCpuOp> &op_kernel) { | ||||
std::string op_type; | |||||
auto status = GetOriginalType(node, op_type); | |||||
GE_CHK_BOOL_EXEC_NOLOG(status == SUCCESS, return status); | |||||
const std::string op_type = NodeUtils::GetNodeType(node); | |||||
auto kernel = OpKernelRegistry::GetInstance().CreateHostCpuOp(op_type); | auto kernel = OpKernelRegistry::GetInstance().CreateHostCpuOp(op_type); | ||||
if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
GELOGD("Op of type %s is not supported by host cpu engine", op_type.c_str()); | GELOGD("Op of type %s is not supported by host cpu engine", op_type.c_str()); | ||||
@@ -85,7 +85,7 @@ bool LabelGotoTask::Distribute() { | |||||
return false; | return false; | ||||
} | } | ||||
rt_ret = rtLabelListCpy((void**)label_list.data(), label_list.size(), label_info_, label_info_size); | |||||
rt_ret = rtLabelListCpy(reinterpret_cast<void**>(label_list.data()), label_list.size(), label_info_, label_info_size); | |||||
if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); | ||||
return false; | return false; | ||||
@@ -707,7 +707,7 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) { | |||||
if (!kernel_name.empty() && (kernel_buffer.GetSize() > 0)) { | if (!kernel_name.empty() && (kernel_buffer.GetSize() > 0)) { | ||||
GE_CHECK_NOTNULL(kernel_buffer.GetData()); | GE_CHECK_NOTNULL(kernel_buffer.GetData()); | ||||
std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize()); | std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize()); | ||||
tbe_kernel = std::make_shared<OpKernelBin>(kernel_name, std::move(data)); | |||||
tbe_kernel = MakeShared<OpKernelBin>(kernel_name, std::move(data)); | |||||
GE_CHECK_NOTNULL(tbe_kernel); | GE_CHECK_NOTNULL(tbe_kernel); | ||||
GELOGI("Node [%s][%s] start recovery extra attr %s from %s", node_op_desc->GetName().c_str(), | GELOGI("Node [%s][%s] start recovery extra attr %s from %s", node_op_desc->GetName().c_str(), | ||||
node_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str()); | node_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str()); | ||||
@@ -793,7 +793,6 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP | |||||
GELOGI("Start AutoFindBpOpIndex"); | GELOGI("Start AutoFindBpOpIndex"); | ||||
NodePtr bp_node = nullptr; | NodePtr bp_node = nullptr; | ||||
uint32_t current_idx = 0; | uint32_t current_idx = 0; | ||||
uint32_t netoutput_idx = 0; | |||||
for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { | for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { | ||||
OpDescPtr op_desc = node->GetOpDesc(); | OpDescPtr op_desc = node->GetOpDesc(); | ||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
@@ -811,7 +810,6 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP | |||||
if (op_desc->GetName() == NODE_NAME_NET_OUTPUT) { | if (op_desc->GetName() == NODE_NAME_NET_OUTPUT) { | ||||
if (bp_node == nullptr) { | if (bp_node == nullptr) { | ||||
bp_node = node; | bp_node = node; | ||||
netoutput_idx = current_idx - 1; | |||||
} | } | ||||
} | } | ||||
if (graph->GetNeedIteration()) { | if (graph->GetNeedIteration()) { | ||||
@@ -836,34 +834,30 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP | |||||
if (bp_node == nullptr) { | if (bp_node == nullptr) { | ||||
GELOGW("not find bp_node."); | GELOGW("not find bp_node."); | ||||
return SUCCESS; | return SUCCESS; | ||||
} else if (bp_node->GetName() == NODE_NAME_NET_OUTPUT) { | |||||
profiling_point.bp_index = netoutput_idx; | |||||
GELOGI("First bp name %s, idx %u", bp_node->GetName().c_str(), netoutput_idx); | |||||
} else { | |||||
profiling_point.bp_index = FindLastBpFromBpNode(graph, bp_node); | |||||
} | } | ||||
return SUCCESS; | |||||
return FindLastBpFromBpNode(graph, bp_node, profiling_point.bp_index); | |||||
} | } | ||||
uint32_t TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const { | |||||
uint32_t last_bp = 0; | |||||
Status TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &target_node, | |||||
uint32_t &bp_index) const { | |||||
bp_index = 0; | |||||
auto target_desc = target_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(target_desc); | |||||
OpDescPtr bp_op_desc = nullptr; | OpDescPtr bp_op_desc = nullptr; | ||||
for (auto &in_anchor : bp_node->GetAllInDataAnchors()) { | |||||
auto out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
if (out_anchor == nullptr || out_anchor->GetOwnerNode() == nullptr) { | |||||
continue; | |||||
} | |||||
auto out_node_desc = out_anchor->GetOwnerNode()->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(out_node_desc); | |||||
if (bp_op_desc == nullptr || ((out_node_desc->GetId()) > (bp_op_desc->GetId()))) { | |||||
bp_op_desc = out_node_desc; | |||||
for (auto &in_node : target_node->GetInAllNodes()) { | |||||
GE_CHECK_NOTNULL(in_node); | |||||
auto in_node_desc = in_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(in_node_desc); | |||||
if ((bp_op_desc == nullptr || (in_node_desc->GetId() > bp_op_desc->GetId())) && | |||||
(in_node_desc->GetStreamId() == target_desc->GetStreamId())){ | |||||
bp_op_desc = in_node_desc; | |||||
} | } | ||||
GELOGI("bp_op_desc is %s, id is %ld", bp_op_desc->GetName().c_str(), bp_op_desc->GetId()); | |||||
} | } | ||||
if (bp_op_desc == nullptr) { | if (bp_op_desc == nullptr) { | ||||
return last_bp; | |||||
GELOGI("Did not find bp node."); | |||||
return SUCCESS; | |||||
} | } | ||||
uint32_t current_idx = 0; | uint32_t current_idx = 0; | ||||
for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { | for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { | ||||
@@ -871,12 +865,14 @@ uint32_t TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const | |||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
current_idx++; | current_idx++; | ||||
if (op_desc->GetName() == bp_op_desc->GetName()) { | if (op_desc->GetName() == bp_op_desc->GetName()) { | ||||
last_bp = current_idx; | |||||
GELOGI("First bp name %s, idx %u", op_desc->GetName().c_str(), last_bp); | |||||
bp_index = current_idx; | |||||
GELOGI("Find bp name %s, idx %u", op_desc->GetName().c_str(), bp_index); | |||||
break; | break; | ||||
} | } | ||||
} | } | ||||
return last_bp; | |||||
GELOGI("Last bp node[%s], type[%s], index[%u], stream id[%ld]", bp_op_desc->GetName().c_str(), | |||||
bp_op_desc->GetType().c_str(), bp_index, bp_op_desc->GetStreamId()); | |||||
return SUCCESS; | |||||
} | } | ||||
Status TaskGenerator::FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str, | Status TaskGenerator::FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str, | ||||
@@ -116,7 +116,7 @@ class TaskGenerator { | |||||
Status AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point) const; | Status AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point) const; | ||||
Status AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, | Status AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, | ||||
vector<uint32_t> &all_reduce_nodes) const; | vector<uint32_t> &all_reduce_nodes) const; | ||||
uint32_t FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const; | |||||
Status FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node, uint32_t &bp_index) const; | |||||
Status FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str, | Status FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str, | ||||
ProfilingPoint &profiling_point) const; | ProfilingPoint &profiling_point) const; | ||||
@@ -1378,7 +1378,9 @@ Status ModelManager::LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_ | |||||
Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) { | Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) { | ||||
GELOGD("Aicpu kernel launch task in, kernel name %s.", kernel_name.c_str()); | GELOGD("Aicpu kernel launch task in, kernel name %s.", kernel_name.c_str()); | ||||
std::lock_guard<std::mutex> lock(cust_aicpu_mutex_); | std::lock_guard<std::mutex> lock(cust_aicpu_mutex_); | ||||
if (cust_aicpu_so_.size() == 0) return SUCCESS; | |||||
if (cust_aicpu_so_.empty()) { | |||||
return SUCCESS; | |||||
} | |||||
// get current context | // get current context | ||||
rtContext_t rt_cur_ctx = nullptr; | rtContext_t rt_cur_ctx = nullptr; | ||||
auto rt_error = rtCtxGetCurrent(&rt_cur_ctx); | auto rt_error = rtCtxGetCurrent(&rt_cur_ctx); | ||||
@@ -1394,9 +1396,19 @@ Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
rtStream_t stream = nullptr; | |||||
vector<void *> allocated_mem; | vector<void *> allocated_mem; | ||||
std::function<void()> callback = [&]() { | |||||
for (auto mem : allocated_mem) { | |||||
GE_CHK_RT(rtFree(mem)); | |||||
} | |||||
if (stream != nullptr) { | |||||
GE_CHK_RT(rtStreamDestroy(stream)); | |||||
} | |||||
}; | |||||
GE_MAKE_GUARD(release, callback); | |||||
rtError_t status; | rtError_t status; | ||||
rtStream_t stream = nullptr; | |||||
vector<CustAicpuSoBuf> v_cust_so; | vector<CustAicpuSoBuf> v_cust_so; | ||||
void *args = nullptr; | void *args = nullptr; | ||||
@@ -1471,13 +1483,6 @@ Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) { | |||||
GELOGE(RT_FAILED, "[Call][RtStreamSynchronize] fail, ret = 0x%X", status); | GELOGE(RT_FAILED, "[Call][RtStreamSynchronize] fail, ret = 0x%X", status); | ||||
return RT_ERROR_TO_GE_STATUS(status); | return RT_ERROR_TO_GE_STATUS(status); | ||||
} | } | ||||
std::function<void()> callback = [&]() { | |||||
for (auto mem : allocated_mem) { | |||||
GE_CHK_RT(rtFree(mem)); | |||||
} | |||||
GE_CHK_RT(rtStreamDestroy(stream)); | |||||
}; | |||||
GE_MAKE_GUARD(release, callback); | |||||
GELOGI("Cpu kernel launch task success."); | GELOGI("Cpu kernel launch task success."); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -645,6 +645,7 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne | |||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
args_addr = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[args_size_]); | args_addr = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[args_size_]); | ||||
GE_CHECK_NOTNULL(args_addr); | |||||
errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_); | errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_); | ||||
if (sec_ret != EOK) { | if (sec_ret != EOK) { | ||||
REPORT_CALL_ERROR("E19999", "Call memcpy_s fail, size:%u, ret:0x%X", args_size_, sec_ret); | REPORT_CALL_ERROR("E19999", "Call memcpy_s fail, size:%u, ret:0x%X", args_size_, sec_ret); | ||||
@@ -1000,6 +1001,7 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k | |||||
// copy args to new host memory | // copy args to new host memory | ||||
args_addr = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[args_size_]); | args_addr = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[args_size_]); | ||||
GE_CHECK_NOTNULL(args_addr); | |||||
GE_PRINT_DYNAMIC_MEMORY(new, "cce task physical memory.", sizeof(uint8_t) * args_size_) | GE_PRINT_DYNAMIC_MEMORY(new, "cce task physical memory.", sizeof(uint8_t) * args_size_) | ||||
errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_); | errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_); | ||||
if (sec_ret != EOK) { | if (sec_ret != EOK) { | ||||
@@ -3139,10 +3139,10 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||||
} | } | ||||
// Avoid repeated prerun for graphs that own the same graph_id in online inference concurrency | // Avoid repeated prerun for graphs that own the same graph_id in online inference concurrency | ||||
if (count > 1 && graph_node->GetBuildFlag()) { | if (count > 1 && graph_node->GetBuildFlag()) { | ||||
graph_node->Lock(); | |||||
GELOGD("Avoid repeatively prerun, graph_id:%u.", args.graph_id); | GELOGD("Avoid repeatively prerun, graph_id:%u.", args.graph_id); | ||||
// In the online inference concurrency scenario, graph_node is allowed to be locked 'count' times | // In the online inference concurrency scenario, graph_node is allowed to be locked 'count' times | ||||
graph_node->SetSemSize(count); | graph_node->SetSemSize(count); | ||||
graph_node->Lock(); | |||||
graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context, | graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context, | ||||
args.input_tensor, graph_node->GetGeRootModel(), GetThreadLocalContext(), args.callback })); | args.input_tensor, graph_node->GetGeRootModel(), GetThreadLocalContext(), args.callback })); | ||||
GELOGI("[PreRunThread] Loop end. Start to run with cached build model."); | GELOGI("[PreRunThread] Loop end. Start to run with cached build model."); | ||||
@@ -284,9 +284,6 @@ Status DynamicShapePartitioner::InitClusters() { | |||||
auto cluster = MakeShared<Cluster>(rank++, type, node, this); | auto cluster = MakeShared<Cluster>(rank++, type, node, this); | ||||
REQUIRE_NOT_NULL(cluster, "[New][Memory] for cluster failed."); | REQUIRE_NOT_NULL(cluster, "[New][Memory] for cluster failed."); | ||||
node_2_cluster_[node] = cluster; | node_2_cluster_[node] = cluster; | ||||
if (cluster->IsUnknownShape()) { | |||||
ordered_cluster_.push_back(cluster); | |||||
} | |||||
int64_t group_index = -1; | int64_t group_index = -1; | ||||
if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { | if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { | ||||
@@ -306,7 +303,7 @@ Status DynamicShapePartitioner::InitClusters() { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status DynamicShapePartitioner::TopologicalSortClusters() { | |||||
Status DynamicShapePartitioner::TopologicalSortClusters(const OrderedFilter &ordered_filter) { | |||||
ordered_cluster_.clear(); | ordered_cluster_.clear(); | ||||
// BFS topological sort clusters for known shape cluster | // BFS topological sort clusters for known shape cluster | ||||
std::queue<ClusterPtr> ready_clusters; | std::queue<ClusterPtr> ready_clusters; | ||||
@@ -331,7 +328,7 @@ Status DynamicShapePartitioner::TopologicalSortClusters() { | |||||
auto cluster = ready_clusters.front(); | auto cluster = ready_clusters.front(); | ||||
ready_clusters.pop(); | ready_clusters.pop(); | ||||
cluster->UpdateRank(rank++); | cluster->UpdateRank(rank++); | ||||
if (cluster->IsKnownShape() || cluster->IsInputNode()) { | |||||
if (ordered_filter == nullptr || ordered_filter(cluster)) { | |||||
ordered_cluster_.push_back(cluster); | ordered_cluster_.push_back(cluster); | ||||
} | } | ||||
for (const auto &out_cluster : cluster->Outputs()) { | for (const auto &out_cluster : cluster->Outputs()) { | ||||
@@ -378,7 +375,6 @@ void DynamicShapePartitioner::MergeClustersControlFlow() { | |||||
continue; | continue; | ||||
} | } | ||||
bool is_unknown_cluster = cluster->IsUnknownShape(); | |||||
for (++rit; rit != control_cluster.rend(); ++rit) { | for (++rit; rit != control_cluster.rend(); ++rit) { | ||||
const auto &cluster_from = *rit; | const auto &cluster_from = *rit; | ||||
if (all_merged_clusters.count(cluster_from) > 0) { | if (all_merged_clusters.count(cluster_from) > 0) { | ||||
@@ -395,11 +391,6 @@ void DynamicShapePartitioner::MergeClustersControlFlow() { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
if (!is_unknown_cluster && cluster->IsUnknownShape()) { | |||||
GELOGD("Add to ordered cluster: %s", cluster->DebugString().c_str()); | |||||
ordered_cluster_.push_back(cluster); | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -475,9 +466,19 @@ void DynamicShapePartitioner::MergeClustersInputData() { | |||||
} | } | ||||
Status DynamicShapePartitioner::MergeClusters() { | Status DynamicShapePartitioner::MergeClusters() { | ||||
const auto filter_known = [](const ClusterPtr &cluster) { | |||||
return cluster->IsKnownShape() || cluster->IsInputNode(); | |||||
}; | |||||
const auto filter_unknown = [](const ClusterPtr &cluster) { | |||||
return cluster->IsUnknownShape(); | |||||
}; | |||||
MergeClustersControlFlow(); | MergeClustersControlFlow(); | ||||
REQUIRE_SUCCESS(TopologicalSortClusters(filter_unknown), | |||||
"[TopologicalSort][Clusters] after merge control flow clusters failed."); | |||||
MergeClustersUnknownShape(); | MergeClustersUnknownShape(); | ||||
REQUIRE_SUCCESS(TopologicalSortClusters(), "[TopologicalSort][Clusters] after merge unknown shape clusters failed."); | |||||
REQUIRE_SUCCESS(TopologicalSortClusters(filter_known), | |||||
"[TopologicalSort][Clusters] after merge unknown shape clusters failed."); | |||||
MergeClustersKnownShape(); | MergeClustersKnownShape(); | ||||
MergeClustersInputData(); | MergeClustersInputData(); | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -111,6 +111,8 @@ class DynamicShapePartitioner { | |||||
Status Partition(); | Status Partition(); | ||||
using OrderedFilter = std::function<bool(const std::shared_ptr<Cluster> &cluster)>; | |||||
private: | private: | ||||
Status PartitionImpl(); | Status PartitionImpl(); | ||||
// Collect nodes that satisfy the unknown-shape rules: | // Collect nodes that satisfy the unknown-shape rules: | ||||
@@ -138,7 +140,7 @@ class DynamicShapePartitioner { | |||||
// Merge clusters step3 | // Merge clusters step3 | ||||
void MergeClustersInputData(); | void MergeClustersInputData(); | ||||
// Topological sort clusters after merge unknown shape clusters. | // Topological sort clusters after merge unknown shape clusters. | ||||
Status TopologicalSortClusters(); | |||||
Status TopologicalSortClusters(const OrderedFilter &ordered_filter); | |||||
// Deduplicate merged clusters | // Deduplicate merged clusters | ||||
void PruneUniqueClusters(); | void PruneUniqueClusters(); | ||||
// Establish the input-output anchors for each partition of the cluster and record links to other clusters | // Establish the input-output anchors for each partition of the cluster and record links to other clusters | ||||
@@ -16,8 +16,6 @@ | |||||
#include "mark_force_unknown_for_cond_pass.h" | #include "mark_force_unknown_for_cond_pass.h" | ||||
#include <queue> | |||||
#include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
#include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
@@ -26,17 +24,7 @@ namespace { | |||||
inline bool IsMergeInLoop(const NodePtr &node) { | inline bool IsMergeInLoop(const NodePtr &node) { | ||||
const static std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; | const static std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; | ||||
std::string node_type; | |||||
(void)GetOriginalType(node, node_type); | |||||
return kLoopMergeInputs.count(node_type) > 0; | |||||
} | |||||
inline bool IsSwitchInLoop(const NodePtr &node) { | |||||
const static std::set<std::string> kLoopSwitchInputs{ MERGE, REFMERGE, LOOPCOND }; | |||||
std::string node_type; | |||||
(void)GetOriginalType(node, node_type); | |||||
return kLoopSwitchInputs.count(node_type) > 0; | |||||
return kLoopMergeInputs.count(NodeUtils::GetNodeType(node)) > 0; | |||||
} | } | ||||
} | } | ||||
@@ -44,10 +32,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { | |||||
GELOGD("MarkForceUnknownForCondPass Enter"); | GELOGD("MarkForceUnknownForCondPass Enter"); | ||||
std::map<NodePtr, std::vector<NodePtr>> switch_groups; | std::map<NodePtr, std::vector<NodePtr>> switch_groups; | ||||
for (const auto &node : graph->GetDirectNode()) { | for (const auto &node : graph->GetDirectNode()) { | ||||
std::string node_type; | |||||
GE_CHK_STATUS_RET(GetOriginalType(node, node_type), | |||||
"[Get][OriginalType] of node in graph:%s failed.", graph->GetName().c_str()); | |||||
if (kMergeOpTypes.count(node_type) == 0) { | |||||
if (kMergeOpTypes.count(NodeUtils::GetNodeType(node)) == 0) { | |||||
continue; | continue; | ||||
} | } | ||||
@@ -65,6 +50,51 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { | |||||
} | } | ||||
/// | /// | ||||
/// @brief Deal with Switch node for LoopCond | |||||
/// @param [in] Switch node | |||||
/// @param [in] dest span | |||||
/// @param [out] Search queue | |||||
/// @return true: Switch In while loop / false: Not in while Loop. | |||||
/// | |||||
bool MarkForceUnknownForCondPass::DealAsLoopSwitch(const NodePtr &node, uint32_t dst_span, | |||||
std::queue<std::pair<NodePtr, uint32_t>> &search_queue) { | |||||
/// LoopCond --->\. | |||||
/// \. | |||||
/// Enter-----------+ \. | |||||
/// +--> Merge --> Switch --> Exit | |||||
/// NextIteration---+ | |||||
const auto is_loop_op = [](const NodePtr &n) { | |||||
return NodeUtils::GetNodeType(n) == LOOPCOND; | |||||
}; | |||||
const auto is_exit_op = [](const NodePtr &n) { | |||||
return kExitOpTypes.count(NodeUtils::GetNodeType(n)) > 0; | |||||
}; | |||||
const auto src_nodes = node->GetInAllNodes(); | |||||
const auto dst_nodes = node->GetOutAllNodes(); | |||||
if (std::none_of(src_nodes.begin(), src_nodes.end(), is_loop_op) && | |||||
std::none_of(dst_nodes.begin(), dst_nodes.end(), is_exit_op)) { | |||||
return false; | |||||
} | |||||
for (const auto &m : src_nodes) { | |||||
if (kMergeOpTypes.count(NodeUtils::GetNodeType(m)) > 0) { | |||||
for (const auto &n : m->GetInAllNodes()) { | |||||
if (kNextIterationOpTypes.count(NodeUtils::GetNodeType(n)) > 0) { | |||||
continue; | |||||
} | |||||
search_queue.push({n, dst_span}); | |||||
GELOGD("Travel in Loop: %s <-- %s <-- %s, span is: %u", node->GetName().c_str(), m->GetName().c_str(), | |||||
n->GetName().c_str(), dst_span); | |||||
} | |||||
} | |||||
} | |||||
return true; | |||||
} | |||||
/// | |||||
/// @brief Mark force unknown shape for Switch node | /// @brief Mark force unknown shape for Switch node | ||||
/// @param [in] merge node | /// @param [in] merge node | ||||
/// @param [out] switch group | /// @param [out] switch group | ||||
@@ -72,6 +102,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { | |||||
/// | /// | ||||
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std::vector<NodePtr> &switch_group) { | void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std::vector<NodePtr> &switch_group) { | ||||
// Switch --> {Switch --> Merge} --> Merge | // Switch --> {Switch --> Merge} --> Merge | ||||
GELOGD("Search Switch node for Merge: %s", node->GetName().c_str()); | |||||
std::unordered_set<NodePtr> nodes_seen; | std::unordered_set<NodePtr> nodes_seen; | ||||
std::queue<std::pair<NodePtr, uint32_t>> search_queue({{node, 0}}); | std::queue<std::pair<NodePtr, uint32_t>> search_queue({{node, 0}}); | ||||
while (!search_queue.empty()) { | while (!search_queue.empty()) { | ||||
@@ -79,43 +110,25 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: | |||||
const auto dst_span = search_queue.front().second; | const auto dst_span = search_queue.front().second; | ||||
search_queue.pop(); | search_queue.pop(); | ||||
// Switch --> Identity --> Constant | |||||
for (const auto &in_node : dst_node->GetInControlNodes()) { | |||||
if (nodes_seen.count(in_node) > 0) { | |||||
GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); | |||||
continue; | |||||
} | |||||
nodes_seen.insert(in_node); | |||||
if (in_node->GetType() == IDENTITY) { | |||||
GELOGD("Travel node: %s, In control: %s, span is: %u", dst_node->GetName().c_str(), | |||||
in_node->GetName().c_str(), dst_span); | |||||
search_queue.push({in_node, dst_span}); | |||||
} | |||||
} | |||||
for (const auto &in_node : dst_node->GetInDataNodes()) { | |||||
for (const auto &in_node : dst_node->GetInAllNodes()) { | |||||
if (nodes_seen.count(in_node) > 0) { | if (nodes_seen.count(in_node) > 0) { | ||||
GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); | GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); | ||||
continue; | continue; | ||||
} | } | ||||
nodes_seen.insert(in_node); | nodes_seen.insert(in_node); | ||||
std::string node_type; | |||||
(void)GetOriginalType(in_node, node_type); | |||||
const std::string node_type = NodeUtils::GetNodeType(in_node); | |||||
GELOGD("Travel node: %s, %s node: %s, span is: %u", dst_node->GetName().c_str(), node_type.c_str(), | GELOGD("Travel node: %s, %s node: %s, span is: %u", dst_node->GetName().c_str(), node_type.c_str(), | ||||
in_node->GetName().c_str(), dst_span); | in_node->GetName().c_str(), dst_span); | ||||
if (kSwitchOpTypes.count(node_type) > 0) { // Switch input node. | if (kSwitchOpTypes.count(node_type) > 0) { // Switch input node. | ||||
if (DealAsLoopSwitch(in_node, dst_span, search_queue)) { | |||||
continue; | |||||
} | |||||
if (dst_span > 0) { | if (dst_span > 0) { | ||||
search_queue.push({in_node, dst_span - 1}); | search_queue.push({in_node, dst_span - 1}); | ||||
} else { | } else { | ||||
const auto &all_in_nodes = in_node->GetInDataNodes(); | |||||
if (std::any_of(all_in_nodes.begin(), all_in_nodes.end(), IsSwitchInLoop)) { | |||||
GELOGW("Travel node: %s, %s node: %s, Skip LoopCond switch", dst_node->GetName().c_str(), node_type.c_str(), | |||||
in_node->GetName().c_str()); | |||||
} else { | |||||
switch_group.emplace_back(in_node); | |||||
} | |||||
switch_group.emplace_back(in_node); | |||||
} | } | ||||
} else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node. | } else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node. | ||||
search_queue.push({in_node, dst_span + 1}); | search_queue.push({in_node, dst_span + 1}); | ||||
@@ -19,6 +19,8 @@ | |||||
#include "inc/graph_pass.h" | #include "inc/graph_pass.h" | ||||
#include <queue> | |||||
namespace ge { | namespace ge { | ||||
class MarkForceUnknownForCondPass : public GraphPass { | class MarkForceUnknownForCondPass : public GraphPass { | ||||
public: | public: | ||||
@@ -26,6 +28,15 @@ class MarkForceUnknownForCondPass : public GraphPass { | |||||
private: | private: | ||||
/// | /// | ||||
/// @brief Deal with Switch node for LoopCond | |||||
/// @param [in] Switch node | |||||
/// @param [in] dest span | |||||
/// @param [out] Search queue | |||||
/// @return true: Switch In while loop / false: Not in while Loop. | |||||
/// | |||||
bool DealAsLoopSwitch(const NodePtr &node, uint32_t dst_span, std::queue<std::pair<NodePtr, uint32_t>> &search_queue); | |||||
/// | |||||
/// @brief Mark force unknown shape for Switch node | /// @brief Mark force unknown shape for Switch node | ||||
/// @param [in] merge node | /// @param [in] merge node | ||||
/// @param [out] switch group | /// @param [out] switch group | ||||
@@ -24,7 +24,9 @@ using std::string; | |||||
namespace ge { | namespace ge { | ||||
namespace { | namespace { | ||||
const int64_t kLoopType = 1; | |||||
constexpr int64_t kLoopType = 1; | |||||
constexpr uint8_t kMaxTransOp = 3; | |||||
constexpr uint8_t kTransOpIoSize = 1; | |||||
} | } | ||||
Status NextIterationPass::Run(ComputeGraphPtr graph) { | Status NextIterationPass::Run(ComputeGraphPtr graph) { | ||||
@@ -287,18 +289,25 @@ void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, i | |||||
std::string node_type; | std::string node_type; | ||||
for (const auto &switch_node : loop_group.switch_nodes) { | for (const auto &switch_node : loop_group.switch_nodes) { | ||||
SetControlFlowGroup(switch_node, group_index); | SetControlFlowGroup(switch_node, group_index); | ||||
for (const auto &node : switch_node->GetOutDataNodes()) { | |||||
(void)GetOriginalType(node, node_type); | |||||
if (kExitOpTypes.count(node_type) > 0) { | |||||
SetControlFlowGroup(node, group_index); | |||||
} else { | |||||
// For: Switch -> Cast -> Exit | |||||
for (const auto &n : node->GetOutDataNodes()) { | |||||
(void)GetOriginalType(n, node_type); | |||||
if (kExitOpTypes.count(node_type) > 0) { | |||||
SetControlFlowGroup(n, group_index); | |||||
} | |||||
for (auto node : switch_node->GetOutDataNodes()) { | |||||
// Switch --> Exit | |||||
// Switch --> Cast --> Exit | |||||
// Switch --> TransData --> Cast --> Exit | |||||
for (uint8_t i = 0; i < kMaxTransOp; ++i) { | |||||
if (node->GetInDataNodes().size() != kTransOpIoSize || node->GetAllOutDataAnchorsSize() != kTransOpIoSize) { | |||||
break; | |||||
} | } | ||||
if (kExitOpTypes.count(NodeUtils::GetNodeType(node)) > 0) { | |||||
SetControlFlowGroup(node, group_index); | |||||
break; | |||||
} | |||||
const auto &all_nodes = node->GetOutAllNodes(); | |||||
if (all_nodes.size() != kTransOpIoSize) { | |||||
break; | |||||
} | |||||
node = all_nodes.at(0); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -15,7 +15,7 @@ | |||||
*/ | */ | ||||
#include "graph/passes/parallel_group_pass.h"
#include <queue>
#include <unordered_set>
#include "framework/common/debug/ge_log.h"
#include "common/ge/ge_util.h"
#include "framework/common/ge_inner_error_codes.h"
@@ -299,24 +299,19 @@ Status ParallelGroupPass::ReplaceWithSwitchAndMerge(NodePtr pre_node, NodePtr cu | |||||
for (const auto &switch_node : cur_itr->second.first) { | for (const auto &switch_node : cur_itr->second.first) { | ||||
int64_t pre_id = pre_node->GetOpDesc()->GetId(); | int64_t pre_id = pre_node->GetOpDesc()->GetId(); | ||||
int64_t switch_id = switch_node->GetOpDesc()->GetId(); | int64_t switch_id = switch_node->GetOpDesc()->GetId(); | ||||
// avoid ring | |||||
if (pre_id > switch_id) { | |||||
auto merge_node = cur_itr->second.second; | |||||
if (AddCtrlEdge(merge_node, pre_node) != SUCCESS) { | |||||
GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.", | |||||
pre_node->GetName().c_str(), switch_node->GetName().c_str()); | |||||
REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.", | |||||
pre_node->GetName().c_str(), switch_node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
} else { | |||||
if (AddCtrlEdge(pre_node, switch_node) != SUCCESS) { | |||||
GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.", | |||||
pre_node->GetName().c_str(), switch_node->GetName().c_str()); | |||||
REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.", | |||||
pre_node->GetName().c_str(), switch_node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
NodePtr first_node = pre_node; | |||||
NodePtr second_node = switch_node; | |||||
if (pre_id > switch_id && IsIndirectConnect(switch_node, pre_node)) { | |||||
// avoid ring, merge->pre_node | |||||
first_node = cur_itr->second.second; | |||||
second_node = pre_node; | |||||
} | |||||
if (AddCtrlEdge(first_node, second_node) != SUCCESS) { | |||||
GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.", | |||||
first_node->GetName().c_str(), second_node->GetName().c_str()); | |||||
REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.", | |||||
first_node->GetName().c_str(), second_node->GetName().c_str()); | |||||
return FAILED; | |||||
} | } | ||||
} | } | ||||
} else { | } else { | ||||
@@ -345,4 +340,29 @@ bool ParallelGroupPass::IsWhileStreamSwitch(OpDescPtr switch_op_desc) { | |||||
return (AttrUtils::GetInt(switch_op_desc, ATTR_NAME_STREAM_SWITCH_TYPE, stream_switch_type) && | return (AttrUtils::GetInt(switch_op_desc, ATTR_NAME_STREAM_SWITCH_TYPE, stream_switch_type) && | ||||
stream_switch_type == kLoopType); | stream_switch_type == kLoopType); | ||||
} | } | ||||
bool ParallelGroupPass::IsIndirectConnect(const NodePtr &node_a, const NodePtr &node_b) { | |||||
if (node_a == nullptr || node_b == nullptr) { | |||||
GELOGW("node_a or node_b is nullptr."); | |||||
return false; | |||||
} | |||||
int64_t end_id = node_b->GetOpDesc()->GetId(); | |||||
std::queue<NodePtr> nodes; | |||||
nodes.push(node_a); | |||||
while (!nodes.empty()) { | |||||
NodePtr tmp_node = nodes.front(); | |||||
nodes.pop(); | |||||
if (tmp_node == nullptr || tmp_node->GetOpDesc() == nullptr || | |||||
tmp_node->GetOpDesc()->GetId() > end_id) { | |||||
continue; | |||||
} | |||||
if (tmp_node == node_b) { | |||||
return true; | |||||
} | |||||
for (const auto &out_node : tmp_node->GetOutAllNodes()) { | |||||
nodes.push(out_node); | |||||
} | |||||
} | |||||
return false; | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -48,6 +48,7 @@ class ParallelGroupPass : public GraphPass { | |||||
bool IsBigSmallLoopStreamSwitch(OpDescPtr switch_op_desc); | bool IsBigSmallLoopStreamSwitch(OpDescPtr switch_op_desc); | ||||
bool IsWhileStreamSwitch(OpDescPtr switch_op_desc); | bool IsWhileStreamSwitch(OpDescPtr switch_op_desc); | ||||
bool IsIndirectConnect(const NodePtr &node_a, const NodePtr &node_b); | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H | #endif // GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H |
@@ -395,8 +395,9 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr & | |||||
peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str()); | peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str()); | ||||
int64_t group_index = -1; | int64_t group_index = -1; | ||||
(void)AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||||
SetControlFlowGroup(stream_switch, group_index); | |||||
if (AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { | |||||
SetControlFlowGroup(stream_switch, group_index); | |||||
} | |||||
return stream_switch; | return stream_switch; | ||||
} | } | ||||
@@ -568,6 +568,7 @@ Status InsertNewOpUtil::GetDataRelatedNode(NodePtr &node, std::map<NodePtr, std: | |||||
} | } | ||||
std::unique_ptr<domi::AippOpParams> aipp_params(new (std::nothrow) domi::AippOpParams()); | std::unique_ptr<domi::AippOpParams> aipp_params(new (std::nothrow) domi::AippOpParams()); | ||||
GE_CHECK_NOTNULL(aipp_params); | |||||
ge::GeAttrValue::NAMED_ATTRS aipp_attr; | ge::GeAttrValue::NAMED_ATTRS aipp_attr; | ||||
GE_CHK_BOOL_RET_STATUS(AttrUtils::GetNamedAttrs(data_op, ATTR_NAME_AIPP, aipp_attr), ACL_ERROR_GE_AIPP_NOT_EXIST, | GE_CHK_BOOL_RET_STATUS(AttrUtils::GetNamedAttrs(data_op, ATTR_NAME_AIPP, aipp_attr), ACL_ERROR_GE_AIPP_NOT_EXIST, | ||||
"[Get][Attr] %s from op:%s failed", ATTR_NAME_AIPP.c_str(), data_op->GetName().c_str()); | "[Get][Attr] %s from op:%s failed", ATTR_NAME_AIPP.c_str(), data_op->GetName().c_str()); | ||||
@@ -1206,7 +1206,7 @@ Status MultiBatchGraphCopyer::CheckCopyResult(const std::vector<NodePtr> &start_ | |||||
auto dims = NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims(); | auto dims = NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims(); | ||||
if (!IsAllDimsPositive(dims)) { | if (!IsAllDimsPositive(dims)) { | ||||
REPORT_CALL_ERROR("E19999", "Failed to copy multi batch graph, the node %s still has unknown shape %s", | REPORT_CALL_ERROR("E19999", "Failed to copy multi batch graph, the node %s still has unknown shape %s", | ||||
node->GetName().c_str(), formats::ShapeToString(dims).c_str()); | |||||
node->GetName().c_str(), formats::ShapeToString(dims).c_str()); | |||||
GELOGE(INTERNAL_ERROR, "[Check][Param] Failed to copy multi batch graph, the node %s still has unknown shape %s", | GELOGE(INTERNAL_ERROR, "[Check][Param] Failed to copy multi batch graph, the node %s still has unknown shape %s", | ||||
node->GetName().c_str(), formats::ShapeToString(dims).c_str()); | node->GetName().c_str(), formats::ShapeToString(dims).c_str()); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
@@ -45,6 +45,7 @@ Status FillKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge | |||||
GELOGE(PARAM_INVALID, "Parameter's invalid, Input opDescPtr is nullptr."); | GELOGE(PARAM_INVALID, "Parameter's invalid, Input opDescPtr is nullptr."); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
GELOGD("FillKernel in, name: %s.", op_desc_ptr->GetName().c_str()); | |||||
GE_CHECK_NOTNULL(input.at(kFillDimsInputIndex)); | GE_CHECK_NOTNULL(input.at(kFillDimsInputIndex)); | ||||
GE_CHECK_NOTNULL(input.at(kFillDataInputIndex)); | GE_CHECK_NOTNULL(input.at(kFillDataInputIndex)); | ||||
@@ -57,6 +58,13 @@ Status FillKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge | |||||
return NOT_CHANGED; | return NOT_CHANGED; | ||||
} | } | ||||
auto output_desc = op_desc_ptr->GetOutputDescPtr(0); | |||||
GE_CHECK_NOTNULL(output_desc); | |||||
if (output_desc->GetShape().IsUnknownShape()) { | |||||
GELOGD("Output is unknown shape, [%s] skip FillKernel.", op_desc_ptr->GetName().c_str()); | |||||
return NOT_CHANGED; | |||||
} | |||||
GeTensorPtr output_ptr; | GeTensorPtr output_ptr; | ||||
output_ptr = MakeShared<GeTensor>(op_desc_ptr->GetOutputDesc(0)); | output_ptr = MakeShared<GeTensor>(op_desc_ptr->GetOutputDesc(0)); | ||||
if (output_ptr == nullptr) { | if (output_ptr == nullptr) { | ||||
@@ -297,13 +297,15 @@ Status HybridModelAsyncExecutor::PrepareInputs(const InputData ¤t_data, Hy | |||||
} | } | ||||
} | } | ||||
tensor_desc->SetShape(shape); | tensor_desc->SetShape(shape); | ||||
args.input_desc[input_index] = tensor_desc; | |||||
GELOGD("Update shape of input[%zu] to [%s]", input_index, tensor_desc->MutableShape().ToString().c_str()); | |||||
GELOGD("Update shape[%s] of input[%zu] to [%s]", | |||||
shape.ToString().c_str(), input_index, tensor_desc->MutableShape().ToString().c_str()); | |||||
GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorMemorySizeInBytes(*tensor_desc, tensor_size), | GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorMemorySizeInBytes(*tensor_desc, tensor_size), | ||||
"[Invoke][GetTensorMemorySizeInBytes]Failed to calc tensor size," | "[Invoke][GetTensorMemorySizeInBytes]Failed to calc tensor size," | ||||
"index = %zu, shape = [%s], model_id = %u.", | "index = %zu, shape = [%s], model_id = %u.", | ||||
input_index, tensor_desc->GetShape().ToString().c_str(), model_id_); | input_index, tensor_desc->GetShape().ToString().c_str(), model_id_); | ||||
GELOGD("Input tensor[%zu] size = %zu", input_index, tensor_size); | |||||
GELOGD("Input tensor[%zu] size = %ld", input_index, tensor_size); | |||||
TensorUtils::SetSize(*tensor_desc, tensor_size); | |||||
args.input_desc[input_index] = tensor_desc; | |||||
} | } | ||||
GE_CHECK_GE(tensor_size, 0); | GE_CHECK_GE(tensor_size, 0); | ||||
@@ -326,17 +326,45 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() { | |||||
} | } | ||||
void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) { | void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) { | ||||
if (node_item_->root_data_.count(input_idx) > 0) { | |||||
GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx); | |||||
root_tensor_values_[input_idx] = tensor; | |||||
const auto is_persist_tensor = [](const std::map<const NodeItem *, std::set<int>> &items, int idx) { | |||||
const auto is_exist = [&idx](const std::pair<const NodeItem *, std::set<int>> &items) { | |||||
return items.second.count(idx) > 0; | |||||
}; | |||||
return std::any_of(items.begin(), items.end(), is_exist); | |||||
}; | |||||
if (root_tensor_values_.count(input_idx) > 0) { | |||||
return; | |||||
} | } | ||||
if (node_item_->enter_data_.count(input_idx) > 0) { | |||||
if (is_persist_tensor(node_item_->root_data_, input_idx)) { | |||||
GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx); | |||||
root_tensor_values_[input_idx] = tensor; | |||||
} else if (is_persist_tensor(node_item_->enter_data_, input_idx)) { | |||||
GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx); | GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx); | ||||
root_tensor_values_[input_idx] = tensor; | root_tensor_values_[input_idx] = tensor; | ||||
} | } | ||||
} | } | ||||
void NodeState::UpdatePersistTensor() { | |||||
const auto update_tensor = [&](const std::map<const NodeItem *, std::set<int>> &items) { | |||||
for (const auto &item : items) { | |||||
for (const auto idx : item.second) { | |||||
UpdatePersistTensor(idx); | |||||
} | |||||
} | |||||
}; | |||||
if (root_tensor_values_.empty()) { | |||||
return; | |||||
} | |||||
update_tensor(node_item_->root_data_); | |||||
if (iteration_count_ > 0) { | |||||
update_tensor(node_item_->enter_data_); | |||||
} | |||||
} | |||||
void NodeState::UpdatePersistTensor(int input_idx) { | void NodeState::UpdatePersistTensor(int input_idx) { | ||||
const auto it = root_tensor_values_.find(input_idx); | const auto it = root_tensor_values_.find(input_idx); | ||||
if (it == root_tensor_values_.end()) { | if (it == root_tensor_values_.end()) { | ||||
@@ -363,16 +391,9 @@ void NodeState::ResetContext(uint64_t iteration) { | |||||
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | ||||
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | ||||
for (auto item : node_item_->root_data_) { | |||||
UpdatePersistTensor(item.first); | |||||
} | |||||
if (iteration > 0) { | if (iteration > 0) { | ||||
data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size()); | data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size()); | ||||
ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size()); | ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size()); | ||||
for (auto item : node_item_->enter_data_) { | |||||
UpdatePersistTensor(item.first); | |||||
} | |||||
} | } | ||||
iteration_count_ = iteration; | iteration_count_ = iteration; | ||||
@@ -132,6 +132,7 @@ struct NodeState { | |||||
void RunNextIteration(); | void RunNextIteration(); | ||||
void SavePersistTensor(int input_idx, const TensorValue &tensor); | void SavePersistTensor(int input_idx, const TensorValue &tensor); | ||||
void UpdatePersistTensor(); | |||||
Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | ||||
@@ -373,6 +373,7 @@ Status ExecutionEngine::DoExecuteAsync(NodeState &node_state, | |||||
auto executor = node_item.node_executor; | auto executor = node_item.node_executor; | ||||
GE_CHECK_NOTNULL(executor); | GE_CHECK_NOTNULL(executor); | ||||
RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] Start"); | RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] Start"); | ||||
node_state.UpdatePersistTensor(); | |||||
GE_CHK_STATUS_RET(executor->PrepareTask(*task, task_context), "[Prepare][Task] for [%s] failed.", | GE_CHK_STATUS_RET(executor->PrepareTask(*task, task_context), "[Prepare][Task] for [%s] failed.", | ||||
node_state.GetName().c_str()); | node_state.GetName().c_str()); | ||||
RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] End"); | RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] End"); | ||||
@@ -147,6 +147,7 @@ class HybridModel { | |||||
GeRootModelPtr ge_root_model_; | GeRootModelPtr ge_root_model_; | ||||
std::map<uint32_t, NodeItem *> input_nodes_; | std::map<uint32_t, NodeItem *> input_nodes_; | ||||
ComputeGraphPtr root_graph_; | ComputeGraphPtr root_graph_; | ||||
ComputeGraphPtr orig_root_graph_; | |||||
std::map<std::string, NodePtr> device_variable_nodes_; //lint !e148 | std::map<std::string, NodePtr> device_variable_nodes_; //lint !e148 | ||||
std::map<std::string, NodePtr> host_variable_nodes_; //lint !e148 | std::map<std::string, NodePtr> host_variable_nodes_; //lint !e148 | ||||
std::map<std::string, std::unique_ptr<TensorValue>> variable_tensors_; | std::map<std::string, std::unique_ptr<TensorValue>> variable_tensors_; | ||||
@@ -147,6 +147,7 @@ Status HybridModelBuilder::Build() { | |||||
GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName()); | GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName()); | ||||
hybrid_model_.model_name_ = ge_root_model_->GetModelName(); | hybrid_model_.model_name_ = ge_root_model_->GetModelName(); | ||||
GELOGI("[%s] Start to build hybrid model.", GetGraphName()); | GELOGI("[%s] Start to build hybrid model.", GetGraphName()); | ||||
GE_CHK_STATUS_RET(CopyGraph(), "[Invoke][CopyGraph] failed, model_name_:[%s]", GetGraphName()); | |||||
GE_CHK_STATUS_RET(InitRuntimeParams(), "[Invoke][InitRuntimeParams] failed, model_name_:[%s]", GetGraphName()); | GE_CHK_STATUS_RET(InitRuntimeParams(), "[Invoke][InitRuntimeParams] failed, model_name_:[%s]", GetGraphName()); | ||||
GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), | GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), | ||||
"[Invoke][RecoverGraphUnknownFlag] failed, model_name_:[%s]", GetGraphName()); | "[Invoke][RecoverGraphUnknownFlag] failed, model_name_:[%s]", GetGraphName()); | ||||
@@ -171,11 +172,12 @@ Status HybridModelBuilder::Build() { | |||||
Status HybridModelBuilder::BuildForSingleOp() { | Status HybridModelBuilder::BuildForSingleOp() { | ||||
GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName()); | GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName()); | ||||
hybrid_model_.root_graph_ = ge_root_model_->GetRootGraph(); | |||||
hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName(); | hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName(); | ||||
GELOGI("[%s] Start to build hybrid model.", GetGraphName()); | GELOGI("[%s] Start to build hybrid model.", GetGraphName()); | ||||
auto ret = ge_root_model_->GetSubgraphInstanceNameToModel(); | auto ret = ge_root_model_->GetSubgraphInstanceNameToModel(); | ||||
const GeModelPtr ge_model = ret[ge_root_model_->GetRootGraph()->GetName()]; | |||||
GE_CHK_STATUS_RET(IndexTaskDefs(ge_root_model_->GetRootGraph(), ge_model), | |||||
const GeModelPtr ge_model = ret[hybrid_model_.root_graph_->GetName()]; | |||||
GE_CHK_STATUS_RET(IndexTaskDefs(hybrid_model_.root_graph_, ge_model), | |||||
"[Invoke][IndexTaskDefs] failed, model_name_:[%s]", GetGraphName()); | "[Invoke][IndexTaskDefs] failed, model_name_:[%s]", GetGraphName()); | ||||
GE_CHK_STATUS_RET(LoadGraph(), "[Invoke][LoadGraph] failed, model_name_:[%s]", GetGraphName()); | GE_CHK_STATUS_RET(LoadGraph(), "[Invoke][LoadGraph] failed, model_name_:[%s]", GetGraphName()); | ||||
GE_CHK_STATUS_RET(InitWeights(), "[Invoke][InitWeights] failed, model_name_:[%s]", GetGraphName()); | GE_CHK_STATUS_RET(InitWeights(), "[Invoke][InitWeights] failed, model_name_:[%s]", GetGraphName()); | ||||
@@ -190,6 +192,27 @@ Status HybridModelBuilder::ValidateParams() { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
///
/// @brief Copy the root graph of ge_root_model_ into a new ComputeGraph with the same name and
///        install the copy as hybrid_model_.root_graph_.
/// @return SUCCESS when the copy was installed; GRAPH_FAILED when GraphUtils::CopyComputeGraph
///         fails; the status produced by GE_CHECK_NOTNULL when the root graph or the newly
///         allocated graph is null.
///
Status HybridModelBuilder::CopyGraph() {
  GELOGD("Copy compute graph begin.");
  // Fetch the root graph once and verify it before it is dereferenced for its name.
  const auto root_graph = ge_root_model_->GetRootGraph();
  GE_CHECK_NOTNULL(root_graph);

  const std::string new_graph_name = root_graph->GetName();
  ComputeGraphPtr new_root_graph = MakeShared<ComputeGraph>(new_graph_name);
  GE_CHECK_NOTNULL(new_root_graph);

  int32_t depth = 0;
  std::map<ConstNodePtr, NodePtr> node_old_2_new;
  std::map<ConstOpDescPtr, OpDescPtr> op_desc_old_2_new;
  const graphStatus ret =
      GraphUtils::CopyComputeGraph(root_graph, new_root_graph, node_old_2_new, op_desc_old_2_new, depth);
  if (ret != GRAPH_SUCCESS) {
    GELOGE(GRAPH_FAILED, "Copy compute graph failed.");
    return GRAPH_FAILED;
  }
  hybrid_model_.root_graph_ = new_root_graph;

  GELOGD("Copy compute graph[%s] success.", new_graph_name.c_str());
  return SUCCESS;
}
Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_item) { | Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_item) { | ||||
auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
GE_CHK_STATUS_RET(ParseForceInfershapeNodes(node, node_item), | GE_CHK_STATUS_RET(ParseForceInfershapeNodes(node, node_item), | ||||
@@ -265,10 +288,6 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
if (node->GetType() == MEMCPYASYNC) { // Convert MemcpyAsync to Identity. | |||||
node->GetOpDesc()->SetType(IDENTITY); | |||||
} | |||||
std::unique_ptr<NodeItem> new_node; | std::unique_ptr<NodeItem> new_node; | ||||
GE_CHK_STATUS_RET(NodeItem::Create(node, new_node), "[Invoke][Create] failed, model_name_:[%s]", GetGraphName()); | GE_CHK_STATUS_RET(NodeItem::Create(node, new_node), "[Invoke][Create] failed, model_name_:[%s]", GetGraphName()); | ||||
GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor)); | GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor)); | ||||
@@ -814,12 +833,13 @@ Status HybridModelBuilder::BuildOutputMapping(GraphItem &graph_item, | |||||
} | } | ||||
Status HybridModelBuilder::LoadGraph() { | Status HybridModelBuilder::LoadGraph() { | ||||
auto root_graph = ge_root_model_->GetRootGraph(); | |||||
auto root_graph = hybrid_model_.root_graph_; | |||||
if (!GetContext().GetHostExecFlag()) { | if (!GetContext().GetHostExecFlag()) { | ||||
std::shared_ptr<ComputeGraph> merged_graph; | std::shared_ptr<ComputeGraph> merged_graph; | ||||
GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", | GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", | ||||
root_graph->GetDirectNodesSize(), | root_graph->GetDirectNodesSize(), | ||||
root_graph->GetAllNodesSize()); | root_graph->GetAllNodesSize()); | ||||
hybrid_model_.orig_root_graph_ = root_graph; | |||||
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(root_graph, merged_graph), | GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(root_graph, merged_graph), | ||||
"[Invoke][UnfoldSubgraphs]Failed to unfold subgraphs, model_name_:%s.", GetGraphName()); | "[Invoke][UnfoldSubgraphs]Failed to unfold subgraphs, model_name_:%s.", GetGraphName()); | ||||
root_graph = std::move(merged_graph); | root_graph = std::move(merged_graph); | ||||
@@ -877,6 +897,7 @@ Status HybridModelBuilder::LoadGraph() { | |||||
} | } | ||||
for (auto &it : hybrid_model_.known_shape_sub_models_) { | for (auto &it : hybrid_model_.known_shape_sub_models_) { | ||||
auto node_item = MutableNodeItem(it.first); | auto node_item = MutableNodeItem(it.first); | ||||
GE_CHECK_NOTNULL(node_item); | |||||
AscendString graph_name; | AscendString graph_name; | ||||
GE_CHK_GRAPH_STATUS_RET(it.second->GetGraph().GetName(graph_name), "Failed to get subgraph name"); | GE_CHK_GRAPH_STATUS_RET(it.second->GetGraph().GetName(graph_name), "Failed to get subgraph name"); | ||||
auto subgraph = hybrid_model_.GetRootGraph()->GetSubgraph(graph_name.GetString()); | auto subgraph = hybrid_model_.GetRootGraph()->GetSubgraph(graph_name.GetString()); | ||||
@@ -1023,6 +1044,7 @@ Status HybridModelBuilder::InitConstantOps() { | |||||
} else { | } else { | ||||
var_tensor.reset(new(std::nothrow)TensorValue(nullptr, 0)); | var_tensor.reset(new(std::nothrow)TensorValue(nullptr, 0)); | ||||
} | } | ||||
GE_CHECK_NOTNULL(var_tensor); | |||||
} else { | } else { | ||||
GE_CHK_STATUS_RET_NOLOG(VarNodeToTensor(var_node, var_tensor)); | GE_CHK_STATUS_RET_NOLOG(VarNodeToTensor(var_node, var_tensor)); | ||||
GELOGD("Init const op tensor. name = %s, size = %ld", var_name.c_str(), var_tensor->GetSize()); | GELOGD("Init const op tensor. name = %s, size = %ld", var_name.c_str(), var_tensor->GetSize()); | ||||
@@ -1125,7 +1147,9 @@ Status HybridModelBuilder::InitWeights() { | |||||
sub_weight_buffer->GetSize()); | sub_weight_buffer->GetSize()); | ||||
auto subgraph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph()); | auto subgraph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph()); | ||||
if (subgraph != ge_root_model_->GetRootGraph()) { | if (subgraph != ge_root_model_->GetRootGraph()) { | ||||
subgraph = ge_root_model_->GetRootGraph()->GetSubgraph(subgraph_model.first); | |||||
subgraph = hybrid_model_.root_graph_->GetSubgraph(subgraph_model.first); | |||||
} else { | |||||
subgraph = hybrid_model_.root_graph_; | |||||
} | } | ||||
GE_CHECK_NOTNULL(subgraph); | GE_CHECK_NOTNULL(subgraph); | ||||
hybrid_model_.weight_buffer_map_.emplace(subgraph->GetName(), std::move(sub_weight_buffer)); | hybrid_model_.weight_buffer_map_.emplace(subgraph->GetName(), std::move(sub_weight_buffer)); | ||||
@@ -1304,7 +1328,7 @@ Status HybridModelBuilder::IndexTaskDefs(const ComputeGraphPtr &sub_graph, const | |||||
} | } | ||||
Status HybridModelBuilder::IndexTaskDefs() { | Status HybridModelBuilder::IndexTaskDefs() { | ||||
const auto root_graph = ge_root_model_->GetRootGraph(); | |||||
const auto &root_graph = hybrid_model_.root_graph_; | |||||
const auto &root_graph_name = root_graph->GetName(); | const auto &root_graph_name = root_graph->GetName(); | ||||
if (SetOutputNameAttr(*root_graph) != SUCCESS) { | if (SetOutputNameAttr(*root_graph) != SUCCESS) { | ||||
GELOGW("Set output name attr failed."); | GELOGW("Set output name attr failed."); | ||||
@@ -1338,7 +1362,7 @@ Status HybridModelBuilder::IndexTaskDefs() { | |||||
Status HybridModelBuilder::IndexSpecialNodes() { | Status HybridModelBuilder::IndexSpecialNodes() { | ||||
GELOGD("Start to index special nodes"); | GELOGD("Start to index special nodes"); | ||||
const auto &root_graph = ge_root_model_->GetRootGraph(); | |||||
const auto &root_graph = hybrid_model_.root_graph_; | |||||
for (auto &node : root_graph->GetAllNodes()) { | for (auto &node : root_graph->GetAllNodes()) { | ||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | GE_CHECK_NOTNULL(node->GetOpDesc()); | ||||
@@ -1493,7 +1517,7 @@ Status HybridModelBuilder::InitRuntimeParams() { | |||||
runtime_param_.session_id = ret ? static_cast<uint64_t>(value) : 0; | runtime_param_.session_id = ret ? static_cast<uint64_t>(value) : 0; | ||||
ret = ge::AttrUtils::GetInt(first_model, ATTR_MODEL_TASK_GEN_VAR_ADDR, value); | ret = ge::AttrUtils::GetInt(first_model, ATTR_MODEL_TASK_GEN_VAR_ADDR, value); | ||||
runtime_param_.logic_var_base = ret ? static_cast<uint64_t>(value) : 0; | runtime_param_.logic_var_base = ret ? static_cast<uint64_t>(value) : 0; | ||||
runtime_param_.graph_id = ge_root_model_->GetRootGraph()->GetGraphID(); | |||||
runtime_param_.graph_id = hybrid_model_.root_graph_->GetGraphID(); | |||||
value = 0; | value = 0; | ||||
for (auto &it : ge_root_model_->GetSubgraphInstanceNameToModel()) { | for (auto &it : ge_root_model_->GetSubgraphInstanceNameToModel()) { | ||||
(void) ge::AttrUtils::GetInt(it.second, ATTR_MODEL_VAR_SIZE, value); | (void) ge::AttrUtils::GetInt(it.second, ATTR_MODEL_VAR_SIZE, value); | ||||
@@ -1630,7 +1654,7 @@ Status HybridModelBuilder::TransAllVarData() { | |||||
} | } | ||||
Status HybridModelBuilder::CopyVarData() { | Status HybridModelBuilder::CopyVarData() { | ||||
GE_CHK_STATUS_RET(TransVarDataUtils::CopyVarData(ge_root_model_->GetRootGraph(), | |||||
GE_CHK_STATUS_RET(TransVarDataUtils::CopyVarData(hybrid_model_.root_graph_, | |||||
runtime_param_.session_id, | runtime_param_.session_id, | ||||
hybrid_model_.device_id_), | hybrid_model_.device_id_), | ||||
"[Invoke][CopyVarData] failed."); | "[Invoke][CopyVarData] failed."); | ||||
@@ -1713,7 +1737,7 @@ Status HybridModelBuilder::LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem | |||||
} | } | ||||
Status HybridModelBuilder::RecoverGraphUnknownFlag() { | Status HybridModelBuilder::RecoverGraphUnknownFlag() { | ||||
const auto &root_graph = ge_root_model_->GetRootGraph(); | |||||
const auto &root_graph = hybrid_model_.root_graph_; | |||||
for (auto &sub_graph : root_graph->GetAllSubgraphs()) { | for (auto &sub_graph : root_graph->GetAllSubgraphs()) { | ||||
GE_CHECK_NOTNULL(sub_graph); | GE_CHECK_NOTNULL(sub_graph); | ||||
for (const auto &node : sub_graph->GetDirectNode()) { | for (const auto &node : sub_graph->GetDirectNode()) { | ||||
@@ -56,6 +56,7 @@ class HybridModelBuilder { | |||||
Status BuildOutputMapping(GraphItem &partitioned_call, const NodeItem &node_item, bool is_root_graph); | Status BuildOutputMapping(GraphItem &partitioned_call, const NodeItem &node_item, bool is_root_graph); | ||||
Status ValidateParams(); | Status ValidateParams(); | ||||
Status LoadGraph(); | Status LoadGraph(); | ||||
Status CopyGraph(); | |||||
Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); | Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); | ||||
static Status InitHcclExecutorOnDemand(const GeModelPtr &ge_model); | static Status InitHcclExecutorOnDemand(const GeModelPtr &ge_model); | ||||
Status LoadTask(NodeItem &node_item); | Status LoadTask(NodeItem &node_item); | ||||
@@ -14,10 +14,8 @@ | |||||
* limitations under the License. | * limitations under the License. | ||||
*/ | */ | ||||
#include "node_item.h" | |||||
#include <sstream> | |||||
#include "common/debug/log.h" | |||||
#include "graph/common/omg_util.h" | |||||
#include "hybrid/model/node_item.h" | |||||
#include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "hybrid/executor/worker/shape_inference_engine.h" | #include "hybrid/executor/worker/shape_inference_engine.h" | ||||
@@ -26,6 +24,8 @@ | |||||
namespace ge { | namespace ge { | ||||
namespace hybrid { | namespace hybrid { | ||||
namespace { | namespace { | ||||
const uint8_t kMaxTransCount = 3; | |||||
const uint32_t kTransOpIoSize = 1; | |||||
const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; | const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; | ||||
const char *const kNodeTypeRetVal = "_RetVal"; | const char *const kNodeTypeRetVal = "_RetVal"; | ||||
const std::set<std::string> kControlOpTypes{ | const std::set<std::string> kControlOpTypes{ | ||||
@@ -41,6 +41,25 @@ const std::set<std::string> kMergeOpTypes{ | |||||
MERGE, REFMERGE, STREAMMERGE | MERGE, REFMERGE, STREAMMERGE | ||||
}; | }; | ||||
// Reports whether `node` is an Enter-type node, or is fed by one through a short
// chain of single-input nodes.
//
// Starting at `node`, at most kMaxTransCount nodes are inspected: `node` itself and
// then, step by step, its sole data predecessor. The walk only continues past a node
// that has exactly kTransOpIoSize input data node and kTransOpIoSize input data anchor.
//
// `node` is taken by value, so re-pointing it while walking does not affect the caller.
//
// Returns true as soon as an inspected node's type is found in kEnterOpTypes; returns
// false when a node on the path does not have exactly one data input, or when
// kMaxTransCount nodes were inspected without a match.
//
// NOTE(review): `node` (and each predecessor) is dereferenced without a null check;
// presumably callers always pass a valid node — confirm.
bool IsEnterFeedNode(NodePtr node) { | |||||
// For: Enter -> node | |||||
// For: Enter -> Cast -> node | |||||
// For: Enter -> TransData -> Cast -> node | |||||
for (uint8_t i = 0; i < kMaxTransCount; ++i) { | |||||
if (kEnterOpTypes.count(NodeUtils::GetNodeType(node)) > 0) { | |||||
// The name logged here is that of the matching node, which may be a predecessor
// of the node originally passed in.
GELOGD("Node[%s] is Enter feed node.", node->GetName().c_str()); | |||||
return true; | |||||
} | |||||
const auto all_nodes = node->GetInDataNodes(); | |||||
// Stop unless the current node has exactly one data input.
if (all_nodes.size() != kTransOpIoSize || node->GetAllInDataAnchorsSize() != kTransOpIoSize) { | |||||
return false; | |||||
} | |||||
// Step to the single data predecessor and inspect it on the next iteration.
node = all_nodes.at(0); | |||||
} | |||||
return false; | |||||
} | |||||
Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { | Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { | ||||
uint32_t parent_index = 0; | uint32_t parent_index = 0; | ||||
if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | ||||
@@ -98,8 +117,7 @@ Status ParseFusedSubgraph(NodeItem &node_item) { | |||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
std::string node_type; | |||||
GE_CHK_STATUS_RET(GetOriginalType(node, node_type)); | |||||
const std::string node_type = NodeUtils::GetNodeType(node); | |||||
if (node_type == DATA) { | if (node_type == DATA) { | ||||
GE_CHK_GRAPH_STATUS_RET(ParseInputMapping(*node, *op_desc, *fused_subgraph)); | GE_CHK_GRAPH_STATUS_RET(ParseInputMapping(*node, *op_desc, *fused_subgraph)); | ||||
} else if (node_type == kNodeTypeRetVal) { | } else if (node_type == kNodeTypeRetVal) { | ||||
@@ -398,19 +416,21 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { | |||||
data_send_.emplace(node_item); | data_send_.emplace(node_item); | ||||
node_item->data_recv_[this] = anchor_index; | node_item->data_recv_[this] = anchor_index; | ||||
if (is_root_node_) { | if (is_root_node_) { | ||||
node_item->root_data_[anchor_index] = this; | |||||
auto &data_anchors = node_item->root_data_[this]; | |||||
data_anchors.emplace(anchor_index); | |||||
} | } | ||||
// If Enter feed Not Merge, take as root Node. | // If Enter feed Not Merge, take as root Node. | ||||
if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) { | |||||
node_item->enter_data_[anchor_index] = this; | |||||
if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE)) { | |||||
auto &data_anchors = node_item->enter_data_[this]; | |||||
data_anchors.emplace(anchor_index); | |||||
} | } | ||||
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | ||||
} | } | ||||
void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { | void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { | ||||
if (switch_index < switch_groups_.size()) { | if (switch_index < switch_groups_.size()) { | ||||
std::vector<const NodeItem *> &switch_group = switch_groups_[switch_index]; | |||||
switch_group.emplace_back(node_item); | |||||
auto &switch_group = switch_groups_[switch_index]; | |||||
switch_group.emplace(node_item); | |||||
} else { | } else { | ||||
ctrl_send_.insert(node_item); | ctrl_send_.insert(node_item); | ||||
} | } | ||||
@@ -420,7 +440,7 @@ void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { | |||||
node_item->root_ctrl_.emplace(this); | node_item->root_ctrl_.emplace(this); | ||||
} | } | ||||
// If Enter feed control signal, take as root Node. | // If Enter feed control signal, take as root Node. | ||||
if (IsEnterOp() && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) { | |||||
if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) { | |||||
node_item->enter_ctrl_.emplace(this); | node_item->enter_ctrl_.emplace(this); | ||||
} | } | ||||
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | ||||
@@ -433,8 +453,8 @@ void NodeItem::SetMergeCtrl(NodeItem *node_item, uint32_t merge_index) { | |||||
} | } | ||||
// this is StreamMerge node, node_item is StreamActive node. | // this is StreamMerge node, node_item is StreamActive node. | ||||
std::vector<const NodeItem *> &switch_group = switch_groups_[merge_index]; | |||||
switch_group.emplace_back(node_item); | |||||
auto &switch_group = switch_groups_[merge_index]; | |||||
switch_group.emplace(node_item); | |||||
node_item->ctrl_send_.emplace(this); | node_item->ctrl_send_.emplace(this); | ||||
GELOGI("Node[%s] will control node[%s]", node_item->NodeName().c_str(), NodeName().c_str()); | GELOGI("Node[%s] will control node[%s]", node_item->NodeName().c_str(), NodeName().c_str()); | ||||
@@ -148,14 +148,14 @@ struct NodeItem { | |||||
int64_t frame_index_ = -1; | int64_t frame_index_ = -1; | ||||
int64_t parent_frame_ = -1; | int64_t parent_frame_ = -1; | ||||
std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node | std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node | ||||
std::map<int, const NodeItem *> root_data_; // Recv data from root node | |||||
std::map<const NodeItem *, std::set<int>> root_data_; // Recv data from root node | |||||
std::set<const NodeItem *> enter_ctrl_; // Recv ctrl from Enter node | std::set<const NodeItem *> enter_ctrl_; // Recv ctrl from Enter node | ||||
std::map<int, const NodeItem *> enter_data_; // Recv data from Enter node | |||||
std::map<const NodeItem *, std::set<int>> enter_data_; // Recv data from Enter node | |||||
std::set<const NodeItem *> data_send_; // Send data notify to | std::set<const NodeItem *> data_send_; // Send data notify to | ||||
std::map<const NodeItem *, int> data_recv_; // Recv data notify from | std::map<const NodeItem *, int> data_recv_; // Recv data notify from | ||||
std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | ||||
std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from | std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from | ||||
std::vector<std::vector<const NodeItem *>> switch_groups_; // Send ctrl notify to | |||||
std::vector<std::set<const NodeItem *>> switch_groups_; // Send ctrl notify to | |||||
std::shared_ptr<NodeTask> kernel_task; | std::shared_ptr<NodeTask> kernel_task; | ||||
std::unique_ptr<FusedSubgraph> fused_subgraph; | std::unique_ptr<FusedSubgraph> fused_subgraph; | ||||
@@ -342,6 +342,7 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do | |||||
GE_CHK_RT_RET(rtEventDestroy(evt)); | GE_CHK_RT_RET(rtEventDestroy(evt)); | ||||
} | } | ||||
GELOGI("rdma callback success."); | GELOGI("rdma callback success."); | ||||
return SUCCESS; | |||||
}; | }; | ||||
HcclResult hccl_ret = HcomExecEnqueueRemoteAccess(context.GetNodeItem().NodeType(), addr_infos, callback); | HcclResult hccl_ret = HcomExecEnqueueRemoteAccess(context.GetNodeItem().NodeType(), addr_infos, callback); | ||||
@@ -17,13 +17,9 @@ | |||||
#include "hybrid/node_executor/rts/rts_node_executor.h" | #include "hybrid/node_executor/rts/rts_node_executor.h" | ||||
#include "hybrid/node_executor/rts/rts_task_factory.h" | #include "hybrid/node_executor/rts/rts_task_factory.h" | ||||
#include "common/debug/log.h" | |||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "common/types.h" | |||||
#include "graph/common/omg_util.h" | |||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include "hybrid/model/hybrid_model.h" | #include "hybrid/model/hybrid_model.h" | ||||
#include "runtime/rt.h" | |||||
namespace ge { | namespace ge { | ||||
namespace hybrid { | namespace hybrid { | ||||
@@ -33,6 +29,7 @@ REGISTER_RTS_TASK_CREATOR(IDENTITY, IdentityNodeTask); | |||||
REGISTER_RTS_TASK_CREATOR(IDENTITYN, IdentityNNodeTask); | REGISTER_RTS_TASK_CREATOR(IDENTITYN, IdentityNNodeTask); | ||||
REGISTER_RTS_TASK_CREATOR(READVARIABLEOP, ReadVariableOpNodeTask); | REGISTER_RTS_TASK_CREATOR(READVARIABLEOP, ReadVariableOpNodeTask); | ||||
REGISTER_RTS_TASK_CREATOR(PROFILINGTRAININGTRACE, ProfilingTraceNodeTask); | REGISTER_RTS_TASK_CREATOR(PROFILINGTRAININGTRACE, ProfilingTraceNodeTask); | ||||
REGISTER_RTS_TASK_CREATOR(MEMCPYASYNC, IdentityNodeTask); | |||||
Status IdentityNodeTask::DoCopyTensor(TaskContext &context, int index) { | Status IdentityNodeTask::DoCopyTensor(TaskContext &context, int index) { | ||||
auto input_desc = context.MutableInputDesc(index); | auto input_desc = context.MutableInputDesc(index); | ||||
@@ -133,8 +130,7 @@ Status ProfilingTraceNodeTask::ExecuteAsync(TaskContext &context, std::function< | |||||
Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | ||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
GELOGD("[%s] Load for local task.", node->GetName().c_str()); | GELOGD("[%s] Load for local task.", node->GetName().c_str()); | ||||
std::string node_type; | |||||
GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed."); | |||||
const std::string node_type = NodeUtils::GetNodeType(node); | |||||
RtsNodeTaskPtr rts_task = RtsTaskFactory::GetInstance().Create(node_type); | RtsNodeTaskPtr rts_task = RtsTaskFactory::GetInstance().Create(node_type); | ||||
if (rts_task == nullptr) { | if (rts_task == nullptr) { | ||||
GELOGE(UNSUPPORTED, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), node_type.c_str()); | GELOGE(UNSUPPORTED, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), node_type.c_str()); | ||||
@@ -43,7 +43,6 @@ namespace hybrid { | |||||
REGISTER_RTS_TASK_CREATOR(STREAMACTIVE, StreamActiveNodeTask); | REGISTER_RTS_TASK_CREATOR(STREAMACTIVE, StreamActiveNodeTask); | ||||
REGISTER_RTS_TASK_CREATOR(STREAMSWITCH, StreamSwitchNodeTask); | REGISTER_RTS_TASK_CREATOR(STREAMSWITCH, StreamSwitchNodeTask); | ||||
REGISTER_RTS_TASK_CREATOR(STREAMMERGE, StreamMergeNodeTask); | REGISTER_RTS_TASK_CREATOR(STREAMMERGE, StreamMergeNodeTask); | ||||
REGISTER_RTS_TASK_CREATOR(MEMCPYASYNC, MemcpyAsyncNodeTask); | |||||
REGISTER_RTS_TASK_CREATOR(ENTER, PassThroughNodeTask); | REGISTER_RTS_TASK_CREATOR(ENTER, PassThroughNodeTask); | ||||
REGISTER_RTS_TASK_CREATOR(REFENTER, PassThroughNodeTask); | REGISTER_RTS_TASK_CREATOR(REFENTER, PassThroughNodeTask); | ||||
@@ -168,34 +167,6 @@ Status StreamMergeNodeTask::ExecuteAsync(TaskContext &task_context, std::functio | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status MemcpyAsyncNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) { | |||||
GELOGD("[%s] Start to execute.", task_context.GetNodeName()); | |||||
auto input_desc = task_context.MutableInputDesc(0); | |||||
GE_CHECK_NOTNULL(input_desc); | |||||
int64_t copy_size = 0; | |||||
GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorSizeInBytes(*input_desc, copy_size)); | |||||
// copy_size would not be negative since GetTensorSizeInBytes returned successfully. | |||||
if (copy_size > 0) { | |||||
const auto in_v = task_context.MutableInput(0); | |||||
const auto out_v = task_context.MutableOutput(0); | |||||
GE_CHECK_NOTNULL(in_v); | |||||
GE_CHECK_NOTNULL(out_v); | |||||
GELOGD("[%s] input size: %zu, output size: %zu, copy size: %ld", task_context.GetNodeName(), | |||||
in_v->GetSize(), out_v->GetSize(), copy_size); | |||||
GE_CHK_RT_RET(rtMemcpyAsync(out_v->MutableData(), out_v->GetSize(), in_v->GetData(), copy_size, | |||||
RT_MEMCPY_DEVICE_TO_DEVICE, task_context.GetStream())); | |||||
} else { | |||||
GELOGW("[%s] invalid copy size: %ld", task_context.GetNodeName(), copy_size); | |||||
} | |||||
if (done_callback) { | |||||
GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); | |||||
} | |||||
GELOGD("[%s] Done executing successfully.", task_context.GetNodeName()); | |||||
return SUCCESS; | |||||
} | |||||
Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) { | Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) { | ||||
GELOGD("[%s] Start to execute.", task_context.GetNodeName()); | GELOGD("[%s] Start to execute.", task_context.GetNodeName()); | ||||
const auto in_x = task_context.GetInput(0); // x | const auto in_x = task_context.GetInput(0); // x | ||||
@@ -60,11 +60,6 @@ class StreamMergeNodeTask : public RtsNodeTask { | |||||
Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override; | Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override; | ||||
}; | }; | ||||
class MemcpyAsyncNodeTask : public RtsNodeTask { | |||||
public: | |||||
Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override; | |||||
}; | |||||
class PassThroughNodeTask : public RtsNodeTask { | class PassThroughNodeTask : public RtsNodeTask { | ||||
public: | public: | ||||
Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override; | Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override; | ||||
@@ -458,10 +458,6 @@ Status TaskContext::PropagateOutputs() { | |||||
subgraph_context_->all_inputs_[input_offset].SetName( | subgraph_context_->all_inputs_[input_offset].SetName( | ||||
node_item_->NodeName() + "_in_" + std::to_string(dst_input_idx)); | node_item_->NodeName() + "_in_" + std::to_string(dst_input_idx)); | ||||
} | } | ||||
auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item); | |||||
GE_CHECK_NOTNULL(dst_node_state); | |||||
dst_node_state->SavePersistTensor(dst_input_idx, *tensor); | |||||
} | } | ||||
} | } | ||||
(void)guard; | (void)guard; | ||||
@@ -493,6 +489,7 @@ void TaskContext::ReleaseInputsAndOutputs() { | |||||
void TaskContext::ReleaseInput(int index) { | void TaskContext::ReleaseInput(int index) { | ||||
auto input_tensor = MutableInput(index); | auto input_tensor = MutableInput(index); | ||||
if (input_tensor != nullptr) { | if (input_tensor != nullptr) { | ||||
node_state_->SavePersistTensor(index, *input_tensor); | |||||
input_tensor->Destroy(); | input_tensor->Destroy(); | ||||
GELOGD("[%s] Tensor of input[%d] released", GetNodeName(), index); | GELOGD("[%s] Tensor of input[%d] released", GetNodeName(), index); | ||||
} | } | ||||
@@ -33,6 +33,10 @@ | |||||
#include "register/op_tiling.h" | #include "register/op_tiling.h" | ||||
namespace ge { | namespace ge { | ||||
// NOTE(review): this file defines classes and looks like a header; if so, an unnamed
// namespace here gives every including translation unit its own copy of the
// constant — confirm this is intended.
namespace { | |||||
// Element count of the MemcpyAsyncTask::addresses_ array declared further down.
const int kAddressNum = 2; | |||||
} // namespace | |||||
class StreamResource; | class StreamResource; | ||||
struct SingleOpModelParam; | struct SingleOpModelParam; | ||||
class OpTask { | class OpTask { | ||||
@@ -256,7 +260,7 @@ class MemcpyAsyncTask : public OpTask { | |||||
friend class SingleOpModel; | friend class SingleOpModel; | ||||
friend class RtsKernelTaskBuilder; | friend class RtsKernelTaskBuilder; | ||||
uintptr_t addresses_[2]; | |||||
uintptr_t addresses_[kAddressNum]; | |||||
size_t dst_max_; | size_t dst_max_; | ||||
size_t count_; | size_t count_; | ||||
rtMemcpyKind_t kind_; | rtMemcpyKind_t kind_; | ||||
@@ -26,9 +26,9 @@ extern "C" { | |||||
#endif | #endif | ||||
// Current version is 1.0.0 | // Current version is 1.0.0 | ||||
#define ACL_MAJOR_VERSION 1 | |||||
#define ACL_MINOR_VERSION 0 | |||||
#define ACL_PATCH_VERSION 0 | |||||
#define ACL_MAJOR_VERSION 1 | |||||
#define ACL_MINOR_VERSION 0 | |||||
#define ACL_PATCH_VERSION 0 | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -72,11 +72,11 @@ ACL_FUNC_VISIBILITY aclError aclrtGetVersion(int32_t *majorVersion, int32_t *min | |||||
* | * | ||||
* @retval null for failed | * @retval null for failed | ||||
* @retval OtherValues success | * @retval OtherValues success | ||||
*/ | |||||
*/ | |||||
ACL_FUNC_VISIBILITY const char *aclGetRecentErrMsg(); | ACL_FUNC_VISIBILITY const char *aclGetRecentErrMsg(); | ||||
#ifdef __cplusplus | #ifdef __cplusplus | ||||
} | } | ||||
#endif | #endif | ||||
#endif // INC_EXTERNAL_ACL_ACL_H_ | |||||
#endif // INC_EXTERNAL_ACL_ACL_H_ |
@@ -136,50 +136,49 @@ static const int ACL_ERROR_PROFILING_FAILURE = 500005; | |||||
#define ACL_UNKNOWN_RANK 0xFFFFFFFFFFFFFFFE | #define ACL_UNKNOWN_RANK 0xFFFFFFFFFFFFFFFE | ||||
typedef enum { | typedef enum { | ||||
ACL_DT_UNDEFINED = -1, | |||||
ACL_FLOAT = 0, | |||||
ACL_FLOAT16 = 1, | |||||
ACL_INT8 = 2, | |||||
ACL_INT32 = 3, | |||||
ACL_UINT8 = 4, | |||||
ACL_INT16 = 6, | |||||
ACL_UINT16 = 7, | |||||
ACL_UINT32 = 8, | |||||
ACL_INT64 = 9, | |||||
ACL_UINT64 = 10, | |||||
ACL_DOUBLE = 11, | |||||
ACL_BOOL = 12, | |||||
ACL_STRING = 13, | |||||
ACL_DT_UNDEFINED = -1, | |||||
ACL_FLOAT = 0, | |||||
ACL_FLOAT16 = 1, | |||||
ACL_INT8 = 2, | |||||
ACL_INT32 = 3, | |||||
ACL_UINT8 = 4, | |||||
ACL_INT16 = 6, | |||||
ACL_UINT16 = 7, | |||||
ACL_UINT32 = 8, | |||||
ACL_INT64 = 9, | |||||
ACL_UINT64 = 10, | |||||
ACL_DOUBLE = 11, | |||||
ACL_BOOL = 12, | |||||
ACL_STRING = 13, | |||||
} aclDataType; | } aclDataType; | ||||
typedef enum { | typedef enum { | ||||
ACL_FORMAT_UNDEFINED = -1, | |||||
ACL_FORMAT_NCHW = 0, | |||||
ACL_FORMAT_NHWC = 1, | |||||
ACL_FORMAT_ND = 2, | |||||
ACL_FORMAT_NC1HWC0 = 3, | |||||
ACL_FORMAT_FRACTAL_Z = 4, | |||||
ACL_FORMAT_NC1HWC0_C04 = 12, | |||||
ACL_FORMAT_NDHWC = 27, | |||||
ACL_FORMAT_FRACTAL_NZ = 29, | |||||
ACL_FORMAT_NCDHW = 30, | |||||
ACL_FORMAT_NDC1HWC0 = 32, | |||||
ACL_FRACTAL_Z_3D = 33 | |||||
ACL_FORMAT_UNDEFINED = -1, | |||||
ACL_FORMAT_NCHW = 0, | |||||
ACL_FORMAT_NHWC = 1, | |||||
ACL_FORMAT_ND = 2, | |||||
ACL_FORMAT_NC1HWC0 = 3, | |||||
ACL_FORMAT_FRACTAL_Z = 4, | |||||
ACL_FORMAT_NC1HWC0_C04 = 12, | |||||
ACL_FORMAT_NDHWC = 27, | |||||
ACL_FORMAT_FRACTAL_NZ = 29, | |||||
ACL_FORMAT_NCDHW = 30, | |||||
ACL_FORMAT_NDC1HWC0 = 32, | |||||
ACL_FRACTAL_Z_3D = 33 | |||||
} aclFormat; | } aclFormat; | ||||
typedef enum { | typedef enum { | ||||
ACL_DEBUG = 0, | |||||
ACL_INFO = 1, | |||||
ACL_WARNING = 2, | |||||
ACL_ERROR = 3, | |||||
ACL_DEBUG = 0, | |||||
ACL_INFO = 1, | |||||
ACL_WARNING = 2, | |||||
ACL_ERROR = 3, | |||||
} aclLogLevel; | } aclLogLevel; | ||||
typedef enum { | typedef enum { | ||||
ACL_MEMTYPE_DEVICE = 0, | |||||
ACL_MEMTYPE_HOST = 1, | |||||
ACL_MEMTYPE_DEVICE = 0, | |||||
ACL_MEMTYPE_HOST = 1, | |||||
} aclMemType; | } aclMemType; | ||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
* @brief Converts data of type aclFloat16 to data of type float | * @brief Converts data of type aclFloat16 to data of type float | ||||
@@ -312,9 +311,7 @@ ACL_FUNC_VISIBILITY size_t aclDataTypeSize(aclDataType dataType); | |||||
* @retval aclTensorDesc pointer. | * @retval aclTensorDesc pointer. | ||||
* @retval nullptr if param is invalid or run out of memory | * @retval nullptr if param is invalid or run out of memory | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclTensorDesc *aclCreateTensorDesc(aclDataType dataType, | |||||
int numDims, | |||||
const int64_t *dims, | |||||
ACL_FUNC_VISIBILITY aclTensorDesc *aclCreateTensorDesc(aclDataType dataType, int numDims, const int64_t *dims, | |||||
aclFormat format); | aclFormat format); | ||||
/** | /** | ||||
@@ -336,8 +333,7 @@ ACL_FUNC_VISIBILITY void aclDestroyTensorDesc(const aclTensorDesc *desc); | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclSetTensorShapeRange(aclTensorDesc* desc, | |||||
size_t dimsCount, | |||||
ACL_FUNC_VISIBILITY aclError aclSetTensorShapeRange(aclTensorDesc *desc, size_t dimsCount, | |||||
int64_t dimsRange[][ACL_TENSOR_SHAPE_RANGE_NUM]); | int64_t dimsRange[][ACL_TENSOR_SHAPE_RANGE_NUM]); | ||||
/** | /** | ||||
@@ -434,9 +430,7 @@ ACL_FUNC_VISIBILITY aclError aclGetTensorDescDimV2(const aclTensorDesc *desc, si | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclGetTensorDescDimRange(const aclTensorDesc *desc, | |||||
size_t index, | |||||
size_t dimRangeNum, | |||||
ACL_FUNC_VISIBILITY aclError aclGetTensorDescDimRange(const aclTensorDesc *desc, size_t index, size_t dimRangeNum, | |||||
int64_t *dimRange); | int64_t *dimRange); | ||||
/** | /** | ||||
@@ -473,7 +467,7 @@ ACL_FUNC_VISIBILITY const char *aclGetTensorDescName(aclTensorDesc *desc); | |||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclTransTensorDescFormat(const aclTensorDesc *srcDesc, aclFormat dstFormat, | ACL_FUNC_VISIBILITY aclError aclTransTensorDescFormat(const aclTensorDesc *srcDesc, aclFormat dstFormat, | ||||
aclTensorDesc **dstDesc); | |||||
aclTensorDesc **dstDesc); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -561,7 +555,7 @@ ACL_FUNC_VISIBILITY aclError aclSetTensorOriginShape(aclTensorDesc *desc, int nu | |||||
* | * | ||||
* @retval null for failed. | * @retval null for failed. | ||||
* @retval OtherValues success. | * @retval OtherValues success. | ||||
*/ | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclTensorDesc *aclGetTensorDescByIndex(aclTensorDesc *desc, size_t index); | ACL_FUNC_VISIBILITY aclTensorDesc *aclGetTensorDescByIndex(aclTensorDesc *desc, size_t index); | ||||
/** | /** | ||||
@@ -572,7 +566,7 @@ ACL_FUNC_VISIBILITY aclTensorDesc *aclGetTensorDescByIndex(aclTensorDesc *desc, | |||||
* | * | ||||
* @retval null for failed | * @retval null for failed | ||||
* @retval OtherValues success | * @retval OtherValues success | ||||
*/ | |||||
*/ | |||||
ACL_FUNC_VISIBILITY void *aclGetTensorDescAddress(const aclTensorDesc *desc); | ACL_FUNC_VISIBILITY void *aclGetTensorDescAddress(const aclTensorDesc *desc); | ||||
/** | /** | ||||
@@ -624,7 +618,7 @@ ACL_FUNC_VISIBILITY aclError aclSetTensorPlaceMent(aclTensorDesc *desc, aclMemTy | |||||
* @param ... [IN] the value of current log | * @param ... [IN] the value of current log | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY void aclAppLog(aclLogLevel logLevel, const char *func, const char *file, uint32_t line, | ACL_FUNC_VISIBILITY void aclAppLog(aclLogLevel logLevel, const char *func, const char *file, uint32_t line, | ||||
const char *fmt, ...); | |||||
const char *fmt, ...); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -632,14 +626,13 @@ ACL_FUNC_VISIBILITY void aclAppLog(aclLogLevel logLevel, const char *func, const | |||||
* | * | ||||
* @retval null for failed | * @retval null for failed | ||||
* @retval OtherValues success | * @retval OtherValues success | ||||
*/ | |||||
*/ | |||||
ACL_FUNC_VISIBILITY const char *aclrtGetSocName(); | ACL_FUNC_VISIBILITY const char *aclrtGetSocName(); | ||||
#define ACL_APP_LOG(level, fmt, ...) \ | |||||
aclAppLog(level, __FUNCTION__, __FILE__, __LINE__, fmt, ##__VA_ARGS__) | |||||
#define ACL_APP_LOG(level, fmt, ...) aclAppLog(level, __FUNCTION__, __FILE__, __LINE__, fmt, ##__VA_ARGS__) | |||||
#ifdef __cplusplus | #ifdef __cplusplus | ||||
} | } | ||||
#endif | #endif | ||||
#endif // INC_EXTERNAL_ACL_ACL_BASE_H_ | |||||
#endif // INC_EXTERNAL_ACL_ACL_BASE_H_ |
@@ -27,19 +27,19 @@ | |||||
extern "C" { | extern "C" { | ||||
#endif | #endif | ||||
#define ACL_MAX_DIM_CNT 128 | |||||
#define ACL_MAX_TENSOR_NAME_LEN 128 | |||||
#define ACL_MAX_BATCH_NUM 128 | |||||
#define ACL_MAX_HW_NUM 128 | |||||
#define ACL_MAX_SHAPE_COUNT 128 | |||||
#define ACL_INVALID_NODE_INDEX 0xFFFFFFFF | |||||
#define ACL_MDL_LOAD_FROM_FILE 1 | |||||
#define ACL_MDL_LOAD_FROM_FILE_WITH_MEM 2 | |||||
#define ACL_MDL_LOAD_FROM_MEM 3 | |||||
#define ACL_MDL_LOAD_FROM_MEM_WITH_MEM 4 | |||||
#define ACL_MDL_LOAD_FROM_FILE_WITH_Q 5 | |||||
#define ACL_MDL_LOAD_FROM_MEM_WITH_Q 6 | |||||
#define ACL_MAX_DIM_CNT 128 | |||||
#define ACL_MAX_TENSOR_NAME_LEN 128 | |||||
#define ACL_MAX_BATCH_NUM 128 | |||||
#define ACL_MAX_HW_NUM 128 | |||||
#define ACL_MAX_SHAPE_COUNT 128 | |||||
#define ACL_INVALID_NODE_INDEX 0xFFFFFFFF | |||||
#define ACL_MDL_LOAD_FROM_FILE 1 | |||||
#define ACL_MDL_LOAD_FROM_FILE_WITH_MEM 2 | |||||
#define ACL_MDL_LOAD_FROM_MEM 3 | |||||
#define ACL_MDL_LOAD_FROM_MEM_WITH_MEM 4 | |||||
#define ACL_MDL_LOAD_FROM_FILE_WITH_Q 5 | |||||
#define ACL_MDL_LOAD_FROM_MEM_WITH_Q 6 | |||||
#define ACL_DYNAMIC_TENSOR_NAME "ascend_mbatch_shape_data" | #define ACL_DYNAMIC_TENSOR_NAME "ascend_mbatch_shape_data" | ||||
#define ACL_DYNAMIC_AIPP_NAME "ascend_dynamic_aipp_data" | #define ACL_DYNAMIC_AIPP_NAME "ascend_dynamic_aipp_data" | ||||
@@ -52,123 +52,123 @@ typedef struct aclAippExtendInfo aclAippExtendInfo; | |||||
typedef struct aclmdlConfigHandle aclmdlConfigHandle; | typedef struct aclmdlConfigHandle aclmdlConfigHandle; | ||||
typedef enum { | typedef enum { | ||||
ACL_YUV420SP_U8 = 1, | |||||
ACL_XRGB8888_U8, | |||||
ACL_RGB888_U8, | |||||
ACL_YUV400_U8, | |||||
ACL_NC1HWC0DI_FP16, | |||||
ACL_NC1HWC0DI_S8, | |||||
ACL_ARGB8888_U8, | |||||
ACL_YUYV_U8, | |||||
ACL_YUV422SP_U8, | |||||
ACL_AYUV444_U8, | |||||
ACL_RAW10, | |||||
ACL_RAW12, | |||||
ACL_RAW16, | |||||
ACL_RAW24, | |||||
ACL_AIPP_RESERVED = 0xffff, | |||||
ACL_YUV420SP_U8 = 1, | |||||
ACL_XRGB8888_U8, | |||||
ACL_RGB888_U8, | |||||
ACL_YUV400_U8, | |||||
ACL_NC1HWC0DI_FP16, | |||||
ACL_NC1HWC0DI_S8, | |||||
ACL_ARGB8888_U8, | |||||
ACL_YUYV_U8, | |||||
ACL_YUV422SP_U8, | |||||
ACL_AYUV444_U8, | |||||
ACL_RAW10, | |||||
ACL_RAW12, | |||||
ACL_RAW16, | |||||
ACL_RAW24, | |||||
ACL_AIPP_RESERVED = 0xffff, | |||||
} aclAippInputFormat; | } aclAippInputFormat; | ||||
// Identifiers of model-load configuration attributes. The per-entry comments state
// whether a pointer-valued attribute is deep- or shallow-copied.
// NOTE(review): the name suffixes (_INT32, _SIZET, _PTR) presumably indicate the
// expected value type of each attribute — confirm against the setter API.
typedef enum { | typedef enum { | ||||
ACL_MDL_PRIORITY_INT32 = 0, | |||||
ACL_MDL_LOAD_TYPE_SIZET, | |||||
ACL_MDL_PATH_PTR, /**< pointer to model load path with deep copy */ | |||||
ACL_MDL_MEM_ADDR_PTR, /**< pointer to model memory with shallow copy */ | |||||
ACL_MDL_MEM_SIZET, | |||||
ACL_MDL_WEIGHT_ADDR_PTR, /**< pointer to weight memory of model with shallow copy */ | |||||
ACL_MDL_WEIGHT_SIZET, | |||||
ACL_MDL_WORKSPACE_ADDR_PTR, /**< pointer to workspace memory of model with shallow copy */ | |||||
ACL_MDL_WORKSPACE_SIZET, | |||||
ACL_MDL_INPUTQ_NUM_SIZET, | |||||
ACL_MDL_INPUTQ_ADDR_PTR, /**< pointer to inputQ with shallow copy */ | |||||
ACL_MDL_OUTPUTQ_NUM_SIZET, | |||||
ACL_MDL_OUTPUTQ_ADDR_PTR /**< pointer to outputQ with shallow copy */ | |||||
ACL_MDL_PRIORITY_INT32 = 0, | |||||
ACL_MDL_LOAD_TYPE_SIZET, | |||||
ACL_MDL_PATH_PTR, /**< pointer to model load path with deep copy */ | |||||
ACL_MDL_MEM_ADDR_PTR, /**< pointer to model memory with shallow copy */ | |||||
ACL_MDL_MEM_SIZET, | |||||
ACL_MDL_WEIGHT_ADDR_PTR, /**< pointer to weight memory of model with shallow copy */ | |||||
ACL_MDL_WEIGHT_SIZET, | |||||
ACL_MDL_WORKSPACE_ADDR_PTR, /**< pointer to workspace memory of model with shallow copy */ | |||||
ACL_MDL_WORKSPACE_SIZET, | |||||
ACL_MDL_INPUTQ_NUM_SIZET, | |||||
ACL_MDL_INPUTQ_ADDR_PTR, /**< pointer to inputQ with shallow copy */ | |||||
ACL_MDL_OUTPUTQ_NUM_SIZET, | |||||
ACL_MDL_OUTPUTQ_ADDR_PTR /**< pointer to outputQ with shallow copy */ | |||||
} aclmdlConfigAttr; | } aclmdlConfigAttr; | ||||
typedef enum { | typedef enum { | ||||
ACL_DATA_WITHOUT_AIPP = 0, | |||||
ACL_DATA_WITH_STATIC_AIPP, | |||||
ACL_DATA_WITH_DYNAMIC_AIPP, | |||||
ACL_DYNAMIC_AIPP_NODE | |||||
ACL_DATA_WITHOUT_AIPP = 0, | |||||
ACL_DATA_WITH_STATIC_AIPP, | |||||
ACL_DATA_WITH_DYNAMIC_AIPP, | |||||
ACL_DYNAMIC_AIPP_NODE | |||||
} aclmdlInputAippType; | } aclmdlInputAippType; | ||||
typedef struct aclmdlIODims { | typedef struct aclmdlIODims { | ||||
char name[ACL_MAX_TENSOR_NAME_LEN]; /**< tensor name */ | |||||
size_t dimCount; /**< dim array count */ | |||||
int64_t dims[ACL_MAX_DIM_CNT]; /**< dim data array */ | |||||
char name[ACL_MAX_TENSOR_NAME_LEN]; /**< tensor name */ | |||||
size_t dimCount; /**< dim array count */ | |||||
int64_t dims[ACL_MAX_DIM_CNT]; /**< dim data array */ | |||||
} aclmdlIODims; | } aclmdlIODims; | ||||
typedef struct aclAippDims { | typedef struct aclAippDims { | ||||
aclmdlIODims srcDims; /**< input dims before model transform */ | |||||
size_t srcSize; /**< input size before model transform */ | |||||
aclmdlIODims aippOutdims; /**< aipp output dims */ | |||||
size_t aippOutSize; /**< aipp output size */ | |||||
aclmdlIODims srcDims; /**< input dims before model transform */ | |||||
size_t srcSize; /**< input size before model transform */ | |||||
aclmdlIODims aippOutdims; /**< aipp output dims */ | |||||
size_t aippOutSize; /**< aipp output size */ | |||||
} aclAippDims; | } aclAippDims; | ||||
typedef struct aclmdlBatch { | typedef struct aclmdlBatch { | ||||
size_t batchCount; /**< batch array count */ | |||||
uint64_t batch[ACL_MAX_BATCH_NUM]; /**< batch data array */ | |||||
size_t batchCount; /**< batch array count */ | |||||
uint64_t batch[ACL_MAX_BATCH_NUM]; /**< batch data array */ | |||||
} aclmdlBatch; | } aclmdlBatch; | ||||
typedef struct aclmdlHW { | typedef struct aclmdlHW { | ||||
size_t hwCount; /**< height&width array count */ | |||||
uint64_t hw[ACL_MAX_HW_NUM][2]; /**< height&width data array */ | |||||
size_t hwCount; /**< height&width array count */ | |||||
uint64_t hw[ACL_MAX_HW_NUM][2]; /**< height&width data array */ | |||||
} aclmdlHW; | } aclmdlHW; | ||||
typedef struct aclAippInfo { | typedef struct aclAippInfo { | ||||
aclAippInputFormat inputFormat; | |||||
int32_t srcImageSizeW; | |||||
int32_t srcImageSizeH; | |||||
int8_t cropSwitch; | |||||
int32_t loadStartPosW; | |||||
int32_t loadStartPosH; | |||||
int32_t cropSizeW; | |||||
int32_t cropSizeH; | |||||
int8_t resizeSwitch; | |||||
int32_t resizeOutputW; | |||||
int32_t resizeOutputH; | |||||
int8_t paddingSwitch; | |||||
int32_t leftPaddingSize; | |||||
int32_t rightPaddingSize; | |||||
int32_t topPaddingSize; | |||||
int32_t bottomPaddingSize; | |||||
int8_t cscSwitch; | |||||
int8_t rbuvSwapSwitch; | |||||
int8_t axSwapSwitch; | |||||
int8_t singleLineMode; | |||||
int32_t matrixR0C0; | |||||
int32_t matrixR0C1; | |||||
int32_t matrixR0C2; | |||||
int32_t matrixR1C0; | |||||
int32_t matrixR1C1; | |||||
int32_t matrixR1C2; | |||||
int32_t matrixR2C0; | |||||
int32_t matrixR2C1; | |||||
int32_t matrixR2C2; | |||||
int32_t outputBias0; | |||||
int32_t outputBias1; | |||||
int32_t outputBias2; | |||||
int32_t inputBias0; | |||||
int32_t inputBias1; | |||||
int32_t inputBias2; | |||||
int32_t meanChn0; | |||||
int32_t meanChn1; | |||||
int32_t meanChn2; | |||||
int32_t meanChn3; | |||||
float minChn0; | |||||
float minChn1; | |||||
float minChn2; | |||||
float minChn3; | |||||
float varReciChn0; | |||||
float varReciChn1; | |||||
float varReciChn2; | |||||
float varReciChn3; | |||||
aclFormat srcFormat; | |||||
aclDataType srcDatatype; | |||||
size_t srcDimNum; | |||||
size_t shapeCount; | |||||
aclAippDims outDims[ACL_MAX_SHAPE_COUNT]; | |||||
aclAippExtendInfo *aippExtend; /**< reserved parameters, current version needs to be null */ | |||||
aclAippInputFormat inputFormat; | |||||
int32_t srcImageSizeW; | |||||
int32_t srcImageSizeH; | |||||
int8_t cropSwitch; | |||||
int32_t loadStartPosW; | |||||
int32_t loadStartPosH; | |||||
int32_t cropSizeW; | |||||
int32_t cropSizeH; | |||||
int8_t resizeSwitch; | |||||
int32_t resizeOutputW; | |||||
int32_t resizeOutputH; | |||||
int8_t paddingSwitch; | |||||
int32_t leftPaddingSize; | |||||
int32_t rightPaddingSize; | |||||
int32_t topPaddingSize; | |||||
int32_t bottomPaddingSize; | |||||
int8_t cscSwitch; | |||||
int8_t rbuvSwapSwitch; | |||||
int8_t axSwapSwitch; | |||||
int8_t singleLineMode; | |||||
int32_t matrixR0C0; | |||||
int32_t matrixR0C1; | |||||
int32_t matrixR0C2; | |||||
int32_t matrixR1C0; | |||||
int32_t matrixR1C1; | |||||
int32_t matrixR1C2; | |||||
int32_t matrixR2C0; | |||||
int32_t matrixR2C1; | |||||
int32_t matrixR2C2; | |||||
int32_t outputBias0; | |||||
int32_t outputBias1; | |||||
int32_t outputBias2; | |||||
int32_t inputBias0; | |||||
int32_t inputBias1; | |||||
int32_t inputBias2; | |||||
int32_t meanChn0; | |||||
int32_t meanChn1; | |||||
int32_t meanChn2; | |||||
int32_t meanChn3; | |||||
float minChn0; | |||||
float minChn1; | |||||
float minChn2; | |||||
float minChn3; | |||||
float varReciChn0; | |||||
float varReciChn1; | |||||
float varReciChn2; | |||||
float varReciChn3; | |||||
aclFormat srcFormat; | |||||
aclDataType srcDatatype; | |||||
size_t srcDimNum; | |||||
size_t shapeCount; | |||||
aclAippDims outDims[ACL_MAX_SHAPE_COUNT]; | |||||
aclAippExtendInfo *aippExtend; /**< reserved parameters, current version needs to be null */ | |||||
} aclAippInfo; | } aclAippInfo; | ||||
/** | /** | ||||
@@ -292,8 +292,7 @@ ACL_FUNC_VISIBILITY aclError aclmdlAddDatasetBuffer(aclmdlDataset *dataset, aclD | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclmdlSetDatasetTensorDesc(aclmdlDataset *dataset, | |||||
aclTensorDesc *tensorDesc, | |||||
ACL_FUNC_VISIBILITY aclError aclmdlSetDatasetTensorDesc(aclmdlDataset *dataset, aclTensorDesc *tensorDesc, | |||||
size_t index); | size_t index); | ||||
/** | /** | ||||
@@ -355,8 +354,7 @@ ACL_FUNC_VISIBILITY aclError aclmdlLoadFromFile(const char *modelPath, uint32_t | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclmdlLoadFromMem(const void *model, size_t modelSize, | |||||
uint32_t *modelId); | |||||
ACL_FUNC_VISIBILITY aclError aclmdlLoadFromMem(const void *model, size_t modelSize, uint32_t *modelId); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -378,9 +376,8 @@ ACL_FUNC_VISIBILITY aclError aclmdlLoadFromMem(const void *model, size_t modelS | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclmdlLoadFromFileWithMem(const char *modelPath, | |||||
uint32_t *modelId, void *workPtr, size_t workSize, | |||||
void *weightPtr, size_t weightSize); | |||||
ACL_FUNC_VISIBILITY aclError aclmdlLoadFromFileWithMem(const char *modelPath, uint32_t *modelId, void *workPtr, | |||||
size_t workSize, void *weightPtr, size_t weightSize); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -403,9 +400,9 @@ ACL_FUNC_VISIBILITY aclError aclmdlLoadFromFileWithMem(const char *modelPath, | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclmdlLoadFromMemWithMem(const void *model, size_t modelSize, | |||||
uint32_t *modelId, void *workPtr, size_t workSize, | |||||
void *weightPtr, size_t weightSize); | |||||
ACL_FUNC_VISIBILITY aclError aclmdlLoadFromMemWithMem(const void *model, size_t modelSize, uint32_t *modelId, | |||||
void *workPtr, size_t workSize, void *weightPtr, | |||||
size_t weightSize); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -440,8 +437,8 @@ ACL_FUNC_VISIBILITY aclError aclmdlLoadFromFileWithQ(const char *modelPath, uint | |||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclmdlLoadFromMemWithQ(const void *model, size_t modelSize, uint32_t *modelId, | ACL_FUNC_VISIBILITY aclError aclmdlLoadFromMemWithQ(const void *model, size_t modelSize, uint32_t *modelId, | ||||
const uint32_t *inputQ, size_t inputQNum, | |||||
const uint32_t *outputQ, size_t outputQNum); | |||||
const uint32_t *inputQ, size_t inputQNum, const uint32_t *outputQ, | |||||
size_t outputQNum); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -471,8 +468,8 @@ ACL_FUNC_VISIBILITY aclError aclmdlExecute(uint32_t modelId, const aclmdlDataset | |||||
* @see aclmdlLoadFromFile | aclmdlLoadFromMem | aclmdlLoadFromFileWithMem | | * @see aclmdlLoadFromFile | aclmdlLoadFromMem | aclmdlLoadFromFileWithMem | | ||||
* aclmdlLoadFromMemWithMem | * aclmdlLoadFromMemWithMem | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclmdlExecuteAsync(uint32_t modelId, const aclmdlDataset *input, | |||||
aclmdlDataset *output, aclrtStream stream); | |||||
ACL_FUNC_VISIBILITY aclError aclmdlExecuteAsync(uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output, | |||||
aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -647,7 +644,7 @@ ACL_FUNC_VISIBILITY aclError aclmdlGetCurOutputDims(const aclmdlDesc *modelDesc, | |||||
* @param modelDesc [IN] model description | * @param modelDesc [IN] model description | ||||
* @param opName [IN] op name | * @param opName [IN] op name | ||||
* @param attr [IN] attr name | * @param attr [IN] attr name | ||||
* | |||||
* | |||||
* @retval the attr value | * @retval the attr value | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY const char *aclmdlGetOpAttr(aclmdlDesc *modelDesc, const char *opName, const char *attr); | ACL_FUNC_VISIBILITY const char *aclmdlGetOpAttr(aclmdlDesc *modelDesc, const char *opName, const char *attr); | ||||
@@ -859,11 +856,11 @@ ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPInputFormat(aclmdlAIPP *aippParmsSet, | |||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
* | * | ||||
* @see aclmdlCreateAIPP | * @see aclmdlCreateAIPP | ||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPCscParams(aclmdlAIPP *aippParmsSet, int8_t csc_switch, | |||||
int16_t cscMatrixR0C0, int16_t cscMatrixR0C1, int16_t cscMatrixR0C2, | |||||
int16_t cscMatrixR1C0, int16_t cscMatrixR1C1, int16_t cscMatrixR1C2, | |||||
int16_t cscMatrixR2C0, int16_t cscMatrixR2C1, int16_t cscMatrixR2C2, | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPCscParams(aclmdlAIPP *aippParmsSet, int8_t csc_switch, int16_t cscMatrixR0C0, | |||||
int16_t cscMatrixR0C1, int16_t cscMatrixR0C2, int16_t cscMatrixR1C0, | |||||
int16_t cscMatrixR1C1, int16_t cscMatrixR1C2, int16_t cscMatrixR2C0, | |||||
int16_t cscMatrixR2C1, int16_t cscMatrixR2C2, | |||||
uint8_t cscOutputBiasR0, uint8_t cscOutputBiasR1, | uint8_t cscOutputBiasR0, uint8_t cscOutputBiasR1, | ||||
uint8_t cscOutputBiasR2, uint8_t cscInputBiasR0, | uint8_t cscOutputBiasR2, uint8_t cscInputBiasR0, | ||||
uint8_t cscInputBiasR1, uint8_t cscInputBiasR2); | uint8_t cscInputBiasR1, uint8_t cscInputBiasR2); | ||||
@@ -879,7 +876,7 @@ ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPCscParams(aclmdlAIPP *aippParmsSet, in | |||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
* | * | ||||
* @see aclmdlCreateAIPP | * @see aclmdlCreateAIPP | ||||
*/ | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPRbuvSwapSwitch(aclmdlAIPP *aippParmsSet, int8_t rbuvSwapSwitch); | ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPRbuvSwapSwitch(aclmdlAIPP *aippParmsSet, int8_t rbuvSwapSwitch); | ||||
/** | /** | ||||
@@ -893,7 +890,7 @@ ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPRbuvSwapSwitch(aclmdlAIPP *aippParmsSe | |||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
* | * | ||||
* @see aclmdlCreateAIPP | * @see aclmdlCreateAIPP | ||||
*/ | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPAxSwapSwitch(aclmdlAIPP *aippParmsSet, int8_t axSwapSwitch); | ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPAxSwapSwitch(aclmdlAIPP *aippParmsSet, int8_t axSwapSwitch); | ||||
/** | /** | ||||
@@ -908,7 +905,7 @@ ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPAxSwapSwitch(aclmdlAIPP *aippParmsSet, | |||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
* | * | ||||
* @see aclmdlCreateAIPP | * @see aclmdlCreateAIPP | ||||
*/ | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPSrcImageSize(aclmdlAIPP *aippParmsSet, int32_t srcImageSizeW, | ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPSrcImageSize(aclmdlAIPP *aippParmsSet, int32_t srcImageSizeW, | ||||
int32_t srcImageSizeH); | int32_t srcImageSizeH); | ||||
@@ -928,14 +925,10 @@ ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPSrcImageSize(aclmdlAIPP *aippParmsSet, | |||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
* | * | ||||
* @see aclmdlCreateAIPP | * @see aclmdlCreateAIPP | ||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPScfParams(aclmdlAIPP *aippParmsSet, | |||||
int8_t scfSwitch, | |||||
int32_t scfInputSizeW, | |||||
int32_t scfInputSizeH, | |||||
int32_t scfOutputSizeW, | |||||
int32_t scfOutputSizeH, | |||||
uint64_t batchIndex); | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPScfParams(aclmdlAIPP *aippParmsSet, int8_t scfSwitch, int32_t scfInputSizeW, | |||||
int32_t scfInputSizeH, int32_t scfOutputSizeW, | |||||
int32_t scfOutputSizeH, uint64_t batchIndex); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -953,13 +946,9 @@ ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPScfParams(aclmdlAIPP *aippParmsSet, | |||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
* | * | ||||
* @see aclmdlCreateAIPP | * @see aclmdlCreateAIPP | ||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPCropParams(aclmdlAIPP *aippParmsSet, | |||||
int8_t cropSwitch, | |||||
int32_t cropStartPosW, | |||||
int32_t cropStartPosH, | |||||
int32_t cropSizeW, | |||||
int32_t cropSizeH, | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPCropParams(aclmdlAIPP *aippParmsSet, int8_t cropSwitch, int32_t cropStartPosW, | |||||
int32_t cropStartPosH, int32_t cropSizeW, int32_t cropSizeH, | |||||
uint64_t batchIndex); | uint64_t batchIndex); | ||||
/** | /** | ||||
@@ -978,7 +967,7 @@ ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPCropParams(aclmdlAIPP *aippParmsSet, | |||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
* | * | ||||
* @see aclmdlCreateAIPP | * @see aclmdlCreateAIPP | ||||
*/ | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPPaddingParams(aclmdlAIPP *aippParmsSet, int8_t paddingSwitch, | ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPPaddingParams(aclmdlAIPP *aippParmsSet, int8_t paddingSwitch, | ||||
int32_t paddingSizeTop, int32_t paddingSizeBottom, | int32_t paddingSizeTop, int32_t paddingSizeBottom, | ||||
int32_t paddingSizeLeft, int32_t paddingSizeRight, | int32_t paddingSizeLeft, int32_t paddingSizeRight, | ||||
@@ -999,13 +988,10 @@ ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPPaddingParams(aclmdlAIPP *aippParmsSet | |||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
* | * | ||||
* @see aclmdlCreateAIPP | * @see aclmdlCreateAIPP | ||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPDtcPixelMean(aclmdlAIPP *aippParmsSet, | |||||
int16_t dtcPixelMeanChn0, | |||||
int16_t dtcPixelMeanChn1, | |||||
int16_t dtcPixelMeanChn2, | |||||
int16_t dtcPixelMeanChn3, | |||||
uint64_t batchIndex); | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPDtcPixelMean(aclmdlAIPP *aippParmsSet, int16_t dtcPixelMeanChn0, | |||||
int16_t dtcPixelMeanChn1, int16_t dtcPixelMeanChn2, | |||||
int16_t dtcPixelMeanChn3, uint64_t batchIndex); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -1022,13 +1008,10 @@ ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPDtcPixelMean(aclmdlAIPP *aippParmsSet, | |||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
* | * | ||||
* @see aclmdlCreateAIPP | * @see aclmdlCreateAIPP | ||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPDtcPixelMin(aclmdlAIPP *aippParmsSet, | |||||
float dtcPixelMinChn0, | |||||
float dtcPixelMinChn1, | |||||
float dtcPixelMinChn2, | |||||
float dtcPixelMinChn3, | |||||
uint64_t batchIndex); | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPDtcPixelMin(aclmdlAIPP *aippParmsSet, float dtcPixelMinChn0, | |||||
float dtcPixelMinChn1, float dtcPixelMinChn2, | |||||
float dtcPixelMinChn3, uint64_t batchIndex); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -1045,13 +1028,10 @@ ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPDtcPixelMin(aclmdlAIPP *aippParmsSet, | |||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
* | * | ||||
* @see aclmdlCreateAIPP | * @see aclmdlCreateAIPP | ||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPPixelVarReci(aclmdlAIPP *aippParmsSet, | |||||
float dtcPixelVarReciChn0, | |||||
float dtcPixelVarReciChn1, | |||||
float dtcPixelVarReciChn2, | |||||
float dtcPixelVarReciChn3, | |||||
uint64_t batchIndex); | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPPixelVarReci(aclmdlAIPP *aippParmsSet, float dtcPixelVarReciChn0, | |||||
float dtcPixelVarReciChn1, float dtcPixelVarReciChn2, | |||||
float dtcPixelVarReciChn3, uint64_t batchIndex); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -1067,10 +1047,8 @@ ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPPixelVarReci(aclmdlAIPP *aippParmsSet, | |||||
* | * | ||||
* @see aclmdlLoadFromFile | aclmdlLoadFromMem | aclmdlLoadFromFileWithMem | | * @see aclmdlLoadFromFile | aclmdlLoadFromMem | aclmdlLoadFromFileWithMem | | ||||
* aclmdlLoadFromMemWithMem | aclmdlGetInputIndexByName | aclmdlCreateAIPP | * aclmdlLoadFromMemWithMem | aclmdlGetInputIndexByName | aclmdlCreateAIPP | ||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlSetInputAIPP(uint32_t modelId, | |||||
aclmdlDataset *dataset, | |||||
size_t index, | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlSetInputAIPP(uint32_t modelId, aclmdlDataset *dataset, size_t index, | |||||
const aclmdlAIPP *aippParmsSet); | const aclmdlAIPP *aippParmsSet); | ||||
/** | /** | ||||
@@ -1087,10 +1065,8 @@ ACL_FUNC_VISIBILITY aclError aclmdlSetInputAIPP(uint32_t modelId, | |||||
* | * | ||||
* @see aclmdlLoadFromFile | aclmdlLoadFromMem | aclmdlLoadFromFileWithMem | | * @see aclmdlLoadFromFile | aclmdlLoadFromMem | aclmdlLoadFromFileWithMem | | ||||
* aclmdlLoadFromMemWithMem | aclmdlGetInputIndexByName | aclmdlCreateAIPP | * aclmdlLoadFromMemWithMem | aclmdlGetInputIndexByName | aclmdlCreateAIPP | ||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPByInputIndex(uint32_t modelId, | |||||
aclmdlDataset *dataset, | |||||
size_t index, | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPByInputIndex(uint32_t modelId, aclmdlDataset *dataset, size_t index, | |||||
const aclmdlAIPP *aippParmsSet); | const aclmdlAIPP *aippParmsSet); | ||||
/** | /** | ||||
@@ -1108,10 +1084,8 @@ ACL_FUNC_VISIBILITY aclError aclmdlSetAIPPByInputIndex(uint32_t modelId, | |||||
* | * | ||||
* @see aclmdlLoadFromFile | aclmdlLoadFromMem | aclmdlLoadFromFileWithMem | | * @see aclmdlLoadFromFile | aclmdlLoadFromMem | aclmdlLoadFromFileWithMem | | ||||
* aclmdlLoadFromMemWithMem | aclmdlGetInputIndexByName | aclmdlCreateAIPP | * aclmdlLoadFromMemWithMem | aclmdlGetInputIndexByName | aclmdlCreateAIPP | ||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlGetAippType(uint32_t modelId, | |||||
size_t index, | |||||
aclmdlInputAippType *type, | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlGetAippType(uint32_t modelId, size_t index, aclmdlInputAippType *type, | |||||
size_t *dynamicAttachedDataIndex); | size_t *dynamicAttachedDataIndex); | ||||
/** | /** | ||||
@@ -1128,7 +1102,7 @@ ACL_FUNC_VISIBILITY aclError aclmdlGetAippType(uint32_t modelId, | |||||
* | * | ||||
* @see aclmdlLoadFromFile | aclmdlLoadFromMem | aclmdlLoadFromFileWithMem | | * @see aclmdlLoadFromFile | aclmdlLoadFromMem | aclmdlLoadFromFileWithMem | | ||||
* aclmdlLoadFromMemWithMem | aclmdlGetInputIndexByName | * aclmdlLoadFromMemWithMem | aclmdlGetInputIndexByName | ||||
*/ | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlGetFirstAippInfo(uint32_t modelId, size_t index, aclAippInfo *aippinfo); | ACL_FUNC_VISIBILITY aclError aclmdlGetFirstAippInfo(uint32_t modelId, size_t index, aclAippInfo *aippinfo); | ||||
/** | /** | ||||
@@ -1147,10 +1121,11 @@ ACL_FUNC_VISIBILITY aclError aclmdlGetFirstAippInfo(uint32_t modelId, size_t ind | |||||
* | * | ||||
* @retval ACL_SUCCESS The function is successfully executed | * @retval ACL_SUCCESS The function is successfully executed | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlCreateAndGetOpDesc(uint32_t deviceId, uint32_t streamId, | |||||
uint32_t taskId, char *opName, size_t opNameLen, aclTensorDesc **inputDesc, size_t *numInputs, | |||||
aclTensorDesc **outputDesc, size_t *numOutputs); | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlCreateAndGetOpDesc(uint32_t deviceId, uint32_t streamId, uint32_t taskId, | |||||
char *opName, size_t opNameLen, aclTensorDesc **inputDesc, | |||||
size_t *numInputs, aclTensorDesc **outputDesc, | |||||
size_t *numOutputs); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -1158,7 +1133,7 @@ ACL_FUNC_VISIBILITY aclError aclmdlCreateAndGetOpDesc(uint32_t deviceId, uint32_ | |||||
* | * | ||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlInitDump(); | ACL_FUNC_VISIBILITY aclError aclmdlInitDump(); | ||||
/** | /** | ||||
@@ -1169,7 +1144,7 @@ ACL_FUNC_VISIBILITY aclError aclmdlInitDump(); | |||||
* | * | ||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlSetDump(const char *dumpCfgPath); | ACL_FUNC_VISIBILITY aclError aclmdlSetDump(const char *dumpCfgPath); | ||||
/** | /** | ||||
@@ -1178,7 +1153,7 @@ ACL_FUNC_VISIBILITY aclError aclmdlSetDump(const char *dumpCfgPath); | |||||
* | * | ||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlFinalizeDump(); | ACL_FUNC_VISIBILITY aclError aclmdlFinalizeDump(); | ||||
/** | /** | ||||
@@ -1190,7 +1165,7 @@ ACL_FUNC_VISIBILITY aclError aclmdlFinalizeDump(); | |||||
* | * | ||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclmdlLoadWithConfig(const aclmdlConfigHandle *handle, uint32_t *modelId); | ACL_FUNC_VISIBILITY aclError aclmdlLoadWithConfig(const aclmdlConfigHandle *handle, uint32_t *modelId); | ||||
/** | /** | ||||
@@ -1200,7 +1175,7 @@ ACL_FUNC_VISIBILITY aclError aclmdlLoadWithConfig(const aclmdlConfigHandle *hand | |||||
* @retval the aclmdlConfigHandle pointer | * @retval the aclmdlConfigHandle pointer | ||||
* | * | ||||
* @see aclmdlDestroyConfigHandle | * @see aclmdlDestroyConfigHandle | ||||
*/ | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclmdlConfigHandle *aclmdlCreateConfigHandle(); | ACL_FUNC_VISIBILITY aclmdlConfigHandle *aclmdlCreateConfigHandle(); | ||||
/** | /** | ||||
@@ -1229,7 +1204,7 @@ ACL_FUNC_VISIBILITY aclError aclmdlDestroyConfigHandle(aclmdlConfigHandle *handl | |||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclmdlSetConfigOpt(aclmdlConfigHandle *handle, aclmdlConfigAttr attr, | ACL_FUNC_VISIBILITY aclError aclmdlSetConfigOpt(aclmdlConfigHandle *handle, aclmdlConfigAttr attr, | ||||
const void *attrValue, size_t valueSize); | |||||
const void *attrValue, size_t valueSize); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -1247,4 +1222,4 @@ ACL_FUNC_VISIBILITY const char *aclmdlGetTensorRealName(const aclmdlDesc *modelD | |||||
} | } | ||||
#endif | #endif | ||||
#endif // INC_EXTERNAL_ACL_ACL_MODEL_H_ | |||||
#endif // INC_EXTERNAL_ACL_ACL_MODEL_H_ |
@@ -33,9 +33,9 @@ typedef void (*aclDataDeallocator)(void *data, size_t length); | |||||
static const int ACL_COMPILE_FLAG_BIN_SELECTOR = 1; | static const int ACL_COMPILE_FLAG_BIN_SELECTOR = 1; | ||||
typedef enum aclEngineType { | typedef enum aclEngineType { | ||||
ACL_ENGINE_SYS, | |||||
ACL_ENGINE_AICORE, | |||||
ACL_ENGINE_VECTOR, | |||||
ACL_ENGINE_SYS, | |||||
ACL_ENGINE_AICORE, | |||||
ACL_ENGINE_VECTOR, | |||||
} aclopEngineType; | } aclopEngineType; | ||||
/** | /** | ||||
@@ -148,7 +148,7 @@ ACL_FUNC_VISIBILITY aclError aclopSetAttrString(aclopAttr *attr, const char *att | |||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclopSetAttrListBool(aclopAttr *attr, const char *attrName, int numValues, | ACL_FUNC_VISIBILITY aclError aclopSetAttrListBool(aclopAttr *attr, const char *attrName, int numValues, | ||||
const uint8_t *values); | |||||
const uint8_t *values); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -163,7 +163,7 @@ ACL_FUNC_VISIBILITY aclError aclopSetAttrListBool(aclopAttr *attr, const char *a | |||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclopSetAttrListInt(aclopAttr *attr, const char *attrName, int numValues, | ACL_FUNC_VISIBILITY aclError aclopSetAttrListInt(aclopAttr *attr, const char *attrName, int numValues, | ||||
const int64_t *values); | |||||
const int64_t *values); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -178,7 +178,7 @@ ACL_FUNC_VISIBILITY aclError aclopSetAttrListInt(aclopAttr *attr, const char *at | |||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclopSetAttrListFloat(aclopAttr *attr, const char *attrName, int numValues, | ACL_FUNC_VISIBILITY aclError aclopSetAttrListFloat(aclopAttr *attr, const char *attrName, int numValues, | ||||
const float *values); | |||||
const float *values); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -193,7 +193,7 @@ ACL_FUNC_VISIBILITY aclError aclopSetAttrListFloat(aclopAttr *attr, const char * | |||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclopSetAttrListString(aclopAttr *attr, const char *attrName, int numValues, | ACL_FUNC_VISIBILITY aclError aclopSetAttrListString(aclopAttr *attr, const char *attrName, int numValues, | ||||
const char **values); | |||||
const char **values); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -208,11 +208,8 @@ ACL_FUNC_VISIBILITY aclError aclopSetAttrListString(aclopAttr *attr, const char | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclopSetAttrListListInt(aclopAttr *attr, | |||||
const char *attrName, | |||||
int numLists, | |||||
const int *numValues, | |||||
const int64_t *const values[]); | |||||
ACL_FUNC_VISIBILITY aclError aclopSetAttrListListInt(aclopAttr *attr, const char *attrName, int numLists, | |||||
const int *numValues, const int64_t *const values[]); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -242,15 +239,10 @@ ACL_FUNC_VISIBILITY aclError aclopSetAttrListListInt(aclopAttr *attr, | |||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_DEPRECATED_MESSAGE("aclopExecute is deprecated, use aclopExecuteV2 instead") | ACL_DEPRECATED_MESSAGE("aclopExecute is deprecated, use aclopExecuteV2 instead") | ||||
ACL_FUNC_VISIBILITY aclError aclopExecute(const char *opType, | |||||
int numInputs, | |||||
const aclTensorDesc *const inputDesc[], | |||||
const aclDataBuffer *const inputs[], | |||||
int numOutputs, | |||||
const aclTensorDesc *const outputDesc[], | |||||
aclDataBuffer *const outputs[], | |||||
const aclopAttr *attr, | |||||
aclrtStream stream); | |||||
ACL_FUNC_VISIBILITY aclError aclopExecute(const char *opType, int numInputs, const aclTensorDesc *const inputDesc[], | |||||
const aclDataBuffer *const inputs[], int numOutputs, | |||||
const aclTensorDesc *const outputDesc[], aclDataBuffer *const outputs[], | |||||
const aclopAttr *attr, aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -280,15 +272,9 @@ ACL_FUNC_VISIBILITY aclError aclopExecute(const char *opType, | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclopExecuteV2(const char *opType, | |||||
int numInputs, | |||||
aclTensorDesc *inputDesc[], | |||||
aclDataBuffer *inputs[], | |||||
int numOutputs, | |||||
aclTensorDesc *outputDesc[], | |||||
aclDataBuffer *outputs[], | |||||
aclopAttr *attr, | |||||
aclrtStream stream); | |||||
ACL_FUNC_VISIBILITY aclError aclopExecuteV2(const char *opType, int numInputs, aclTensorDesc *inputDesc[], | |||||
aclDataBuffer *inputs[], int numOutputs, aclTensorDesc *outputDesc[], | |||||
aclDataBuffer *outputs[], aclopAttr *attr, aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -306,12 +292,9 @@ ACL_FUNC_VISIBILITY aclError aclopExecuteV2(const char *opType, | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclopCreateHandle(const char *opType, | |||||
int numInputs, | |||||
const aclTensorDesc *const inputDesc[], | |||||
int numOutputs, | |||||
const aclTensorDesc *const outputDesc[], | |||||
const aclopAttr *opAttr, | |||||
ACL_FUNC_VISIBILITY aclError aclopCreateHandle(const char *opType, int numInputs, | |||||
const aclTensorDesc *const inputDesc[], int numOutputs, | |||||
const aclTensorDesc *const outputDesc[], const aclopAttr *opAttr, | |||||
aclopHandle **handle); | aclopHandle **handle); | ||||
/** | /** | ||||
@@ -343,12 +326,9 @@ ACL_FUNC_VISIBILITY void aclopDestroyHandle(aclopHandle *handle); | |||||
* | * | ||||
* @see aclopCreateHandle | aclCreateDataBuffer | * @see aclopCreateHandle | aclCreateDataBuffer | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclopExecWithHandle(aclopHandle *handle, | |||||
int numInputs, | |||||
const aclDataBuffer *const inputs[], | |||||
int numOutputs, | |||||
aclDataBuffer *const outputs[], | |||||
aclrtStream stream); | |||||
ACL_FUNC_VISIBILITY aclError aclopExecWithHandle(aclopHandle *handle, int numInputs, | |||||
const aclDataBuffer *const inputs[], int numOutputs, | |||||
aclDataBuffer *const outputs[], aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -364,11 +344,8 @@ ACL_FUNC_VISIBILITY aclError aclopExecWithHandle(aclopHandle *handle, | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclopCast(const aclTensorDesc *srcDesc, | |||||
const aclDataBuffer *srcBuffer, | |||||
const aclTensorDesc *dstDesc, | |||||
aclDataBuffer *dstBuffer, | |||||
uint8_t truncate, | |||||
ACL_FUNC_VISIBILITY aclError aclopCast(const aclTensorDesc *srcDesc, const aclDataBuffer *srcBuffer, | |||||
const aclTensorDesc *dstDesc, aclDataBuffer *dstBuffer, uint8_t truncate, | |||||
aclrtStream stream); | aclrtStream stream); | ||||
/** | /** | ||||
@@ -383,12 +360,9 @@ ACL_FUNC_VISIBILITY aclError aclopCast(const aclTensorDesc *srcDesc, | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclopCreateHandleForCast(aclTensorDesc *srcDesc, | |||||
aclTensorDesc *dstDesc, | |||||
uint8_t truncate, | |||||
ACL_FUNC_VISIBILITY aclError aclopCreateHandleForCast(aclTensorDesc *srcDesc, aclTensorDesc *dstDesc, uint8_t truncate, | |||||
aclopHandle **handle); | aclopHandle **handle); | ||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
* @brief create kernel | * @brief create kernel | ||||
@@ -407,15 +381,10 @@ ACL_FUNC_VISIBILITY aclError aclopCreateHandleForCast(aclTensorDesc *srcDesc, | |||||
* | * | ||||
* @see aclopCompile | * @see aclopCompile | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclopCreateKernel(const char *opType, | |||||
const char *kernelId, | |||||
const char *kernelName, | |||||
void *binData, | |||||
int binSize, | |||||
aclopEngineType enginetype, | |||||
ACL_FUNC_VISIBILITY aclError aclopCreateKernel(const char *opType, const char *kernelId, const char *kernelName, | |||||
void *binData, int binSize, aclopEngineType enginetype, | |||||
aclDataDeallocator deallocator); | aclDataDeallocator deallocator); | ||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
* @brief create kernel | * @brief create kernel | ||||
@@ -430,11 +399,8 @@ ACL_FUNC_VISIBILITY aclError aclopCreateKernel(const char *opType, | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
typedef aclError (*aclopCompileFunc)(int numInputs, | |||||
const aclTensorDesc *const inputDesc[], | |||||
int numOutputs, | |||||
const aclTensorDesc *const outputDesc[], | |||||
const aclopAttr *opAttr, | |||||
typedef aclError (*aclopCompileFunc)(int numInputs, const aclTensorDesc *const inputDesc[], int numOutputs, | |||||
const aclTensorDesc *const outputDesc[], const aclopAttr *opAttr, | |||||
aclopKernelDesc *aclopKernelDesc); | aclopKernelDesc *aclopKernelDesc); | ||||
/** | /** | ||||
@@ -475,11 +441,8 @@ ACL_FUNC_VISIBILITY aclError aclopUnregisterCompileFunc(const char *opType); | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclopSetKernelArgs(aclopKernelDesc *kernelDesc, | |||||
const char *kernelId, | |||||
uint32_t blockDim, | |||||
const void *args, | |||||
uint32_t argSize); | |||||
ACL_FUNC_VISIBILITY aclError aclopSetKernelArgs(aclopKernelDesc *kernelDesc, const char *kernelId, uint32_t blockDim, | |||||
const void *args, uint32_t argSize); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -510,12 +473,9 @@ ACL_FUNC_VISIBILITY aclError aclopSetKernelWorkspaceSizes(aclopKernelDesc *kerne | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclopUpdateParams(const char *opType, | |||||
int numInputs, | |||||
const aclTensorDesc *const inputDesc[], | |||||
int numOutputs, | |||||
const aclTensorDesc *const outputDesc[], | |||||
const aclopAttr *attr); | |||||
ACL_FUNC_VISIBILITY aclError aclopUpdateParams(const char *opType, int numInputs, | |||||
const aclTensorDesc *const inputDesc[], int numOutputs, | |||||
const aclTensorDesc *const outputDesc[], const aclopAttr *attr); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -533,17 +493,12 @@ ACL_FUNC_VISIBILITY aclError aclopUpdateParams(const char *opType, | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclopInferShape(const char *opType, | |||||
int numInputs, | |||||
aclTensorDesc *inputDesc[], | |||||
aclDataBuffer *inputs[], | |||||
int numOutputs, | |||||
aclTensorDesc *outputDesc[], | |||||
ACL_FUNC_VISIBILITY aclError aclopInferShape(const char *opType, int numInputs, aclTensorDesc *inputDesc[], | |||||
aclDataBuffer *inputs[], int numOutputs, aclTensorDesc *outputDesc[], | |||||
aclopAttr *attr); | aclopAttr *attr); | ||||
#ifdef __cplusplus | #ifdef __cplusplus | ||||
} | } | ||||
#endif | #endif | ||||
#endif // INC_EXTERNAL_ACL_ACL_OP_H_ | |||||
#endif // INC_EXTERNAL_ACL_ACL_OP_H_ |
@@ -24,28 +24,22 @@ | |||||
extern "C" { | extern "C" { | ||||
#endif | #endif | ||||
typedef enum aclCompileType { | |||||
ACL_COMPILE_SYS, | |||||
ACL_COMPILE_UNREGISTERED | |||||
} aclopCompileType; | |||||
typedef enum aclCompileType { ACL_COMPILE_SYS, ACL_COMPILE_UNREGISTERED } aclopCompileType; | |||||
typedef enum { | typedef enum { | ||||
ACL_PRECISION_MODE, | |||||
ACL_AICORE_NUM, | |||||
ACL_AUTO_TUNE_MODE, | |||||
ACL_OP_SELECT_IMPL_MODE, | |||||
ACL_OPTYPELIST_FOR_IMPLMODE, | |||||
ACL_OP_DEBUG_LEVEL, | |||||
ACL_DEBUG_DIR, | |||||
ACL_OP_COMPILER_CACHE_MODE, | |||||
ACL_OP_COMPILER_CACHE_DIR, | |||||
ACL_OP_PERFORMANCE_MODE | |||||
ACL_PRECISION_MODE, | |||||
ACL_AICORE_NUM, | |||||
ACL_AUTO_TUNE_MODE, | |||||
ACL_OP_SELECT_IMPL_MODE, | |||||
ACL_OPTYPELIST_FOR_IMPLMODE, | |||||
ACL_OP_DEBUG_LEVEL, | |||||
ACL_DEBUG_DIR, | |||||
ACL_OP_COMPILER_CACHE_MODE, | |||||
ACL_OP_COMPILER_CACHE_DIR, | |||||
ACL_OP_PERFORMANCE_MODE | |||||
} aclCompileOpt; | } aclCompileOpt; | ||||
typedef enum aclCompileFlag { | |||||
ACL_OP_COMPILE_DEFAULT, | |||||
ACL_OP_COMPILE_FUZZ | |||||
} aclOpCompileFlag; | |||||
typedef enum aclCompileFlag { ACL_OP_COMPILE_DEFAULT, ACL_OP_COMPILE_FUZZ } aclOpCompileFlag; | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -65,15 +59,10 @@ typedef enum aclCompileFlag { | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclopCompile(const char *opType, | |||||
int numInputs, | |||||
const aclTensorDesc *const inputDesc[], | |||||
int numOutputs, | |||||
const aclTensorDesc *const outputDesc[], | |||||
const aclopAttr *attr, | |||||
aclopEngineType engineType, | |||||
aclopCompileType compileFlag, | |||||
const char *opPath); | |||||
ACL_FUNC_VISIBILITY aclError aclopCompile(const char *opType, int numInputs, const aclTensorDesc *const inputDesc[], | |||||
int numOutputs, const aclTensorDesc *const outputDesc[], | |||||
const aclopAttr *attr, aclopEngineType engineType, | |||||
aclopCompileType compileFlag, const char *opPath); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -96,11 +85,10 @@ ACL_FUNC_VISIBILITY aclError aclopCompile(const char *opType, | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclopCompileAndExecute(const char *opType, | |||||
int numInputs, const aclTensorDesc *const inputDesc[], const aclDataBuffer *const inputs[], | |||||
int numOutputs, const aclTensorDesc *const outputDesc[], aclDataBuffer *const outputs[], | |||||
const aclopAttr *attr, aclopEngineType engineType, aclopCompileType compileFlag, | |||||
const char *opPath, aclrtStream stream); | |||||
ACL_FUNC_VISIBILITY aclError aclopCompileAndExecute( | |||||
const char *opType, int numInputs, const aclTensorDesc *const inputDesc[], const aclDataBuffer *const inputs[], | |||||
int numOutputs, const aclTensorDesc *const outputDesc[], aclDataBuffer *const outputs[], const aclopAttr *attr, | |||||
aclopEngineType engineType, aclopCompileType compileFlag, const char *opPath, aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -130,4 +118,4 @@ ACL_FUNC_VISIBILITY aclError aclopSetCompileFlag(aclOpCompileFlag flag); | |||||
} | } | ||||
#endif | #endif | ||||
#endif // INC_EXTERNAL_ACL_ACL_OP_COMPILER_H_ | |||||
#endif // INC_EXTERNAL_ACL_ACL_OP_COMPILER_H_ |
@@ -23,32 +23,31 @@ | |||||
extern "C" { | extern "C" { | ||||
#endif | #endif | ||||
#define ACL_PROF_ACL_API 0x0001 | |||||
#define ACL_PROF_TASK_TIME 0x0002 | |||||
#define ACL_PROF_AICORE_METRICS 0x0004 | |||||
#define ACL_PROF_AICPU 0x0008 | |||||
#define ACL_PROF_ACL_API 0x0001 | |||||
#define ACL_PROF_TASK_TIME 0x0002 | |||||
#define ACL_PROF_AICORE_METRICS 0x0004 | |||||
#define ACL_PROF_AICPU 0x0008 | |||||
/** | /** | ||||
* @deprecated please use aclprofGetOpTypeLen and aclprofGetOpNameLen instead | * @deprecated please use aclprofGetOpTypeLen and aclprofGetOpNameLen instead | ||||
*/ | */ | ||||
#define ACL_PROF_MAX_OP_NAME_LEN 257 | |||||
#define ACL_PROF_MAX_OP_TYPE_LEN 65 | |||||
#define ACL_PROF_MAX_OP_NAME_LEN 257 | |||||
#define ACL_PROF_MAX_OP_TYPE_LEN 65 | |||||
typedef enum { | typedef enum { | ||||
ACL_AICORE_ARITHMETIC_UTILIZATION = 0, | |||||
ACL_AICORE_PIPE_UTILIZATION = 1, | |||||
ACL_AICORE_MEMORY_BANDWIDTH = 2, | |||||
ACL_AICORE_L0B_AND_WIDTH = 3, | |||||
ACL_AICORE_RESOURCE_CONFLICT_RATIO = 4, | |||||
ACL_AICORE_NONE = 0xFF | |||||
ACL_AICORE_ARITHMETIC_UTILIZATION = 0, | |||||
ACL_AICORE_PIPE_UTILIZATION = 1, | |||||
ACL_AICORE_MEMORY_BANDWIDTH = 2, | |||||
ACL_AICORE_L0B_AND_WIDTH = 3, | |||||
ACL_AICORE_RESOURCE_CONFLICT_RATIO = 4, | |||||
ACL_AICORE_NONE = 0xFF | |||||
} aclprofAicoreMetrics; | } aclprofAicoreMetrics; | ||||
typedef enum { | typedef enum { | ||||
ACL_STEP_START = 0, // step start | |||||
ACL_STEP_END = 1 // step end | |||||
ACL_STEP_START = 0, // step start | |||||
ACL_STEP_END = 1 // step end | |||||
} aclprofStepTag; | } aclprofStepTag; | ||||
typedef struct aclprofConfig aclprofConfig; | typedef struct aclprofConfig aclprofConfig; | ||||
typedef struct aclprofStopConfig aclprofStopConfig; | typedef struct aclprofStopConfig aclprofStopConfig; | ||||
typedef struct aclprofAicoreEvents aclprofAicoreEvents; | typedef struct aclprofAicoreEvents aclprofAicoreEvents; | ||||
@@ -108,7 +107,8 @@ ACL_FUNC_VISIBILITY aclError aclprofStart(const aclprofConfig *profilerConfig); | |||||
* @see aclprofDestroyConfig | * @see aclprofDestroyConfig | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclprofConfig *aclprofCreateConfig(uint32_t *deviceIdList, uint32_t deviceNums, | ACL_FUNC_VISIBILITY aclprofConfig *aclprofCreateConfig(uint32_t *deviceIdList, uint32_t deviceNums, | ||||
aclprofAicoreMetrics aicoreMetrics, aclprofAicoreEvents *aicoreEvents, uint64_t dataTypeConfig); | |||||
aclprofAicoreMetrics aicoreMetrics, | |||||
aclprofAicoreEvents *aicoreEvents, uint64_t dataTypeConfig); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -148,8 +148,7 @@ ACL_FUNC_VISIBILITY aclError aclprofStop(const aclprofConfig *profilerConfig); | |||||
* | * | ||||
* @see aclprofModelUnSubscribe | * @see aclprofModelUnSubscribe | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclprofModelSubscribe(uint32_t modelId, | |||||
const aclprofSubscribeConfig *profSubscribeConfig); | |||||
ACL_FUNC_VISIBILITY aclError aclprofModelSubscribe(uint32_t modelId, const aclprofSubscribeConfig *profSubscribeConfig); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -177,7 +176,7 @@ ACL_FUNC_VISIBILITY aclError aclprofModelUnSubscribe(uint32_t modelId); | |||||
* @see aclprofDestroySubscribeConfig | * @see aclprofDestroySubscribeConfig | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclprofSubscribeConfig *aclprofCreateSubscribeConfig(int8_t timeInfoSwitch, | ACL_FUNC_VISIBILITY aclprofSubscribeConfig *aclprofCreateSubscribeConfig(int8_t timeInfoSwitch, | ||||
aclprofAicoreMetrics aicoreMetrics, void *fd); | |||||
aclprofAicoreMetrics aicoreMetrics, void *fd); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -229,7 +228,7 @@ ACL_FUNC_VISIBILITY aclError aclprofGetOpNum(const void *opInfo, size_t opInfoLe | |||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclprofGetOpTypeLen(const void *opInfo, size_t opInfoLen, uint32_t index, | ACL_FUNC_VISIBILITY aclError aclprofGetOpTypeLen(const void *opInfo, size_t opInfoLen, uint32_t index, | ||||
size_t *opTypeLen); | |||||
size_t *opTypeLen); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -244,8 +243,8 @@ ACL_FUNC_VISIBILITY aclError aclprofGetOpTypeLen(const void *opInfo, size_t opIn | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclprofGetOpType(const void *opInfo, size_t opInfoLen, uint32_t index, | |||||
char *opType, size_t opTypeLen); | |||||
ACL_FUNC_VISIBILITY aclError aclprofGetOpType(const void *opInfo, size_t opInfoLen, uint32_t index, char *opType, | |||||
size_t opTypeLen); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -260,7 +259,7 @@ ACL_FUNC_VISIBILITY aclError aclprofGetOpType(const void *opInfo, size_t opInfoL | |||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclprofGetOpNameLen(const void *opInfo, size_t opInfoLen, uint32_t index, | ACL_FUNC_VISIBILITY aclError aclprofGetOpNameLen(const void *opInfo, size_t opInfoLen, uint32_t index, | ||||
size_t *opNameLen); | |||||
size_t *opNameLen); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -275,8 +274,8 @@ ACL_FUNC_VISIBILITY aclError aclprofGetOpNameLen(const void *opInfo, size_t opIn | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclprofGetOpName(const void *opInfo, size_t opInfoLen, uint32_t index, | |||||
char *opName, size_t opNameLen); | |||||
ACL_FUNC_VISIBILITY aclError aclprofGetOpName(const void *opInfo, size_t opInfoLen, uint32_t index, char *opName, | |||||
size_t opNameLen); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -339,28 +338,28 @@ ACL_FUNC_VISIBILITY size_t aclprofGetModelId(const void *opInfo, size_t opInfoLe | |||||
* | * | ||||
* @retval 0 for failed | * @retval 0 for failed | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclprofGetStepTimestamp(aclprofStepInfo* stepInfo, aclprofStepTag tag, aclrtStream stream); | |||||
ACL_FUNC_VISIBILITY aclError aclprofGetStepTimestamp(aclprofStepInfo *stepInfo, aclprofStepTag tag, aclrtStream stream); | |||||
/** | |||||
/** | |||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
* @brief create pointer to aclprofStepInfo data | * @brief create pointer to aclprofStepInfo data | ||||
* | * | ||||
* | * | ||||
* @retval aclprofStepInfo pointer | * @retval aclprofStepInfo pointer | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclprofStepInfo* aclprofCreateStepInfo(); | |||||
ACL_FUNC_VISIBILITY aclprofStepInfo *aclprofCreateStepInfo(); | |||||
/** | |||||
/** | |||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
* @brief destroy aclprofStepInfo pointer | * @brief destroy aclprofStepInfo pointer | ||||
* | * | ||||
* | * | ||||
* @retval void | * @retval void | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY void aclprofDestroyStepInfo(aclprofStepInfo* stepinfo); | |||||
ACL_FUNC_VISIBILITY void aclprofDestroyStepInfo(aclprofStepInfo *stepinfo); | |||||
#ifdef __cplusplus | #ifdef __cplusplus | ||||
} | } | ||||
#endif | #endif | ||||
#endif // INC_EXTERNAL_ACL_PROF_H_ | |||||
#endif // INC_EXTERNAL_ACL_PROF_H_ |
@@ -28,63 +28,63 @@ extern "C" { | |||||
#define ACL_EVENT_TIME_LINE 0x00000008u | #define ACL_EVENT_TIME_LINE 0x00000008u | ||||
typedef enum aclrtRunMode { | typedef enum aclrtRunMode { | ||||
ACL_DEVICE, | |||||
ACL_HOST, | |||||
ACL_DEVICE, | |||||
ACL_HOST, | |||||
} aclrtRunMode; | } aclrtRunMode; | ||||
typedef enum aclrtTsId { | typedef enum aclrtTsId { | ||||
ACL_TS_ID_AICORE = 0, | |||||
ACL_TS_ID_AIVECTOR = 1, | |||||
ACL_TS_ID_RESERVED = 2, | |||||
ACL_TS_ID_AICORE = 0, | |||||
ACL_TS_ID_AIVECTOR = 1, | |||||
ACL_TS_ID_RESERVED = 2, | |||||
} aclrtTsId; | } aclrtTsId; | ||||
typedef enum aclrtEventStatus { | typedef enum aclrtEventStatus { | ||||
ACL_EVENT_STATUS_COMPLETE = 0, | |||||
ACL_EVENT_STATUS_NOT_READY = 1, | |||||
ACL_EVENT_STATUS_RESERVED = 2, | |||||
ACL_EVENT_STATUS_COMPLETE = 0, | |||||
ACL_EVENT_STATUS_NOT_READY = 1, | |||||
ACL_EVENT_STATUS_RESERVED = 2, | |||||
} aclrtEventStatus; | } aclrtEventStatus; | ||||
typedef enum aclrtCallbackBlockType { | typedef enum aclrtCallbackBlockType { | ||||
ACL_CALLBACK_NO_BLOCK, | |||||
ACL_CALLBACK_BLOCK, | |||||
ACL_CALLBACK_NO_BLOCK, | |||||
ACL_CALLBACK_BLOCK, | |||||
} aclrtCallbackBlockType; | } aclrtCallbackBlockType; | ||||
typedef enum aclrtMemcpyKind { | typedef enum aclrtMemcpyKind { | ||||
ACL_MEMCPY_HOST_TO_HOST, | |||||
ACL_MEMCPY_HOST_TO_DEVICE, | |||||
ACL_MEMCPY_DEVICE_TO_HOST, | |||||
ACL_MEMCPY_DEVICE_TO_DEVICE, | |||||
ACL_MEMCPY_HOST_TO_HOST, | |||||
ACL_MEMCPY_HOST_TO_DEVICE, | |||||
ACL_MEMCPY_DEVICE_TO_HOST, | |||||
ACL_MEMCPY_DEVICE_TO_DEVICE, | |||||
} aclrtMemcpyKind; | } aclrtMemcpyKind; | ||||
typedef enum aclrtMemMallocPolicy { | typedef enum aclrtMemMallocPolicy { | ||||
ACL_MEM_MALLOC_HUGE_FIRST, | |||||
ACL_MEM_MALLOC_HUGE_ONLY, | |||||
ACL_MEM_MALLOC_NORMAL_ONLY, | |||||
ACL_MEM_MALLOC_HUGE_FIRST_P2P, | |||||
ACL_MEM_MALLOC_HUGE_ONLY_P2P, | |||||
ACL_MEM_MALLOC_NORMAL_ONLY_P2P, | |||||
ACL_MEM_MALLOC_HUGE_FIRST, | |||||
ACL_MEM_MALLOC_HUGE_ONLY, | |||||
ACL_MEM_MALLOC_NORMAL_ONLY, | |||||
ACL_MEM_MALLOC_HUGE_FIRST_P2P, | |||||
ACL_MEM_MALLOC_HUGE_ONLY_P2P, | |||||
ACL_MEM_MALLOC_NORMAL_ONLY_P2P, | |||||
} aclrtMemMallocPolicy; | } aclrtMemMallocPolicy; | ||||
typedef enum aclrtMemAttr { | typedef enum aclrtMemAttr { | ||||
ACL_DDR_MEM, | |||||
ACL_HBM_MEM, | |||||
ACL_DDR_MEM_HUGE, | |||||
ACL_DDR_MEM_NORMAL, | |||||
ACL_HBM_MEM_HUGE, | |||||
ACL_HBM_MEM_NORMAL, | |||||
ACL_DDR_MEM_P2P_HUGE, | |||||
ACL_DDR_MEM_P2P_NORMAL, | |||||
ACL_HBM_MEM_P2P_HUGE, | |||||
ACL_HBM_MEM_P2P_NORMAL, | |||||
ACL_DDR_MEM, | |||||
ACL_HBM_MEM, | |||||
ACL_DDR_MEM_HUGE, | |||||
ACL_DDR_MEM_NORMAL, | |||||
ACL_HBM_MEM_HUGE, | |||||
ACL_HBM_MEM_NORMAL, | |||||
ACL_DDR_MEM_P2P_HUGE, | |||||
ACL_DDR_MEM_P2P_NORMAL, | |||||
ACL_HBM_MEM_P2P_HUGE, | |||||
ACL_HBM_MEM_P2P_NORMAL, | |||||
} aclrtMemAttr; | } aclrtMemAttr; | ||||
typedef enum aclrtGroupAttr { | typedef enum aclrtGroupAttr { | ||||
ACL_GROUP_AICORE_INT, | |||||
ACL_GROUP_AIV_INT, | |||||
ACL_GROUP_AIC_INT, | |||||
ACL_GROUP_SDMANUM_INT, | |||||
ACL_GROUP_ASQNUM_INT, | |||||
ACL_GROUP_GROUPID_INT | |||||
ACL_GROUP_AICORE_INT, | |||||
ACL_GROUP_AIV_INT, | |||||
ACL_GROUP_AIC_INT, | |||||
ACL_GROUP_SDMANUM_INT, | |||||
ACL_GROUP_ASQNUM_INT, | |||||
ACL_GROUP_GROUPID_INT | |||||
} aclrtGroupAttr; | } aclrtGroupAttr; | ||||
typedef struct tagRtGroupInfo aclrtGroupInfo; | typedef struct tagRtGroupInfo aclrtGroupInfo; | ||||
@@ -487,7 +487,7 @@ ACL_FUNC_VISIBILITY aclError aclrtRecordEvent(aclrtEvent event, aclrtStream stre | |||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclrtResetEvent(aclrtEvent event, aclrtStream stream); | ACL_FUNC_VISIBILITY aclError aclrtResetEvent(aclrtEvent event, aclrtStream stream); | ||||
/** | |||||
/** | |||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
* @brief Queries an event's status | * @brief Queries an event's status | ||||
* | * | ||||
@@ -549,9 +549,7 @@ ACL_FUNC_VISIBILITY aclError aclrtEventElapsedTime(float *ms, aclrtEvent start, | |||||
* | * | ||||
* @see aclrtFree | acldvppMalloc | aclrtMallocCached | * @see aclrtFree | acldvppMalloc | aclrtMallocCached | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclrtMalloc(void **devPtr, | |||||
size_t size, | |||||
aclrtMemMallocPolicy policy); | |||||
ACL_FUNC_VISIBILITY aclError aclrtMalloc(void **devPtr, size_t size, aclrtMemMallocPolicy policy); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -574,9 +572,7 @@ ACL_FUNC_VISIBILITY aclError aclrtMalloc(void **devPtr, | |||||
* | * | ||||
* @see aclrtFree | aclrtMalloc | * @see aclrtFree | aclrtMalloc | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclrtMallocCached(void **devPtr, | |||||
size_t size, | |||||
aclrtMemMallocPolicy policy); | |||||
ACL_FUNC_VISIBILITY aclError aclrtMallocCached(void **devPtr, size_t size, aclrtMemMallocPolicy policy); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -667,10 +663,7 @@ ACL_FUNC_VISIBILITY aclError aclrtFreeHost(void *hostPtr); | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclrtMemcpy(void *dst, | |||||
size_t destMax, | |||||
const void *src, | |||||
size_t count, | |||||
ACL_FUNC_VISIBILITY aclError aclrtMemcpy(void *dst, size_t destMax, const void *src, size_t count, | |||||
aclrtMemcpyKind kind); | aclrtMemcpyKind kind); | ||||
/** | /** | ||||
@@ -717,38 +710,31 @@ ACL_FUNC_VISIBILITY aclError aclrtMemset(void *devPtr, size_t maxCount, int32_t | |||||
* | * | ||||
* @see aclrtSynchronizeStream | * @see aclrtSynchronizeStream | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclrtMemcpyAsync(void *dst, | |||||
size_t destMax, | |||||
const void *src, | |||||
size_t count, | |||||
aclrtMemcpyKind kind, | |||||
aclrtStream stream); | |||||
ACL_FUNC_VISIBILITY aclError aclrtMemcpyAsync(void *dst, size_t destMax, const void *src, size_t count, | |||||
aclrtMemcpyKind kind, aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | |||||
* @brief Asynchronous initialize memory | |||||
* and set contents of memory to specified value async | |||||
* | |||||
* @par Function | |||||
* @ingroup AscendCL | |||||
* @brief Asynchronous initialize memory | |||||
* and set contents of memory to specified value async | |||||
* | |||||
* @par Function | |||||
* The memory to be initialized is on the Host or device side, | * The memory to be initialized is on the Host or device side, | ||||
* and the system determines whether | * and the system determines whether | ||||
* it is host or device according to the address | * it is host or device according to the address | ||||
* | * | ||||
* @param devPtr [IN] destination address pointer | |||||
* @param maxCount [IN] Max length of destination address memory | |||||
* @param value [IN] set value | |||||
* @param count [IN] the number of bytes to set | |||||
* @param stream [IN] asynchronous task stream | |||||
* | |||||
* @retval ACL_SUCCESS The function is successfully executed. | |||||
* @retval OtherValues Failure | |||||
* | |||||
* @see aclrtSynchronizeStream | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclrtMemsetAsync(void *devPtr, | |||||
size_t maxCount, | |||||
int32_t value, | |||||
size_t count, | |||||
* @param devPtr [IN] destination address pointer | |||||
* @param maxCount [IN] Max length of destination address memory | |||||
* @param value [IN] set value | |||||
* @param count [IN] the number of bytes to set | |||||
* @param stream [IN] asynchronous task stream | |||||
* | |||||
* @retval ACL_SUCCESS The function is successfully executed. | |||||
* @retval OtherValues Failure | |||||
* | |||||
* @see aclrtSynchronizeStream | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclrtMemsetAsync(void *devPtr, size_t maxCount, int32_t value, size_t count, | |||||
aclrtStream stream); | aclrtStream stream); | ||||
/** | /** | ||||
@@ -894,11 +880,8 @@ ACL_FUNC_VISIBILITY aclError aclrtGetAllGroupInfo(aclrtGroupInfo *groupInfo); | |||||
* | * | ||||
* @see aclrtGetGroupCount | aclrtGetAllGroupInfo | * @see aclrtGetGroupCount | aclrtGetAllGroupInfo | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclrtGetGroupInfoDetail(const aclrtGroupInfo *groupInfo, | |||||
int32_t groupIndex, | |||||
aclrtGroupAttr attr, | |||||
void *attrValue, | |||||
size_t valueLen, | |||||
ACL_FUNC_VISIBILITY aclError aclrtGetGroupInfoDetail(const aclrtGroupInfo *groupInfo, int32_t groupIndex, | |||||
aclrtGroupAttr attr, void *attrValue, size_t valueLen, | |||||
size_t *paramRetSize); | size_t *paramRetSize); | ||||
/** | /** | ||||
@@ -972,5 +955,4 @@ ACL_FUNC_VISIBILITY aclError aclrtSetOpWaitTimeout(uint32_t timeout); | |||||
} | } | ||||
#endif | #endif | ||||
#endif // INC_EXTERNAL_ACL_ACL_RT_H_ | |||||
#endif // INC_EXTERNAL_ACL_ACL_RT_H_ |
@@ -24,10 +24,10 @@ extern "C" { | |||||
#endif | #endif | ||||
enum acltdtTensorType { | enum acltdtTensorType { | ||||
ACL_TENSOR_DATA_UNDEFINED = -1, | |||||
ACL_TENSOR_DATA_TENSOR, | |||||
ACL_TENSOR_DATA_END_OF_SEQUENCE, | |||||
ACL_TENSOR_DATA_ABNORMAL | |||||
ACL_TENSOR_DATA_UNDEFINED = -1, | |||||
ACL_TENSOR_DATA_TENSOR, | |||||
ACL_TENSOR_DATA_END_OF_SEQUENCE, | |||||
ACL_TENSOR_DATA_ABNORMAL | |||||
}; | }; | ||||
typedef struct acltdtDataItem acltdtDataItem; | typedef struct acltdtDataItem acltdtDataItem; | ||||
@@ -64,7 +64,7 @@ ACL_FUNC_VISIBILITY aclDataType acltdtGetDataTypeFromItem(const acltdtDataItem * | |||||
* | * | ||||
* @retval null for failed | * @retval null for failed | ||||
* @retval OtherValues success | * @retval OtherValues success | ||||
*/ | |||||
*/ | |||||
ACL_FUNC_VISIBILITY void *acltdtGetDataAddrFromItem(const acltdtDataItem *dataItem); | ACL_FUNC_VISIBILITY void *acltdtGetDataAddrFromItem(const acltdtDataItem *dataItem); | ||||
/** | /** | ||||
@@ -75,7 +75,7 @@ ACL_FUNC_VISIBILITY void *acltdtGetDataAddrFromItem(const acltdtDataItem *dataIt | |||||
* | * | ||||
* @retval 0 for failed | * @retval 0 for failed | ||||
* @retval OtherValues success | * @retval OtherValues success | ||||
*/ | |||||
*/ | |||||
ACL_FUNC_VISIBILITY size_t acltdtGetDataSizeFromItem(const acltdtDataItem *dataItem); | ACL_FUNC_VISIBILITY size_t acltdtGetDataSizeFromItem(const acltdtDataItem *dataItem); | ||||
/** | /** | ||||
@@ -86,7 +86,7 @@ ACL_FUNC_VISIBILITY size_t acltdtGetDataSizeFromItem(const acltdtDataItem *dataI | |||||
* | * | ||||
* @retval 0 for failed | * @retval 0 for failed | ||||
* @retval OtherValues success | * @retval OtherValues success | ||||
*/ | |||||
*/ | |||||
ACL_FUNC_VISIBILITY size_t acltdtGetDimNumFromItem(const acltdtDataItem *dataItem); | ACL_FUNC_VISIBILITY size_t acltdtGetDimNumFromItem(const acltdtDataItem *dataItem); | ||||
/** | /** | ||||
@@ -118,12 +118,8 @@ ACL_FUNC_VISIBILITY aclError acltdtGetDimsFromItem(const acltdtDataItem *dataIte | |||||
* | * | ||||
* @see acltdtDestroyDataItem | * @see acltdtDestroyDataItem | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY acltdtDataItem *acltdtCreateDataItem(acltdtTensorType tdtType, | |||||
const int64_t *dims, | |||||
size_t dimNum, | |||||
aclDataType dataType, | |||||
void *data, | |||||
size_t size); | |||||
ACL_FUNC_VISIBILITY acltdtDataItem *acltdtCreateDataItem(acltdtTensorType tdtType, const int64_t *dims, size_t dimNum, | |||||
aclDataType dataType, void *data, size_t size); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -254,8 +250,7 @@ ACL_FUNC_VISIBILITY aclError acltdtDestroyChannel(acltdtChannelHandle *handle); | |||||
* | * | ||||
* @see acltdtReceiveTensor | * @see acltdtReceiveTensor | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acltdtSendTensor(const acltdtChannelHandle *handle, | |||||
const acltdtDataset *dataset, | |||||
ACL_FUNC_VISIBILITY aclError acltdtSendTensor(const acltdtChannelHandle *handle, const acltdtDataset *dataset, | |||||
int32_t timeout); | int32_t timeout); | ||||
/** | /** | ||||
@@ -271,13 +266,11 @@ ACL_FUNC_VISIBILITY aclError acltdtSendTensor(const acltdtChannelHandle *handle, | |||||
* | * | ||||
* @see acltdtSendTensor | * @see acltdtSendTensor | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acltdtReceiveTensor(const acltdtChannelHandle *handle, | |||||
acltdtDataset *dataset, | |||||
ACL_FUNC_VISIBILITY aclError acltdtReceiveTensor(const acltdtChannelHandle *handle, acltdtDataset *dataset, | |||||
int32_t timeout); | int32_t timeout); | ||||
#ifdef __cplusplus | #ifdef __cplusplus | ||||
} | } | ||||
#endif | #endif | ||||
#endif //INC_EXTERNAL_ACL_ACL_TDT_H_ | |||||
#endif // INC_EXTERNAL_ACL_ACL_TDT_H_ |
@@ -23,87 +23,87 @@ | |||||
extern "C" { | extern "C" { | ||||
#endif | #endif | ||||
static const int32_t ACL_RT_SUCCESS = 0; // success | |||||
static const int32_t ACL_RT_SUCCESS = 0; // success | |||||
static const int32_t ACL_ERROR_RT_PARAM_INVALID = 107000; // param invalid | |||||
static const int32_t ACL_ERROR_RT_INVALID_DEVICEID = 107001; // invalid device id | |||||
static const int32_t ACL_ERROR_RT_CONTEXT_NULL = 107002; // current context null | |||||
static const int32_t ACL_ERROR_RT_STREAM_CONTEXT = 107003; // stream not in current context | |||||
static const int32_t ACL_ERROR_RT_MODEL_CONTEXT = 107004; // model not in current context | |||||
static const int32_t ACL_ERROR_RT_STREAM_MODEL = 107005; // stream not in model | |||||
static const int32_t ACL_ERROR_RT_EVENT_TIMESTAMP_INVALID = 107006; // event timestamp invalid | |||||
static const int32_t ACL_ERROR_RT_EVENT_TIMESTAMP_REVERSAL = 107007; // event timestamp reversal | |||||
static const int32_t ACL_ERROR_RT_ADDR_UNALIGNED = 107008; // memory address unaligned | |||||
static const int32_t ACL_ERROR_RT_FILE_OPEN = 107009; // open file failed | |||||
static const int32_t ACL_ERROR_RT_FILE_WRITE = 107010; // write file failed | |||||
static const int32_t ACL_ERROR_RT_STREAM_SUBSCRIBE = 107011; // error subscribe stream | |||||
static const int32_t ACL_ERROR_RT_THREAD_SUBSCRIBE = 107012; // error subscribe thread | |||||
static const int32_t ACL_ERROR_RT_GROUP_NOT_SET = 107013; // group not set | |||||
static const int32_t ACL_ERROR_RT_GROUP_NOT_CREATE = 107014; // group not create | |||||
static const int32_t ACL_ERROR_RT_STREAM_NO_CB_REG = 107015; // callback not register to stream | |||||
static const int32_t ACL_ERROR_RT_INVALID_MEMORY_TYPE = 107016; // invalid memory type | |||||
static const int32_t ACL_ERROR_RT_INVALID_HANDLE = 107017; // invalid handle | |||||
static const int32_t ACL_ERROR_RT_INVALID_MALLOC_TYPE = 107018; // invalid malloc type | |||||
static const int32_t ACL_ERROR_RT_WAIT_TIMEOUT = 107019; // wait timeout | |||||
static const int32_t ACL_ERROR_RT_PARAM_INVALID = 107000; // param invalid | |||||
static const int32_t ACL_ERROR_RT_INVALID_DEVICEID = 107001; // invalid device id | |||||
static const int32_t ACL_ERROR_RT_CONTEXT_NULL = 107002; // current context null | |||||
static const int32_t ACL_ERROR_RT_STREAM_CONTEXT = 107003; // stream not in current context | |||||
static const int32_t ACL_ERROR_RT_MODEL_CONTEXT = 107004; // model not in current context | |||||
static const int32_t ACL_ERROR_RT_STREAM_MODEL = 107005; // stream not in model | |||||
static const int32_t ACL_ERROR_RT_EVENT_TIMESTAMP_INVALID = 107006; // event timestamp invalid | |||||
static const int32_t ACL_ERROR_RT_EVENT_TIMESTAMP_REVERSAL = 107007; // event timestamp reversal | |||||
static const int32_t ACL_ERROR_RT_ADDR_UNALIGNED = 107008; // memory address unaligned | |||||
static const int32_t ACL_ERROR_RT_FILE_OPEN = 107009; // open file failed | |||||
static const int32_t ACL_ERROR_RT_FILE_WRITE = 107010; // write file failed | |||||
static const int32_t ACL_ERROR_RT_STREAM_SUBSCRIBE = 107011; // error subscribe stream | |||||
static const int32_t ACL_ERROR_RT_THREAD_SUBSCRIBE = 107012; // error subscribe thread | |||||
static const int32_t ACL_ERROR_RT_GROUP_NOT_SET = 107013; // group not set | |||||
static const int32_t ACL_ERROR_RT_GROUP_NOT_CREATE = 107014; // group not create | |||||
static const int32_t ACL_ERROR_RT_STREAM_NO_CB_REG = 107015; // callback not register to stream | |||||
static const int32_t ACL_ERROR_RT_INVALID_MEMORY_TYPE = 107016; // invalid memory type | |||||
static const int32_t ACL_ERROR_RT_INVALID_HANDLE = 107017; // invalid handle | |||||
static const int32_t ACL_ERROR_RT_INVALID_MALLOC_TYPE = 107018; // invalid malloc type | |||||
static const int32_t ACL_ERROR_RT_WAIT_TIMEOUT = 107019; // wait timeout | |||||
static const int32_t ACL_ERROR_RT_FEATURE_NOT_SUPPORT = 207000; // feature not support | |||||
static const int32_t ACL_ERROR_RT_MEMORY_ALLOCATION = 207001; // memory allocation error | |||||
static const int32_t ACL_ERROR_RT_MEMORY_FREE = 207002; // memory free error | |||||
static const int32_t ACL_ERROR_RT_AICORE_OVER_FLOW = 207003; // aicore over flow | |||||
static const int32_t ACL_ERROR_RT_NO_DEVICE = 207004; // no device | |||||
static const int32_t ACL_ERROR_RT_RESOURCE_ALLOC_FAIL = 207005; // resource alloc fail | |||||
static const int32_t ACL_ERROR_RT_NO_PERMISSION = 207006; // no permission | |||||
static const int32_t ACL_ERROR_RT_NO_EVENT_RESOURCE = 207007; // no event resource | |||||
static const int32_t ACL_ERROR_RT_NO_STREAM_RESOURCE = 207008; // no stream resource | |||||
static const int32_t ACL_ERROR_RT_NO_NOTIFY_RESOURCE = 207009; // no notify resource | |||||
static const int32_t ACL_ERROR_RT_NO_MODEL_RESOURCE = 207010; // no model resource | |||||
static const int32_t ACL_ERROR_RT_NO_CDQ_RESOURCE = 207011; // no cdq resource | |||||
static const int32_t ACL_ERROR_RT_FEATURE_NOT_SUPPORT = 207000; // feature not support | |||||
static const int32_t ACL_ERROR_RT_MEMORY_ALLOCATION = 207001; // memory allocation error | |||||
static const int32_t ACL_ERROR_RT_MEMORY_FREE = 207002; // memory free error | |||||
static const int32_t ACL_ERROR_RT_AICORE_OVER_FLOW = 207003; // aicore over flow | |||||
static const int32_t ACL_ERROR_RT_NO_DEVICE = 207004; // no device | |||||
static const int32_t ACL_ERROR_RT_RESOURCE_ALLOC_FAIL = 207005; // resource alloc fail | |||||
static const int32_t ACL_ERROR_RT_NO_PERMISSION = 207006; // no permission | |||||
static const int32_t ACL_ERROR_RT_NO_EVENT_RESOURCE = 207007; // no event resource | |||||
static const int32_t ACL_ERROR_RT_NO_STREAM_RESOURCE = 207008; // no stream resource | |||||
static const int32_t ACL_ERROR_RT_NO_NOTIFY_RESOURCE = 207009; // no notify resource | |||||
static const int32_t ACL_ERROR_RT_NO_MODEL_RESOURCE = 207010; // no model resource | |||||
static const int32_t ACL_ERROR_RT_NO_CDQ_RESOURCE = 207011; // no cdq resource | |||||
static const int32_t ACL_ERROR_RT_INTERNAL_ERROR = 507000; // runtime internal error | |||||
static const int32_t ACL_ERROR_RT_TS_ERROR = 507001; // ts internel error | |||||
static const int32_t ACL_ERROR_RT_STREAM_TASK_FULL = 507002; // task full in stream | |||||
static const int32_t ACL_ERROR_RT_STREAM_TASK_EMPTY = 507003; // task empty in stream | |||||
static const int32_t ACL_ERROR_RT_STREAM_NOT_COMPLETE = 507004; // stream not complete | |||||
static const int32_t ACL_ERROR_RT_END_OF_SEQUENCE = 507005; // end of sequence | |||||
static const int32_t ACL_ERROR_RT_EVENT_NOT_COMPLETE = 507006; // event not complete | |||||
static const int32_t ACL_ERROR_RT_CONTEXT_RELEASE_ERROR = 507007; // context release error | |||||
static const int32_t ACL_ERROR_RT_SOC_VERSION = 507008; // soc version error | |||||
static const int32_t ACL_ERROR_RT_TASK_TYPE_NOT_SUPPORT = 507009; // task type not support | |||||
static const int32_t ACL_ERROR_RT_LOST_HEARTBEAT = 507010; // ts lost heartbeat | |||||
static const int32_t ACL_ERROR_RT_MODEL_EXECUTE = 507011; // model execute failed | |||||
static const int32_t ACL_ERROR_RT_REPORT_TIMEOUT = 507012; // report timeout | |||||
static const int32_t ACL_ERROR_RT_SYS_DMA = 507013; // sys dma error | |||||
static const int32_t ACL_ERROR_RT_AICORE_TIMEOUT = 507014; // aicore timeout | |||||
static const int32_t ACL_ERROR_RT_AICORE_EXCEPTION = 507015; // aicore exception | |||||
static const int32_t ACL_ERROR_RT_AICORE_TRAP_EXCEPTION = 507016; // aicore trap exception | |||||
static const int32_t ACL_ERROR_RT_AICPU_TIMEOUT = 507017; // aicpu timeout | |||||
static const int32_t ACL_ERROR_RT_AICPU_EXCEPTION = 507018; // aicpu exception | |||||
static const int32_t ACL_ERROR_RT_AICPU_DATADUMP_RSP_ERR = 507019; // aicpu datadump response error | |||||
static const int32_t ACL_ERROR_RT_AICPU_MODEL_RSP_ERR = 507020; // aicpu model operate response error | |||||
static const int32_t ACL_ERROR_RT_PROFILING_ERROR = 507021; // profiling error | |||||
static const int32_t ACL_ERROR_RT_IPC_ERROR = 507022; // ipc error | |||||
static const int32_t ACL_ERROR_RT_MODEL_ABORT_NORMAL = 507023; // model abort normal | |||||
static const int32_t ACL_ERROR_RT_KERNEL_UNREGISTERING = 507024; // kernel unregistering | |||||
static const int32_t ACL_ERROR_RT_RINGBUFFER_NOT_INIT = 507025; // ringbuffer not init | |||||
static const int32_t ACL_ERROR_RT_RINGBUFFER_NO_DATA = 507026; // ringbuffer no data | |||||
static const int32_t ACL_ERROR_RT_KERNEL_LOOKUP = 507027; // kernel lookup error | |||||
static const int32_t ACL_ERROR_RT_KERNEL_DUPLICATE = 507028; // kernel register duplicate | |||||
static const int32_t ACL_ERROR_RT_DEBUG_REGISTER_FAIL = 507029; // debug register failed | |||||
static const int32_t ACL_ERROR_RT_DEBUG_UNREGISTER_FAIL = 507030; // debug unregister failed | |||||
static const int32_t ACL_ERROR_RT_LABEL_CONTEXT = 507031; // label not in current context | |||||
static const int32_t ACL_ERROR_RT_PROGRAM_USE_OUT = 507032; // program register num use out | |||||
static const int32_t ACL_ERROR_RT_DEV_SETUP_ERROR = 507033; // device setup error | |||||
static const int32_t ACL_ERROR_RT_VECTOR_CORE_TIMEOUT = 507034; // vector core timeout | |||||
static const int32_t ACL_ERROR_RT_VECTOR_CORE_EXCEPTION = 507035; // vector core exception | |||||
static const int32_t ACL_ERROR_RT_VECTOR_CORE_TRAP_EXCEPTION = 507036; // vector core trap exception | |||||
static const int32_t ACL_ERROR_RT_CDQ_BATCH_ABNORMAL = 507037; // cdq alloc batch abnormal | |||||
static const int32_t ACL_ERROR_RT_INTERNAL_ERROR = 507000; // runtime internal error | |||||
static const int32_t ACL_ERROR_RT_TS_ERROR = 507001; // ts internel error | |||||
static const int32_t ACL_ERROR_RT_STREAM_TASK_FULL = 507002; // task full in stream | |||||
static const int32_t ACL_ERROR_RT_STREAM_TASK_EMPTY = 507003; // task empty in stream | |||||
static const int32_t ACL_ERROR_RT_STREAM_NOT_COMPLETE = 507004; // stream not complete | |||||
static const int32_t ACL_ERROR_RT_END_OF_SEQUENCE = 507005; // end of sequence | |||||
static const int32_t ACL_ERROR_RT_EVENT_NOT_COMPLETE = 507006; // event not complete | |||||
static const int32_t ACL_ERROR_RT_CONTEXT_RELEASE_ERROR = 507007; // context release error | |||||
static const int32_t ACL_ERROR_RT_SOC_VERSION = 507008; // soc version error | |||||
static const int32_t ACL_ERROR_RT_TASK_TYPE_NOT_SUPPORT = 507009; // task type not support | |||||
static const int32_t ACL_ERROR_RT_LOST_HEARTBEAT = 507010; // ts lost heartbeat | |||||
static const int32_t ACL_ERROR_RT_MODEL_EXECUTE = 507011; // model execute failed | |||||
static const int32_t ACL_ERROR_RT_REPORT_TIMEOUT = 507012; // report timeout | |||||
static const int32_t ACL_ERROR_RT_SYS_DMA = 507013; // sys dma error | |||||
static const int32_t ACL_ERROR_RT_AICORE_TIMEOUT = 507014; // aicore timeout | |||||
static const int32_t ACL_ERROR_RT_AICORE_EXCEPTION = 507015; // aicore exception | |||||
static const int32_t ACL_ERROR_RT_AICORE_TRAP_EXCEPTION = 507016; // aicore trap exception | |||||
static const int32_t ACL_ERROR_RT_AICPU_TIMEOUT = 507017; // aicpu timeout | |||||
static const int32_t ACL_ERROR_RT_AICPU_EXCEPTION = 507018; // aicpu exception | |||||
static const int32_t ACL_ERROR_RT_AICPU_DATADUMP_RSP_ERR = 507019; // aicpu datadump response error | |||||
static const int32_t ACL_ERROR_RT_AICPU_MODEL_RSP_ERR = 507020; // aicpu model operate response error | |||||
static const int32_t ACL_ERROR_RT_PROFILING_ERROR = 507021; // profiling error | |||||
static const int32_t ACL_ERROR_RT_IPC_ERROR = 507022; // ipc error | |||||
static const int32_t ACL_ERROR_RT_MODEL_ABORT_NORMAL = 507023; // model abort normal | |||||
static const int32_t ACL_ERROR_RT_KERNEL_UNREGISTERING = 507024; // kernel unregistering | |||||
static const int32_t ACL_ERROR_RT_RINGBUFFER_NOT_INIT = 507025; // ringbuffer not init | |||||
static const int32_t ACL_ERROR_RT_RINGBUFFER_NO_DATA = 507026; // ringbuffer no data | |||||
static const int32_t ACL_ERROR_RT_KERNEL_LOOKUP = 507027; // kernel lookup error | |||||
static const int32_t ACL_ERROR_RT_KERNEL_DUPLICATE = 507028; // kernel register duplicate | |||||
static const int32_t ACL_ERROR_RT_DEBUG_REGISTER_FAIL = 507029; // debug register failed | |||||
static const int32_t ACL_ERROR_RT_DEBUG_UNREGISTER_FAIL = 507030; // debug unregister failed | |||||
static const int32_t ACL_ERROR_RT_LABEL_CONTEXT = 507031; // label not in current context | |||||
static const int32_t ACL_ERROR_RT_PROGRAM_USE_OUT = 507032; // program register num use out | |||||
static const int32_t ACL_ERROR_RT_DEV_SETUP_ERROR = 507033; // device setup error | |||||
static const int32_t ACL_ERROR_RT_VECTOR_CORE_TIMEOUT = 507034; // vector core timeout | |||||
static const int32_t ACL_ERROR_RT_VECTOR_CORE_EXCEPTION = 507035; // vector core exception | |||||
static const int32_t ACL_ERROR_RT_VECTOR_CORE_TRAP_EXCEPTION = 507036; // vector core trap exception | |||||
static const int32_t ACL_ERROR_RT_CDQ_BATCH_ABNORMAL = 507037; // cdq alloc batch abnormal | |||||
static const int32_t ACL_ERROR_RT_DRV_INTERNAL_ERROR = 507899; // drv internal error | |||||
static const int32_t ACL_ERROR_RT_AICPU_INTERNAL_ERROR = 507900; // aicpu internal error | |||||
static const int32_t ACL_ERROR_RT_SOCKET_CLOSE = 507901; // hdc disconnect | |||||
static const int32_t ACL_ERROR_RT_DRV_INTERNAL_ERROR = 507899; // drv internal error | |||||
static const int32_t ACL_ERROR_RT_AICPU_INTERNAL_ERROR = 507900; // aicpu internal error | |||||
static const int32_t ACL_ERROR_RT_SOCKET_CLOSE = 507901; // hdc disconnect | |||||
#ifdef __cplusplus | #ifdef __cplusplus | ||||
} | } | ||||
#endif | #endif | ||||
#endif // __INC_EXTERNEL_RT_ERROR_CODES_H__ | |||||
#endif // __INC_EXTERNEL_RT_ERROR_CODES_H__ |
@@ -23,17 +23,9 @@ | |||||
extern "C" { | extern "C" { | ||||
#endif | #endif | ||||
typedef enum aclTransType { | |||||
ACL_TRANS_N, | |||||
ACL_TRANS_T, | |||||
ACL_TRANS_NZ, | |||||
ACL_TRANS_NZ_T | |||||
} aclTransType; | |||||
typedef enum aclTransType { ACL_TRANS_N, ACL_TRANS_T, ACL_TRANS_NZ, ACL_TRANS_NZ_T } aclTransType; | |||||
typedef enum aclComputeType { | |||||
ACL_COMPUTE_HIGH_PRECISION, | |||||
ACL_COMPUTE_LOW_PRECISION | |||||
} aclComputeType; | |||||
typedef enum aclComputeType { ACL_COMPUTE_HIGH_PRECISION, ACL_COMPUTE_LOW_PRECISION } aclComputeType; | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -61,12 +53,11 @@ typedef enum aclComputeType { | |||||
* | * | ||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclblasGemvEx(aclTransType transA, int m, int n, | |||||
const void *alpha, const void *a, int lda, aclDataType dataTypeA, | |||||
const void *x, int incx, aclDataType dataTypeX, | |||||
const void *beta, void *y, int incy, aclDataType dataTypeY, | |||||
aclComputeType type, aclrtStream stream); | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclblasGemvEx(aclTransType transA, int m, int n, const void *alpha, const void *a, int lda, | |||||
aclDataType dataTypeA, const void *x, int incx, aclDataType dataTypeX, | |||||
const void *beta, void *y, int incy, aclDataType dataTypeY, | |||||
aclComputeType type, aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -83,15 +74,10 @@ ACL_FUNC_VISIBILITY aclError aclblasGemvEx(aclTransType transA, int m, int n, | |||||
* | * | ||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclblasCreateHandleForGemvEx(aclTransType transA, | |||||
int m, | |||||
int n, | |||||
aclDataType dataTypeA, | |||||
aclDataType dataTypeX, | |||||
aclDataType dataTypeY, | |||||
aclComputeType type, | |||||
aclopHandle **handle); | |||||
*/ | |||||
ACL_FUNC_VISIBILITY aclError aclblasCreateHandleForGemvEx(aclTransType transA, int m, int n, aclDataType dataTypeA, | |||||
aclDataType dataTypeX, aclDataType dataTypeY, | |||||
aclComputeType type, aclopHandle **handle); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -115,18 +101,9 @@ ACL_FUNC_VISIBILITY aclError aclblasCreateHandleForGemvEx(aclTransType transA, | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclblasHgemv(aclTransType transA, | |||||
int m, | |||||
int n, | |||||
const aclFloat16 *alpha, | |||||
const aclFloat16 *a, | |||||
int lda, | |||||
const aclFloat16 *x, | |||||
int incx, | |||||
const aclFloat16 *beta, | |||||
aclFloat16 *y, | |||||
int incy, | |||||
aclComputeType type, | |||||
ACL_FUNC_VISIBILITY aclError aclblasHgemv(aclTransType transA, int m, int n, const aclFloat16 *alpha, | |||||
const aclFloat16 *a, int lda, const aclFloat16 *x, int incx, | |||||
const aclFloat16 *beta, aclFloat16 *y, int incy, aclComputeType type, | |||||
aclrtStream stream); | aclrtStream stream); | ||||
/** | /** | ||||
@@ -142,10 +119,7 @@ ACL_FUNC_VISIBILITY aclError aclblasHgemv(aclTransType transA, | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclblasCreateHandleForHgemv(aclTransType transA, | |||||
int m, | |||||
int n, | |||||
aclComputeType type, | |||||
ACL_FUNC_VISIBILITY aclError aclblasCreateHandleForHgemv(aclTransType transA, int m, int n, aclComputeType type, | |||||
aclopHandle **handle); | aclopHandle **handle); | ||||
/** | /** | ||||
@@ -171,19 +145,9 @@ ACL_FUNC_VISIBILITY aclError aclblasCreateHandleForHgemv(aclTransType transA, | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclblasS8gemv(aclTransType transA, | |||||
int m, | |||||
int n, | |||||
const int32_t *alpha, | |||||
const int8_t *a, | |||||
int lda, | |||||
const int8_t *x, | |||||
int incx, | |||||
const int32_t *beta, | |||||
int32_t *y, | |||||
int incy, | |||||
aclComputeType type, | |||||
aclrtStream stream); | |||||
ACL_FUNC_VISIBILITY aclError aclblasS8gemv(aclTransType transA, int m, int n, const int32_t *alpha, const int8_t *a, | |||||
int lda, const int8_t *x, int incx, const int32_t *beta, int32_t *y, | |||||
int incy, aclComputeType type, aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -198,10 +162,7 @@ ACL_FUNC_VISIBILITY aclError aclblasS8gemv(aclTransType transA, | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclblasCreateHandleForS8gemv(aclTransType transA, | |||||
int m, | |||||
int n, | |||||
aclComputeType type, | |||||
ACL_FUNC_VISIBILITY aclError aclblasCreateHandleForS8gemv(aclTransType transA, int m, int n, aclComputeType type, | |||||
aclopHandle **handle); | aclopHandle **handle); | ||||
/** | /** | ||||
@@ -233,26 +194,11 @@ ACL_FUNC_VISIBILITY aclError aclblasCreateHandleForS8gemv(aclTransType transA, | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclblasGemmEx(aclTransType transA, | |||||
aclTransType transB, | |||||
aclTransType transC, | |||||
int m, | |||||
int n, | |||||
int k, | |||||
const void *alpha, | |||||
const void *matrixA, | |||||
int lda, | |||||
aclDataType dataTypeA, | |||||
const void *matrixB, | |||||
int ldb, | |||||
aclDataType dataTypeB, | |||||
const void *beta, | |||||
void *matrixC, | |||||
int ldc, | |||||
aclDataType dataTypeC, | |||||
aclComputeType type, | |||||
aclrtStream stream); | |||||
ACL_FUNC_VISIBILITY aclError aclblasGemmEx(aclTransType transA, aclTransType transB, aclTransType transC, int m, int n, | |||||
int k, const void *alpha, const void *matrixA, int lda, | |||||
aclDataType dataTypeA, const void *matrixB, int ldb, aclDataType dataTypeB, | |||||
const void *beta, void *matrixC, int ldc, aclDataType dataTypeC, | |||||
aclComputeType type, aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -274,18 +220,10 @@ ACL_FUNC_VISIBILITY aclError aclblasGemmEx(aclTransType transA, | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclblasCreateHandleForGemmEx(aclTransType transA, | |||||
aclTransType transB, | |||||
aclTransType transC, | |||||
int m, | |||||
int n, | |||||
int k, | |||||
aclDataType dataTypeA, | |||||
aclDataType dataTypeB, | |||||
aclDataType dataTypeC, | |||||
aclComputeType type, | |||||
aclopHandle **handle); | |||||
ACL_FUNC_VISIBILITY aclError aclblasCreateHandleForGemmEx(aclTransType transA, aclTransType transB, aclTransType transC, | |||||
int m, int n, int k, aclDataType dataTypeA, | |||||
aclDataType dataTypeB, aclDataType dataTypeC, | |||||
aclComputeType type, aclopHandle **handle); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -313,22 +251,10 @@ ACL_FUNC_VISIBILITY aclError aclblasCreateHandleForGemmEx(aclTransType transA, | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclblasHgemm(aclTransType transA, | |||||
aclTransType transB, | |||||
aclTransType transC, | |||||
int m, | |||||
int n, | |||||
int k, | |||||
const aclFloat16 *alpha, | |||||
const aclFloat16 *matrixA, | |||||
int lda, | |||||
const aclFloat16 *matrixB, | |||||
int ldb, | |||||
const aclFloat16 *beta, | |||||
aclFloat16 *matrixC, | |||||
int ldc, | |||||
aclComputeType type, | |||||
aclrtStream stream); | |||||
ACL_FUNC_VISIBILITY aclError aclblasHgemm(aclTransType transA, aclTransType transB, aclTransType transC, int m, int n, | |||||
int k, const aclFloat16 *alpha, const aclFloat16 *matrixA, int lda, | |||||
const aclFloat16 *matrixB, int ldb, const aclFloat16 *beta, | |||||
aclFloat16 *matrixC, int ldc, aclComputeType type, aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -346,13 +272,8 @@ ACL_FUNC_VISIBILITY aclError aclblasHgemm(aclTransType transA, | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclblasCreateHandleForHgemm(aclTransType transA, | |||||
aclTransType transB, | |||||
aclTransType transC, | |||||
int m, | |||||
int n, | |||||
int k, | |||||
aclComputeType type, | |||||
ACL_FUNC_VISIBILITY aclError aclblasCreateHandleForHgemm(aclTransType transA, aclTransType transB, aclTransType transC, | |||||
int m, int n, int k, aclComputeType type, | |||||
aclopHandle **handle); | aclopHandle **handle); | ||||
/** | /** | ||||
@@ -381,23 +302,10 @@ ACL_FUNC_VISIBILITY aclError aclblasCreateHandleForHgemm(aclTransType transA, | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclblasS8gemm(aclTransType transA, | |||||
aclTransType transB, | |||||
aclTransType transC, | |||||
int m, | |||||
int n, | |||||
int k, | |||||
const int32_t *alpha, | |||||
const int8_t *matrixA, | |||||
int lda, | |||||
const int8_t *matrixB, | |||||
int ldb, | |||||
const int32_t *beta, | |||||
int32_t *matrixC, | |||||
int ldc, | |||||
aclComputeType type, | |||||
aclrtStream stream); | |||||
ACL_FUNC_VISIBILITY aclError aclblasS8gemm(aclTransType transA, aclTransType transB, aclTransType transC, int m, int n, | |||||
int k, const int32_t *alpha, const int8_t *matrixA, int lda, | |||||
const int8_t *matrixB, int ldb, const int32_t *beta, int32_t *matrixC, | |||||
int ldc, aclComputeType type, aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -415,17 +323,12 @@ ACL_FUNC_VISIBILITY aclError aclblasS8gemm(aclTransType transA, | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclblasCreateHandleForS8gemm(aclTransType transA, | |||||
aclTransType transB, | |||||
aclTransType transC, | |||||
int m, | |||||
int n, | |||||
int k, | |||||
aclComputeType type, | |||||
ACL_FUNC_VISIBILITY aclError aclblasCreateHandleForS8gemm(aclTransType transA, aclTransType transB, aclTransType transC, | |||||
int m, int n, int k, aclComputeType type, | |||||
aclopHandle **handle); | aclopHandle **handle); | ||||
#ifdef __cplusplus | #ifdef __cplusplus | ||||
} | } | ||||
#endif | #endif | ||||
#endif // INC_EXTERNAL_ACL_OPS_ACL_CBLAS_H_ | |||||
#endif // INC_EXTERNAL_ACL_OPS_ACL_CBLAS_H_ |
@@ -53,123 +53,109 @@ typedef void (*aclvencCallback)(acldvppPicDesc *input, acldvppStreamDesc *output | |||||
// Supported Pixel Format | // Supported Pixel Format | ||||
enum acldvppPixelFormat { | enum acldvppPixelFormat { | ||||
PIXEL_FORMAT_YUV_400 = 0, // 0 | |||||
PIXEL_FORMAT_YUV_SEMIPLANAR_420 = 1, // 1 | |||||
PIXEL_FORMAT_YVU_SEMIPLANAR_420 = 2, // 2 | |||||
PIXEL_FORMAT_YUV_SEMIPLANAR_422 = 3, // 3 | |||||
PIXEL_FORMAT_YVU_SEMIPLANAR_422 = 4, // 4 | |||||
PIXEL_FORMAT_YUV_SEMIPLANAR_444 = 5, // 5 | |||||
PIXEL_FORMAT_YVU_SEMIPLANAR_444 = 6, // 6 | |||||
PIXEL_FORMAT_YUYV_PACKED_422 = 7, // 7 | |||||
PIXEL_FORMAT_UYVY_PACKED_422 = 8, // 8 | |||||
PIXEL_FORMAT_YVYU_PACKED_422 = 9, // 9 | |||||
PIXEL_FORMAT_VYUY_PACKED_422 = 10, // 10 | |||||
PIXEL_FORMAT_YUV_PACKED_444 = 11, // 11 | |||||
PIXEL_FORMAT_RGB_888 = 12, // 12 | |||||
PIXEL_FORMAT_BGR_888 = 13, // 13 | |||||
PIXEL_FORMAT_ARGB_8888 = 14, // 14 | |||||
PIXEL_FORMAT_ABGR_8888 = 15, // 15 | |||||
PIXEL_FORMAT_RGBA_8888 = 16, // 16 | |||||
PIXEL_FORMAT_BGRA_8888 = 17, // 17 | |||||
PIXEL_FORMAT_YUV_SEMI_PLANNER_420_10BIT = 18, // 18 | |||||
PIXEL_FORMAT_YVU_SEMI_PLANNER_420_10BIT = 19, // 19 | |||||
PIXEL_FORMAT_YVU_PLANAR_420 = 20, // 20 | |||||
PIXEL_FORMAT_YVU_PLANAR_422, | |||||
PIXEL_FORMAT_YVU_PLANAR_444, | |||||
PIXEL_FORMAT_RGB_444 = 23, | |||||
PIXEL_FORMAT_BGR_444, | |||||
PIXEL_FORMAT_ARGB_4444, | |||||
PIXEL_FORMAT_ABGR_4444, | |||||
PIXEL_FORMAT_RGBA_4444, | |||||
PIXEL_FORMAT_BGRA_4444, | |||||
PIXEL_FORMAT_RGB_555, | |||||
PIXEL_FORMAT_BGR_555, | |||||
PIXEL_FORMAT_RGB_565, | |||||
PIXEL_FORMAT_BGR_565, | |||||
PIXEL_FORMAT_ARGB_1555, | |||||
PIXEL_FORMAT_ABGR_1555, | |||||
PIXEL_FORMAT_RGBA_1555, | |||||
PIXEL_FORMAT_BGRA_1555, | |||||
PIXEL_FORMAT_ARGB_8565, | |||||
PIXEL_FORMAT_ABGR_8565, | |||||
PIXEL_FORMAT_RGBA_8565, | |||||
PIXEL_FORMAT_BGRA_8565, | |||||
PIXEL_FORMAT_RGB_BAYER_8BPP = 50, | |||||
PIXEL_FORMAT_RGB_BAYER_10BPP, | |||||
PIXEL_FORMAT_RGB_BAYER_12BPP, | |||||
PIXEL_FORMAT_RGB_BAYER_14BPP, | |||||
PIXEL_FORMAT_RGB_BAYER_16BPP, | |||||
PIXEL_FORMAT_BGR_888_PLANAR = 70, | |||||
PIXEL_FORMAT_HSV_888_PACKAGE, | |||||
PIXEL_FORMAT_HSV_888_PLANAR, | |||||
PIXEL_FORMAT_LAB_888_PACKAGE, | |||||
PIXEL_FORMAT_LAB_888_PLANAR, | |||||
PIXEL_FORMAT_S8C1, | |||||
PIXEL_FORMAT_S8C2_PACKAGE, | |||||
PIXEL_FORMAT_S8C2_PLANAR, | |||||
PIXEL_FORMAT_S16C1, | |||||
PIXEL_FORMAT_U8C1, | |||||
PIXEL_FORMAT_U16C1, | |||||
PIXEL_FORMAT_S32C1, | |||||
PIXEL_FORMAT_U32C1, | |||||
PIXEL_FORMAT_U64C1, | |||||
PIXEL_FORMAT_S64C1, | |||||
PIXEL_FORMAT_YUV_SEMIPLANAR_440 = 1000, | |||||
PIXEL_FORMAT_YVU_SEMIPLANAR_440, | |||||
PIXEL_FORMAT_FLOAT32, | |||||
PIXEL_FORMAT_BUTT, | |||||
PIXEL_FORMAT_UNKNOWN = 10000 | |||||
PIXEL_FORMAT_YUV_400 = 0, // 0 | |||||
PIXEL_FORMAT_YUV_SEMIPLANAR_420 = 1, // 1 | |||||
PIXEL_FORMAT_YVU_SEMIPLANAR_420 = 2, // 2 | |||||
PIXEL_FORMAT_YUV_SEMIPLANAR_422 = 3, // 3 | |||||
PIXEL_FORMAT_YVU_SEMIPLANAR_422 = 4, // 4 | |||||
PIXEL_FORMAT_YUV_SEMIPLANAR_444 = 5, // 5 | |||||
PIXEL_FORMAT_YVU_SEMIPLANAR_444 = 6, // 6 | |||||
PIXEL_FORMAT_YUYV_PACKED_422 = 7, // 7 | |||||
PIXEL_FORMAT_UYVY_PACKED_422 = 8, // 8 | |||||
PIXEL_FORMAT_YVYU_PACKED_422 = 9, // 9 | |||||
PIXEL_FORMAT_VYUY_PACKED_422 = 10, // 10 | |||||
PIXEL_FORMAT_YUV_PACKED_444 = 11, // 11 | |||||
PIXEL_FORMAT_RGB_888 = 12, // 12 | |||||
PIXEL_FORMAT_BGR_888 = 13, // 13 | |||||
PIXEL_FORMAT_ARGB_8888 = 14, // 14 | |||||
PIXEL_FORMAT_ABGR_8888 = 15, // 15 | |||||
PIXEL_FORMAT_RGBA_8888 = 16, // 16 | |||||
PIXEL_FORMAT_BGRA_8888 = 17, // 17 | |||||
PIXEL_FORMAT_YUV_SEMI_PLANNER_420_10BIT = 18, // 18 | |||||
PIXEL_FORMAT_YVU_SEMI_PLANNER_420_10BIT = 19, // 19 | |||||
PIXEL_FORMAT_YVU_PLANAR_420 = 20, // 20 | |||||
PIXEL_FORMAT_YVU_PLANAR_422, | |||||
PIXEL_FORMAT_YVU_PLANAR_444, | |||||
PIXEL_FORMAT_RGB_444 = 23, | |||||
PIXEL_FORMAT_BGR_444, | |||||
PIXEL_FORMAT_ARGB_4444, | |||||
PIXEL_FORMAT_ABGR_4444, | |||||
PIXEL_FORMAT_RGBA_4444, | |||||
PIXEL_FORMAT_BGRA_4444, | |||||
PIXEL_FORMAT_RGB_555, | |||||
PIXEL_FORMAT_BGR_555, | |||||
PIXEL_FORMAT_RGB_565, | |||||
PIXEL_FORMAT_BGR_565, | |||||
PIXEL_FORMAT_ARGB_1555, | |||||
PIXEL_FORMAT_ABGR_1555, | |||||
PIXEL_FORMAT_RGBA_1555, | |||||
PIXEL_FORMAT_BGRA_1555, | |||||
PIXEL_FORMAT_ARGB_8565, | |||||
PIXEL_FORMAT_ABGR_8565, | |||||
PIXEL_FORMAT_RGBA_8565, | |||||
PIXEL_FORMAT_BGRA_8565, | |||||
PIXEL_FORMAT_RGB_BAYER_8BPP = 50, | |||||
PIXEL_FORMAT_RGB_BAYER_10BPP, | |||||
PIXEL_FORMAT_RGB_BAYER_12BPP, | |||||
PIXEL_FORMAT_RGB_BAYER_14BPP, | |||||
PIXEL_FORMAT_RGB_BAYER_16BPP, | |||||
PIXEL_FORMAT_BGR_888_PLANAR = 70, | |||||
PIXEL_FORMAT_HSV_888_PACKAGE, | |||||
PIXEL_FORMAT_HSV_888_PLANAR, | |||||
PIXEL_FORMAT_LAB_888_PACKAGE, | |||||
PIXEL_FORMAT_LAB_888_PLANAR, | |||||
PIXEL_FORMAT_S8C1, | |||||
PIXEL_FORMAT_S8C2_PACKAGE, | |||||
PIXEL_FORMAT_S8C2_PLANAR, | |||||
PIXEL_FORMAT_S16C1, | |||||
PIXEL_FORMAT_U8C1, | |||||
PIXEL_FORMAT_U16C1, | |||||
PIXEL_FORMAT_S32C1, | |||||
PIXEL_FORMAT_U32C1, | |||||
PIXEL_FORMAT_U64C1, | |||||
PIXEL_FORMAT_S64C1, | |||||
PIXEL_FORMAT_YUV_SEMIPLANAR_440 = 1000, | |||||
PIXEL_FORMAT_YVU_SEMIPLANAR_440, | |||||
PIXEL_FORMAT_FLOAT32, | |||||
PIXEL_FORMAT_BUTT, | |||||
PIXEL_FORMAT_UNKNOWN = 10000 | |||||
}; | }; | ||||
// Stream Format | // Stream Format | ||||
enum acldvppStreamFormat { | |||||
H265_MAIN_LEVEL = 0, | |||||
H264_BASELINE_LEVEL, | |||||
H264_MAIN_LEVEL, | |||||
H264_HIGH_LEVEL | |||||
}; | |||||
enum acldvppStreamFormat { H265_MAIN_LEVEL = 0, H264_BASELINE_LEVEL, H264_MAIN_LEVEL, H264_HIGH_LEVEL }; | |||||
// Supported Channel Mode | // Supported Channel Mode | ||||
enum acldvppChannelMode { | |||||
DVPP_CHNMODE_VPC = 1, | |||||
DVPP_CHNMODE_JPEGD = 2, | |||||
DVPP_CHNMODE_JPEGE = 4 | |||||
}; | |||||
enum acldvppChannelMode { DVPP_CHNMODE_VPC = 1, DVPP_CHNMODE_JPEGD = 2, DVPP_CHNMODE_JPEGE = 4 }; | |||||
// Supported Border Type | // Supported Border Type | ||||
enum acldvppBorderType { | |||||
BORDER_CONSTANT = 0, | |||||
BORDER_REPLICATE, | |||||
BORDER_REFLECT, | |||||
BORDER_REFLECT_101 | |||||
}; | |||||
enum acldvppBorderType { BORDER_CONSTANT = 0, BORDER_REPLICATE, BORDER_REFLECT, BORDER_REFLECT_101 }; | |||||
// Venc parameter type | // Venc parameter type | ||||
enum aclvencChannelDescParamType { | enum aclvencChannelDescParamType { | ||||
ACL_VENC_THREAD_ID_UINT64 = 0, | |||||
ACL_VENC_CALLBACK_PTR, | |||||
ACL_VENC_PIXEL_FORMAT_UINT32, | |||||
ACL_VENC_ENCODE_TYPE_UINT32, | |||||
ACL_VENC_PIC_WIDTH_UINT32, | |||||
ACL_VENC_PIC_HEIGHT_UINT32, | |||||
ACL_VENC_KEY_FRAME_INTERVAL_UINT32, | |||||
ACL_VENC_BUF_ADDR_PTR, | |||||
ACL_VENC_BUF_SIZE_UINT32, | |||||
ACL_VENC_RC_MODE_UINT32, | |||||
ACL_VENC_SRC_RATE_UINT32, | |||||
ACL_VENC_MAX_BITRATE_UINT32, | |||||
ACL_VENC_MAX_IP_PROP_UINT32 | |||||
ACL_VENC_THREAD_ID_UINT64 = 0, | |||||
ACL_VENC_CALLBACK_PTR, | |||||
ACL_VENC_PIXEL_FORMAT_UINT32, | |||||
ACL_VENC_ENCODE_TYPE_UINT32, | |||||
ACL_VENC_PIC_WIDTH_UINT32, | |||||
ACL_VENC_PIC_HEIGHT_UINT32, | |||||
ACL_VENC_KEY_FRAME_INTERVAL_UINT32, | |||||
ACL_VENC_BUF_ADDR_PTR, | |||||
ACL_VENC_BUF_SIZE_UINT32, | |||||
ACL_VENC_RC_MODE_UINT32, | |||||
ACL_VENC_SRC_RATE_UINT32, | |||||
ACL_VENC_MAX_BITRATE_UINT32, | |||||
ACL_VENC_MAX_IP_PROP_UINT32 | |||||
}; | }; | ||||
// Jpeg picture format | // Jpeg picture format | ||||
enum acldvppJpegFormat { | enum acldvppJpegFormat { | ||||
ACL_JPEG_CSS_444 = 0, | |||||
ACL_JPEG_CSS_422, | |||||
ACL_JPEG_CSS_420, | |||||
ACL_JPEG_CSS_GRAY, | |||||
ACL_JPEG_CSS_440, | |||||
ACL_JPEG_CSS_411, | |||||
ACL_JPEG_CSS_UNKNOWN = 1000 | |||||
ACL_JPEG_CSS_444 = 0, | |||||
ACL_JPEG_CSS_422, | |||||
ACL_JPEG_CSS_420, | |||||
ACL_JPEG_CSS_GRAY, | |||||
ACL_JPEG_CSS_440, | |||||
ACL_JPEG_CSS_411, | |||||
ACL_JPEG_CSS_UNKNOWN = 1000 | |||||
}; | }; | ||||
/** | /** | ||||
@@ -523,9 +509,7 @@ ACL_FUNC_VISIBILITY uint32_t acldvppGetPicDescRetCode(const acldvppPicDesc *picD | |||||
* @retval null for failed. | * @retval null for failed. | ||||
* @retval other success | * @retval other success | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY acldvppRoiConfig *acldvppCreateRoiConfig(uint32_t left, | |||||
uint32_t right, | |||||
uint32_t top, | |||||
ACL_FUNC_VISIBILITY acldvppRoiConfig *acldvppCreateRoiConfig(uint32_t left, uint32_t right, uint32_t top, | |||||
uint32_t bottom); | uint32_t bottom); | ||||
/** | /** | ||||
@@ -604,10 +588,7 @@ ACL_FUNC_VISIBILITY aclError acldvppSetRoiConfigBottom(acldvppRoiConfig *config, | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppSetRoiConfig(acldvppRoiConfig *config, | |||||
uint32_t left, | |||||
uint32_t right, | |||||
uint32_t top, | |||||
ACL_FUNC_VISIBILITY aclError acldvppSetRoiConfig(acldvppRoiConfig *config, uint32_t left, uint32_t right, uint32_t top, | |||||
uint32_t bottom); | uint32_t bottom); | ||||
/** | /** | ||||
@@ -1096,7 +1077,8 @@ ACL_FUNC_VISIBILITY aclError aclvencSetChannelDescMaxBitRate(aclvencChannelDesc | |||||
* @retval ACL_SUCCESS for success, other for failure | * @retval ACL_SUCCESS for success, other for failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclvencSetChannelDescParam(aclvencChannelDesc *channelDesc, | ACL_FUNC_VISIBILITY aclError aclvencSetChannelDescParam(aclvencChannelDesc *channelDesc, | ||||
aclvencChannelDescParamType paramType, size_t length, const void *param); | |||||
aclvencChannelDescParamType paramType, size_t length, | |||||
const void *param); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -1245,7 +1227,8 @@ ACL_FUNC_VISIBILITY uint32_t aclvencGetChannelDescMaxBitRate(const aclvencChanne | |||||
* @retval ACL_SUCCESS for success, other for failure | * @retval ACL_SUCCESS for success, other for failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclvencGetChannelDescParam(const aclvencChannelDesc *channelDesc, | ACL_FUNC_VISIBILITY aclError aclvencGetChannelDescParam(const aclvencChannelDesc *channelDesc, | ||||
aclvencChannelDescParamType paramType, size_t length, size_t *paramRetSize, void *param); | |||||
aclvencChannelDescParamType paramType, size_t length, | |||||
size_t *paramRetSize, void *param); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -1545,10 +1528,7 @@ ACL_FUNC_VISIBILITY aclError aclvdecDestroyFrameConfig(aclvdecFrameConfig *vdecF | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppJpegGetImageInfo(const void *data, | |||||
uint32_t size, | |||||
uint32_t *width, | |||||
uint32_t *height, | |||||
ACL_FUNC_VISIBILITY aclError acldvppJpegGetImageInfo(const void *data, uint32_t size, uint32_t *width, uint32_t *height, | |||||
int32_t *components); | int32_t *components); | ||||
/** | /** | ||||
@@ -1565,11 +1545,8 @@ ACL_FUNC_VISIBILITY aclError acldvppJpegGetImageInfo(const void *data, | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppJpegGetImageInfoV2(const void *data, | |||||
uint32_t size, | |||||
uint32_t *width, | |||||
uint32_t *height, | |||||
int32_t *components, | |||||
ACL_FUNC_VISIBILITY aclError acldvppJpegGetImageInfoV2(const void *data, uint32_t size, uint32_t *width, | |||||
uint32_t *height, int32_t *components, | |||||
acldvppJpegFormat *format); | acldvppJpegFormat *format); | ||||
/** | /** | ||||
@@ -1584,8 +1561,7 @@ ACL_FUNC_VISIBILITY aclError acldvppJpegGetImageInfoV2(const void *data, | |||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppJpegPredictEncSize(const acldvppPicDesc *inputDesc, | ACL_FUNC_VISIBILITY aclError acldvppJpegPredictEncSize(const acldvppPicDesc *inputDesc, | ||||
const acldvppJpegeConfig *config, | |||||
uint32_t *size); | |||||
const acldvppJpegeConfig *config, uint32_t *size); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -1599,10 +1575,8 @@ ACL_FUNC_VISIBILITY aclError acldvppJpegPredictEncSize(const acldvppPicDesc *inp | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppJpegPredictDecSize(const void *data, | |||||
uint32_t dataSize, | |||||
acldvppPixelFormat outputPixelFormat, | |||||
uint32_t *decSize); | |||||
ACL_FUNC_VISIBILITY aclError acldvppJpegPredictDecSize(const void *data, uint32_t dataSize, | |||||
acldvppPixelFormat outputPixelFormat, uint32_t *decSize); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -1617,11 +1591,8 @@ ACL_FUNC_VISIBILITY aclError acldvppJpegPredictDecSize(const void *data, | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppPngGetImageInfo(const void *data, | |||||
uint32_t dataSize, | |||||
uint32_t *width, | |||||
uint32_t *height, | |||||
int32_t *components); | |||||
ACL_FUNC_VISIBILITY aclError acldvppPngGetImageInfo(const void *data, uint32_t dataSize, uint32_t *width, | |||||
uint32_t *height, int32_t *components); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -1635,10 +1606,8 @@ ACL_FUNC_VISIBILITY aclError acldvppPngGetImageInfo(const void *data, | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppPngPredictDecSize(const void *data, | |||||
uint32_t dataSize, | |||||
acldvppPixelFormat outputPixelFormat, | |||||
uint32_t *decSize); | |||||
ACL_FUNC_VISIBILITY aclError acldvppPngPredictDecSize(const void *data, uint32_t dataSize, | |||||
acldvppPixelFormat outputPixelFormat, uint32_t *decSize); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -1702,10 +1671,8 @@ ACL_FUNC_VISIBILITY aclError acldvppDestroyChannel(acldvppChannelDesc *channelDe | |||||
* @see acldvppCreateChannel | acldvppCreatePicDesc | * @see acldvppCreateChannel | acldvppCreatePicDesc | ||||
* | acldvppCreateResizeConfig | * | acldvppCreateResizeConfig | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppVpcResizeAsync(acldvppChannelDesc *channelDesc, | |||||
acldvppPicDesc *inputDesc, | |||||
acldvppPicDesc *outputDesc, | |||||
acldvppResizeConfig *resizeConfig, | |||||
ACL_FUNC_VISIBILITY aclError acldvppVpcResizeAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc, | |||||
acldvppPicDesc *outputDesc, acldvppResizeConfig *resizeConfig, | |||||
aclrtStream stream); | aclrtStream stream); | ||||
/** | /** | ||||
@@ -1741,10 +1708,8 @@ ACL_FUNC_VISIBILITY aclError acldvppVpcResizeAsync(acldvppChannelDesc *channelDe | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppVpcCropAsync(acldvppChannelDesc *channelDesc, | |||||
acldvppPicDesc *inputDesc, | |||||
acldvppPicDesc *outputDesc, | |||||
acldvppRoiConfig *cropArea, | |||||
ACL_FUNC_VISIBILITY aclError acldvppVpcCropAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc, | |||||
acldvppPicDesc *outputDesc, acldvppRoiConfig *cropArea, | |||||
aclrtStream stream); | aclrtStream stream); | ||||
/** | /** | ||||
@@ -1781,13 +1746,9 @@ ACL_FUNC_VISIBILITY aclError acldvppVpcCropAsync(acldvppChannelDesc *channelDesc | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppVpcCropResizeAsync(acldvppChannelDesc *channelDesc, | |||||
acldvppPicDesc *inputDesc, | |||||
acldvppPicDesc *outputDesc, | |||||
acldvppRoiConfig *cropArea, | |||||
acldvppResizeConfig *resizeConfig, | |||||
aclrtStream stream); | |||||
ACL_FUNC_VISIBILITY aclError acldvppVpcCropResizeAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc, | |||||
acldvppPicDesc *outputDesc, acldvppRoiConfig *cropArea, | |||||
acldvppResizeConfig *resizeConfig, aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -1811,12 +1772,9 @@ ACL_FUNC_VISIBILITY aclError acldvppVpcCropResizeAsync(acldvppChannelDesc *chann | |||||
* @see acldvppCreateChannel | acldvppCreateBatchPicDesc | acldvppCreateRoiConfig | * @see acldvppCreateChannel | acldvppCreateBatchPicDesc | acldvppCreateRoiConfig | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppVpcBatchCropAsync(acldvppChannelDesc *channelDesc, | ACL_FUNC_VISIBILITY aclError acldvppVpcBatchCropAsync(acldvppChannelDesc *channelDesc, | ||||
acldvppBatchPicDesc *srcBatchPicDescs, | |||||
uint32_t *roiNums, | |||||
uint32_t size, | |||||
acldvppBatchPicDesc *dstBatchPicDescs, | |||||
acldvppRoiConfig *cropAreas[], | |||||
aclrtStream stream); | |||||
acldvppBatchPicDesc *srcBatchPicDescs, uint32_t *roiNums, | |||||
uint32_t size, acldvppBatchPicDesc *dstBatchPicDescs, | |||||
acldvppRoiConfig *cropAreas[], aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -1841,13 +1799,10 @@ ACL_FUNC_VISIBILITY aclError acldvppVpcBatchCropAsync(acldvppChannelDesc *channe | |||||
* @see acldvppCreateChannel | acldvppCreateBatchPicDesc | acldvppCreateRoiConfig | acldvppCreateDvppConfig | * @see acldvppCreateChannel | acldvppCreateBatchPicDesc | acldvppCreateRoiConfig | acldvppCreateDvppConfig | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppVpcBatchCropResizeAsync(acldvppChannelDesc *channelDesc, | ACL_FUNC_VISIBILITY aclError acldvppVpcBatchCropResizeAsync(acldvppChannelDesc *channelDesc, | ||||
acldvppBatchPicDesc *srcBatchPicDescs, | |||||
uint32_t *roiNums, | |||||
uint32_t size, | |||||
acldvppBatchPicDesc *dstBatchPicDescs, | |||||
acldvppBatchPicDesc *srcBatchPicDescs, uint32_t *roiNums, | |||||
uint32_t size, acldvppBatchPicDesc *dstBatchPicDescs, | |||||
acldvppRoiConfig *cropAreas[], | acldvppRoiConfig *cropAreas[], | ||||
acldvppResizeConfig *resizeConfig, | |||||
aclrtStream stream); | |||||
acldvppResizeConfig *resizeConfig, aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -1870,12 +1825,9 @@ ACL_FUNC_VISIBILITY aclError acldvppVpcBatchCropResizeAsync(acldvppChannelDesc * | |||||
* | * | ||||
* @see acldvppCreateChannel | acldvppCreatePicDesc | acldvppCreateRoiConfig | * @see acldvppCreateChannel | acldvppCreatePicDesc | acldvppCreateRoiConfig | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppVpcCropAndPasteAsync(acldvppChannelDesc *channelDesc, | |||||
acldvppPicDesc *inputDesc, | |||||
acldvppPicDesc *outputDesc, | |||||
acldvppRoiConfig *cropArea, | |||||
acldvppRoiConfig *pasteArea, | |||||
aclrtStream stream); | |||||
ACL_FUNC_VISIBILITY aclError acldvppVpcCropAndPasteAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc, | |||||
acldvppPicDesc *outputDesc, acldvppRoiConfig *cropArea, | |||||
acldvppRoiConfig *pasteArea, aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -1899,13 +1851,10 @@ ACL_FUNC_VISIBILITY aclError acldvppVpcCropAndPasteAsync(acldvppChannelDesc *cha | |||||
* | * | ||||
* @see acldvppCreateChannel | acldvppCreatePicDesc | acldvppCreateRoiConfig | acldvppCreateResizeConfig | * @see acldvppCreateChannel | acldvppCreatePicDesc | acldvppCreateRoiConfig | acldvppCreateResizeConfig | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppVpcCropResizePasteAsync(acldvppChannelDesc *channelDesc, | |||||
acldvppPicDesc *inputDesc, | |||||
acldvppPicDesc *outputDesc, | |||||
acldvppRoiConfig *cropArea, | |||||
ACL_FUNC_VISIBILITY aclError acldvppVpcCropResizePasteAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc, | |||||
acldvppPicDesc *outputDesc, acldvppRoiConfig *cropArea, | |||||
acldvppRoiConfig *pasteArea, | acldvppRoiConfig *pasteArea, | ||||
acldvppResizeConfig *resizeConfig, | |||||
aclrtStream stream); | |||||
acldvppResizeConfig *resizeConfig, aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -1930,14 +1879,11 @@ ACL_FUNC_VISIBILITY aclError acldvppVpcCropResizePasteAsync(acldvppChannelDesc * | |||||
* | * | ||||
* @see acldvppCreateChannel | acldvppCreateBatchPicDesc | acldvppCreateRoiConfig | * @see acldvppCreateChannel | acldvppCreateBatchPicDesc | acldvppCreateRoiConfig | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppVpcBatchCropAndPasteAsync(acldvppChannelDesc *channelDesc, | |||||
acldvppBatchPicDesc *srcBatchPicDescs, | |||||
uint32_t *roiNums, | |||||
uint32_t size, | |||||
acldvppBatchPicDesc *dstBatchPicDescs, | |||||
acldvppRoiConfig *cropAreas[], | |||||
acldvppRoiConfig *pasteAreas[], | |||||
aclrtStream stream); | |||||
ACL_FUNC_VISIBILITY aclError acldvppVpcBatchCropAndPasteAsync(acldvppChannelDesc *channelDesc, | |||||
acldvppBatchPicDesc *srcBatchPicDescs, uint32_t *roiNums, | |||||
uint32_t size, acldvppBatchPicDesc *dstBatchPicDescs, | |||||
acldvppRoiConfig *cropAreas[], | |||||
acldvppRoiConfig *pasteAreas[], aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -1963,16 +1909,10 @@ ACL_FUNC_VISIBILITY aclError acldvppVpcCropResizePasteAsync(acldvppChannelDesc * | |||||
* | * | ||||
* @see acldvppCreateChannel | acldvppCreateBatchPicDesc | acldvppCreateRoiConfig | acldvppCreateResizeConfig | * @see acldvppCreateChannel | acldvppCreateBatchPicDesc | acldvppCreateRoiConfig | acldvppCreateResizeConfig | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppVpcBatchCropResizePasteAsync(acldvppChannelDesc *channelDesc, | |||||
acldvppBatchPicDesc *srcBatchPicDescs, | |||||
uint32_t *roiNums, | |||||
uint32_t size, | |||||
acldvppBatchPicDesc *dstBatchPicDescs, | |||||
acldvppRoiConfig *cropAreas[], | |||||
acldvppRoiConfig *pasteAreas[], | |||||
acldvppResizeConfig *resizeConfig, | |||||
aclrtStream stream); | |||||
ACL_FUNC_VISIBILITY aclError acldvppVpcBatchCropResizePasteAsync( | |||||
acldvppChannelDesc *channelDesc, acldvppBatchPicDesc *srcBatchPicDescs, uint32_t *roiNums, uint32_t size, | |||||
acldvppBatchPicDesc *dstBatchPicDescs, acldvppRoiConfig *cropAreas[], acldvppRoiConfig *pasteAreas[], | |||||
acldvppResizeConfig *resizeConfig, aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -2000,11 +1940,8 @@ ACL_FUNC_VISIBILITY aclError acldvppVpcBatchCropResizePasteAsync(acldvppChannelD | |||||
* | * | ||||
* @see acldvppCreateChannel | acldvppCreatePicDesc | * @see acldvppCreateChannel | acldvppCreatePicDesc | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppJpegDecodeAsync(acldvppChannelDesc *channelDesc, | |||||
const void *data, | |||||
uint32_t size, | |||||
acldvppPicDesc *outputDesc, | |||||
aclrtStream stream); | |||||
ACL_FUNC_VISIBILITY aclError acldvppJpegDecodeAsync(acldvppChannelDesc *channelDesc, const void *data, uint32_t size, | |||||
acldvppPicDesc *outputDesc, aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -2022,11 +1959,8 @@ ACL_FUNC_VISIBILITY aclError acldvppJpegDecodeAsync(acldvppChannelDesc *channelD | |||||
* | * | ||||
* @see acldvppCreateChannel | acldvppCreateJpegeConfig | * @see acldvppCreateChannel | acldvppCreateJpegeConfig | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppJpegEncodeAsync(acldvppChannelDesc *channelDesc, | |||||
acldvppPicDesc *inputDesc, | |||||
const void *data, | |||||
uint32_t *size, | |||||
acldvppJpegeConfig *config, | |||||
ACL_FUNC_VISIBILITY aclError acldvppJpegEncodeAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc, | |||||
const void *data, uint32_t *size, acldvppJpegeConfig *config, | |||||
aclrtStream stream); | aclrtStream stream); | ||||
/** | /** | ||||
@@ -2044,11 +1978,8 @@ ACL_FUNC_VISIBILITY aclError acldvppJpegEncodeAsync(acldvppChannelDesc *channelD | |||||
* | * | ||||
* @see acldvppCreateChannel | acldvppCreatePicDesc | * @see acldvppCreateChannel | acldvppCreatePicDesc | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppPngDecodeAsync(acldvppChannelDesc *channelDesc, | |||||
const void *data, | |||||
uint32_t size, | |||||
acldvppPicDesc *outputDesc, | |||||
aclrtStream stream); | |||||
ACL_FUNC_VISIBILITY aclError acldvppPngDecodeAsync(acldvppChannelDesc *channelDesc, const void *data, uint32_t size, | |||||
acldvppPicDesc *outputDesc, aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -2103,11 +2034,8 @@ ACL_FUNC_VISIBILITY aclError aclvdecDestroyChannel(aclvdecChannelDesc *channelDe | |||||
* | * | ||||
* @see aclvdecCreateChannel | acldvppCreateStreamDesc | acldvppCreatePicDesc | * @see aclvdecCreateChannel | acldvppCreateStreamDesc | acldvppCreatePicDesc | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclvdecSendFrame(aclvdecChannelDesc *channelDesc, | |||||
acldvppStreamDesc *input, | |||||
acldvppPicDesc *output, | |||||
aclvdecFrameConfig *config, | |||||
void *userData); | |||||
ACL_FUNC_VISIBILITY aclError aclvdecSendFrame(aclvdecChannelDesc *channelDesc, acldvppStreamDesc *input, | |||||
acldvppPicDesc *output, aclvdecFrameConfig *config, void *userData); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -2126,10 +2054,8 @@ ACL_FUNC_VISIBILITY aclError aclvdecSendFrame(aclvdecChannelDesc *channelDesc, | |||||
* | * | ||||
* @see aclvdecCreateChannel | acldvppCreateStreamDesc | acldvppCreatePicDesc | aclvdecSendFrame | * @see aclvdecCreateChannel | acldvppCreateStreamDesc | acldvppCreatePicDesc | aclvdecSendFrame | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclvdecSendSkippedFrame(aclvdecChannelDesc *channelDesc, | |||||
acldvppStreamDesc *input, | |||||
aclvdecFrameConfig *config, | |||||
void *userData); | |||||
ACL_FUNC_VISIBILITY aclError aclvdecSendSkippedFrame(aclvdecChannelDesc *channelDesc, acldvppStreamDesc *input, | |||||
aclvdecFrameConfig *config, void *userData); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -2150,10 +2076,8 @@ ACL_FUNC_VISIBILITY aclError aclvdecSendSkippedFrame(aclvdecChannelDesc *channel | |||||
* | * | ||||
* @see acldvppCreateChannel | acldvppCreatePicDesc | * @see acldvppCreateChannel | acldvppCreatePicDesc | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppVpcConvertColorAsync(acldvppChannelDesc *channelDesc, | |||||
acldvppPicDesc *inputDesc, | |||||
acldvppPicDesc *outputDesc, | |||||
aclrtStream stream); | |||||
ACL_FUNC_VISIBILITY aclError acldvppVpcConvertColorAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc, | |||||
acldvppPicDesc *outputDesc, aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -2175,11 +2099,8 @@ ACL_FUNC_VISIBILITY aclError acldvppVpcConvertColorAsync(acldvppChannelDesc *cha | |||||
* | * | ||||
* @see acldvppCreateChannel | acldvppCreatePicDesc | * @see acldvppCreateChannel | acldvppCreatePicDesc | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppVpcPyrDownAsync(acldvppChannelDesc *channelDesc, | |||||
acldvppPicDesc *inputDesc, | |||||
acldvppPicDesc *outputDesc, | |||||
void *reserve, | |||||
aclrtStream stream); | |||||
ACL_FUNC_VISIBILITY aclError acldvppVpcPyrDownAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc, | |||||
acldvppPicDesc *outputDesc, void *reserve, aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -2191,8 +2112,7 @@ ACL_FUNC_VISIBILITY aclError acldvppVpcPyrDownAsync(acldvppChannelDesc *channelD | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppSetChannelDescMode(acldvppChannelDesc *channelDesc, | |||||
uint32_t mode); | |||||
ACL_FUNC_VISIBILITY aclError acldvppSetChannelDescMode(acldvppChannelDesc *channelDesc, uint32_t mode); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -2227,8 +2147,7 @@ ACL_FUNC_VISIBILITY uint32_t acldvppGetResizeConfigInterpolation(const acldvppRe | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError aclvdecSetChannelDescOutMode(aclvdecChannelDesc *channelDesc, | |||||
uint32_t outMode); | |||||
ACL_FUNC_VISIBILITY aclError aclvdecSetChannelDescOutMode(aclvdecChannelDesc *channelDesc, uint32_t outMode); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -2325,9 +2244,7 @@ ACL_FUNC_VISIBILITY uint32_t acldvppGetLutMapDims(const acldvppLutMap *lutMap); | |||||
* @retval ACL_SUCCESS The function is successfully executed. | * @retval ACL_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppGetLutMapData(const acldvppLutMap *lutMap, | |||||
uint32_t dim, | |||||
uint8_t **data, | |||||
ACL_FUNC_VISIBILITY aclError acldvppGetLutMapData(const acldvppLutMap *lutMap, uint32_t dim, uint8_t **data, | |||||
uint32_t *len); | uint32_t *len); | ||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -2345,10 +2262,8 @@ ACL_FUNC_VISIBILITY aclError acldvppGetLutMapData(const acldvppLutMap *lutMap, | |||||
* @see acldvppCreateChannel|acldvppCreatePicDesc|acldvppCreateLutMap | * @see acldvppCreateChannel|acldvppCreatePicDesc|acldvppCreateLutMap | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppVpcEqualizeHistAsync(const acldvppChannelDesc *channelDesc, | ACL_FUNC_VISIBILITY aclError acldvppVpcEqualizeHistAsync(const acldvppChannelDesc *channelDesc, | ||||
const acldvppPicDesc *inputDesc, | |||||
acldvppPicDesc *outputDesc, | |||||
const acldvppLutMap *lutMap, | |||||
aclrtStream stream); | |||||
const acldvppPicDesc *inputDesc, acldvppPicDesc *outputDesc, | |||||
const acldvppLutMap *lutMap, aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -2369,8 +2284,7 @@ ACL_FUNC_VISIBILITY acldvppBorderConfig *acldvppCreateBorderConfig(); | |||||
* | * | ||||
* @retval ACL_SUCCESS for success, other for failure | * @retval ACL_SUCCESS for success, other for failure | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppSetBorderConfigValue(acldvppBorderConfig *borderConfig, | |||||
uint32_t index, | |||||
ACL_FUNC_VISIBILITY aclError acldvppSetBorderConfigValue(acldvppBorderConfig *borderConfig, uint32_t index, | |||||
double value); | double value); | ||||
/** | /** | ||||
@@ -2515,10 +2429,8 @@ ACL_FUNC_VISIBILITY aclError acldvppDestroyBorderConfig(acldvppBorderConfig *bor | |||||
* @see acldvppCreateChannel|acldvppCreatePicDesc|acldvppCreateBorderConfig | * @see acldvppCreateChannel|acldvppCreatePicDesc|acldvppCreateBorderConfig | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppVpcMakeBorderAsync(const acldvppChannelDesc *channelDesc, | ACL_FUNC_VISIBILITY aclError acldvppVpcMakeBorderAsync(const acldvppChannelDesc *channelDesc, | ||||
const acldvppPicDesc *inputDesc, | |||||
acldvppPicDesc *outputDesc, | |||||
const acldvppBorderConfig *borderConfig, | |||||
aclrtStream stream); | |||||
const acldvppPicDesc *inputDesc, acldvppPicDesc *outputDesc, | |||||
const acldvppBorderConfig *borderConfig, aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -2535,11 +2447,8 @@ ACL_FUNC_VISIBILITY aclError acldvppVpcMakeBorderAsync(const acldvppChannelDesc | |||||
* | * | ||||
* @see acldvppCreateChannel | acldvppCreatePicDesc | acldvppCreateHist | * @see acldvppCreateChannel | acldvppCreatePicDesc | acldvppCreateHist | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppVpcCalcHistAsync(acldvppChannelDesc *channelDesc, | |||||
acldvppPicDesc *srcPicDesc, | |||||
acldvppHist *hist, | |||||
void *reserve, | |||||
aclrtStream stream); | |||||
ACL_FUNC_VISIBILITY aclError acldvppVpcCalcHistAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *srcPicDesc, | |||||
acldvppHist *hist, void *reserve, aclrtStream stream); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -2548,7 +2457,7 @@ ACL_FUNC_VISIBILITY aclError acldvppVpcCalcHistAsync(acldvppChannelDesc *channel | |||||
* @retval null for failed. | * @retval null for failed. | ||||
* @retval OtherValues success. | * @retval OtherValues success. | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY acldvppHist* acldvppCreateHist(); | |||||
ACL_FUNC_VISIBILITY acldvppHist *acldvppCreateHist(); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -2605,7 +2514,7 @@ ACL_FUNC_VISIBILITY aclError acldvppGetHistData(acldvppHist *hist, uint32_t dim, | |||||
* | * | ||||
* @see acldvppCreateHist | acldvppVpcCalcHistAsync | * @see acldvppCreateHist | acldvppVpcCalcHistAsync | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY uint32_t acldvppGetHistRetCode(acldvppHist* hist); | |||||
ACL_FUNC_VISIBILITY uint32_t acldvppGetHistRetCode(acldvppHist *hist); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -2624,7 +2533,6 @@ ACL_FUNC_VISIBILITY uint32_t acldvppGetHistRetCode(acldvppHist* hist); | |||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppClearHist(acldvppHist *hist); | ACL_FUNC_VISIBILITY aclError acldvppClearHist(acldvppHist *hist); | ||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
* @brief dvpp vpc batch crop, resize config and make border. | * @brief dvpp vpc batch crop, resize config and make border. | ||||
@@ -2648,18 +2556,13 @@ ACL_FUNC_VISIBILITY aclError acldvppClearHist(acldvppHist *hist); | |||||
* | * | ||||
* @see acldvppCreateChannel | acldvppCreateBatchPicDesc | acldvppCreateRoiConfig | acldvppCreateResizeConfig | * @see acldvppCreateChannel | acldvppCreateBatchPicDesc | acldvppCreateRoiConfig | acldvppCreateResizeConfig | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclError acldvppVpcBatchCropResizeMakeBorderAsync(acldvppChannelDesc *channelDesc, | |||||
acldvppBatchPicDesc *srcBatchPicDescs, | |||||
uint32_t *roiNums, | |||||
uint32_t size, | |||||
acldvppBatchPicDesc *dstBatchPicDescs, | |||||
acldvppRoiConfig *cropAreas[], | |||||
acldvppBorderConfig *borderCfgs[], | |||||
acldvppResizeConfig *resizeConfig, | |||||
aclrtStream stream); | |||||
ACL_FUNC_VISIBILITY aclError acldvppVpcBatchCropResizeMakeBorderAsync( | |||||
acldvppChannelDesc *channelDesc, acldvppBatchPicDesc *srcBatchPicDescs, uint32_t *roiNums, uint32_t size, | |||||
acldvppBatchPicDesc *dstBatchPicDescs, acldvppRoiConfig *cropAreas[], acldvppBorderConfig *borderCfgs[], | |||||
acldvppResizeConfig *resizeConfig, aclrtStream stream); | |||||
#ifdef __cplusplus | #ifdef __cplusplus | ||||
} | } | ||||
#endif | #endif | ||||
#endif // INC_EXTERNAL_ACL_OPS_ACL_DVPP_H_ | |||||
#endif // INC_EXTERNAL_ACL_OPS_ACL_DVPP_H_ |
@@ -32,8 +32,8 @@ typedef struct aclfvSearchResult aclfvSearchResult; | |||||
// search operation type | // search operation type | ||||
enum aclfvSearchType { | enum aclfvSearchType { | ||||
SEARCH_1_N, // 1:N operation type | |||||
SEARCH_N_M // N:M operation type | |||||
SEARCH_1_N, // 1:N operation type | |||||
SEARCH_N_M // N:M operation type | |||||
}; | }; | ||||
/** | /** | ||||
@@ -104,7 +104,8 @@ ACL_FUNC_VISIBILITY aclError aclfvSetNMTopNum(aclfvInitPara *initPara, uint32_t | |||||
* @retval OtherValues success. | * @retval OtherValues success. | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclfvFeatureInfo *aclfvCreateFeatureInfo(uint32_t id0, uint32_t id1, uint32_t offset, | ACL_FUNC_VISIBILITY aclfvFeatureInfo *aclfvCreateFeatureInfo(uint32_t id0, uint32_t id1, uint32_t offset, | ||||
uint32_t featureLen, uint32_t featureCount, uint8_t *featureData, uint32_t featureDataLen); | |||||
uint32_t featureLen, uint32_t featureCount, | |||||
uint8_t *featureData, uint32_t featureDataLen); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -233,8 +234,9 @@ ACL_FUNC_VISIBILITY aclError aclfvDestroySearchInput(aclfvSearchInput *searchInp | |||||
* @retval null for failed. OtherValues success | * @retval null for failed. OtherValues success | ||||
*/ | */ | ||||
ACL_FUNC_VISIBILITY aclfvSearchResult *aclfvCreateSearchResult(uint32_t queryCnt, uint32_t *resultNum, | ACL_FUNC_VISIBILITY aclfvSearchResult *aclfvCreateSearchResult(uint32_t queryCnt, uint32_t *resultNum, | ||||
uint32_t resultNumDataLen, uint32_t *id0, uint32_t *id1, uint32_t *resultOffset, float *resultDistance, | |||||
uint32_t dataLen); | |||||
uint32_t resultNumDataLen, uint32_t *id0, uint32_t *id1, | |||||
uint32_t *resultOffset, float *resultDistance, | |||||
uint32_t dataLen); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -343,4 +345,4 @@ ACL_FUNC_VISIBILITY aclError aclfvSearch(aclfvSearchType type, aclfvSearchInput | |||||
} | } | ||||
#endif | #endif | ||||
#endif // INC_EXTERNAL_ACL_OPS_ACL_RETR_H_ | |||||
#endif // INC_EXTERNAL_ACL_OPS_ACL_RETR_H_ |
@@ -142,7 +142,7 @@ class GE_FUNC_VISIBILITY Session { | |||||
/// | /// | ||||
Status BuildGraph(uint32_t graphId, const std::vector<InputTensorInfo> &inputs); | Status BuildGraph(uint32_t graphId, const std::vector<InputTensorInfo> &inputs); | ||||
Status BuildGraph(uint32_t graphId, const std::vector<ge::Tensor> &inputs); /*lint !e148*/ | |||||
Status BuildGraph(uint32_t graphId, const std::vector<ge::Tensor> &inputs); /*lint !e148*/ | |||||
/// | /// | ||||
/// @ingroup ge_graph | /// @ingroup ge_graph | ||||
@@ -27,7 +27,7 @@ | |||||
#ifdef __cplusplus | #ifdef __cplusplus | ||||
extern "C" { | extern "C" { | ||||
#endif // __cplusplus | |||||
#endif // __cplusplus | |||||
/** | /** | ||||
* @brief Initialize HCCL. | * @brief Initialize HCCL. | ||||
@@ -66,14 +66,15 @@ extern HcclResult HcclCommInitRootInfo(uint32_t nRanks, const HcclRootInfo *root | |||||
* @param sendBuf A pointer identifying the input data address of the operator. | * @param sendBuf A pointer identifying the input data address of the operator. | ||||
* @param recvBuf A pointer identifying the output data address of the operator. | * @param recvBuf A pointer identifying the output data address of the operator. | ||||
* @param count An integer(u64) identifying the number of the output data. | * @param count An integer(u64) identifying the number of the output data. | ||||
* @param dataType The data type of the operator, must be one of the following types: int8, int16, int32, float16, float32. | |||||
* @param dataType The data type of the operator, must be one of the following types: int8, int16, int32, float16, | |||||
* float32. | |||||
* @param op The reduction type of the operator, must be one of the following types: sum, min, max, prod. | * @param op The reduction type of the operator, must be one of the following types: sum, min, max, prod. | ||||
* @param comm A pointer identifying the communication resource based on. | * @param comm A pointer identifying the communication resource based on. | ||||
* @param stream A pointer identifying the stream information. | * @param stream A pointer identifying the stream information. | ||||
* @return HcclResult | |||||
* @return HcclResult | |||||
*/ | */ | ||||
extern HcclResult HcclAllReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, | |||||
HcclReduceOp op, HcclComm comm, aclrtStream stream); | |||||
extern HcclResult HcclAllReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, HcclReduceOp op, | |||||
HcclComm comm, aclrtStream stream); | |||||
/** | /** | ||||
* @brief Broadcast operator. | * @brief Broadcast operator. | ||||
@@ -84,10 +85,10 @@ HcclReduceOp op, HcclComm comm, aclrtStream stream); | |||||
* @param root An integer(u32) identifying the root rank in the operator. | * @param root An integer(u32) identifying the root rank in the operator. | ||||
* @param comm A pointer identifying the communication resource based on | * @param comm A pointer identifying the communication resource based on | ||||
* @param stream A pointer identifying the stream information. | * @param stream A pointer identifying the stream information. | ||||
* @return HcclResult | |||||
* @return HcclResult | |||||
*/ | */ | ||||
extern HcclResult HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root, HcclComm comm, | |||||
aclrtStream stream); | |||||
extern HcclResult HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root, HcclComm comm, | |||||
aclrtStream stream); | |||||
/** | /** | ||||
* @brief ReduceScatter operator. | * @brief ReduceScatter operator. | ||||
@@ -99,10 +100,10 @@ aclrtStream stream); | |||||
* @param op The reduction type of the operator, must be one of the following types: sum, min, max, prod. | * @param op The reduction type of the operator, must be one of the following types: sum, min, max, prod. | ||||
* @param comm A pointer identifying the communication resource based on. | * @param comm A pointer identifying the communication resource based on. | ||||
* @param stream A pointer identifying the stream information. | * @param stream A pointer identifying the stream information. | ||||
* @return HcclResult | |||||
* @return HcclResult | |||||
*/ | */ | ||||
extern HcclResult HcclReduceScatter(void *sendBuf, void *recvBuf, uint64_t recvCount, HcclDataType dataType, | |||||
HcclReduceOp op, HcclComm comm, aclrtStream stream); | |||||
extern HcclResult HcclReduceScatter(void *sendBuf, void *recvBuf, uint64_t recvCount, HcclDataType dataType, | |||||
HcclReduceOp op, HcclComm comm, aclrtStream stream); | |||||
/** | /** | ||||
* @brief AllGather operator. | * @brief AllGather operator. | ||||
@@ -113,16 +114,16 @@ HcclReduceOp op, HcclComm comm, aclrtStream stream); | |||||
* @param dataType The data type of the operator, must be one of the following types: int8, int32, float16, float32. | * @param dataType The data type of the operator, must be one of the following types: int8, int32, float16, float32. | ||||
* @param comm A pointer identifying the communication resource based on. | * @param comm A pointer identifying the communication resource based on. | ||||
* @param stream A pointer identifying the stream information. | * @param stream A pointer identifying the stream information. | ||||
* @return HcclResult | |||||
* @return HcclResult | |||||
*/ | */ | ||||
extern HcclResult HcclAllGather(void *sendBuf, void *recvBuf, uint64_t sendCount, HcclDataType dataType, | |||||
HcclComm comm, aclrtStream stream); | |||||
extern HcclResult HcclAllGather(void *sendBuf, void *recvBuf, uint64_t sendCount, HcclDataType dataType, HcclComm comm, | |||||
aclrtStream stream); | |||||
/** | /** | ||||
* @brief Get the rank size of this comm. | * @brief Get the rank size of this comm. | ||||
* | * | ||||
* @param comm A pointer identifying the communication resource based on. | * @param comm A pointer identifying the communication resource based on. | ||||
* @param rankSize A pointer identifying the rank size. | * @param rankSize A pointer identifying the rank size. | ||||
* @return HcclResult | |||||
* @return HcclResult | |||||
*/ | */ | ||||
extern HcclResult HcclGetRankSize(HcclComm comm, uint32_t *rankSize); | extern HcclResult HcclGetRankSize(HcclComm comm, uint32_t *rankSize); | ||||
@@ -131,7 +132,7 @@ extern HcclResult HcclGetRankSize(HcclComm comm, uint32_t *rankSize); | |||||
* | * | ||||
* @param comm A pointer identifying the communication resource based on. | * @param comm A pointer identifying the communication resource based on. | ||||
* @param rank A pointer identifying the rank id. | * @param rank A pointer identifying the rank id. | ||||
* @return HcclResult | |||||
* @return HcclResult | |||||
*/ | */ | ||||
extern HcclResult HcclGetRankId(HcclComm comm, uint32_t *rank); | extern HcclResult HcclGetRankId(HcclComm comm, uint32_t *rank); | ||||
/** | /** | ||||
@@ -139,7 +140,7 @@ extern HcclResult HcclGetRankId(HcclComm comm, uint32_t *rank); | |||||
* | * | ||||
* @param comm A pointer identifying the communication resource based on. | * @param comm A pointer identifying the communication resource based on. | ||||
* @param stream A pointer identifying the stream information. | * @param stream A pointer identifying the stream information. | ||||
* @return HcclResult | |||||
* @return HcclResult | |||||
*/ | */ | ||||
extern HcclResult HcclBarrier(HcclComm comm, aclrtStream stream); | extern HcclResult HcclBarrier(HcclComm comm, aclrtStream stream); | ||||
@@ -154,5 +155,5 @@ extern HcclResult HcclCommDestroy(HcclComm comm); | |||||
#ifdef __cplusplus | #ifdef __cplusplus | ||||
} | } | ||||
#endif // __cplusplus | |||||
#endif // HCCL_H_ | |||||
#endif // __cplusplus | |||||
#endif // HCCL_H_ |
@@ -16,10 +16,10 @@ | |||||
/** | /** | ||||
* @file hccl_types.h | * @file hccl_types.h | ||||
* @brief HCCL data type definition | |||||
* | |||||
* @brief HCCL data type definition | |||||
* | |||||
*/ | */ | ||||
#ifndef HCCL_TYPES_H_ | #ifndef HCCL_TYPES_H_ | ||||
#define HCCL_TYPES_H_ | #define HCCL_TYPES_H_ | ||||
@@ -27,33 +27,33 @@ | |||||
#ifdef __cplusplus | #ifdef __cplusplus | ||||
extern "C" { | extern "C" { | ||||
#endif // __cplusplus | |||||
#endif // __cplusplus | |||||
/** | /** | ||||
* @brief HCCL functions return value definition | * @brief HCCL functions return value definition | ||||
*/ | */ | ||||
typedef enum { | typedef enum { | ||||
HCCL_SUCCESS = 0, /**< success */ | |||||
HCCL_E_PARA = 1, /**< parameter error */ | |||||
HCCL_E_PTR = 2, /**< empty pointer */ | |||||
HCCL_E_MEMORY = 3, /**< memory error */ | |||||
HCCL_E_INTERNAL = 4, /**< internal error */ | |||||
HCCL_E_NOT_SUPPORT = 5, /**< not support feature */ | |||||
HCCL_E_NOT_FOUND = 6, /**< not found specific resource */ | |||||
HCCL_E_UNAVAIL = 7, /**< resource unavailable */ | |||||
HCCL_E_SYSCALL = 8, /**< call system interface error */ | |||||
HCCL_E_TIMEOUT = 9, /**< timeout */ | |||||
HCCL_E_OPEN_FILE_FAILURE = 10, /**< open file fail */ | |||||
HCCL_E_TCP_CONNECT = 11, /**< tcp connect fail */ | |||||
HCCL_E_ROCE_CONNECT = 12, /**< roce connect fail */ | |||||
HCCL_E_TCP_TRANSFER = 13, /**< tcp transfer fail */ | |||||
HCCL_E_ROCE_TRANSFER = 14, /**< roce transfer fail */ | |||||
HCCL_E_RUNTIME = 15, /**< call runtime api fail */ | |||||
HCCL_E_DRV = 16, /**< call driver api fail */ | |||||
HCCL_E_PROFILING = 17, /**< call profiling api fail */ | |||||
HCCL_E_CCE = 18, /**< call cce api fail */ | |||||
HCCL_E_NETWORK = 19, /**< call network api fail */ | |||||
HCCL_E_RESERVED /**< reserved */ | |||||
HCCL_SUCCESS = 0, /**< success */ | |||||
HCCL_E_PARA = 1, /**< parameter error */ | |||||
HCCL_E_PTR = 2, /**< empty pointer */ | |||||
HCCL_E_MEMORY = 3, /**< memory error */ | |||||
HCCL_E_INTERNAL = 4, /**< internal error */ | |||||
HCCL_E_NOT_SUPPORT = 5, /**< not support feature */ | |||||
HCCL_E_NOT_FOUND = 6, /**< not found specific resource */ | |||||
HCCL_E_UNAVAIL = 7, /**< resource unavailable */ | |||||
HCCL_E_SYSCALL = 8, /**< call system interface error */ | |||||
HCCL_E_TIMEOUT = 9, /**< timeout */ | |||||
HCCL_E_OPEN_FILE_FAILURE = 10, /**< open file fail */ | |||||
HCCL_E_TCP_CONNECT = 11, /**< tcp connect fail */ | |||||
HCCL_E_ROCE_CONNECT = 12, /**< roce connect fail */ | |||||
HCCL_E_TCP_TRANSFER = 13, /**< tcp transfer fail */ | |||||
HCCL_E_ROCE_TRANSFER = 14, /**< roce transfer fail */ | |||||
HCCL_E_RUNTIME = 15, /**< call runtime api fail */ | |||||
HCCL_E_DRV = 16, /**< call driver api fail */ | |||||
HCCL_E_PROFILING = 17, /**< call profiling api fail */ | |||||
HCCL_E_CCE = 18, /**< call cce api fail */ | |||||
HCCL_E_NETWORK = 19, /**< call network api fail */ | |||||
HCCL_E_RESERVED /**< reserved */ | |||||
} HcclResult; | } HcclResult; | ||||
/** | /** | ||||
@@ -65,37 +65,37 @@ typedef void *HcclComm; | |||||
* @brief HCCL Reduction operation | * @brief HCCL Reduction operation | ||||
*/ | */ | ||||
typedef enum { | typedef enum { | ||||
HCCL_REDUCE_SUM = 0, /**< sum */ | |||||
HCCL_REDUCE_PROD = 1, /**< prod */ | |||||
HCCL_REDUCE_MAX = 2, /**< max */ | |||||
HCCL_REDUCE_MIN = 3, /**< min */ | |||||
HCCL_REDUCE_RESERVED /**< reserved */ | |||||
HCCL_REDUCE_SUM = 0, /**< sum */ | |||||
HCCL_REDUCE_PROD = 1, /**< prod */ | |||||
HCCL_REDUCE_MAX = 2, /**< max */ | |||||
HCCL_REDUCE_MIN = 3, /**< min */ | |||||
HCCL_REDUCE_RESERVED /**< reserved */ | |||||
} HcclReduceOp; | } HcclReduceOp; | ||||
/** | /** | ||||
* @brief HCCL data type | * @brief HCCL data type | ||||
*/ | */ | ||||
typedef enum { | typedef enum { | ||||
HCCL_DATA_TYPE_INT8 = 0, /**< int8 */ | |||||
HCCL_DATA_TYPE_INT16 = 1, /**< int16 */ | |||||
HCCL_DATA_TYPE_INT32 = 2, /**< int32 */ | |||||
HCCL_DATA_TYPE_FP16 = 3, /**< fp16 */ | |||||
HCCL_DATA_TYPE_FP32 = 4, /**< fp32 */ | |||||
HCCL_DATA_TYPE_INT64 = 5, /**< int64 */ | |||||
HCCL_DATA_TYPE_UINT64 = 6, /**< uint64 */ | |||||
HCCL_DATA_TYPE_RESERVED /**< reserved */ | |||||
HCCL_DATA_TYPE_INT8 = 0, /**< int8 */ | |||||
HCCL_DATA_TYPE_INT16 = 1, /**< int16 */ | |||||
HCCL_DATA_TYPE_INT32 = 2, /**< int32 */ | |||||
HCCL_DATA_TYPE_FP16 = 3, /**< fp16 */ | |||||
HCCL_DATA_TYPE_FP32 = 4, /**< fp32 */ | |||||
HCCL_DATA_TYPE_INT64 = 5, /**< int64 */ | |||||
HCCL_DATA_TYPE_UINT64 = 6, /**< uint64 */ | |||||
HCCL_DATA_TYPE_RESERVED /**< reserved */ | |||||
} HcclDataType; | } HcclDataType; | ||||
const uint32_t HCCL_ROOT_INFO_BYTES = 4108; // 4108: root info length | |||||
const uint32_t HCCL_ROOT_INFO_BYTES = 4108; // 4108: root info length | |||||
/** | /** | ||||
* @brief HCCL root info | * @brief HCCL root info | ||||
*/ | */ | ||||
typedef struct HcclRootInfoDef { | typedef struct HcclRootInfoDef { | ||||
char internal[HCCL_ROOT_INFO_BYTES]; | |||||
char internal[HCCL_ROOT_INFO_BYTES]; | |||||
} HcclRootInfo; | } HcclRootInfo; | ||||
#ifdef __cplusplus | #ifdef __cplusplus | ||||
} | } | ||||
#endif // __cplusplus | |||||
#endif // HCCL_TYPES_H_ | |||||
#endif // __cplusplus | |||||
#endif // HCCL_TYPES_H_ |
@@ -23,87 +23,87 @@ | |||||
extern "C" { | extern "C" { | ||||
#endif | #endif | ||||
static const int32_t ACL_RT_SUCCESS = 0; // success | |||||
static const int32_t ACL_RT_SUCCESS = 0; // success | |||||
static const int32_t ACL_ERROR_RT_PARAM_INVALID = 107000; // param invalid | |||||
static const int32_t ACL_ERROR_RT_INVALID_DEVICEID = 107001; // invalid device id | |||||
static const int32_t ACL_ERROR_RT_CONTEXT_NULL = 107002; // current context null | |||||
static const int32_t ACL_ERROR_RT_STREAM_CONTEXT = 107003; // stream not in current context | |||||
static const int32_t ACL_ERROR_RT_MODEL_CONTEXT = 107004; // model not in current context | |||||
static const int32_t ACL_ERROR_RT_STREAM_MODEL = 107005; // stream not in model | |||||
static const int32_t ACL_ERROR_RT_EVENT_TIMESTAMP_INVALID = 107006; // event timestamp invalid | |||||
static const int32_t ACL_ERROR_RT_EVENT_TIMESTAMP_REVERSAL = 107007; // event timestamp reversal | |||||
static const int32_t ACL_ERROR_RT_ADDR_UNALIGNED = 107008; // memory address unaligned | |||||
static const int32_t ACL_ERROR_RT_FILE_OPEN = 107009; // open file failed | |||||
static const int32_t ACL_ERROR_RT_FILE_WRITE = 107010; // write file failed | |||||
static const int32_t ACL_ERROR_RT_STREAM_SUBSCRIBE = 107011; // error subscribe stream | |||||
static const int32_t ACL_ERROR_RT_THREAD_SUBSCRIBE = 107012; // error subscribe thread | |||||
static const int32_t ACL_ERROR_RT_GROUP_NOT_SET = 107013; // group not set | |||||
static const int32_t ACL_ERROR_RT_GROUP_NOT_CREATE = 107014; // group not create | |||||
static const int32_t ACL_ERROR_RT_STREAM_NO_CB_REG = 107015; // callback not register to stream | |||||
static const int32_t ACL_ERROR_RT_INVALID_MEMORY_TYPE = 107016; // invalid memory type | |||||
static const int32_t ACL_ERROR_RT_INVALID_HANDLE = 107017; // invalid handle | |||||
static const int32_t ACL_ERROR_RT_INVALID_MALLOC_TYPE = 107018; // invalid malloc type | |||||
static const int32_t ACL_ERROR_RT_WAIT_TIMEOUT = 107019; // wait timeout | |||||
static const int32_t ACL_ERROR_RT_PARAM_INVALID = 107000; // param invalid | |||||
static const int32_t ACL_ERROR_RT_INVALID_DEVICEID = 107001; // invalid device id | |||||
static const int32_t ACL_ERROR_RT_CONTEXT_NULL = 107002; // current context null | |||||
static const int32_t ACL_ERROR_RT_STREAM_CONTEXT = 107003; // stream not in current context | |||||
static const int32_t ACL_ERROR_RT_MODEL_CONTEXT = 107004; // model not in current context | |||||
static const int32_t ACL_ERROR_RT_STREAM_MODEL = 107005; // stream not in model | |||||
static const int32_t ACL_ERROR_RT_EVENT_TIMESTAMP_INVALID = 107006; // event timestamp invalid | |||||
static const int32_t ACL_ERROR_RT_EVENT_TIMESTAMP_REVERSAL = 107007; // event timestamp reversal | |||||
static const int32_t ACL_ERROR_RT_ADDR_UNALIGNED = 107008; // memory address unaligned | |||||
static const int32_t ACL_ERROR_RT_FILE_OPEN = 107009; // open file failed | |||||
static const int32_t ACL_ERROR_RT_FILE_WRITE = 107010; // write file failed | |||||
static const int32_t ACL_ERROR_RT_STREAM_SUBSCRIBE = 107011; // error subscribe stream | |||||
static const int32_t ACL_ERROR_RT_THREAD_SUBSCRIBE = 107012; // error subscribe thread | |||||
static const int32_t ACL_ERROR_RT_GROUP_NOT_SET = 107013; // group not set | |||||
static const int32_t ACL_ERROR_RT_GROUP_NOT_CREATE = 107014; // group not create | |||||
static const int32_t ACL_ERROR_RT_STREAM_NO_CB_REG = 107015; // callback not register to stream | |||||
static const int32_t ACL_ERROR_RT_INVALID_MEMORY_TYPE = 107016; // invalid memory type | |||||
static const int32_t ACL_ERROR_RT_INVALID_HANDLE = 107017; // invalid handle | |||||
static const int32_t ACL_ERROR_RT_INVALID_MALLOC_TYPE = 107018; // invalid malloc type | |||||
static const int32_t ACL_ERROR_RT_WAIT_TIMEOUT = 107019; // wait timeout | |||||
static const int32_t ACL_ERROR_RT_FEATURE_NOT_SUPPORT = 207000; // feature not support | |||||
static const int32_t ACL_ERROR_RT_MEMORY_ALLOCATION = 207001; // memory allocation error | |||||
static const int32_t ACL_ERROR_RT_MEMORY_FREE = 207002; // memory free error | |||||
static const int32_t ACL_ERROR_RT_AICORE_OVER_FLOW = 207003; // aicore over flow | |||||
static const int32_t ACL_ERROR_RT_NO_DEVICE = 207004; // no device | |||||
static const int32_t ACL_ERROR_RT_RESOURCE_ALLOC_FAIL = 207005; // resource alloc fail | |||||
static const int32_t ACL_ERROR_RT_NO_PERMISSION = 207006; // no permission | |||||
static const int32_t ACL_ERROR_RT_NO_EVENT_RESOURCE = 207007; // no event resource | |||||
static const int32_t ACL_ERROR_RT_NO_STREAM_RESOURCE = 207008; // no stream resource | |||||
static const int32_t ACL_ERROR_RT_NO_NOTIFY_RESOURCE = 207009; // no notify resource | |||||
static const int32_t ACL_ERROR_RT_NO_MODEL_RESOURCE = 207010; // no model resource | |||||
static const int32_t ACL_ERROR_RT_NO_CDQ_RESOURCE = 207011; // no cdq resource | |||||
static const int32_t ACL_ERROR_RT_FEATURE_NOT_SUPPORT = 207000; // feature not support | |||||
static const int32_t ACL_ERROR_RT_MEMORY_ALLOCATION = 207001; // memory allocation error | |||||
static const int32_t ACL_ERROR_RT_MEMORY_FREE = 207002; // memory free error | |||||
static const int32_t ACL_ERROR_RT_AICORE_OVER_FLOW = 207003; // aicore over flow | |||||
static const int32_t ACL_ERROR_RT_NO_DEVICE = 207004; // no device | |||||
static const int32_t ACL_ERROR_RT_RESOURCE_ALLOC_FAIL = 207005; // resource alloc fail | |||||
static const int32_t ACL_ERROR_RT_NO_PERMISSION = 207006; // no permission | |||||
static const int32_t ACL_ERROR_RT_NO_EVENT_RESOURCE = 207007; // no event resource | |||||
static const int32_t ACL_ERROR_RT_NO_STREAM_RESOURCE = 207008; // no stream resource | |||||
static const int32_t ACL_ERROR_RT_NO_NOTIFY_RESOURCE = 207009; // no notify resource | |||||
static const int32_t ACL_ERROR_RT_NO_MODEL_RESOURCE = 207010; // no model resource | |||||
static const int32_t ACL_ERROR_RT_NO_CDQ_RESOURCE = 207011; // no cdq resource | |||||
static const int32_t ACL_ERROR_RT_INTERNAL_ERROR = 507000; // runtime internal error | |||||
static const int32_t ACL_ERROR_RT_TS_ERROR = 507001; // ts internel error | |||||
static const int32_t ACL_ERROR_RT_STREAM_TASK_FULL = 507002; // task full in stream | |||||
static const int32_t ACL_ERROR_RT_STREAM_TASK_EMPTY = 507003; // task empty in stream | |||||
static const int32_t ACL_ERROR_RT_STREAM_NOT_COMPLETE = 507004; // stream not complete | |||||
static const int32_t ACL_ERROR_RT_END_OF_SEQUENCE = 507005; // end of sequence | |||||
static const int32_t ACL_ERROR_RT_EVENT_NOT_COMPLETE = 507006; // event not complete | |||||
static const int32_t ACL_ERROR_RT_CONTEXT_RELEASE_ERROR = 507007; // context release error | |||||
static const int32_t ACL_ERROR_RT_SOC_VERSION = 507008; // soc version error | |||||
static const int32_t ACL_ERROR_RT_TASK_TYPE_NOT_SUPPORT = 507009; // task type not support | |||||
static const int32_t ACL_ERROR_RT_LOST_HEARTBEAT = 507010; // ts lost heartbeat | |||||
static const int32_t ACL_ERROR_RT_MODEL_EXECUTE = 507011; // model execute failed | |||||
static const int32_t ACL_ERROR_RT_REPORT_TIMEOUT = 507012; // report timeout | |||||
static const int32_t ACL_ERROR_RT_SYS_DMA = 507013; // sys dma error | |||||
static const int32_t ACL_ERROR_RT_AICORE_TIMEOUT = 507014; // aicore timeout | |||||
static const int32_t ACL_ERROR_RT_AICORE_EXCEPTION = 507015; // aicore exception | |||||
static const int32_t ACL_ERROR_RT_AICORE_TRAP_EXCEPTION = 507016; // aicore trap exception | |||||
static const int32_t ACL_ERROR_RT_AICPU_TIMEOUT = 507017; // aicpu timeout | |||||
static const int32_t ACL_ERROR_RT_AICPU_EXCEPTION = 507018; // aicpu exception | |||||
static const int32_t ACL_ERROR_RT_AICPU_DATADUMP_RSP_ERR = 507019; // aicpu datadump response error | |||||
static const int32_t ACL_ERROR_RT_AICPU_MODEL_RSP_ERR = 507020; // aicpu model operate response error | |||||
static const int32_t ACL_ERROR_RT_PROFILING_ERROR = 507021; // profiling error | |||||
static const int32_t ACL_ERROR_RT_IPC_ERROR = 507022; // ipc error | |||||
static const int32_t ACL_ERROR_RT_MODEL_ABORT_NORMAL = 507023; // model abort normal | |||||
static const int32_t ACL_ERROR_RT_KERNEL_UNREGISTERING = 507024; // kernel unregistering | |||||
static const int32_t ACL_ERROR_RT_RINGBUFFER_NOT_INIT = 507025; // ringbuffer not init | |||||
static const int32_t ACL_ERROR_RT_RINGBUFFER_NO_DATA = 507026; // ringbuffer no data | |||||
static const int32_t ACL_ERROR_RT_KERNEL_LOOKUP = 507027; // kernel lookup error | |||||
static const int32_t ACL_ERROR_RT_KERNEL_DUPLICATE = 507028; // kernel register duplicate | |||||
static const int32_t ACL_ERROR_RT_DEBUG_REGISTER_FAIL = 507029; // debug register failed | |||||
static const int32_t ACL_ERROR_RT_DEBUG_UNREGISTER_FAIL = 507030; // debug unregister failed | |||||
static const int32_t ACL_ERROR_RT_LABEL_CONTEXT = 507031; // label not in current context | |||||
static const int32_t ACL_ERROR_RT_PROGRAM_USE_OUT = 507032; // program register num use out | |||||
static const int32_t ACL_ERROR_RT_DEV_SETUP_ERROR = 507033; // device setup error | |||||
static const int32_t ACL_ERROR_RT_VECTOR_CORE_TIMEOUT = 507034; // vector core timeout | |||||
static const int32_t ACL_ERROR_RT_VECTOR_CORE_EXCEPTION = 507035; // vector core exception | |||||
static const int32_t ACL_ERROR_RT_VECTOR_CORE_TRAP_EXCEPTION = 507036; // vector core trap exception | |||||
static const int32_t ACL_ERROR_RT_CDQ_BATCH_ABNORMAL = 507037; // cdq alloc batch abnormal | |||||
static const int32_t ACL_ERROR_RT_INTERNAL_ERROR = 507000; // runtime internal error | |||||
static const int32_t ACL_ERROR_RT_TS_ERROR = 507001; // ts internel error | |||||
static const int32_t ACL_ERROR_RT_STREAM_TASK_FULL = 507002; // task full in stream | |||||
static const int32_t ACL_ERROR_RT_STREAM_TASK_EMPTY = 507003; // task empty in stream | |||||
static const int32_t ACL_ERROR_RT_STREAM_NOT_COMPLETE = 507004; // stream not complete | |||||
static const int32_t ACL_ERROR_RT_END_OF_SEQUENCE = 507005; // end of sequence | |||||
static const int32_t ACL_ERROR_RT_EVENT_NOT_COMPLETE = 507006; // event not complete | |||||
static const int32_t ACL_ERROR_RT_CONTEXT_RELEASE_ERROR = 507007; // context release error | |||||
static const int32_t ACL_ERROR_RT_SOC_VERSION = 507008; // soc version error | |||||
static const int32_t ACL_ERROR_RT_TASK_TYPE_NOT_SUPPORT = 507009; // task type not support | |||||
static const int32_t ACL_ERROR_RT_LOST_HEARTBEAT = 507010; // ts lost heartbeat | |||||
static const int32_t ACL_ERROR_RT_MODEL_EXECUTE = 507011; // model execute failed | |||||
static const int32_t ACL_ERROR_RT_REPORT_TIMEOUT = 507012; // report timeout | |||||
static const int32_t ACL_ERROR_RT_SYS_DMA = 507013; // sys dma error | |||||
static const int32_t ACL_ERROR_RT_AICORE_TIMEOUT = 507014; // aicore timeout | |||||
static const int32_t ACL_ERROR_RT_AICORE_EXCEPTION = 507015; // aicore exception | |||||
static const int32_t ACL_ERROR_RT_AICORE_TRAP_EXCEPTION = 507016; // aicore trap exception | |||||
static const int32_t ACL_ERROR_RT_AICPU_TIMEOUT = 507017; // aicpu timeout | |||||
static const int32_t ACL_ERROR_RT_AICPU_EXCEPTION = 507018; // aicpu exception | |||||
static const int32_t ACL_ERROR_RT_AICPU_DATADUMP_RSP_ERR = 507019; // aicpu datadump response error | |||||
static const int32_t ACL_ERROR_RT_AICPU_MODEL_RSP_ERR = 507020; // aicpu model operate response error | |||||
static const int32_t ACL_ERROR_RT_PROFILING_ERROR = 507021; // profiling error | |||||
static const int32_t ACL_ERROR_RT_IPC_ERROR = 507022; // ipc error | |||||
static const int32_t ACL_ERROR_RT_MODEL_ABORT_NORMAL = 507023; // model abort normal | |||||
static const int32_t ACL_ERROR_RT_KERNEL_UNREGISTERING = 507024; // kernel unregistering | |||||
static const int32_t ACL_ERROR_RT_RINGBUFFER_NOT_INIT = 507025; // ringbuffer not init | |||||
static const int32_t ACL_ERROR_RT_RINGBUFFER_NO_DATA = 507026; // ringbuffer no data | |||||
static const int32_t ACL_ERROR_RT_KERNEL_LOOKUP = 507027; // kernel lookup error | |||||
static const int32_t ACL_ERROR_RT_KERNEL_DUPLICATE = 507028; // kernel register duplicate | |||||
static const int32_t ACL_ERROR_RT_DEBUG_REGISTER_FAIL = 507029; // debug register failed | |||||
static const int32_t ACL_ERROR_RT_DEBUG_UNREGISTER_FAIL = 507030; // debug unregister failed | |||||
static const int32_t ACL_ERROR_RT_LABEL_CONTEXT = 507031; // label not in current context | |||||
static const int32_t ACL_ERROR_RT_PROGRAM_USE_OUT = 507032; // program register num use out | |||||
static const int32_t ACL_ERROR_RT_DEV_SETUP_ERROR = 507033; // device setup error | |||||
static const int32_t ACL_ERROR_RT_VECTOR_CORE_TIMEOUT = 507034; // vector core timeout | |||||
static const int32_t ACL_ERROR_RT_VECTOR_CORE_EXCEPTION = 507035; // vector core exception | |||||
static const int32_t ACL_ERROR_RT_VECTOR_CORE_TRAP_EXCEPTION = 507036; // vector core trap exception | |||||
static const int32_t ACL_ERROR_RT_CDQ_BATCH_ABNORMAL = 507037; // cdq alloc batch abnormal | |||||
static const int32_t ACL_ERROR_RT_DRV_INTERNAL_ERROR = 507899; // drv internal error | |||||
static const int32_t ACL_ERROR_RT_AICPU_INTERNAL_ERROR = 507900; // aicpu internal error | |||||
static const int32_t ACL_ERROR_RT_SOCKET_CLOSE = 507901; // hdc disconnect | |||||
static const int32_t ACL_ERROR_RT_DRV_INTERNAL_ERROR = 507899; // drv internal error | |||||
static const int32_t ACL_ERROR_RT_AICPU_INTERNAL_ERROR = 507900; // aicpu internal error | |||||
static const int32_t ACL_ERROR_RT_SOCKET_CLOSE = 507901; // hdc disconnect | |||||
#ifdef __cplusplus | #ifdef __cplusplus | ||||
} | } | ||||
#endif | #endif | ||||
#endif // __INC_EXTERNEL_RT_ERROR_CODES_H__ | |||||
#endif // __INC_EXTERNEL_RT_ERROR_CODES_H__ |
@@ -125,13 +125,13 @@ GE_ERRORNO_CLIENT(GE_CLI_GE_ALREADY_INITIALIZED, 10, "GE is already initialized. | |||||
GE_ERRORNO_CLIENT(GE_CLI_GE_NOT_INITIALIZED, 11, "GE is not yet initialized or is finalized."); // 1343229963 | GE_ERRORNO_CLIENT(GE_CLI_GE_NOT_INITIALIZED, 11, "GE is not yet initialized or is finalized."); // 1343229963 | ||||
// Init module error code definition | // Init module error code definition | ||||
GE_ERRORNO_INIT(GE_MULTI_INIT, 0, "Multiple initializations are not supported."); // 1343234048 | |||||
GE_ERRORNO_INIT(GE_FINALIZE_NOT_INIT, 1, "Finalize is not allowed before initialization."); // 1343234049 | |||||
GE_ERRORNO_INIT(GE_MULTI_FINALIZE, 2, "Multiple finalizations are not supported."); // 1343234050 | |||||
GE_ERRORNO_INIT(GE_PROF_MULTI_INIT, 3, "Multiple profiling initializations are not supported."); // 1343234051 | |||||
GE_ERRORNO_INIT(GE_PROF_NOT_INIT, 4, "Profing initializations have not been done."); // 1343234052 | |||||
GE_ERRORNO_INIT(GE_MULTI_INIT, 0, "Multiple initializations are not supported."); // 1343234048 | |||||
GE_ERRORNO_INIT(GE_FINALIZE_NOT_INIT, 1, "Finalize is not allowed before initialization."); // 1343234049 | |||||
GE_ERRORNO_INIT(GE_MULTI_FINALIZE, 2, "Multiple finalizations are not supported."); // 1343234050 | |||||
GE_ERRORNO_INIT(GE_PROF_MULTI_INIT, 3, "Multiple profiling initializations are not supported."); // 1343234051 | |||||
GE_ERRORNO_INIT(GE_PROF_NOT_INIT, 4, "Profing initializations have not been done."); // 1343234052 | |||||
GE_ERRORNO_INIT(GE_PROF_MODE_CONFLICT, 5, | GE_ERRORNO_INIT(GE_PROF_MODE_CONFLICT, 5, | ||||
"Profiling command mode which is preferred is running, the api mode will not work."); // 1343234053 | |||||
"Profiling command mode which is preferred is running, the api mode will not work."); // 1343234053 | |||||
// Session module error code definition | // Session module error code definition | ||||
GE_ERRORNO_SESSION(GE_SESS_INIT_FAILED, 0, "Failed to initialize session."); // 1343238144 | GE_ERRORNO_SESSION(GE_SESS_INIT_FAILED, 0, "Failed to initialize session."); // 1343238144 | ||||
@@ -216,8 +216,8 @@ GE_ERRORNO_ENGINE(GE_ENG_FINALIZE_FAILED, 1, "Engine finalize failed."); | |||||
GE_ERRORNO_ENGINE(GE_ENG_MEMTYPE_ERROR, 2, "Memory type HBM is necessary when engine is in device"); // 1343246338 | GE_ERRORNO_ENGINE(GE_ENG_MEMTYPE_ERROR, 2, "Memory type HBM is necessary when engine is in device"); // 1343246338 | ||||
// Optimize error code | // Optimize error code | ||||
GE_ERRORNO_GRAPH(TO_BE_DELETED, 63, "The node of the graph to be deleted."); // 1343242303 | |||||
GE_ERRORNO_GRAPH(NOT_CHANGED, 64, "The node of the graph no changed."); // 1343242304 | |||||
GE_ERRORNO_GRAPH(TO_BE_DELETED, 63, "The node of the graph to be deleted."); // 1343242303 | |||||
GE_ERRORNO_GRAPH(NOT_CHANGED, 64, "The node of the graph no changed."); // 1343242304 | |||||
// Ops module error code definition | // Ops module error code definition | ||||
GE_ERRORNO_OPS(GE_OPS_KERNEL_STORE_INIT_FAILED, 0, "Failed to initialize OpsKernelInfoStore."); // 1343250432 | GE_ERRORNO_OPS(GE_OPS_KERNEL_STORE_INIT_FAILED, 0, "Failed to initialize OpsKernelInfoStore."); // 1343250432 | ||||
@@ -169,6 +169,6 @@ GE_FUNC_VISIBILITY bool GetAttrDefListValue(const std::string &key, int idx, int | |||||
GE_FUNC_VISIBILITY bool GetAttrDefListValue(const std::string &key, int idx, uint32_t *value, const AttrDefMap &attr); | GE_FUNC_VISIBILITY bool GetAttrDefListValue(const std::string &key, int idx, uint32_t *value, const AttrDefMap &attr); | ||||
GE_FUNC_VISIBILITY bool GetAttrDefListValue(const std::string &key, int idx, float *value, const AttrDefMap &attr); | GE_FUNC_VISIBILITY bool GetAttrDefListValue(const std::string &key, int idx, float *value, const AttrDefMap &attr); | ||||
GE_FUNC_VISIBILITY bool GetAttrDefListValue(const std::string &key, int idx, double *value, const AttrDefMap &attr); | GE_FUNC_VISIBILITY bool GetAttrDefListValue(const std::string &key, int idx, double *value, const AttrDefMap &attr); | ||||
} | |||||
} // namespace ge | |||||
#endif // INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ | #endif // INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ |
@@ -43,6 +43,11 @@ GE_FUNC_VISIBILITY ge::Status RegProfCtrlCallback(MsprofCtrlCallback func); | |||||
GE_FUNC_VISIBILITY ge::Status RegProfSetDeviceCallback(MsprofSetDeviceCallback func); | GE_FUNC_VISIBILITY ge::Status RegProfSetDeviceCallback(MsprofSetDeviceCallback func); | ||||
GE_FUNC_VISIBILITY ge::Status RegProfReporterCallback(MsprofReporterCallback func); | GE_FUNC_VISIBILITY ge::Status RegProfReporterCallback(MsprofReporterCallback func); | ||||
GE_FUNC_VISIBILITY ge::Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t len); | GE_FUNC_VISIBILITY ge::Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t len); | ||||
/// | |||||
/// @brief Output the profiling data of a single operator in PyTorch; multithreading is not supported | ||||
/// @return Status result | |||||
/// | |||||
GE_FUNC_VISIBILITY ge::Status ProfSetStepInfo(uint64_t index_id, uint16_t tag_id, rtStream_t stream); | GE_FUNC_VISIBILITY ge::Status ProfSetStepInfo(uint64_t index_id, uint16_t tag_id, rtStream_t stream); | ||||
#endif // INC_FRAMEWORK_COMMON_GE_PROFILING_H_ | #endif // INC_FRAMEWORK_COMMON_GE_PROFILING_H_ |
@@ -42,8 +42,9 @@ class GE_FUNC_VISIBILITY ScopeGuard { | |||||
if (on_exit_scope_ != nullptr) { | if (on_exit_scope_ != nullptr) { | ||||
try { | try { | ||||
on_exit_scope_(); | on_exit_scope_(); | ||||
} catch (std::bad_function_call &e) { } | |||||
catch (...) { } | |||||
} catch (std::bad_function_call &e) { | |||||
} catch (...) { | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -52,7 +52,7 @@ class GE_FUNC_VISIBILITY StringUtils { | |||||
return s; | return s; | ||||
} | } | ||||
// lint -esym(551,*) | // lint -esym(551,*) | ||||
static std::string &Rtrim(std::string &s) { /*lint !e618*/ | |||||
static std::string &Rtrim(std::string &s) { /*lint !e618*/ | |||||
#if __cplusplus >= 201103L | #if __cplusplus >= 201103L | ||||
(void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int c) { return !std::isspace(c); })); | (void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int c) { return !std::isspace(c); })); | ||||
#else | #else | ||||
@@ -28,4 +28,4 @@ GE_FUNC_VISIBILITY Status ParserInitialize(const std::map<std::string, std::stri | |||||
// Finalize parser, release all resources | // Finalize parser, release all resources | ||||
GE_FUNC_VISIBILITY Status ParserFinalize(); | GE_FUNC_VISIBILITY Status ParserFinalize(); | ||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_FRAMEWORK_OMG_PARSER_PARSER_API_H_ | |||||
#endif // INC_FRAMEWORK_OMG_PARSER_PARSER_API_H_ |
@@ -1 +1 @@ | |||||
Subproject commit d9a260e2b42236ffaf514bc6397116e370506068 | |||||
Subproject commit 0a9ebe1c7f7b27554659f39e387110ac30d4a1e6 |
@@ -1 +1 @@ | |||||
Subproject commit c074dfa5960d67f2910122d46d4d264dd6554aad | |||||
Subproject commit b79ef8ad19c8ab4335a97b2c3668d2776b62ce0a |
@@ -345,6 +345,10 @@ INT32 mmIsDir(const CHAR *fileName) | |||||
/**
 * @brief Copy the value of environment variable `name` into the caller's buffer.
 *
 * @param name  Name of the environment variable to look up.
 * @param value Caller-provided output buffer.
 * @param len   Capacity of `value` in bytes, including room for the terminating '\0'.
 * @return 0 when the variable is unset (`value` is left untouched) or was copied completely;
 *         -1 when an argument is null or the value plus its terminator does not fit in `len` bytes
 *         (`value` is left untouched in that case).
 */
INT32 mmGetEnv(const CHAR *name, CHAR *value, UINT32 len)
{
  if (name == nullptr) {
    return -1;  // getenv(nullptr) is undefined behaviour
  }
  const char *env = getenv(name);
  if (env == nullptr) {
    return 0;  // unset variable: keep the original "success, buffer untouched" result
  }
  const size_t env_len = strlen(env);
  // An unbounded strcpy here would ignore `len` and could write past the end of `value`.
  if (value == nullptr || env_len >= len) {
    return -1;
  }
  memcpy(value, env, env_len + 1);  // +1 also copies the terminating '\0'
  return 0;
}
@@ -16,6 +16,7 @@ | |||||
#include "toolchain/prof_engine.h" | #include "toolchain/prof_engine.h" | ||||
#include "toolchain/prof_mgr_core.h" | #include "toolchain/prof_mgr_core.h" | ||||
#include "runtime/base.h" | |||||
void * ProfMgrStartUp(const ProfMgrCfg *cfg) | void * ProfMgrStartUp(const ProfMgrCfg *cfg) | ||||
{ | { | ||||
@@ -32,3 +33,10 @@ int Msprof::Engine::RegisterEngine(const std::string& module, const Msprof::Engi | |||||
return 0; | return 0; | ||||
} | } | ||||
// Placeholder implementation: the callback is not stored or invoked; always returns 0.
rtError_t rtSetMsprofReporterCallback(MsprofReporterCallback callback) {
  static_cast<void>(callback);  // parameter deliberately unused
  const rtError_t result = 0;
  return result;
}
// Placeholder implementation: neither the name nor the callback is recorded; always returns 0.
rtError_t rtRegDeviceStateCallback(const char *regName, rtDeviceStateCallback callback) {
  static_cast<void>(regName);   // parameters deliberately unused
  static_cast<void>(callback);
  const rtError_t result = 0;
  return result;
}
@@ -158,6 +158,7 @@ set(COMMON_SRC_FILES | |||||
"${GE_CODE_DIR}/ge/opskernel_manager/ops_kernel_builder_manager.cc" | "${GE_CODE_DIR}/ge/opskernel_manager/ops_kernel_builder_manager.cc" | ||||
"${GE_CODE_DIR}/ge/graph/load/model_manager/model_manager.cc" | "${GE_CODE_DIR}/ge/graph/load/model_manager/model_manager.cc" | ||||
"${GE_CODE_DIR}/ge/common/profiling/profiling_manager.cc" | "${GE_CODE_DIR}/ge/common/profiling/profiling_manager.cc" | ||||
"${GE_CODE_DIR}/ge/common/profiling/ge_profiling.cc" | |||||
"${GE_CODE_DIR}/ge/graph/manager/host_mem_manager.cc" | "${GE_CODE_DIR}/ge/graph/manager/host_mem_manager.cc" | ||||
"${GE_CODE_DIR}/ge/graph/manager/memory_api.cc" | "${GE_CODE_DIR}/ge/graph/manager/memory_api.cc" | ||||
"${GE_CODE_DIR}/ge/session/inner_session.cc" | "${GE_CODE_DIR}/ge/session/inner_session.cc" | ||||
@@ -725,7 +726,6 @@ set(PASS_TEST_FILES | |||||
"graph/passes/memcpy_addr_async_unittest.cc" | "graph/passes/memcpy_addr_async_unittest.cc" | ||||
"graph/passes/hccl_continuous_pass_unittest.cc" | "graph/passes/hccl_continuous_pass_unittest.cc" | ||||
"graph/passes/hccl_memcpy_pass_unittest.cc" | "graph/passes/hccl_memcpy_pass_unittest.cc" | ||||
) | ) | ||||
set(KERNEL_TEST_FILES | set(KERNEL_TEST_FILES | ||||
@@ -858,7 +858,6 @@ set(HYBRID_TEST_FILES | |||||
"hybrid/executor/hybrid_model_async_executor_unittest.cc" | "hybrid/executor/hybrid_model_async_executor_unittest.cc" | ||||
"hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" | "hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" | ||||
"hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" | "hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" | ||||
) | ) | ||||
set(OTHERS_TEST_FILES | set(OTHERS_TEST_FILES | ||||
@@ -886,6 +885,7 @@ add_library(ge_ut_graph STATIC | |||||
target_compile_definitions(ge_ut_graph PRIVATE | target_compile_definitions(ge_ut_graph PRIVATE | ||||
google=ascend_private | google=ascend_private | ||||
FMK_SUPPORT_DUMP | |||||
) | ) | ||||
target_compile_options(ge_ut_graph PRIVATE | target_compile_options(ge_ut_graph PRIVATE | ||||
@@ -349,7 +349,7 @@ class UtestLogicalStreamAllocator : public testing::Test { | |||||
/// B --> C(AllReduce) --- D | /// B --> C(AllReduce) --- D | ||||
/// / | /// / | ||||
/// stream id: 0 A | /// stream id: 0 A | ||||
/// \ | |||||
/// \. | |||||
/// E --> F(AllReduce) --- G | /// E --> F(AllReduce) --- G | ||||
/// stream id: 2 2 2 | /// stream id: 2 2 2 | ||||
/// | /// | ||||
@@ -599,7 +599,7 @@ TEST_F(UtestLogicalStreamAllocator, test_label_not_reusable2) { | |||||
/// case of multi-output, then unuse stream | /// case of multi-output, then unuse stream | ||||
/// sub1 | /// sub1 | ||||
/// / | \ | |||||
/// / | \. | |||||
/// sub2 sub3 sub4 | /// sub2 sub3 sub4 | ||||
TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) { | TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) { | ||||
SubGraphInfoPtr data = CreateDataSubgraph(); | SubGraphInfoPtr data = CreateDataSubgraph(); | ||||
@@ -624,7 +624,7 @@ TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) { | |||||
/// if paralle id 1, then use stream | /// if paralle id 1, then use stream | ||||
/// sub1 | /// sub1 | ||||
/// / | | \ | |||||
/// / | | \. | |||||
/// sub2 sub3 sub4 sub5 | /// sub2 sub3 sub4 sub5 | ||||
TEST_F(UtestLogicalStreamAllocator, test_parallel_one) { | TEST_F(UtestLogicalStreamAllocator, test_parallel_one) { | ||||
SubGraphInfoPtr data = CreateDataSubgraph(); | SubGraphInfoPtr data = CreateDataSubgraph(); | ||||
@@ -653,7 +653,7 @@ TEST_F(UtestLogicalStreamAllocator, test_parallel_one) { | |||||
/// if the param of engine independent is true, then set independent stream | /// if the param of engine independent is true, then set independent stream | ||||
/// sub1 | /// sub1 | ||||
/// / | | \ | |||||
/// / | | \. | |||||
/// sub2 sub3 sub4 sub5 | /// sub2 sub3 sub4 sub5 | ||||
TEST_F(UtestLogicalStreamAllocator, test_independent) { | TEST_F(UtestLogicalStreamAllocator, test_independent) { | ||||
SubGraphInfoPtr data = CreateDataSubgraph(); | SubGraphInfoPtr data = CreateDataSubgraph(); | ||||
@@ -692,7 +692,7 @@ TEST_F(UtestLogicalStreamAllocator, test_independent) { | |||||
/// set stream based on stream label, and then based on independent | /// set stream based on stream label, and then based on independent | ||||
/// sub1 | /// sub1 | ||||
/// / | | \ | |||||
/// / | | \. | |||||
/// sub2 sub3 sub4 sub5 | /// sub2 sub3 sub4 sub5 | ||||
TEST_F(UtestLogicalStreamAllocator, test_independent_switch_label) { | TEST_F(UtestLogicalStreamAllocator, test_independent_switch_label) { | ||||
SubGraphInfoPtr data = CreateDataSubgraph(); | SubGraphInfoPtr data = CreateDataSubgraph(); | ||||
@@ -36,7 +36,7 @@ class UtestStreamAllocator : public testing::Test { | |||||
/// | /// | ||||
/// A | /// A | ||||
/// / \ | |||||
/// / \. | |||||
/// B C | /// B C | ||||
/// | | | /// | | | ||||
/// D 400 | /// D 400 | ||||
@@ -116,7 +116,9 @@ TEST_F(UtestTaskGeneratorTest, FindLastBpFromBpNode) { | |||||
TaskGenerator task_generator(nullptr, 0); | TaskGenerator task_generator(nullptr, 0); | ||||
auto net_output = graph->FindNode("Node_Output"); | auto net_output = graph->FindNode("Node_Output"); | ||||
// netoutput has no data input, return default value 0 | // netoutput has no data input, return default value 0 | ||||
EXPECT_EQ(task_generator.FindLastBpFromBpNode(graph, net_output), 0); | |||||
uint32_t bp_index = 0; | |||||
EXPECT_EQ(task_generator.FindLastBpFromBpNode(graph, net_output, bp_index), 0); | |||||
EXPECT_EQ(bp_index, 2); | |||||
} | } | ||||
TEST_F(UtestTaskGeneratorTest, UpdateOpIsVarAttr) { | TEST_F(UtestTaskGeneratorTest, UpdateOpIsVarAttr) { | ||||
@@ -438,4 +438,22 @@ TEST_F(UtestModelManagerModelManager, test_data_input_tensor) { | |||||
auto ret = mm.DataInputTensor(model_id,inputs); | auto ret = mm.DataInputTensor(model_id,inputs); | ||||
EXPECT_EQ(PARAM_INVALID, ret); // HybridDavinciModel::impl_ is null. | EXPECT_EQ(PARAM_INVALID, ret); // HybridDavinciModel::impl_ is null. | ||||
} | } | ||||
TEST_F(UtestModelManagerModelManager, test_launch_kernel_cust_aicpu) { | |||||
ModelManager mm; | |||||
// cust_aicpu_so_ is empty. | |||||
EXPECT_EQ(mm.LaunchKernelCustAicpuSo("empty_cust_aicpu"), SUCCESS); | |||||
// deleteCustOp after Launch will deleted. | |||||
uintptr_t resource_id = 1; // for rtCtxGetCurrent stub | |||||
std::vector<char> kernel_bin(256); | |||||
auto &cust_resource_001 = mm.cust_aicpu_so_[resource_id]; | |||||
auto tbe_kernel = std::shared_ptr<OpKernelBin>(new OpKernelBin("deleteCustOp", std::move(kernel_bin))); | |||||
auto &cust_opkernel_001 = cust_resource_001["deleteCustOp"] = tbe_kernel; | |||||
EXPECT_FALSE(mm.cust_aicpu_so_.empty()); | |||||
EXPECT_EQ(mm.LaunchKernelCustAicpuSo("deleteCustOp"), SUCCESS); | |||||
EXPECT_TRUE(mm.cust_aicpu_so_.empty()); | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -55,7 +55,7 @@ class UtestGraphPassesAssertPass : public Test { | |||||
}; | }; | ||||
/// D E | /// D E | ||||
/// | \ | \ | |||||
/// | \ | \. | |||||
/// F C G | /// F C G | ||||
/// : | : | /// : | : | ||||
/// H A I | /// H A I | ||||
@@ -134,8 +134,8 @@ TEST_F(UtestGraphPassesAssertPass, assert_pass_test2) { | |||||
EXPECT_EQ(graph->FindNode("D"), nullptr); | EXPECT_EQ(graph->FindNode("D"), nullptr); | ||||
} | } | ||||
/// E F | |||||
/// | \ | \ | |||||
/// E F | |||||
/// | \ | \. | |||||
/// H C -> D G | /// H C -> D G | ||||
/// \ | : | /// \ | : | ||||
/// A I | /// A I | ||||
@@ -130,7 +130,7 @@ class UTESTGraphPassesBasePass : public testing::Test { | |||||
/// reshape1 | /// reshape1 | ||||
/// | | /// | | ||||
/// add1 | /// add1 | ||||
/// / \ | |||||
/// / \. | |||||
/// | | | /// | | | ||||
/// data1 const1 | /// data1 const1 | ||||
ComputeGraphPtr BuildGraph1() { | ComputeGraphPtr BuildGraph1() { | ||||
@@ -148,9 +148,9 @@ ComputeGraphPtr BuildGraph1() { | |||||
} | } | ||||
/// sum1 | /// sum1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
/// / \. | |||||
/// reshape1 addn1 | /// reshape1 addn1 | ||||
/// | c | | /// | c | | ||||
/// add1 <--- shape1 | /// add1 <--- shape1 | ||||
@@ -217,7 +217,7 @@ void CheckIterOrder(UtestTestPass *pass, std::vector<std::unordered_set<std::str | |||||
/// Op1 | /// Op1 | ||||
/// | | /// | | ||||
/// Merge | /// Merge | ||||
/// / \ | |||||
/// / \. | |||||
/// Op2 Op3 | /// Op2 Op3 | ||||
TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) { | TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) { | ||||
auto builder = ut::GraphBuilder("g1"); | auto builder = ut::GraphBuilder("g1"); | ||||
@@ -245,7 +245,7 @@ TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) { | |||||
/// Op1 | /// Op1 | ||||
/// | | /// | | ||||
/// Merge | /// Merge | ||||
/// / \ | |||||
/// / \. | |||||
/// Op2 Op3 | /// Op2 Op3 | ||||
TEST_F(UTESTGraphPassesBasePass, del_isolate_success) { | TEST_F(UTESTGraphPassesBasePass, del_isolate_success) { | ||||
auto builder = ut::GraphBuilder("g1"); | auto builder = ut::GraphBuilder("g1"); | ||||
@@ -459,7 +459,7 @@ TEST_F(UTESTGraphPassesBasePass, while_loop) { | |||||
/// data1 const | /// data1 const | ||||
/// \ / | /// \ / | ||||
/// while | /// while | ||||
/// / \ | |||||
/// / \. | |||||
/// | | | /// | | | ||||
/// cast1 cast2 | /// cast1 cast2 | ||||
ComputeGraphPtr BuildWhileGraph1() { | ComputeGraphPtr BuildWhileGraph1() { | ||||
@@ -34,11 +34,11 @@ namespace { | |||||
/// net_output | /// net_output | ||||
/// | | /// | | ||||
/// merge | /// merge | ||||
/// / \ | |||||
/// / \. | |||||
/// square add | /// square add | ||||
/// F| T/ T\ | |||||
/// F| T/ T\. | |||||
/// switch1 switch2 | /// switch1 switch2 | ||||
/// / \ / \ | |||||
/// / \ / \. | |||||
/// var1 var2 var3 | /// var1 var2 var3 | ||||
/// | /// | ||||
ComputeGraphPtr BuildGraph1() { | ComputeGraphPtr BuildGraph1() { | ||||
@@ -173,8 +173,8 @@ namespace { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | | /// | | ||||
/// addnYes1 | /// addnYes1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
/// const1 const2 | /// const1 const2 | ||||
ComputeGraphPtr BuildGraph1() { | ComputeGraphPtr BuildGraph1() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -223,8 +223,8 @@ ComputeGraphPtr BuildGraph2() { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | c | /// | c | ||||
/// addnYes1 <----- dataNo1 | /// addnYes1 <----- dataNo1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
/// const1 const2 | /// const1 const2 | ||||
ComputeGraphPtr BuildGraph3() { | ComputeGraphPtr BuildGraph3() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -249,8 +249,8 @@ ComputeGraphPtr BuildGraph3() { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | c | /// | c | ||||
/// addnYes1 <--------- | /// addnYes1 <--------- | ||||
/// / \ \ | |||||
/// / \ c \ | |||||
/// / \ \. | |||||
/// / \ c \. | |||||
/// const1 const2 <----- dataNo1 | /// const1 const2 <----- dataNo1 | ||||
ComputeGraphPtr BuildGraph4() { | ComputeGraphPtr BuildGraph4() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -276,7 +276,7 @@ ComputeGraphPtr BuildGraph4() { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | c | /// | c | ||||
/// addnYes1 <----- dataNo1 | /// addnYes1 <----- dataNo1 | ||||
/// / \ | |||||
/// / \. | |||||
/// / \ c | /// / \ c | ||||
/// const1 const2 <----- dataNo2 | /// const1 const2 <----- dataNo2 | ||||
ComputeGraphPtr BuildGraph5() { | ComputeGraphPtr BuildGraph5() { | ||||
@@ -306,8 +306,8 @@ ComputeGraphPtr BuildGraph5() { | |||||
/// addYes1 <---- const3 | /// addYes1 <---- const3 | ||||
/// | | /// | | ||||
/// addnYes1 <- | /// addnYes1 <- | ||||
/// / \ \ | |||||
/// / \ \ | |||||
/// / \ \. | |||||
/// / \ \. | |||||
/// const1 const2 const4 | /// const1 const2 const4 | ||||
ComputeGraphPtr BuildGraph6() { | ComputeGraphPtr BuildGraph6() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -332,12 +332,12 @@ ComputeGraphPtr BuildGraph6() { | |||||
} | } | ||||
/// netoutput1 | /// netoutput1 | ||||
/// / \ | |||||
/// / \. | |||||
/// shapeNo1 ShpaeNo2 | /// shapeNo1 ShpaeNo2 | ||||
/// \ / | /// \ / | ||||
/// huberLoss1 | /// huberLoss1 | ||||
/// / | \ | |||||
/// / | \ | |||||
/// / | \. | |||||
/// / | \. | |||||
/// const1 const2 const3 | /// const1 const2 const3 | ||||
ComputeGraphPtr BuildGraph7() { | ComputeGraphPtr BuildGraph7() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -365,8 +365,8 @@ ComputeGraphPtr BuildGraph7() { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | | /// | | ||||
/// addnNo1 | /// addnNo1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
/// const1 const2 | /// const1 const2 | ||||
ComputeGraphPtr BuildGraph8() { | ComputeGraphPtr BuildGraph8() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -389,8 +389,8 @@ ComputeGraphPtr BuildGraph8() { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | | /// | | ||||
/// addnYes1 | /// addnYes1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
/// const1 data1 | /// const1 data1 | ||||
ComputeGraphPtr BuildGraph9() { | ComputeGraphPtr BuildGraph9() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -409,12 +409,12 @@ ComputeGraphPtr BuildGraph9() { | |||||
} | } | ||||
/// netoutput1 | /// netoutput1 | ||||
/// / \ | |||||
/// / \. | |||||
/// addDim sqrt1 | /// addDim sqrt1 | ||||
/// \ / | /// \ / | ||||
/// switch1 | /// switch1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
/// const1 const2 | /// const1 const2 | ||||
ComputeGraphPtr BuildGraph10() { | ComputeGraphPtr BuildGraph10() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -63,8 +63,8 @@ namespace { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | | /// | | ||||
/// addnNo1 | /// addnNo1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
/// const1 const2 | /// const1 const2 | ||||
ComputeGraphPtr BuildGraph8() { | ComputeGraphPtr BuildGraph8() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -87,8 +87,8 @@ ComputeGraphPtr BuildGraph8() { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | | /// | | ||||
/// addnYes1 | /// addnYes1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
///const1 data1 | ///const1 data1 | ||||
ComputeGraphPtr BuildGraph9() { | ComputeGraphPtr BuildGraph9() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -64,6 +64,7 @@ class UtestGraphPassesFoldingKernelFillKernel : public testing::Test { | |||||
op_desc_ptr->AddInputDesc(dims_tensor_desc); | op_desc_ptr->AddInputDesc(dims_tensor_desc); | ||||
op_desc_ptr->AddInputDesc(value_tensor_desc); | op_desc_ptr->AddInputDesc(value_tensor_desc); | ||||
op_desc_ptr->AddOutputDesc(dims_tensor_desc); | |||||
std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor}; | std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor}; | ||||
std::vector<GeTensorPtr> outputs; | std::vector<GeTensorPtr> outputs; | ||||
@@ -124,6 +125,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillBoolShape2And3) { | |||||
op_desc_ptr->AddInputDesc(dims_tensor_desc); | op_desc_ptr->AddInputDesc(dims_tensor_desc); | ||||
op_desc_ptr->AddInputDesc(value_tensor_desc); | op_desc_ptr->AddInputDesc(value_tensor_desc); | ||||
op_desc_ptr->AddOutputDesc(dims_tensor_desc); | |||||
std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor}; | std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor}; | ||||
std::vector<GeTensorPtr> outputs; | std::vector<GeTensorPtr> outputs; | ||||
@@ -230,6 +232,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsHaveNegativeNumber) { | |||||
op_desc_ptr->AddInputDesc(dims_tensor_desc); | op_desc_ptr->AddInputDesc(dims_tensor_desc); | ||||
op_desc_ptr->AddInputDesc(value_tensor_desc); | op_desc_ptr->AddInputDesc(value_tensor_desc); | ||||
op_desc_ptr->AddOutputDesc(dims_tensor_desc); | |||||
std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor}; | std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor}; | ||||
std::vector<GeTensorPtr> outputs; | std::vector<GeTensorPtr> outputs; | ||||
@@ -284,6 +287,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsTypeNotSupport) { | |||||
op_desc_ptr->AddInputDesc(dims_tensor_desc); | op_desc_ptr->AddInputDesc(dims_tensor_desc); | ||||
op_desc_ptr->AddInputDesc(value_tensor_desc); | op_desc_ptr->AddInputDesc(value_tensor_desc); | ||||
op_desc_ptr->AddOutputDesc(dims_tensor_desc); | |||||
std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor}; | std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor}; | ||||
std::vector<GeTensorPtr> outputs; | std::vector<GeTensorPtr> outputs; | ||||
@@ -310,6 +314,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsOverflow) { | |||||
op_desc_ptr->AddInputDesc(dims_tensor_desc); | op_desc_ptr->AddInputDesc(dims_tensor_desc); | ||||
op_desc_ptr->AddInputDesc(value_tensor_desc); | op_desc_ptr->AddInputDesc(value_tensor_desc); | ||||
op_desc_ptr->AddOutputDesc(dims_tensor_desc); | |||||
std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor}; | std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor}; | ||||
std::vector<GeTensorPtr> outputs; | std::vector<GeTensorPtr> outputs; | ||||
@@ -336,6 +341,7 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsMulDataTypeOverflow) { | |||||
op_desc_ptr->AddInputDesc(dims_tensor_desc); | op_desc_ptr->AddInputDesc(dims_tensor_desc); | ||||
op_desc_ptr->AddInputDesc(value_tensor_desc); | op_desc_ptr->AddInputDesc(value_tensor_desc); | ||||
op_desc_ptr->AddOutputDesc(dims_tensor_desc); | |||||
std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor}; | std::vector<ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor}; | ||||
std::vector<GeTensorPtr> outputs; | std::vector<GeTensorPtr> outputs; | ||||
@@ -343,3 +349,33 @@ TEST_F(UtestGraphPassesFoldingKernelFillKernel, FillDimsMulDataTypeOverflow) { | |||||
EXPECT_EQ(PARAM_INVALID, status); | EXPECT_EQ(PARAM_INVALID, status); | ||||
} | } | ||||
TEST_F(UtestGraphPassesFoldingKernelFillKernel, OutputdescUnknown) { | |||||
ge::OpDescPtr op_dims = std::make_shared<ge::OpDesc>(); | |||||
vector <int64_t> dims_vec = {2}; | |||||
vector <int32_t> dims_value_vec = {2, 3}; | |||||
GeTensorDesc dims_tensor_desc(GeShape(dims_vec), FORMAT_NCHW, DT_INT32); | |||||
GeTensorPtr dim_tensor = std::make_shared<GeTensor>(dims_tensor_desc, (uint8_t *) dims_value_vec.data(), | |||||
dims_value_vec.size() * sizeof(int32_t)); | |||||
OpDescUtils::SetWeights(op_dims, dim_tensor); | |||||
ge::OpDescPtr op_value = std::make_shared<ge::OpDesc>(); | |||||
vector <uint8_t> data_vec = {1}; | |||||
GeTensorDesc value_tensor_desc(GeShape(), FORMAT_NCHW, DT_BOOL); | |||||
GeTensorPtr value_tensor = | |||||
std::make_shared<GeTensor>(value_tensor_desc, (uint8_t *) data_vec.data(), data_vec.size() * sizeof(bool)); | |||||
OpDescUtils::SetWeights(op_value, value_tensor); | |||||
op_desc_ptr->AddInputDesc(dims_tensor_desc); | |||||
op_desc_ptr->AddInputDesc(value_tensor_desc); | |||||
vector <int64_t> out_vec = {-1, -1}; | |||||
GeTensorDesc out_tensor_desc(GeShape(out_vec), FORMAT_NCHW, DT_INT32); | |||||
op_desc_ptr->AddOutputDesc(out_tensor_desc); | |||||
std::vector <ge::ConstGeTensorPtr> input = {dim_tensor, value_tensor}; | |||||
std::vector <GeTensorPtr> outputs; | |||||
Status status = kernel->Compute(op_desc_ptr, input, outputs); | |||||
EXPECT_EQ(NOT_CHANGED, status); | |||||
} |
@@ -46,7 +46,7 @@ class UtestGraphPassesFoldingKernelSsdPriorboxKernel : public testing::Test { | |||||
/// convolution data | /// convolution data | ||||
/// | / | /// | / | ||||
/// ssdpriorbox | /// ssdpriorbox | ||||
/// \ | |||||
/// \. | |||||
/// reshape | /// reshape | ||||
class NodeBuilder { | class NodeBuilder { | ||||
public: | public: | ||||
@@ -120,7 +120,7 @@ TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) { | |||||
/// graph with subgraph | /// graph with subgraph | ||||
/// const | /// const | ||||
/// / \ | |||||
/// / \. | |||||
/// cast1 cast1 | /// cast1 cast1 | ||||
/// \ / | /// \ / | ||||
/// case | /// case | ||||
@@ -69,62 +69,100 @@ static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string | |||||
return graph.AddNode(op_desc); | return graph.AddNode(op_desc); | ||||
} | } | ||||
static void CreateLoopGraph(ComputeGraphPtr &graph, NodePtr &merge) { | |||||
static void CreateLoopGraph(ComputeGraphPtr &graph, NodePtr &merge, vector<NodePtr> &loop, vector<NodePtr> &cond) { | |||||
/******************************************************************************* | /******************************************************************************* | ||||
* Exit Identify | |||||
* \ / \. | |||||
* \ / \. | |||||
* Switch Add | |||||
* / | | | |||||
* / | | | |||||
* / | | | |||||
* LoopCond | | | |||||
* \ | | | |||||
* \ | | | |||||
* \ | | | |||||
* Less | | | |||||
* \ | NextIteration | |||||
* \ | | | |||||
* \ | | | |||||
* Merge <---------| | |||||
* | | |||||
* | | |||||
* Enter | |||||
* | | |||||
* +--------------------- Merge ----------------------+ | |||||
* / | | |||||
* / | | |||||
* / | | |||||
* / | | |||||
* Exit Identify | | |||||
* \ / \. | | |||||
* \ / \. | | |||||
* Switch Add Add | |||||
* / | | | | |||||
* / | | | | |||||
* / | | | | |||||
* LoopCond | | | | |||||
* \ | | | | |||||
* \ | | | | |||||
* \ | | | | |||||
* Less | | | | |||||
* \ | NextIteration | | |||||
* \ | | | | |||||
* \ | | | | |||||
* Merge <---------| | | |||||
* | | | |||||
* | | | |||||
* Enter | | |||||
* \ | | |||||
* \ | | |||||
* Switch Switch | |||||
* | | | |||||
* +-----------------Equal----------------------+ | |||||
* | | |||||
******************************************************************************/ | ******************************************************************************/ | ||||
auto data1 = CreateNode(*graph, "data", DATA, 1, 1); | |||||
auto data1 = CreateNode(*graph, "data1", DATA, 1, 1); | |||||
auto data2 = CreateNode(*graph, "data2", DATA, 1, 1); | |||||
auto equal1 = CreateNode(*graph, "equal1", EQUAL, 2, 1); | |||||
auto switch1 = CreateNode(*graph, "switch1", SWITCH, 2, 2); | |||||
auto switch2 = CreateNode(*graph, "switch2", SWITCH, 2, 2); | |||||
auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); | auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); | ||||
auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2); | |||||
auto less1 = CreateNode(*graph, "less", LESS, 2, 1); | |||||
auto merge1 = CreateNode(*graph, "merge1", MERGE, 2, 2); | |||||
auto less1 = CreateNode(*graph, "less1", LESS, 2, 1); | |||||
auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1); | auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1); | ||||
auto switch1 = CreateNode(*graph, "switch", SWITCH, 2, 2); | |||||
auto switch3 = CreateNode(*graph, "switch3", SWITCH, 2, 2); | |||||
auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1); | auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1); | ||||
auto add1 = CreateNode(*graph, "add", ADD, 2, 1); | |||||
auto add1 = CreateNode(*graph, "add1", ADD, 2, 1); | |||||
auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); | auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); | ||||
auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); | auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); | ||||
auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1); | |||||
auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1); | |||||
auto value1 = CreateNode(*graph, "const1", CONSTANT, 0, 1); | |||||
auto value2 = CreateNode(*graph, "const2", CONSTANT, 0, 1); | |||||
auto add2 = CreateNode(*graph, "add2", ADD, 2, 1); | |||||
auto merge2 = CreateNode(*graph, "merge2", MERGE, 2, 2); | |||||
auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); | auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); | ||||
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), equal1->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(data2->GetOutDataAnchor(0), equal1->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(data1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(data2->GetOutDataAnchor(0), switch2->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(equal1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(equal1->GetOutDataAnchor(0), switch2->GetInDataAnchor(1)); | |||||
cond.emplace_back(switch1); | |||||
cond.emplace_back(switch2); | |||||
GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); // false | |||||
GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); | GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); | ||||
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); | GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); | ||||
GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); | GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); | ||||
GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0)); | GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0)); | ||||
GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch3->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch3->GetInDataAnchor(1)); | |||||
loop.emplace_back(merge1); | |||||
GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(switch1->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(switch3->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); // false | |||||
GraphUtils::AddEdge(switch3->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); // true | |||||
loop.emplace_back(switch3); | |||||
GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); | GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); | ||||
GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); | GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); | ||||
GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0)); | GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0)); | ||||
GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); | GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); | ||||
GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||||
merge = merge1; | |||||
GraphUtils::AddEdge(switch2->GetOutDataAnchor(1), add2->GetInDataAnchor(1)); // true | |||||
GraphUtils::AddEdge(value2->GetOutDataAnchor(0), add2->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), merge2->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(add2->GetOutDataAnchor(0), merge2->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(merge2->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||||
cond.emplace_back(merge2); | |||||
merge = merge2; | |||||
} | } | ||||
static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) { | static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) { | ||||
@@ -197,12 +235,24 @@ static void CreateCondGraph(ComputeGraphPtr &graph, NodePtr &merge) { | |||||
TEST_F(UtestMarkForceUnknownForCondPass, skip_while_loop_merge) { | TEST_F(UtestMarkForceUnknownForCondPass, skip_while_loop_merge) { | ||||
auto graph = std::make_shared<ComputeGraph>("test_graph"); | auto graph = std::make_shared<ComputeGraph>("test_graph"); | ||||
NodePtr merge; | NodePtr merge; | ||||
CreateLoopGraph(graph, merge); | |||||
AttrUtils::SetBool(merge->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); | |||||
vector<NodePtr> loop; | |||||
vector<NodePtr> cond; | |||||
CreateLoopGraph(graph, merge, loop, cond); | |||||
MarkForceUnknownForCondPass mark_force_unknown_pass; | MarkForceUnknownForCondPass mark_force_unknown_pass; | ||||
EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip LoopCond | EXPECT_EQ(mark_force_unknown_pass.Run(graph), SUCCESS); // skip LoopCond | ||||
EXPECT_EQ(loop.size(), 2); | |||||
for (const auto &node : loop) { | |||||
EXPECT_FALSE(node->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)); | |||||
} | |||||
EXPECT_EQ(cond.size(), 3); | |||||
for (const auto &node : cond) { | |||||
int64_t group_index = -1; | |||||
EXPECT_TRUE(AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)); | |||||
EXPECT_EQ(group_index, merge->GetOpDesc()->GetId()); | |||||
} | |||||
} | } | ||||
TEST_F(UtestMarkForceUnknownForCondPass, skip_known_shape_merge) { | TEST_F(UtestMarkForceUnknownForCondPass, skip_known_shape_merge) { | ||||
@@ -110,8 +110,8 @@ TEST_F(UtestGraphPassesMergePass, multiple_inputs) { | |||||
} | } | ||||
/// Merge | /// Merge | ||||
/// | \ | |||||
/// | \ | |||||
/// | \. | |||||
/// | \. | |||||
/// Op1 Op2 Merge2 | /// Op1 Op2 Merge2 | ||||
/// \ | | | /// \ | | | ||||
/// \ | Op3 | /// \ | Op3 | ||||
@@ -137,10 +137,10 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch_meet_net_output_with_da | |||||
} | } | ||||
/// Merge | /// Merge | ||||
/// | \ | |||||
/// | \ | |||||
/// | \. | |||||
/// | \. | |||||
/// Op1 Op2 Merge2 | /// Op1 Op2 Merge2 | ||||
/// \ | | \ | |||||
/// \ | | \. | |||||
/// \ | Op3 | /// \ | Op3 | ||||
/// \ | : | /// \ | : | ||||
/// NetOutput | /// NetOutput | ||||
@@ -165,8 +165,8 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch_meet_net_output_with_co | |||||
TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch) { | TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch) { | ||||
/// Merge | /// Merge | ||||
/// | \ | |||||
/// | \ | |||||
/// | \. | |||||
/// | \. | |||||
/// Op1 Op2 Merge2 | /// Op1 Op2 Merge2 | ||||
/// \ | | | /// \ | | | ||||
/// \ | Op3 | /// \ | Op3 | ||||
@@ -210,7 +210,7 @@ TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch) { | |||||
/// Op1 Op2 Merge2 | /// Op1 Op2 Merge2 | ||||
/// \ | | /// \ | | ||||
/// \ Op3 | /// \ Op3 | ||||
/// \ | |||||
/// \. | |||||
/// Merge3 | /// Merge3 | ||||
ret = pass_.Run(merge_node2); | ret = pass_.Run(merge_node2); | ||||
@@ -224,7 +224,7 @@ TEST_F(UtestGraphPassesMergePass, single_non_const_input) { | |||||
/// Op1 | /// Op1 | ||||
/// | | /// | | ||||
/// Merge | /// Merge | ||||
/// / \ | |||||
/// / \. | |||||
/// Op2 Op3 | /// Op2 Op3 | ||||
auto merge_node = NewNode("Merge", MERGE, 1, 2); | auto merge_node = NewNode("Merge", MERGE, 1, 2); | ||||
auto node1 = NewNode("Op1", RELU, 1, 1); | auto node1 = NewNode("Op1", RELU, 1, 1); | ||||
@@ -253,7 +253,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input) { | |||||
/// Const | /// Const | ||||
/// | | /// | | ||||
/// Merge Pass Const | /// Merge Pass Const | ||||
/// / \ ===> / \ | |||||
/// / \ ===> / \. | |||||
/// Op1 Op2 Op1 Op2 | /// Op1 Op2 Op1 Op2 | ||||
auto merge_node = NewNode("Merge", MERGE, 1, 2); | auto merge_node = NewNode("Merge", MERGE, 1, 2); | ||||
auto const_node = NewNode("Const", CONSTANT, 1, 1); | auto const_node = NewNode("Const", CONSTANT, 1, 1); | ||||
@@ -284,7 +284,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input_value_index_two_out_nodes) | |||||
/// / | ===> / \(control anchor) | /// / | ===> / \(control anchor) | ||||
/// Op1 | \ Op1 Constant | /// Op1 | \ Op1 Constant | ||||
/// Op2 Op3 | | /// Op2 Op3 | | ||||
/// / \ | |||||
/// / \. | |||||
/// Op2 Op3 | /// Op2 Op3 | ||||
auto merge_node = NewNode("Merge", MERGE, 1, 2); | auto merge_node = NewNode("Merge", MERGE, 1, 2); | ||||
auto const_node = NewNode("Const", CONSTANT, 1, 1); | auto const_node = NewNode("Const", CONSTANT, 1, 1); | ||||
@@ -329,7 +329,7 @@ TEST_F(UtestGraphPassesMergePass, single_const_input_value_index_two_out_nodes1) | |||||
/// / | ===> / \(control anchor) | /// / | ===> / \(control anchor) | ||||
/// Op1 | \ Op1 Constant | /// Op1 | \ Op1 Constant | ||||
/// Op2 Op3 | | /// Op2 Op3 | | ||||
/// / \ | |||||
/// / \. | |||||
/// Op2 Op3 | /// Op2 Op3 | ||||
auto merge_node = NewNode("Merge", MERGE, 1, 2); | auto merge_node = NewNode("Merge", MERGE, 1, 2); | ||||
auto const_node = NewNode("Const", CONSTANT, 1, 1); | auto const_node = NewNode("Const", CONSTANT, 1, 1); | ||||
@@ -357,7 +357,7 @@ TEST_F(UtestGraphPassesMergePass, const_with_control_input) { | |||||
/// C | /// C | ||||
/// | | /// | | ||||
/// Merge | /// Merge | ||||
/// / \ | |||||
/// / \. | |||||
/// Op1 Op2 | /// Op1 Op2 | ||||
auto switch_node = NewNode("Switch", SWITCH, 1, 2); | auto switch_node = NewNode("Switch", SWITCH, 1, 2); | ||||
auto identity_node = NewNode("Identity", SWITCH, 1, 1); | auto identity_node = NewNode("Identity", SWITCH, 1, 1); | ||||
@@ -381,7 +381,7 @@ TEST_F(UtestGraphPassesMergePass, const_with_control_input) { | |||||
/// . | /// . | ||||
/// . | /// . | ||||
/// C | /// C | ||||
/// / \ | |||||
/// / \. | |||||
/// Op1 Op2 | /// Op1 Op2 | ||||
auto ret = pass_.Run(merge_node); | auto ret = pass_.Run(merge_node); | ||||
EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
@@ -19,7 +19,8 @@ | |||||
#include <string> | #include <string> | ||||
#define private public | #define private public | ||||
#include "inc/graph/ge_local_context.h" | |||||
#include "inc/external/ge/ge_api_types.h" | |||||
#include "common/ge_inner_error_codes.h" | #include "common/ge_inner_error_codes.h" | ||||
#include "inc/pass_manager.h" | #include "inc/pass_manager.h" | ||||
#include "utils/graph_utils.h" | #include "utils/graph_utils.h" | ||||
@@ -66,11 +67,11 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { | |||||
void BuildDefaultGraph() { | void BuildDefaultGraph() { | ||||
/// input | /// input | ||||
/// \ | |||||
/// \. | |||||
/// sqrt pred | /// sqrt pred | ||||
/// \ / | /// \ / | ||||
/// cast | /// cast | ||||
/// / \ | |||||
/// / \. | |||||
/// switch_t switch_f | /// switch_t switch_f | ||||
/// | | | /// | | | ||||
/// F T | /// F T | ||||
@@ -118,13 +119,13 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { | |||||
void BuildDefaultGraph1() { | void BuildDefaultGraph1() { | ||||
/// input | /// input | ||||
/// \ | |||||
/// \. | |||||
/// sqrt pred | /// sqrt pred | ||||
/// \ / | /// \ / | ||||
/// Switch | /// Switch | ||||
/// | | | /// | | | ||||
/// ----F T---- | /// ----F T---- | ||||
/// \ | / \ | |||||
/// \ | / \. | |||||
/// \ Merge1 Merge2 | /// \ Merge1 Merge2 | ||||
/// \_________| | /// \_________| | ||||
input_node_ = NewNode("input", RELU, 0, 1); | input_node_ = NewNode("input", RELU, 0, 1); | ||||
@@ -164,14 +165,14 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { | |||||
void BuildDefaultGraph2() { | void BuildDefaultGraph2() { | ||||
/// input input1 | /// input input1 | ||||
/// \ \ | |||||
/// \ \. | |||||
/// sqrt pred sqrt1 pred1 | /// sqrt pred sqrt1 pred1 | ||||
/// \ / \ / | /// \ / \ / | ||||
/// Switch Switch1 | /// Switch Switch1 | ||||
/// | | _______| | /// | | _______| | ||||
/// | | / | /// | | / | ||||
/// ____F T____ | /// ____F T____ | ||||
/// \ | / \ | |||||
/// \ | / \. | |||||
/// \ Merge1 Merge2 | /// \ Merge1 Merge2 | ||||
/// \__________| | /// \__________| | ||||
input_node_ = NewNode("input", RELU, 0, 2); | input_node_ = NewNode("input", RELU, 0, 2); | ||||
@@ -225,6 +226,70 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { | |||||
output_true_node_->GetOpDesc()->SetIsInputConst({false}); | output_true_node_->GetOpDesc()->SetIsInputConst({false}); | ||||
} | } | ||||
void BuildDefaultGraph3() { | |||||
/// input | |||||
/// \ | |||||
/// sqrt pred | |||||
/// \ / | |||||
/// Switch | |||||
/// | | | |||||
/// F T ------ | |||||
/// / \_/_ \ | |||||
/// / / \ \ | |||||
/// Merge sqrt2 sqrt3 | |||||
/// / \ \ | |||||
/// sqrt1 \ relu | |||||
/// \ \ | |||||
/// \ sqrt4 | |||||
/// \ / | |||||
/// Merge1 | |||||
input_node_ = NewNode("input", RELU, 0, 1); | |||||
AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); | |||||
pred_node_ = NewNode("pred", GREATER, 2, 1); | |||||
sqrt_node_ = NewNode("sqrt", SQRT, 1, 1); | |||||
cast_node_ = NewNode("cast", CAST, 2, 2); | |||||
switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1); | |||||
AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true); | |||||
switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1); | |||||
AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false); | |||||
output_false_node_ = NewNode("false_output", RELU, 1, 2); | |||||
AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); | |||||
output_true_node_ = NewNode("true_output", RELU, 1, 2); | |||||
AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); | |||||
merge_node_ = NewNode("merge", STREAMMERGE, 2, 1); | |||||
sqrt_node1_ = NewNode("sqrt1", SQRT, 1, 1); | |||||
AttrUtils::SetStr(sqrt_node1_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); | |||||
sqrt_node2_ = NewNode("sqrt2", SQRT, 1, 1); | |||||
AttrUtils::SetStr(sqrt_node2_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); | |||||
sqrt_node3_ = NewNode("sqrt3", SQRT, 1, 1); | |||||
relu_node_ = NewNode("relu", RELU, 1, 1); | |||||
sqrt_node4_ = NewNode("sqrt4", SQRT, 1, 1); | |||||
AttrUtils::SetStr(sqrt_node4_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); | |||||
merge_node1_ = NewNode("merge1", STREAMMERGE, 2, 1); | |||||
GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), sqrt_node2_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), sqrt_node3_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(merge_node_->GetOutDataAnchor(0), sqrt_node1_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(sqrt_node3_->GetOutDataAnchor(0), relu_node_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(relu_node_->GetOutDataAnchor(0), sqrt_node4_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(sqrt_node2_->GetOutDataAnchor(0), merge_node1_->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(sqrt_node4_->GetOutDataAnchor(0), merge_node1_->GetInDataAnchor(1)); | |||||
output_false_node_->GetOpDesc()->SetIsInputConst({false}); | |||||
output_true_node_->GetOpDesc()->SetIsInputConst({false}); | |||||
} | |||||
ComputeGraphPtr graph_; | ComputeGraphPtr graph_; | ||||
ComputeGraphPtr sub_graph_; | ComputeGraphPtr sub_graph_; | ||||
GeTensorDescPtr default_tensor_desc_; | GeTensorDescPtr default_tensor_desc_; | ||||
@@ -235,6 +300,9 @@ class UtestGraphPassesParallelGgroupPass : public testing::Test { | |||||
NodePtr cast_node1_; | NodePtr cast_node1_; | ||||
NodePtr sqrt_node_; | NodePtr sqrt_node_; | ||||
NodePtr sqrt_node1_; | NodePtr sqrt_node1_; | ||||
NodePtr sqrt_node2_; | |||||
NodePtr sqrt_node3_; | |||||
NodePtr sqrt_node4_; | |||||
NodePtr input_node_; | NodePtr input_node_; | ||||
NodePtr input_node1_; | NodePtr input_node1_; | ||||
NodePtr switch_node_t; | NodePtr switch_node_t; | ||||
@@ -278,6 +346,16 @@ TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph2) { | |||||
EXPECT_EQ(true, input_node1_->GetOutControlAnchor()->IsLinkedWith(cast_node1_->GetInControlAnchor())); | EXPECT_EQ(true, input_node1_->GetOutControlAnchor()->IsLinkedWith(cast_node1_->GetInControlAnchor())); | ||||
} | } | ||||
TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph3) {
  // Set OPTION_GRAPH_RUN_MODE to "1" in the thread-local context before the pass runs.
  std::map<std::string, std::string> graph_options{{OPTION_GRAPH_RUN_MODE, "1"}};
  GetThreadLocalContext().SetGraphOption(graph_options);

  BuildDefaultGraph3();
  EXPECT_EQ(pass_.Run(graph_), GRAPH_SUCCESS);

  // After the pass, merge1 must be control-linked to sqrt1.
  const auto merge1_ctrl_out = merge_node1_->GetOutControlAnchor();
  const auto sqrt1_ctrl_in = sqrt_node1_->GetInControlAnchor();
  EXPECT_EQ(true, merge1_ctrl_out->IsLinkedWith(sqrt1_ctrl_in));
}
TEST_F(UtestGraphPassesParallelGgroupPass, normal_subgraph) { | TEST_F(UtestGraphPassesParallelGgroupPass, normal_subgraph) { | ||||
BuildDefaultGraph1(); | BuildDefaultGraph1(); | ||||
NodePtr input_node1 = NewNode("input1", RELU, 0, 1, true); | NodePtr input_node1 = NewNode("input1", RELU, 0, 1, true); | ||||
@@ -31,9 +31,9 @@ class UtestReshapeRecoveryPass : public testing::Test { | |||||
namespace { | namespace { | ||||
/// netoutput1 | /// netoutput1 | ||||
/// | \ | |||||
///transdata1 \ | |||||
/// | \ | |||||
/// | \. | |||||
///transdata1 \. | |||||
/// | \. | |||||
/// | transdata2 | /// | transdata2 | ||||
/// | / | /// | / | ||||
/// var1 const1 | /// var1 const1 | ||||
@@ -35,7 +35,7 @@ namespace { | |||||
/// transdata1 | /// transdata1 | ||||
/// | | /// | | ||||
/// reshape1 | /// reshape1 | ||||
/// | \ | |||||
/// | \. | |||||
/// var1 const1 | /// var1 const1 | ||||
ut::GraphBuilder Graph1Builder() { | ut::GraphBuilder Graph1Builder() { | ||||
ut::GraphBuilder builder = ut::GraphBuilder("g1"); | ut::GraphBuilder builder = ut::GraphBuilder("g1"); | ||||
@@ -55,11 +55,11 @@ ut::GraphBuilder Graph1Builder() { | |||||
} | } | ||||
/// netoutput1 | /// netoutput1 | ||||
/// | \ | |||||
///transdata1 \ | |||||
/// | \ | |||||
/// | \. | |||||
///transdata1 \. | |||||
/// | \. | |||||
/// reshape1 reshape2 | /// reshape1 reshape2 | ||||
/// | \ / \ | |||||
/// | \ / \. | |||||
/// var1 const1 var2 | /// var1 const1 var2 | ||||
ut::GraphBuilder Graph2Builder() { | ut::GraphBuilder Graph2Builder() { | ||||
ut::GraphBuilder builder = ut::GraphBuilder("g2"); | ut::GraphBuilder builder = ut::GraphBuilder("g2"); | ||||
@@ -83,9 +83,9 @@ ut::GraphBuilder Graph2Builder() { | |||||
} | } | ||||
/// netoutput1 | /// netoutput1 | ||||
/// | \ | |||||
///transdata1 \ | |||||
/// | \ | |||||
/// | \. | |||||
///transdata1 \. | |||||
/// | \. | |||||
/// reshape1 transdata2 | /// reshape1 transdata2 | ||||
/// | \ / | /// | \ / | ||||
/// var1 const1 | /// var1 const1 | ||||
@@ -34,7 +34,7 @@ class UtestResourcePairControlPass : public testing::Test { | |||||
namespace { | namespace { | ||||
/// netoutput1 | /// netoutput1 | ||||
/// | \ | |||||
/// | \. | |||||
/// StackPush StackPop | /// StackPush StackPop | ||||
/// | | | /// | | | ||||
/// var1 const1 | /// var1 const1 | ||||
@@ -63,9 +63,9 @@ ComputeGraphPtr BuildGraph1() { | |||||
/// netoutput1 | /// netoutput1 | ||||
/// | | /// | | ||||
/// merge1 | /// merge1 | ||||
/// / \ | |||||
/// / \. | |||||
/// / add1 | /// / add1 | ||||
/// / F| \ | |||||
/// / F| \. | |||||
/// addn1 swtich2 var3 | /// addn1 swtich2 var3 | ||||
/// \F T/ | | /// \F T/ | | ||||
/// switch1 | | /// switch1 | | ||||
@@ -101,9 +101,9 @@ ComputeGraphPtr BuildGraph2() { | |||||
/// add1 | /// add1 | ||||
/// / \T | /// / \T | ||||
/// var3 swtich2 | /// var3 swtich2 | ||||
/// T/ \ | |||||
/// switch1 \ | |||||
/// / \ \ | |||||
/// T/ \. | |||||
/// switch1 \. | |||||
/// / \ \. | |||||
/// var1 var2 var4 | /// var1 var2 var4 | ||||
ComputeGraphPtr BuildGraph3() { | ComputeGraphPtr BuildGraph3() { | ||||
auto builder = ut::GraphBuilder("g3"); | auto builder = ut::GraphBuilder("g3"); | ||||
@@ -129,7 +129,7 @@ ComputeGraphPtr BuildGraph3() { | |||||
/// netoutput1 | /// netoutput1 | ||||
/// | | /// | | ||||
/// merge1 | /// merge1 | ||||
/// / \ | |||||
/// / \. | |||||
/// add1 addn1 | /// add1 addn1 | ||||
/// / \T F/ | /// / \T F/ | ||||
/// var3 swtich2 | /// var3 swtich2 | ||||
@@ -402,7 +402,7 @@ TEST_F(UtestGraphPassesTransOpBreadthFusionPass, test_multi_anchor_case) { | |||||
} | } | ||||
/// ----> netoutput1 | /// ----> netoutput1 | ||||
/// / | \ | |||||
/// / | \. | |||||
/// transdata1 transdata2 transdata3 | /// transdata1 transdata2 transdata3 | ||||
/// \ / | | /// \ / | | ||||
/// var1-------------- | /// var1-------------- | ||||
@@ -432,7 +432,7 @@ static ComputeGraphPtr BuildGraph1() { | |||||
} | } | ||||
/// ---------> netoutput1 | /// ---------> netoutput1 | ||||
/// / | \ | |||||
/// / | \. | |||||
/// transdata1 transdata2(l1) transdata3(l1) | /// transdata1 transdata2(l1) transdata3(l1) | ||||
/// \ / | | /// \ / | | ||||
/// var1------------------ | /// var1------------------ | ||||
@@ -456,19 +456,19 @@ TEST_F(UtestGraphPassesTransOpDepthFusionPass, test_transop_with_multi_out_edge) | |||||
/// -->transpose1 -->transpose3-->sinh2 | /// -->transpose1 -->transpose3-->sinh2 | ||||
/// | \ / | /// | \ / | ||||
/// | -->transpose2 | /// | -->transpose2 | ||||
/// | \ | |||||
/// | \. | |||||
/// / -->cast3-->cast4-->sinh3 | /// / -->cast3-->cast4-->sinh3 | ||||
/// / | /// / | ||||
/// / -->transpose4-->transpose5-->sinh4 | /// / -->transpose4-->transpose5-->sinh4 | ||||
/// / / | /// / / | ||||
/// Node4D-->Cast1-->Cast2-->Cast5 -->reshape2-->sinh5 | /// Node4D-->Cast1-->Cast2-->Cast5 -->reshape2-->sinh5 | ||||
/// \ \ | |||||
/// \ \. | |||||
/// \ -->sinh6 | /// \ -->sinh6 | ||||
/// \ | |||||
/// \. | |||||
/// \ -->transpose6-->transpose7-->sinh9 | /// \ -->transpose6-->transpose7-->sinh9 | ||||
/// \ / | /// \ / | ||||
/// -->reshape-->cast6-->cast7-->sinh8 | /// -->reshape-->cast6-->cast7-->sinh8 | ||||
/// \ | |||||
/// \. | |||||
/// -->sinh7 | /// -->sinh7 | ||||
/// after optimized graph | /// after optimized graph | ||||
@@ -479,15 +479,15 @@ TEST_F(UtestGraphPassesTransOpDepthFusionPass, test_transop_with_multi_out_edge) | |||||
/// / /-->transpose3-->sinh2 | /// / /-->transpose3-->sinh2 | ||||
/// -->Cast1 | /// -->Cast1 | ||||
/// / \-->sinh7 | /// / \-->sinh7 | ||||
/// / \ | |||||
/// / \. | |||||
/// / -->sinh9 | /// / -->sinh9 | ||||
/// Node4D | /// Node4D | ||||
/// \ -->sinh4 | /// \ -->sinh4 | ||||
/// \ / | /// \ / | ||||
/// -->Cast5-->sinh5 | /// -->Cast5-->sinh5 | ||||
/// \ \ | |||||
/// \ \. | |||||
/// \ -->sinh6 | /// \ -->sinh6 | ||||
/// \ | |||||
/// \. | |||||
/// -->Cast7-->sinh8 | /// -->Cast7-->sinh8 | ||||
ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | ||||
@@ -180,7 +180,7 @@ ComputeGraphPtr GetGraph7(size_t symmetric_transdata_num, size_t asymmetric_tran | |||||
/// TransData TransData ... MatMul ... | /// TransData TransData ... MatMul ... | ||||
/// \ | / / / | /// \ | / / / | ||||
/// HcomAllReduce | /// HcomAllReduce | ||||
/// / | \ \ \ | |||||
/// / | \ \ \. | |||||
/// TransData TransData ... RealDiv ... | /// TransData TransData ... RealDiv ... | ||||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | ||||
NodePtr allreduce = | NodePtr allreduce = | ||||
@@ -340,7 +340,7 @@ TEST(UtestTransopNearbyAllreduceFusionPass, test7_all_reduce_with_multiple_trans | |||||
/// TransData TransData ... MatMul ... | /// TransData TransData ... MatMul ... | ||||
/// \ | / / / | /// \ | / / / | ||||
/// HcomAllReduce | /// HcomAllReduce | ||||
/// / | \ \ \ | |||||
/// / | \ \ \. | |||||
/// TransData TransData ... RealDiv ... | /// TransData TransData ... RealDiv ... | ||||
size_t symmetric_transdata_num = 20; | size_t symmetric_transdata_num = 20; | ||||
size_t asymmetric_transdata_num = 20; | size_t asymmetric_transdata_num = 20; | ||||
@@ -66,7 +66,7 @@ namespace { | |||||
/// transdata2 | /// transdata2 | ||||
/// | | /// | | ||||
/// assign1 | /// assign1 | ||||
/// / \ | |||||
/// / \. | |||||
/// transdata1 | | /// transdata1 | | ||||
/// | | | /// | | | ||||
/// var1 const1 | /// var1 const1 | ||||
@@ -35,8 +35,8 @@ namespace { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | | /// | | ||||
/// addnYes1 | /// addnYes1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
/// const1 const2 | /// const1 const2 | ||||
ComputeGraphPtr BuildGraph1() { | ComputeGraphPtr BuildGraph1() { | ||||
@@ -57,9 +57,9 @@ ComputeGraphPtr BuildGraph1() { | |||||
/// | /// | ||||
/// netoutput1 | /// netoutput1 | ||||
/// / \ \ | |||||
/// add1 assign1 \ | |||||
/// / \ / \ \ | |||||
/// / \ \. | |||||
/// add1 assign1 \. | |||||
/// / \ / \ \. | |||||
/// var1 var2 const1 var3 | /// var1 var2 const1 var3 | ||||
ComputeGraphPtr BuildGraph2() { | ComputeGraphPtr BuildGraph2() { | ||||
@@ -103,4 +103,32 @@ TEST_F(UtestHybridModelAsyncExecutor, Test_execute) { | |||||
context.callback_manager->callback_queue_.Push(eof_entry); | context.callback_manager->callback_queue_.Push(eof_entry); | ||||
ASSERT_EQ(executor.Execute(args), SUCCESS); | ASSERT_EQ(executor.Execute(args), SUCCESS); | ||||
} | } | ||||
TEST_F(UtestHybridModelAsyncExecutor, test_PrepareInputs) {
  // Checks that PrepareInputs replaces the registered dynamic shape of input 0
  // with the concrete shape passed in InputData, and reports the expected size.
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
  ge_root_model->SetModelName("test_name");
  HybridModel hybrid_model(ge_root_model);
  HybridModelAsyncExecutor executor(&hybrid_model);

  // Register input 0 as dynamic: first dim unknown (-1) with range [1, 256], size -1.
  GeTensorDescPtr tensor_desc = make_shared<GeTensorDesc>(GeShape({-1, 16, 16, 3}));
  tensor_desc->SetShapeRange({{1, 256}, {16, 16}, {16, 16}, {3, 3}});
  executor.input_tensor_desc_.insert({0, tensor_desc});
  executor.device_id_ = 0;
  executor.input_sizes_.insert({0, -1});
  executor.is_input_dynamic_.push_back(true);

  // Caller-side buffer. It is value-initialised so that any read of it sees
  // defined bytes, and checked because nothrow new returns nullptr on failure.
  constexpr uint64_t kInputBytes = 3072;
  unique_ptr<uint8_t[]> data_buf(new (std::nothrow) uint8_t[kInputBytes]());
  ASSERT_TRUE(data_buf != nullptr);
  InputData input_data;
  input_data.blobs.push_back(DataBuffer(data_buf.get(), kInputBytes, false));
  input_data.shapes.push_back({1, 16, 16, 3});

  HybridModelExecutor::ExecuteArgs args;
  auto ret = executor.PrepareInputs(input_data, args);
  ASSERT_EQ(ret, SUCCESS);

  // Guard the element access below instead of indexing an empty vector or a null entry.
  ASSERT_FALSE(args.input_desc.empty());
  ASSERT_TRUE(args.input_desc[0] != nullptr);
  ASSERT_EQ(args.input_desc[0]->GetShape().ToString(), GeShape({1, 16, 16, 3}).ToString());

  // NOTE(review): 3104 is kInputBytes + 32; presumably padding added when the
  // tensor size is computed — confirm against PrepareInputs.
  int64_t tensor_size = 0;
  TensorUtils::GetSize(*(args.input_desc[0]), tensor_size);
  ASSERT_EQ(tensor_size, 3104);
}
} // namespace ge | } // namespace ge |
@@ -249,6 +249,9 @@ TEST_F(UtestSubgraphExecutor, cond_graph_schedule_tasks) { | |||||
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | ||||
ASSERT_EQ(graph_context.callback_manager->Init(), SUCCESS); | ASSERT_EQ(graph_context.callback_manager->Init(), SUCCESS); | ||||
auto root_graph = hybrid_model.root_graph_; | |||||
switch_t = root_graph->FindNode("switch_t"); | |||||
switch_f = root_graph->FindNode("switch_f"); | |||||
const auto node_it_t = hybrid_model.node_items_.find(switch_t); | const auto node_it_t = hybrid_model.node_items_.find(switch_t); | ||||
const auto node_it_f = hybrid_model.node_items_.find(switch_f); | const auto node_it_f = hybrid_model.node_items_.find(switch_f); | ||||
ASSERT_NE(hybrid_model.node_items_.end(), node_it_t); | ASSERT_NE(hybrid_model.node_items_.end(), node_it_t); | ||||
@@ -214,11 +214,17 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||||
ASSERT_EQ(it->second->frame_index_, index); | ASSERT_EQ(it->second->frame_index_, index); | ||||
ASSERT_EQ(it->second->parent_frame_, -1); | ASSERT_EQ(it->second->parent_frame_, -1); | ||||
}; | }; | ||||
TestFrameGroup(enter1, control_group_index); | |||||
TestFrameGroup(active1, control_group_index); | |||||
TestFrameGroup(active2, control_group_index); | |||||
TestFrameGroup(active3, control_group_index); | |||||
TestFrameGroup(output1, -1); | |||||
auto root_graph = hybrid_model.root_graph_; | |||||
auto enter1_node = root_graph->FindNode("enter"); | |||||
auto active1_node = root_graph->FindNode("active1"); | |||||
auto active2_node = root_graph->FindNode("active2"); | |||||
auto active3_node = root_graph->FindNode("active3"); | |||||
auto output1_node = root_graph->FindNode("net_output"); | |||||
TestFrameGroup(enter1_node, control_group_index); | |||||
TestFrameGroup(active1_node, control_group_index); | |||||
TestFrameGroup(active2_node, control_group_index); | |||||
TestFrameGroup(active3_node, control_group_index); | |||||
TestFrameGroup(output1_node, -1); | |||||
engine_mapping.clear(); | engine_mapping.clear(); | ||||
task_executor.clear(); | task_executor.clear(); | ||||
@@ -373,4 +379,14 @@ TEST_F(UtestHybridModelBuilder, TestInitHcclExecutorOnDemand) { | |||||
NodeExecutorManager::GetInstance().builders_.erase(NodeExecutorManager::ExecutorType::HCCL); | NodeExecutorManager::GetInstance().builders_.erase(NodeExecutorManager::ExecutorType::HCCL); | ||||
ASSERT_EQ(HybridModelBuilder::InitHcclExecutorOnDemand(ge_model), SUCCESS); | ASSERT_EQ(HybridModelBuilder::InitHcclExecutorOnDemand(ge_model), SUCCESS); | ||||
} | } | ||||
TEST_F(UtestHybridModelBuilder, copy_graph_success) {
  // CopyGraph on a model whose root graph has no nodes added must report SUCCESS.
  auto empty_graph = std::make_shared<ComputeGraph>("test");
  auto root_model = make_shared<GeRootModel>(empty_graph);
  HybridModel model(root_model);
  HybridModelBuilder builder(model);
  EXPECT_EQ(builder.CopyGraph(), SUCCESS);
}
} // namespace ge | } // namespace ge |
@@ -25,6 +25,7 @@ | |||||
#define private public | #define private public | ||||
#include "common/profiling/profiling_manager.h" | #include "common/profiling/profiling_manager.h" | ||||
#include "graph/ge_local_context.h" | #include "graph/ge_local_context.h" | ||||
#include "inc/framework/common/profiling/ge_profiling.h" | |||||
#undef protected | #undef protected | ||||
#undef private | #undef private | ||||
@@ -115,4 +116,20 @@ TEST_F(UtestGeProfilinganager, get_fp_bp_point_empty) { | |||||
ProfilingManager::Instance().GetFpBpPoint(fp_point, bp_point); | ProfilingManager::Instance().GetFpBpPoint(fp_point, bp_point); | ||||
EXPECT_EQ(fp_point, ""); | EXPECT_EQ(fp_point, ""); | ||||
EXPECT_EQ(bp_point, ""); | EXPECT_EQ(bp_point, ""); | ||||
} | |||||
} | |||||
TEST_F(UtestGeProfilinganager, set_step_info_success) {
  // Tag 0 followed by tag 1 for the same index and stream: both calls must succeed.
  // The stream is a dummy non-null handle; this test never dereferences it.
  const uint64_t step_index = 0;
  const auto fake_stream = reinterpret_cast<rtStream_t>(0x1);
  EXPECT_EQ(ProfSetStepInfo(step_index, 0, fake_stream), ge::SUCCESS);
  EXPECT_EQ(ProfSetStepInfo(step_index, 1, fake_stream), ge::SUCCESS);
}
TEST_F(UtestGeProfilinganager, set_step_info_failed) {
  // A tag-1 call with no tag-0 call issued in this test must report FAILED.
  // NOTE(review): presumably this depends on state ProfSetStepInfo keeps between
  // calls (and so on test order) — confirm against its implementation.
  const uint64_t step_index = 0;
  const auto fake_stream = reinterpret_cast<rtStream_t>(0x1);
  EXPECT_EQ(ProfSetStepInfo(step_index, 1, fake_stream), ge::FAILED);
}