@@ -0,0 +1,69 @@ | |||||
/**
 * \file include/lite/pack_model.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include <cstdint>
#include <string>
namespace lite { | |||||
//! Bit-packed feature flags stored in the packed-model header.
//! Requires <cstdint> for uint32_t (previously relied on a transitive
//! include via <string>).
struct FeatureBits32 {
    //! whether the algo-policy blob packed into the model is a fastrun cache
    uint32_t is_fast_run_cache : 1;
    //! reserved for new fields
    uint32_t : 31;
};
//! Header fields written into the packed-model flatbuffer; see
//! ModelPacker::set_header and FbsHelper::build_header.
struct Header {
    //! model name; filled from the "name" field of the info JSON when an
    //! info data file is given (see ModelPacker::set_header)
    std::string name;
    //! model encryption method name, this is used to find the right
    //! decryption method. [ AES_default | RC4_default |
    //! SIMPLE_FAST_RC4_default ], default is NONE.
    std::string model_decryption_method;
    //! info data encryption method name, this is used to find the right
    //! decryption method. [ AES_default | RC4_default |
    //! SIMPLE_FAST_RC4_default ], default is NONE.
    std::string info_decryption_method;
    //! info parse method name.
    std::string info_parse_method = "LITE_default";
    //! fastrun cache parse method name.
    std::string info_cache_parse_method = "LITE_parse_cache";
    //! feature flags (currently only is_fast_run_cache)
    FeatureBits32 fb32;
};
class FbsHelper; | |||||
//! Packs a (bare or already packed) model file together with optional info
//! JSON and fastrun-cache data into a single packed-model file.
class ModelPacker {
public:
    //! \param model_path           path of the input model (bare or packed)
    //! \param packed_model_path    output path of the packed model
    //! \param info_data_path       optional JSON info file to pack in
    //! \param info_algo_policy_path optional fastrun cache / algo policy file
    //! \param info_binary_cache_path optional binary cache file
    ModelPacker(
            std::string model_path, std::string packed_model_path,
            std::string info_data_path = "", std::string info_algo_policy_path = "",
            std::string info_binary_cache_path = "");

    //! Configure the header written into the packed model. Note the argument
    //! order: the MODEL decryption method comes first, then the INFO one.
    void set_header(
            std::string model_decryption_method = "NONE",
            std::string info_decryption_method = "NONE", bool is_fast_run_cache = true);

    //! Build the flatbuffer and write it (prefixed with the "packed_model"
    //! tag) to packed_model_path.
    void pack_model();

private:
    std::string m_packed_model_path;
    std::string m_info_data_path;
    //! fastrun cache / algo policy
    std::string m_info_algo_policy_path;
    //! binary cache
    std::string m_info_binary_cache_path;
    Header m_header;

    friend class FbsHelper;
    //! NOTE(review): owning raw pointer allocated with `new` in the
    //! constructor and never released (no destructor) — this leaks; consider
    //! std::unique_ptr<FbsHelper> plus an out-of-line destructor.
    FbsHelper* m_fbs_helper;
};
} // namespace lite |
@@ -1,7 +1,7 @@ | |||||
# BUILD the load and run for lite | # BUILD the load and run for lite | ||||
include_directories(PUBLIC | include_directories(PUBLIC | ||||
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/lite/load_and_run/src>) | $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/lite/load_and_run/src>) | ||||
file(GLOB_RECURSE SOURCES ./*.cpp) | |||||
file(GLOB_RECURSE SOURCES ./*.cpp ${PROJECT_SOURCE_DIR}/lite/src/pack_model/*.cpp) | |||||
add_executable(load_and_run ${SOURCES}) | add_executable(load_and_run ${SOURCES}) | ||||
target_link_libraries(load_and_run lite_static) | target_link_libraries(load_and_run lite_static) | ||||
@@ -43,6 +43,8 @@ public: | |||||
virtual void wait() = 0; | virtual void wait() = 0; | ||||
virtual ~ModelBase() = default; | virtual ~ModelBase() = default; | ||||
virtual const std::string& get_model_path() const = 0; | |||||
}; | }; | ||||
} // namespace lar | } // namespace lar | ||||
@@ -60,6 +60,8 @@ public: | |||||
//! get algo strategy | //! get algo strategy | ||||
Strategy& get_lite_strategy() { return m_strategy; } | Strategy& get_lite_strategy() { return m_strategy; } | ||||
const std::string& get_model_path() const override { return model_path; } | |||||
private: | private: | ||||
bool share_model_mem; | bool share_model_mem; | ||||
bool enable_layout_transform; | bool enable_layout_transform; | ||||
@@ -107,6 +107,8 @@ public: | |||||
std::move(out_file), m_format.val()); | std::move(out_file), m_format.val()); | ||||
} | } | ||||
const std::string& get_model_path() const override { return model_path; } | |||||
private: | private: | ||||
bool share_model_mem; | bool share_model_mem; | ||||
std::string model_path; | std::string model_path; | ||||
@@ -0,0 +1,87 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/model_options.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include "model_options.h" | |||||
#include "device_options.h" | |||||
#include "lite/pack_model.h" | |||||
#include "misc.h" | |||||
#include "models/model_lite.h" | |||||
#include "models/model_mdl.h" | |||||
namespace lar { | |||||
template <typename ModelImpl> | |||||
void PackModelOption::config_model_internel( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelImpl> model) { | |||||
if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { | |||||
lite::ModelPacker packer( | |||||
model->get_model_path(), packed_model_dump, pack_info_json, pack_cache, | |||||
pack_binary_cache); | |||||
packer.set_header(pack_info_cryption, pack_model_cryption, is_fast_run_cache); | |||||
packer.pack_model(); | |||||
} | |||||
} | |||||
} // namespace lar | |||||
using namespace lar; | |||||
////////////////////// PackModel options //////////////////////// | |||||
//! Construct the option, snapshotting all pack-related gflags that were set
//! on the command line (empty flags keep the member defaults).
PackModelOption::PackModelOption() {
    m_option_name = "pack_model";
    //! copy a flag into its member only when the user actually set it
    auto copy_if_set = [](const std::string& flag, std::string& member) {
        if (!flag.empty()) {
            member = flag;
        }
    };
    copy_if_set(FLAGS_packed_model_dump, packed_model_dump);
    copy_if_set(FLAGS_pack_info_json, pack_info_json);
    copy_if_set(FLAGS_pack_cache, pack_cache);
    copy_if_set(FLAGS_pack_info_cryption, pack_info_cryption);
    copy_if_set(FLAGS_pack_model_cryption, pack_model_cryption);
}
bool PackModelOption::is_valid() { | |||||
return !FLAGS_packed_model_dump.empty(); | |||||
} | |||||
std::shared_ptr<OptionBase> PackModelOption::create_option() { | |||||
static std::shared_ptr<PackModelOption> option(new PackModelOption); | |||||
if (PackModelOption::is_valid()) { | |||||
return std::static_pointer_cast<OptionBase>(option); | |||||
} else { | |||||
return nullptr; | |||||
} | |||||
} | |||||
//! Dispatch to config_model_internel<ModelLite>/<ModelMdl> depending on the
//! concrete model type. CONFIG_MODEL_FUN is a project macro — presumably
//! declared in option_base.h; TODO confirm.
void PackModelOption::config_model(
        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
    CONFIG_MODEL_FUN;
}
////////////////////// PackModel gflags //////////////////////// | |||||
DEFINE_string(packed_model_dump, "", "The output file path of packed model."); | |||||
DEFINE_string( | |||||
pack_info_json, "", | |||||
"An encrypted or not encrypted json format file to pack into the model."); | |||||
DEFINE_string(pack_cache, "", "Pack the fastrun cache or algo policy into the model."); | |||||
DEFINE_string( | |||||
pack_info_cryption, "NONE", | |||||
"The info data encryption method name, this is used to find the right " | |||||
"decryption method. --pack-info-cryption [ AES_default | RC4_default | " | |||||
"SIMPLE_FAST_RC4_default ], default is NONE. See " | |||||
"https://megengine.megvii-inc.com/user-guide/deployment/lite/advance/" | |||||
"pack-lite-model.html for more details."); | |||||
DEFINE_string( | |||||
pack_model_cryption, "NONE", | |||||
"The model encryption method name, this is used to find the right decryption " | |||||
"method. --pack-model-cryption [ AES_default | RC4_default | " | |||||
"SIMPLE_FAST_RC4_default ], default is NONE. See " | |||||
"https://megengine.megvii-inc.com/user-guide/deployment/lite/advance/" | |||||
"pack-lite-model.html for more details."); | |||||
REGIST_OPTION_CREATOR(pack_model, lar::PackModelOption::create_option); |
@@ -0,0 +1,45 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/model_options.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | |||||
#include <gflags/gflags.h> | |||||
#include "models/model.h" | |||||
#include "option_base.h" | |||||
DECLARE_string(packed_model_dump); | |||||
DECLARE_string(pack_info_json); | |||||
DECLARE_string(pack_cache); | |||||
DECLARE_string(pack_info_cryption); | |||||
DECLARE_string(pack_model_cryption); | |||||
namespace lar { | |||||
//! load_and_run option that packs the executed model (plus optional info
//! JSON and fastrun cache) into a single packed-model file after the run.
class PackModelOption : public OptionBase {
public:
    //! valid iff --packed-model-dump was given
    static bool is_valid();

    //! singleton factory used by REGIST_OPTION_CREATOR
    static std::shared_ptr<OptionBase> create_option();

    void config_model(
            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;

    std::string option_name() const override { return m_option_name; }

private:
    PackModelOption();

    //! per-model-type implementation, invoked via CONFIG_MODEL_FUN
    template <typename ModelImpl>
    void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>);

    std::string m_option_name;
    std::string packed_model_dump;   //! output path of the packed model
    std::string pack_info_json;      //! info JSON file to pack in
    std::string pack_cache;          //! fastrun cache / algo policy file
    std::string pack_binary_cache;   //! binary cache file (no gflag yet)
    std::string pack_info_cryption;  //! info decryption method name
    std::string pack_model_cryption; //! model decryption method name
    bool is_fast_run_cache = true;
};
} // namespace lar |
@@ -119,6 +119,12 @@ class TestNetwork(TestShuffleNet): | |||||
network.load(model_path) | network.load(model_path) | ||||
self.do_forward(network) | self.do_forward(network) | ||||
def test_pack_cache_to_model(self):
    # Load and forward a model that already has a fastrun cache packed
    # into it, verifying the packed-cache path loads cleanly.
    packed_model = os.path.join(self.source_dir, "test_pack_cache_to_model.lite")
    net = LiteNetwork()
    net.load(packed_model)
    self.do_forward(net)
def test_network_basic(self): | def test_network_basic(self): | ||||
network = LiteNetwork() | network = LiteNetwork() | ||||
network.load(self.model_path) | network.load(self.model_path) | ||||
@@ -0,0 +1,232 @@ | |||||
/**
 * \file src/pack_model/pack_model.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "lite/pack_model.h" | |||||
#include "../misc.h" | |||||
#if LITE_BUILD_WITH_MGE | |||||
#include "megbrain/utils/infile_persistent_cache.h" | |||||
#endif | |||||
#include <flatbuffers/flatbuffers.h> | |||||
#include "nlohmann/json.hpp" | |||||
#include "pack_model_generated.h" | |||||
namespace lite { | |||||
class FbsHelper { | |||||
public: | |||||
FbsHelper() = default; | |||||
FbsHelper(ModelPacker* packer, std::string model_path); | |||||
flatbuffers::Offset<model_parse::ModelHeader> build_header(); | |||||
flatbuffers::Offset<model_parse::ModelInfo> build_info(); | |||||
flatbuffers::Offset<model_parse::ModelData> build_data(); | |||||
flatbuffers::FlatBufferBuilder& builder() { return m_builder; } | |||||
private: | |||||
ModelPacker* m_packer; | |||||
flatbuffers::FlatBufferBuilder m_builder; | |||||
std::vector<uint8_t> m_model_buffer; | |||||
const model_parse::ModelHeader* m_model_header = nullptr; | |||||
const model_parse::ModelInfo* m_model_info = nullptr; | |||||
const model_parse::ModelData* m_model_data = nullptr; | |||||
}; | |||||
} // namespace lite | |||||
using namespace lite; | |||||
using namespace model_parse; | |||||
//! Read the whole file at \p path into a byte vector. Asserts on open /
//! size / read failure.
std::vector<uint8_t> read_file(std::string path) {
    FILE* fin = fopen(path.c_str(), "rb");
    LITE_ASSERT(fin, "failed to open %s: %s", path.c_str(), strerror(errno));
    fseek(fin, 0, SEEK_END);
    long size = ftell(fin);
    //! BUGFIX: ftell reports errors as -1; previously this was silently
    //! converted to a huge size_t
    LITE_ASSERT(size >= 0, "failed to get size of %s: %s", path.c_str(),
                strerror(errno));
    fseek(fin, 0, SEEK_SET);
    std::vector<uint8_t> buf(static_cast<size_t>(size));
    //! BUGFIX: fread(ptr, 0, 1, f) returns 0, so the old unconditional
    //! `nr == 1` assert fired on empty files
    if (size > 0) {
        auto nr = fread(buf.data(), static_cast<size_t>(size), 1, fin);
        LITE_ASSERT(nr == 1);
    }
    fclose(fin);
    return buf;
}
//! Load the input model; when it carries the packed-model tag, keep pointers
//! to its existing header/info/data tables so they can be merged on repack.
FbsHelper::FbsHelper(ModelPacker* packer, std::string model_path) : m_packer(packer) {
    m_model_buffer = read_file(model_path);
    //! a packed model starts with the 12-byte tag "packed_model"
    constexpr size_t tag_size = 12;
    //! BUGFIX: guard against files shorter than the tag — the old
    //! std::string(model_ptr, 12) read out of bounds for tiny inputs
    if (m_model_buffer.size() < tag_size)
        return;
    const char* model_ptr =
            static_cast<const char*>(static_cast<void*>(m_model_buffer.data()));
    std::string tag(model_ptr, tag_size);
    if (tag == "packed_model") {
        uint8_t* buffer = m_model_buffer.data() + tag_size;
        auto model = GetPackModel(buffer)->models()->Get(0);
        m_model_header = model->header();
        m_model_info = model->info();
        m_model_data = model->data();
    }
}
//! Build the ModelHeader table: reuse the header of an already packed input
//! model when present, otherwise take the packer's configured header.
flatbuffers::Offset<ModelHeader> FbsHelper::build_header() {
    //! intern a string into the builder, deduplicating identical contents
    auto shared = [&](const auto& s) { return m_builder.CreateSharedString(s); };

    flatbuffers::Offset<flatbuffers::String> name, info_decryption_method,
            info_parse_method, model_decryption_method, info_cache_parse_method;
    bool is_fast_run_cache;

    if (m_model_header) {
        //! repacking: copy the old packed model's header verbatim
        name = shared(m_model_header->name());
        info_decryption_method = shared(m_model_header->info_decryption_method());
        info_parse_method = shared(m_model_header->info_parse_method());
        model_decryption_method = shared(m_model_header->model_decryption_method());
        info_cache_parse_method = shared(m_model_header->info_cache_parse_method());
        is_fast_run_cache = m_model_header->is_fast_run_cache();
    } else {
        //! bare model: use the header configured via ModelPacker::set_header
        const Header& cfg = m_packer->m_header;
        name = shared(cfg.name);
        info_decryption_method = shared(cfg.info_decryption_method);
        info_parse_method = shared(cfg.info_parse_method);
        model_decryption_method = shared(cfg.model_decryption_method);
        info_cache_parse_method = shared(cfg.info_cache_parse_method);
        is_fast_run_cache = cfg.fb32.is_fast_run_cache;
    }

    return CreateModelHeader(
            m_builder, name, info_decryption_method, info_parse_method,
            model_decryption_method, info_cache_parse_method, is_fast_run_cache);
}
//! Build the ModelData table: when repacking, copy the model bytes out of
//! the old packed file; otherwise the whole input file is the (bare) model.
flatbuffers::Offset<ModelData> FbsHelper::build_data() {
    if (!m_model_data) {
        return CreateModelData(m_builder, m_builder.CreateVector(m_model_buffer));
    }
    auto bytes = m_model_data->data()->Data();
    auto nr_bytes = m_model_data->data()->size();
    return CreateModelData(m_builder, m_builder.CreateVector(bytes, nr_bytes));
}
//! Build the ModelInfo table (info JSON data + algo-policy / fastrun-cache
//! blobs). Note: all CreateVector calls must happen BEFORE ModelInfoBuilder
//! is constructed — nested flatbuffer builds cannot interleave.
flatbuffers::Offset<ModelInfo> FbsHelper::build_info() {
    flatbuffers::Offset<flatbuffers::Vector<uint8_t>> fb_data;
    //! prefer the info data already inside a packed input model unless the
    //! caller supplied a new info file
    if (m_model_info && m_model_info->data() && m_packer->m_info_data_path.empty()) {
        auto data = m_model_info->data()->Data();
        auto size = m_model_info->data()->size();
        fb_data = m_builder.CreateVector(data, size);
    } else if (!m_packer->m_info_data_path.empty()) {
        auto info_data = read_file(m_packer->m_info_data_path);
        fb_data = m_builder.CreateVector(info_data);
    }
    flatbuffers::Offset<flatbuffers::Vector<uint8_t>> fb_algo_policy;
    //! NOTE(review): fb_binary_cache is never populated in this function —
    //! binary_cache is always written empty here; confirm that is intended.
    flatbuffers::Offset<flatbuffers::Vector<uint8_t>> fb_binary_cache;
    if (m_packer->m_header.fb32.is_fast_run_cache) {
        std::vector<uint8_t> info_algo_policy;
        if (!m_packer->m_info_algo_policy_path.empty()) {
            info_algo_policy = read_file(m_packer->m_info_algo_policy_path);
            if (m_model_info && m_model_info->algo_policy()) {
                //! merge the old packed cache with the new cache file: both
                //! blobs are assumed to start with a uint32 category count
                //! (format of InFilePersistentCache::dump_cache) — the counts
                //! are summed and the two payloads concatenated.
                //! NOTE(review): assumes both blobs are >= 4 bytes; confirm.
                auto cache = m_model_info->algo_policy()->Data();
                auto size = m_model_info->algo_policy()->size();
                uint32_t nr_category_1, nr_category_2, nr_category;
                memcpy(&nr_category_1, cache, sizeof(uint32_t));
                memcpy(&nr_category_2, info_algo_policy.data(), sizeof(uint32_t));
                nr_category = nr_category_1 + nr_category_2;
                std::vector<uint8_t> cache_append;
                cache_append.resize(sizeof(nr_category));
                memcpy(cache_append.data(), &nr_category, sizeof(nr_category));
                cache_append.insert(
                        cache_append.end(), cache + sizeof(nr_category), cache + size);
                cache_append.insert(
                        cache_append.end(),
                        info_algo_policy.begin() + sizeof(nr_category),
                        info_algo_policy.end());
                fb_algo_policy = m_builder.CreateVector(cache_append);
            } else {
                fb_algo_policy = m_builder.CreateVector(info_algo_policy);
            }
        }
#if LITE_BUILD_WITH_MGE
        else {
            //! no cache file given: dump the in-process fastrun cache
            info_algo_policy = static_cast<mgb::InFilePersistentCache&>(
                                       mgb::PersistentCache::inst())
                                       .dump_cache();
            fb_algo_policy = m_builder.CreateVector(info_algo_policy);
        }
#endif
    }
    ModelInfoBuilder builder(m_builder);
    builder.add_data(fb_data);
    builder.add_algo_policy(fb_algo_policy);
    builder.add_binary_cache(fb_binary_cache);
    return builder.Finish();
}
//! Construct a packer for \p model_path; the model file is read eagerly by
//! the FbsHelper created here.
//! NOTE(review): m_fbs_helper is allocated with `new` and ModelPacker has no
//! destructor — this leaks one FbsHelper (plus the model buffer it holds)
//! per ModelPacker; fixing it needs a header change (add ~ModelPacker or use
//! std::unique_ptr).
ModelPacker::ModelPacker(
        std::string model_path, std::string packed_model_path,
        std::string info_data_path, std::string info_algo_policy_path,
        std::string info_binary_cache_path)
        : m_packed_model_path(packed_model_path),
          m_info_data_path(info_data_path),
          m_info_algo_policy_path(info_algo_policy_path),
          m_info_binary_cache_path(info_binary_cache_path) {
    m_fbs_helper = new FbsHelper(this, model_path);
}
void ModelPacker::set_header( | |||||
std::string model_decryption_method, std::string info_decryption_method, | |||||
bool is_fast_run_cache) { | |||||
m_header.model_decryption_method = model_decryption_method; | |||||
m_header.info_decryption_method = info_decryption_method; | |||||
memset(&m_header.fb32, 0, sizeof(m_header.fb32)); | |||||
m_header.fb32.is_fast_run_cache = is_fast_run_cache; | |||||
if (!m_info_data_path.empty()) { | |||||
auto buf = read_file(m_info_data_path); | |||||
std::string json_string( | |||||
static_cast<const char*>(static_cast<void*>(buf.data())), buf.size()); | |||||
auto info = nlohmann::json::parse(json_string); | |||||
m_header.name = info["name"]; | |||||
} | |||||
} | |||||
void ModelPacker::pack_model() { | |||||
auto fb_header = m_fbs_helper->build_header(); | |||||
auto fb_info = m_fbs_helper->build_info(); | |||||
auto fb_data = m_fbs_helper->build_data(); | |||||
ModelBuilder model_builder(m_fbs_helper->builder()); | |||||
model_builder.add_header(fb_header); | |||||
model_builder.add_info(fb_info); | |||||
model_builder.add_data(fb_data); | |||||
auto model = model_builder.Finish(); | |||||
std::vector<flatbuffers::Offset<Model>> models; | |||||
models.emplace_back(model); | |||||
auto fb_models = m_fbs_helper->builder().CreateVector(models); | |||||
PackModelBuilder pack_model_builder(m_fbs_helper->builder()); | |||||
pack_model_builder.add_models(fb_models); | |||||
m_fbs_helper->builder().Finish(pack_model_builder.Finish()); | |||||
FILE* fptr = fopen(m_packed_model_path.c_str(), "wb"); | |||||
std::string packed_model_tag = "packed_model"; | |||||
auto nr_tag = fwrite(packed_model_tag.c_str(), 1, packed_model_tag.size(), fptr); | |||||
LITE_ASSERT(nr_tag == packed_model_tag.size()); | |||||
auto fb_size = m_fbs_helper->builder().GetSize(); | |||||
auto nr_fb = fwrite(m_fbs_helper->builder().GetBufferPointer(), 1, fb_size, fptr); | |||||
LITE_ASSERT(nr_fb == fb_size); | |||||
fclose(fptr); | |||||
} |
@@ -0,0 +1,36 @@ | |||||
/**
 * \file src/parse_info/cache_parse.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once | |||||
#include "lite/global.h" | |||||
#if LITE_BUILD_WITH_MGE | |||||
#include "megbrain/utils/infile_persistent_cache.h" | |||||
#endif | |||||
namespace lite { | |||||
//! The LITE_parse_cache parse info function | |||||
bool parse_info_cache( | |||||
const uint8_t* cache, size_t cache_length, bool is_fast_run_cache = true, | |||||
const uint8_t* binary_cache = nullptr, size_t binary_cache_length = 0) { | |||||
LITE_MARK_USED_VAR(binary_cache); | |||||
LITE_MARK_USED_VAR(binary_cache_length); | |||||
#if LITE_BUILD_WITH_MGE | |||||
if (is_fast_run_cache) { | |||||
mgb::PersistentCache::set_impl( | |||||
std::make_shared<mgb::InFilePersistentCache>(cache, cache_length)); | |||||
} | |||||
#endif | |||||
return true; | |||||
} | |||||
} // namespace lite | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -11,6 +11,7 @@ | |||||
#include "model_parser.h" | #include "model_parser.h" | ||||
#include "decryption/decrypt_base.h" | #include "decryption/decrypt_base.h" | ||||
#include "parse_info/cache_parse.h" | |||||
#include "parse_info/parse_info_base.h" | #include "parse_info/parse_info_base.h" | ||||
using namespace lite; | using namespace lite; | ||||
@@ -41,6 +42,10 @@ void ModelParser::parse_header() { | |||||
m_model_decryption_name = model->header()->model_decryption_method()->c_str(); | m_model_decryption_name = model->header()->model_decryption_method()->c_str(); | ||||
m_info_decryption_name = model->header()->info_decryption_method()->c_str(); | m_info_decryption_name = model->header()->info_decryption_method()->c_str(); | ||||
m_info_parse_func_name = model->header()->info_parse_method()->c_str(); | m_info_parse_func_name = model->header()->info_parse_method()->c_str(); | ||||
if (model->header()->info_cache_parse_method()) | |||||
m_info_cache_parse_func_name = | |||||
model->header()->info_cache_parse_method()->c_str(); | |||||
m_is_fast_run_cache = model->header()->is_fast_run_cache(); | |||||
m_info = model->info(); | m_info = model->info(); | ||||
m_model_data = model->data(); | m_model_data = model->data(); | ||||
@@ -54,31 +59,52 @@ bool ModelParser::parse_model_info( | |||||
if (m_is_bare_model || !m_info) { | if (m_is_bare_model || !m_info) { | ||||
return false; | return false; | ||||
} | } | ||||
size_t info_length = m_info->data()->size(); | |||||
const uint8_t* info_data = m_info->data()->Data(); | |||||
//! decryption the info | |||||
auto info_ptr = | |||||
decrypt_memory(info_data, info_length, m_info_decryption_name, info_length); | |||||
//! parse the info | |||||
LITE_LOCK_GUARD(parse_info_static_data().map_mutex); | |||||
auto it_parse = | |||||
parse_info_static_data().parse_info_methods.find(m_info_parse_func_name); | |||||
if (it_parse == parse_info_static_data().parse_info_methods.end()) { | |||||
LITE_THROW(ssprintf( | |||||
"can't find model info parse function %s.", | |||||
m_info_parse_func_name.c_str())); | |||||
//! parse ModelInfo::data | |||||
if (m_info->data()) { | |||||
size_t info_length = m_info->data()->size(); | |||||
const uint8_t* info_data = m_info->data()->Data(); | |||||
//! decryption the info | |||||
auto info_ptr = decrypt_memory( | |||||
info_data, info_length, m_info_decryption_name, info_length); | |||||
//! parse the info | |||||
LITE_LOCK_GUARD(parse_info_static_data().map_mutex); | |||||
auto it_parse = parse_info_static_data().parse_info_methods.find( | |||||
m_info_parse_func_name); | |||||
if (it_parse == parse_info_static_data().parse_info_methods.end()) { | |||||
LITE_THROW(ssprintf( | |||||
"can't find model info parse function %s.", | |||||
m_info_parse_func_name.c_str())); | |||||
} | |||||
auto model_info_parse_func = | |||||
parse_info_static_data().parse_info_methods[m_info_parse_func_name]; | |||||
//! convert for NetworkIOInner to NetworkIO | |||||
if (model_info_parse_func) { | |||||
model_info_parse_func( | |||||
info_ptr.get(), info_length, m_model_name, network_config, | |||||
network_io, isolated_config_map, extra_info); | |||||
} else { | |||||
LITE_THROW(ssprintf( | |||||
"model info parse function of %s is empty", | |||||
m_info_parse_func_name.c_str())); | |||||
} | |||||
} | } | ||||
auto model_info_parse_func = | |||||
parse_info_static_data().parse_info_methods[m_info_parse_func_name]; | |||||
//! convert for NetworkIOInner to NetworkIO | |||||
if (model_info_parse_func) { | |||||
model_info_parse_func( | |||||
info_ptr.get(), info_length, m_model_name, network_config, network_io, | |||||
isolated_config_map, extra_info); | |||||
} else { | |||||
LITE_THROW(ssprintf( | |||||
"model info parse function of %s is empty", | |||||
m_info_parse_func_name.c_str())); | |||||
//! parse ModelInfo::algo_policy | |||||
if (m_info->algo_policy()) { | |||||
size_t cache_length = m_info->algo_policy()->size(); | |||||
const uint8_t* cache = m_info->algo_policy()->Data(); | |||||
if (m_info_cache_parse_func_name == "LITE_parse_cache") { | |||||
if (m_is_fast_run_cache) { | |||||
parse_info_cache(cache, cache_length); | |||||
} else if (m_info->binary_cache()) { | |||||
size_t binary_cache_length = m_info->binary_cache()->size(); | |||||
const uint8_t* binary_cache = m_info->binary_cache()->Data(); | |||||
parse_info_cache( | |||||
cache, cache_length, m_is_fast_run_cache, binary_cache, | |||||
binary_cache_length); | |||||
} else { | |||||
LITE_THROW("opencl binary cache is not given"); | |||||
} | |||||
} | |||||
} | } | ||||
return true; | return true; | ||||
} | } | ||||
@@ -60,6 +60,8 @@ private: | |||||
std::string m_model_decryption_name; | std::string m_model_decryption_name; | ||||
//! the function name to parse the model info | //! the function name to parse the model info | ||||
std::string m_info_parse_func_name; | std::string m_info_parse_func_name; | ||||
std::string m_info_cache_parse_func_name; | |||||
bool m_is_fast_run_cache; | |||||
//! if a model is not added json info to the model is not crypted, the | //! if a model is not added json info to the model is not crypted, the | ||||
//! model is a bare model | //! model is a bare model | ||||
bool m_is_bare_model = true; | bool m_is_bare_model = true; | ||||
@@ -5,10 +5,14 @@ table ModelHeader { | |||||
info_decryption_method:string; | info_decryption_method:string; | ||||
info_parse_method:string; | info_parse_method:string; | ||||
model_decryption_method:string; | model_decryption_method:string; | ||||
info_cache_parse_method:string; | |||||
is_fast_run_cache:bool; | |||||
} | } | ||||
table ModelInfo { | table ModelInfo { | ||||
data:[ubyte]; | data:[ubyte]; | ||||
algo_policy:[ubyte]; | |||||
binary_cache:[ubyte]; | |||||
} | } | ||||
table ModelData { | table ModelData { | ||||
@@ -970,6 +970,25 @@ TEST(TestNetWork, LoadPackedModel) { | |||||
network->wait(); | network->wait(); | ||||
} | } | ||||
//! Regression test: a model with a fastrun cache packed into it must load
//! and forward successfully (exercises the LITE_parse_cache path).
TEST(TestNetWork, LoadPackedCacheModel) {
    auto tensor = get_input_data("./input_data.npy");
    //! fixture produced by the pack-model tooling with a cache blob inside
    std::string model_path = "./test_pack_cache_to_model.lite";
    std::string input_name = "data";
    NetworkIO IO;
    Config config;
    std::shared_ptr<Network> network = std::make_shared<Network>(config, IO);
    network->load_model(model_path);

    std::shared_ptr<Tensor> input_tensor = network->get_io_tensor(input_name);
    //! share the npy buffer into the input tensor (no copy)
    auto src_ptr = tensor->get_memory_ptr();
    auto src_layout = tensor->get_layout();
    input_tensor->reset(src_ptr, src_layout);

    network->forward();
    network->wait();
}
TEST(TestNetWork, GlabalLayoutTransform) { | TEST(TestNetWork, GlabalLayoutTransform) { | ||||
auto tensor = get_input_data("./input_data.npy"); | auto tensor = get_input_data("./input_data.npy"); | ||||
std::string model_path = "./shufflenet.mge"; | std::string model_path = "./shufflenet.mge"; | ||||
@@ -216,6 +216,46 @@ void InFilePersistentCache::dump_cache(OutputFile* out_file) { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
std::vector<uint8_t> InFilePersistentCache::dump_cache() { | |||||
std::vector<uint8_t> ret; | |||||
uint32_t nr_category = m_cache.size(); | |||||
ret.resize(sizeof(nr_category)); | |||||
memcpy(ret.data(), &nr_category, sizeof(nr_category)); | |||||
auto write_to_buffer = [&ret](uint32_t val) { | |||||
std::vector<uint8_t> vec(sizeof(val)); | |||||
memcpy(vec.data(), &val, sizeof(val)); | |||||
ret.insert(ret.end(), vec.begin(), vec.end()); | |||||
}; | |||||
for (const auto& cached_category : m_cache) { | |||||
uint32_t category_size = cached_category.first.size(); | |||||
write_to_buffer(category_size); | |||||
std::vector<uint8_t> category( | |||||
cached_category.first.begin(), cached_category.first.end()); | |||||
ret.insert(ret.end(), category.begin(), category.end()); | |||||
uint32_t nr_bobs = cached_category.second.size(); | |||||
write_to_buffer(nr_bobs); | |||||
for (const auto& item : cached_category.second) { | |||||
uint32_t size_first = item.first.size; | |||||
write_to_buffer(size_first); | |||||
ret.insert( | |||||
ret.end(), item.first.data_refhold.get(), | |||||
item.first.data_refhold.get() + size_first); | |||||
uint32_t size_second = item.second.size; | |||||
write_to_buffer(size_second); | |||||
ret.insert( | |||||
ret.end(), item.second.data_refhold.get(), | |||||
item.second.data_refhold.get() + size_second); | |||||
} | |||||
} | |||||
return ret; | |||||
} | |||||
Maybe<InFilePersistentCache::Blob> InFilePersistentCache::get( | Maybe<InFilePersistentCache::Blob> InFilePersistentCache::get( | ||||
const std::string& category, const Blob& key) { | const std::string& category, const Blob& key) { | ||||
decltype(m_cache.begin()) iter0; | decltype(m_cache.begin()) iter0; | ||||
@@ -71,6 +71,7 @@ public: | |||||
*/ | */ | ||||
MGE_WIN_DECLSPEC_FUC void dump_cache(const char* path); | MGE_WIN_DECLSPEC_FUC void dump_cache(const char* path); | ||||
MGE_WIN_DECLSPEC_FUC void dump_cache(OutputFile* out_file); | MGE_WIN_DECLSPEC_FUC void dump_cache(OutputFile* out_file); | ||||
MGE_WIN_DECLSPEC_FUC std::vector<uint8_t> dump_cache(); | |||||
MGE_WIN_DECLSPEC_FUC Maybe<Blob> get( | MGE_WIN_DECLSPEC_FUC Maybe<Blob> get( | ||||
const std::string& category, const Blob& key) override; | const std::string& category, const Blob& key) override; | ||||