From fca7991db263a9754d09e7d56ca8fe05af0d3abc Mon Sep 17 00:00:00 2001 From: y00500818 Date: Tue, 15 Dec 2020 11:12:20 +0800 Subject: [PATCH] add validation of fmk type for plugin load. --- ge/common/ge/tbe_plugin_manager.cc | 14 +++++++++++--- parser | 2 +- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/ge/common/ge/tbe_plugin_manager.cc b/ge/common/ge/tbe_plugin_manager.cc index 8a594cb9..36e7646a 100644 --- a/ge/common/ge/tbe_plugin_manager.cc +++ b/ge/common/ge/tbe_plugin_manager.cc @@ -184,12 +184,20 @@ void TBEPluginManager::GetCustomOpPath(std::string &customop_path) { void TBEPluginManager::LoadCustomOpLib() { LoadPluginSo(options_); + std::string fmk_type = std::to_string(domi::TENSORFLOW); + auto it = options_.find(ge::FRAMEWORK_TYPE); + if (it != options_.end()) { + fmk_type = it->second; + } + GELOGD("frameworkType is %s", fmk_type.c_str()); std::vector registration_datas = domi::OpRegistry::Instance()->registrationDatas; GELOGI("The size of registration_datas is: %zu", registration_datas.size()); for (OpRegistrationData reg_data : registration_datas) { - GELOGD("Begin to register optype: %s, imply_type: %s", reg_data.GetOmOptype().c_str(), - TypeUtils::ImplyTypeToSerialString(reg_data.GetImplyType()).c_str()); - domi::OpRegistry::Instance()->Register(reg_data); + if (std::to_string(reg_data.GetFrameworkType()) == fmk_type) { + GELOGD("Begin to register optype: %s, imply_type: %s", reg_data.GetOmOptype().c_str(), + TypeUtils::ImplyTypeToSerialString(reg_data.GetImplyType()).c_str()); + domi::OpRegistry::Instance()->Register(reg_data); + } } } diff --git a/parser b/parser index c8414582..befc2aac 160000 --- a/parser +++ b/parser @@ -1 +1 @@ -Subproject commit c841458262316866e7bfa7783f7ee3205e12e2c9 +Subproject commit befc2aac08de4f1f1c38e476c4d3fd53174653ff