@@ -0,0 +1,69 @@ | |||
/** | |||
 * \file include/lite/pack_model.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||
#pragma once | |||
#include <cstdint>
#include <string>
namespace lite { | |||
//! 32-bit feature bitset stored in the packed-model header.
struct FeatureBits32 {
    //! whether ModelInfo::algo_policy carries a fast-run (algo policy) cache
    uint32_t is_fast_run_cache : 1;
    //! reserved for new fields
    uint32_t : 31;
};
//! Meta information written into a packed model: how the embedded model /
//! info blobs were encrypted and which functions should parse them.
struct Header {
    //! model name; filled from the "name" field of the info json when one
    //! is supplied (see ModelPacker::set_header)
    std::string name;
    //! model encryption method name, this is used to find the right
    //! decryption method. [ AES_default | RC4_default |
    //! SIMPLE_FAST_RC4_default ], default is NONE.
    std::string model_decryption_method;
    //! info data encryption method name, this is used to find the right
    //! decryption method. [ AES_default | RC4_default |
    //! SIMPLE_FAST_RC4_default ], default is NONE.
    std::string info_decryption_method;
    //! info parse method name.
    std::string info_parse_method = "LITE_default";
    //! fastrun cache parse method name.
    std::string info_cache_parse_method = "LITE_parse_cache";
    //! feature flags, e.g. whether the packed cache is a fast-run cache
    FeatureBits32 fb32;
};
class FbsHelper;

//! Packs a lite model (bare, or already packed) together with optional
//! info json / fast-run cache / binary cache blobs into a single
//! "packed_model"-tagged flatbuffer file.
class ModelPacker {
public:
    //! \param model_path source model file (bare or already packed)
    //! \param packed_model_path output path of the packed model
    //! \param info_data_path optional json info file to embed
    //! \param info_algo_policy_path optional fast-run cache / algo policy file
    //! \param info_binary_cache_path optional binary cache file
    ModelPacker(
            std::string model_path, std::string packed_model_path,
            std::string info_data_path = "", std::string info_algo_policy_path = "",
            std::string info_binary_cache_path = "");

    //! configure the header fields; also reads the model name from the
    //! info json ("name" key) when info_data_path was given
    void set_header(
            std::string model_decryption_method = "NONE",
            std::string info_decryption_method = "NONE", bool is_fast_run_cache = true);

    //! build the flatbuffer and write the "packed_model" tag plus the
    //! buffer to packed_model_path
    void pack_model();

private:
    std::string m_packed_model_path;
    std::string m_info_data_path;
    //! fastrun cache / algo policy
    std::string m_info_algo_policy_path;
    //! binary cache
    std::string m_info_binary_cache_path;
    Header m_header;

    friend class FbsHelper;

    //! NOTE(review): allocated with `new` in the constructor and never
    //! deleted (no destructor declared) — appears to leak; TODO confirm
    //! intended lifetime / consider an owning smart pointer
    FbsHelper* m_fbs_helper;
};
} // namespace lite |
@@ -1,7 +1,7 @@ | |||
# BUILD the load and run for lite | |||
include_directories(PUBLIC | |||
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/lite/load_and_run/src>) | |||
file(GLOB_RECURSE SOURCES ./*.cpp) | |||
file(GLOB_RECURSE SOURCES ./*.cpp ${PROJECT_SOURCE_DIR}/lite/src/pack_model/*.cpp) | |||
add_executable(load_and_run ${SOURCES}) | |||
target_link_libraries(load_and_run lite_static) | |||
@@ -43,6 +43,8 @@ public: | |||
virtual void wait() = 0; | |||
virtual ~ModelBase() = default; | |||
virtual const std::string& get_model_path() const = 0; | |||
}; | |||
} // namespace lar | |||
@@ -60,6 +60,8 @@ public: | |||
//! get algo strategy | |||
Strategy& get_lite_strategy() { return m_strategy; } | |||
const std::string& get_model_path() const override { return model_path; } | |||
private: | |||
bool share_model_mem; | |||
bool enable_layout_transform; | |||
@@ -107,6 +107,8 @@ public: | |||
std::move(out_file), m_format.val()); | |||
} | |||
const std::string& get_model_path() const override { return model_path; } | |||
private: | |||
bool share_model_mem; | |||
std::string model_path; | |||
@@ -0,0 +1,87 @@ | |||
/** | |||
* \file lite/load_and_run/src/options/model_options.cpp | |||
* | |||
* This file is part of MegEngine, a deep learning framework developed by | |||
* Megvii. | |||
* | |||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
*/ | |||
#include "model_options.h" | |||
#include "device_options.h" | |||
#include "lite/pack_model.h" | |||
#include "misc.h" | |||
#include "models/model_lite.h" | |||
#include "models/model_mdl.h" | |||
namespace lar { | |||
template <typename ModelImpl> | |||
void PackModelOption::config_model_internel( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelImpl> model) { | |||
if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { | |||
lite::ModelPacker packer( | |||
model->get_model_path(), packed_model_dump, pack_info_json, pack_cache, | |||
pack_binary_cache); | |||
packer.set_header(pack_info_cryption, pack_model_cryption, is_fast_run_cache); | |||
packer.pack_model(); | |||
} | |||
} | |||
} // namespace lar | |||
using namespace lar; | |||
////////////////////// PackModel options //////////////////////// | |||
//! Construct the option, copying each pack-related gflag into the
//! corresponding member only when the user actually set it.
PackModelOption::PackModelOption() {
    m_option_name = "pack_model";
    auto assign_if_set = [](const std::string& flag, std::string& member) {
        if (!flag.empty())
            member = flag;
    };
    assign_if_set(FLAGS_packed_model_dump, packed_model_dump);
    assign_if_set(FLAGS_pack_info_json, pack_info_json);
    assign_if_set(FLAGS_pack_cache, pack_cache);
    assign_if_set(FLAGS_pack_info_cryption, pack_info_cryption);
    assign_if_set(FLAGS_pack_model_cryption, pack_model_cryption);
}
//! The pack_model option is enabled iff an output path was given via
//! --packed-model-dump.
bool PackModelOption::is_valid() {
    return !FLAGS_packed_model_dump.empty();
}
std::shared_ptr<OptionBase> PackModelOption::create_option() { | |||
static std::shared_ptr<PackModelOption> option(new PackModelOption); | |||
if (PackModelOption::is_valid()) { | |||
return std::static_pointer_cast<OptionBase>(option); | |||
} else { | |||
return nullptr; | |||
} | |||
} | |||
//! Dispatch config_model_internel over the concrete model types
//! (macro expands per supported model implementation).
void PackModelOption::config_model(
        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
    CONFIG_MODEL_FUN;
}
////////////////////// PackModel gflags //////////////////////// | |||
DEFINE_string(packed_model_dump, "", "The output file path of packed model."); | |||
DEFINE_string( | |||
pack_info_json, "", | |||
"An encrypted or not encrypted json format file to pack into the model."); | |||
DEFINE_string(pack_cache, "", "Pack the fastrun cache or algo policy into the model."); | |||
DEFINE_string( | |||
pack_info_cryption, "NONE", | |||
"The info data encryption method name, this is used to find the right " | |||
"decryption method. --pack-info-cryption [ AES_default | RC4_default | " | |||
"SIMPLE_FAST_RC4_default ], default is NONE. See " | |||
"https://megengine.megvii-inc.com/user-guide/deployment/lite/advance/" | |||
"pack-lite-model.html for more details."); | |||
DEFINE_string( | |||
pack_model_cryption, "NONE", | |||
"The model encryption method name, this is used to find the right decryption " | |||
"method. --pack-model-cryption [ AES_default | RC4_default | " | |||
"SIMPLE_FAST_RC4_default ], default is NONE. See " | |||
"https://megengine.megvii-inc.com/user-guide/deployment/lite/advance/" | |||
"pack-lite-model.html for more details."); | |||
REGIST_OPTION_CREATOR(pack_model, lar::PackModelOption::create_option); |
@@ -0,0 +1,45 @@ | |||
/** | |||
* \file lite/load_and_run/src/options/model_options.h | |||
* | |||
* This file is part of MegEngine, a deep learning framework developed by | |||
* Megvii. | |||
* | |||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||
*/ | |||
#pragma once | |||
#include <gflags/gflags.h> | |||
#include "models/model.h" | |||
#include "option_base.h" | |||
DECLARE_string(packed_model_dump); | |||
DECLARE_string(pack_info_json); | |||
DECLARE_string(pack_cache); | |||
DECLARE_string(pack_info_cryption); | |||
DECLARE_string(pack_model_cryption); | |||
namespace lar {
//! Option that packs a lite model (model + info json + fastrun cache)
//! after the model has run; enabled when --packed-model-dump is given.
class PackModelOption : public OptionBase {
public:
    //! whether the option is enabled (an output path was supplied)
    static bool is_valid();

    //! get the option instance (shared singleton), or nullptr when disabled
    static std::shared_ptr<OptionBase> create_option();

    //! config the model according to the runtime stage
    void config_model(
            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;

    std::string option_name() const override { return m_option_name; }

private:
    PackModelOption();

    //! packs the model at RunStage::AFTER_MODEL_RUNNING
    template <typename ModelImpl>
    void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>);

    std::string m_option_name;
    std::string packed_model_dump;
    std::string pack_info_json;
    std::string pack_cache;
    //! binary cache path; NOTE(review): never filled from a gflag in this
    //! file — TODO confirm whether a --pack-binary-cache flag is intended
    std::string pack_binary_cache;
    std::string pack_info_cryption;
    std::string pack_model_cryption;
    bool is_fast_run_cache = true;
};
}  // namespace lar
@@ -119,6 +119,12 @@ class TestNetwork(TestShuffleNet): | |||
network.load(model_path) | |||
self.do_forward(network) | |||
def test_pack_cache_to_model(self):
    """Load a model with a fastrun cache packed inside and run one
    forward pass to check that the packed cache is accepted."""
    model_path = os.path.join(self.source_dir, "test_pack_cache_to_model.lite")
    network = LiteNetwork()
    network.load(model_path)
    self.do_forward(network)
def test_network_basic(self): | |||
network = LiteNetwork() | |||
network.load(self.model_path) | |||
@@ -0,0 +1,232 @@ | |||
/** | |||
* \file src/pack_model/pack_model.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||
#include "lite/pack_model.h" | |||
#include "../misc.h" | |||
#if LITE_BUILD_WITH_MGE | |||
#include "megbrain/utils/infile_persistent_cache.h" | |||
#endif | |||
#include <flatbuffers/flatbuffers.h> | |||
#include "nlohmann/json.hpp" | |||
#include "pack_model_generated.h" | |||
namespace lite { | |||
//! Helper that owns the FlatBufferBuilder and builds the header / info /
//! data tables, reusing content of an already packed source model.
class FbsHelper {
public:
    FbsHelper() = default;
    //! reads model_path; when the file starts with the "packed_model" tag,
    //! keeps pointers into its header/info/data tables for reuse
    FbsHelper(ModelPacker* packer, std::string model_path);
    flatbuffers::Offset<model_parse::ModelHeader> build_header();
    flatbuffers::Offset<model_parse::ModelInfo> build_info();
    flatbuffers::Offset<model_parse::ModelData> build_data();
    flatbuffers::FlatBufferBuilder& builder() { return m_builder; }

private:
    ModelPacker* m_packer;
    flatbuffers::FlatBufferBuilder m_builder;
    //! raw bytes of the source model file
    std::vector<uint8_t> m_model_buffer;
    //! non-null only when the source model was already a packed model
    const model_parse::ModelHeader* m_model_header = nullptr;
    const model_parse::ModelInfo* m_model_info = nullptr;
    const model_parse::ModelData* m_model_data = nullptr;
};
} // namespace lite | |||
using namespace lite; | |||
using namespace model_parse; | |||
//! Read a whole file into memory.
//! \param path file to read (asserts on open/seek/read failure)
//! \return file content as a byte vector; empty vector for an empty file
std::vector<uint8_t> read_file(std::string path) {
    FILE* fin = fopen(path.c_str(), "rb");
    LITE_ASSERT(fin, "failed to open %s: %s", path.c_str(), strerror(errno));
    auto err = fseek(fin, 0, SEEK_END);
    LITE_ASSERT(!err, "failed to seek %s: %s", path.c_str(), strerror(errno));
    long size = ftell(fin);
    LITE_ASSERT(size >= 0, "failed to tell %s: %s", path.c_str(), strerror(errno));
    fseek(fin, 0, SEEK_SET);
    std::vector<uint8_t> buf(static_cast<size_t>(size));
    if (size > 0) {
        //! read byte-wise (nmemb == size) so short reads are detectable;
        //! the previous fread(ptr, size, 1, ...) form also asserted on
        //! perfectly valid empty files
        auto nr = fread(buf.data(), 1, static_cast<size_t>(size), fin);
        LITE_ASSERT(
                nr == static_cast<size_t>(size), "failed to read %s", path.c_str());
    }
    fclose(fin);
    return buf;
}
//! Load the source model and, when it is already a packed model, cache
//! pointers to its header/info/data tables for later merging.
FbsHelper::FbsHelper(ModelPacker* packer, std::string model_path) : m_packer(packer) {
    m_model_buffer = read_file(model_path);
    //! a packed model starts with the 12-byte tag "packed_model"; a file
    //! shorter than the tag is necessarily a bare model — guard before
    //! reading the tag to avoid an out-of-bounds read
    constexpr size_t tag_len = 12;
    if (m_model_buffer.size() < tag_len)
        return;
    const char* model_ptr =
            static_cast<const char*>(static_cast<void*>(m_model_buffer.data()));
    std::string tag(model_ptr, tag_len);
    if (tag == "packed_model") {
        uint8_t* buffer = m_model_buffer.data() + tag_len;
        auto model = GetPackModel(buffer)->models()->Get(0);
        m_model_header = model->header();
        m_model_info = model->info();
        m_model_data = model->data();
    }
}
//! Build the ModelHeader table. When the source model was already packed
//! its header fields are carried over unchanged; otherwise the fields come
//! from the header configured via ModelPacker::set_header.
flatbuffers::Offset<ModelHeader> FbsHelper::build_header() {
    flatbuffers::Offset<flatbuffers::String> name, info_decryption_method,
            info_parse_method, model_decryption_method, info_cache_parse_method;
    bool is_fast_run_cache;
    if (m_model_header) {
        //! reuse the header of the existing packed model
        auto&& header = m_model_header;
        name = m_builder.CreateSharedString(header->name());
        info_decryption_method =
                m_builder.CreateSharedString(header->info_decryption_method());
        info_parse_method = m_builder.CreateSharedString(header->info_parse_method());
        model_decryption_method =
                m_builder.CreateSharedString(header->model_decryption_method());
        info_cache_parse_method =
                m_builder.CreateSharedString(header->info_cache_parse_method());
        is_fast_run_cache = header->is_fast_run_cache();
    } else {
        //! take the values configured on the packer
        auto&& header = m_packer->m_header;
        name = m_builder.CreateSharedString(header.name);
        info_decryption_method =
                m_builder.CreateSharedString(header.info_decryption_method);
        info_parse_method = m_builder.CreateSharedString(header.info_parse_method);
        model_decryption_method =
                m_builder.CreateSharedString(header.model_decryption_method);
        info_cache_parse_method =
                m_builder.CreateSharedString(header.info_cache_parse_method);
        is_fast_run_cache = header.fb32.is_fast_run_cache;
    }
    return CreateModelHeader(
            m_builder, name, info_decryption_method, info_parse_method,
            model_decryption_method, info_cache_parse_method, is_fast_run_cache);
}
//! Build the ModelData table: reuse the model blob already embedded in a
//! packed source model, otherwise embed the raw file read from disk.
flatbuffers::Offset<ModelData> FbsHelper::build_data() {
    if (!m_model_data) {
        return CreateModelData(m_builder, m_builder.CreateVector(m_model_buffer));
    }
    auto embedded = m_model_data->data();
    return CreateModelData(
            m_builder, m_builder.CreateVector(embedded->Data(), embedded->size()));
}
//! Build the ModelInfo table: info json blob, fast-run cache (algo policy)
//! and optional binary cache.
flatbuffers::Offset<ModelInfo> FbsHelper::build_info() {
    //! info data: a newly supplied json file wins over the blob already
    //! embedded in the source packed model
    flatbuffers::Offset<flatbuffers::Vector<uint8_t>> fb_data;
    if (m_model_info && m_model_info->data() && m_packer->m_info_data_path.empty()) {
        auto data = m_model_info->data()->Data();
        auto size = m_model_info->data()->size();
        fb_data = m_builder.CreateVector(data, size);
    } else if (!m_packer->m_info_data_path.empty()) {
        auto info_data = read_file(m_packer->m_info_data_path);
        fb_data = m_builder.CreateVector(info_data);
    }

    flatbuffers::Offset<flatbuffers::Vector<uint8_t>> fb_algo_policy;
    flatbuffers::Offset<flatbuffers::Vector<uint8_t>> fb_binary_cache;
    if (m_packer->m_header.fb32.is_fast_run_cache) {
        std::vector<uint8_t> info_algo_policy;
        if (!m_packer->m_info_algo_policy_path.empty()) {
            info_algo_policy = read_file(m_packer->m_info_algo_policy_path);
            if (m_model_info && m_model_info->algo_policy()) {
                //! both the source packed model and the given file carry a
                //! cache: merge them. The dump layout (see
                //! InFilePersistentCache::dump_cache) starts with a uint32
                //! category count, so the merged blob is the summed count
                //! followed by both payloads
                auto cache = m_model_info->algo_policy()->Data();
                auto size = m_model_info->algo_policy()->size();
                LITE_ASSERT(
                        size >= sizeof(uint32_t) &&
                                info_algo_policy.size() >= sizeof(uint32_t),
                        "invalid fastrun cache");
                uint32_t nr_category_1, nr_category_2, nr_category;
                memcpy(&nr_category_1, cache, sizeof(uint32_t));
                memcpy(&nr_category_2, info_algo_policy.data(), sizeof(uint32_t));
                nr_category = nr_category_1 + nr_category_2;

                std::vector<uint8_t> cache_append;
                cache_append.resize(sizeof(nr_category));
                memcpy(cache_append.data(), &nr_category, sizeof(nr_category));
                cache_append.insert(
                        cache_append.end(), cache + sizeof(nr_category), cache + size);
                cache_append.insert(
                        cache_append.end(),
                        info_algo_policy.begin() + sizeof(nr_category),
                        info_algo_policy.end());
                fb_algo_policy = m_builder.CreateVector(cache_append);
            } else {
                fb_algo_policy = m_builder.CreateVector(info_algo_policy);
            }
        }
#if LITE_BUILD_WITH_MGE
        else {
            //! no cache file given: dump the in-memory fast-run cache
            info_algo_policy = static_cast<mgb::InFilePersistentCache&>(
                                       mgb::PersistentCache::inst())
                                       .dump_cache();
            fb_algo_policy = m_builder.CreateVector(info_algo_policy);
        }
#endif
    }

    //! fix: the binary (e.g. opencl) cache path was accepted by the
    //! constructor but never packed; embed the file when one was supplied
    if (!m_packer->m_info_binary_cache_path.empty()) {
        auto binary_cache = read_file(m_packer->m_info_binary_cache_path);
        fb_binary_cache = m_builder.CreateVector(binary_cache);
    }

    ModelInfoBuilder builder(m_builder);
    builder.add_data(fb_data);
    builder.add_algo_policy(fb_algo_policy);
    builder.add_binary_cache(fb_binary_cache);
    return builder.Finish();
}
//! \param model_path source model (bare or already packed)
//! \param packed_model_path output path of the packed model
//! \param info_data_path optional json info file to embed
//! \param info_algo_policy_path optional fast-run cache / algo policy file
//! \param info_binary_cache_path optional binary cache file
ModelPacker::ModelPacker(
        std::string model_path, std::string packed_model_path,
        std::string info_data_path, std::string info_algo_policy_path,
        std::string info_binary_cache_path)
        : m_packed_model_path(packed_model_path),
          m_info_data_path(info_data_path),
          m_info_algo_policy_path(info_algo_policy_path),
          m_info_binary_cache_path(info_binary_cache_path) {
    //! NOTE(review): never deleted — ModelPacker declares no destructor,
    //! so this allocation leaks; TODO confirm and consider owning it
    m_fbs_helper = new FbsHelper(this, model_path);
}
void ModelPacker::set_header( | |||
std::string model_decryption_method, std::string info_decryption_method, | |||
bool is_fast_run_cache) { | |||
m_header.model_decryption_method = model_decryption_method; | |||
m_header.info_decryption_method = info_decryption_method; | |||
memset(&m_header.fb32, 0, sizeof(m_header.fb32)); | |||
m_header.fb32.is_fast_run_cache = is_fast_run_cache; | |||
if (!m_info_data_path.empty()) { | |||
auto buf = read_file(m_info_data_path); | |||
std::string json_string( | |||
static_cast<const char*>(static_cast<void*>(buf.data())), buf.size()); | |||
auto info = nlohmann::json::parse(json_string); | |||
m_header.name = info["name"]; | |||
} | |||
} | |||
void ModelPacker::pack_model() { | |||
auto fb_header = m_fbs_helper->build_header(); | |||
auto fb_info = m_fbs_helper->build_info(); | |||
auto fb_data = m_fbs_helper->build_data(); | |||
ModelBuilder model_builder(m_fbs_helper->builder()); | |||
model_builder.add_header(fb_header); | |||
model_builder.add_info(fb_info); | |||
model_builder.add_data(fb_data); | |||
auto model = model_builder.Finish(); | |||
std::vector<flatbuffers::Offset<Model>> models; | |||
models.emplace_back(model); | |||
auto fb_models = m_fbs_helper->builder().CreateVector(models); | |||
PackModelBuilder pack_model_builder(m_fbs_helper->builder()); | |||
pack_model_builder.add_models(fb_models); | |||
m_fbs_helper->builder().Finish(pack_model_builder.Finish()); | |||
FILE* fptr = fopen(m_packed_model_path.c_str(), "wb"); | |||
std::string packed_model_tag = "packed_model"; | |||
auto nr_tag = fwrite(packed_model_tag.c_str(), 1, packed_model_tag.size(), fptr); | |||
LITE_ASSERT(nr_tag == packed_model_tag.size()); | |||
auto fb_size = m_fbs_helper->builder().GetSize(); | |||
auto nr_fb = fwrite(m_fbs_helper->builder().GetBufferPointer(), 1, fb_size, fptr); | |||
LITE_ASSERT(nr_fb == fb_size); | |||
fclose(fptr); | |||
} |
@@ -0,0 +1,36 @@ | |||
/** | |||
* \file src/parse_info/cache_parse.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||
#pragma once | |||
#include "lite/global.h" | |||
#if LITE_BUILD_WITH_MGE | |||
#include "megbrain/utils/infile_persistent_cache.h" | |||
#endif | |||
namespace lite { | |||
//! The LITE_parse_cache parse info function | |||
bool parse_info_cache( | |||
const uint8_t* cache, size_t cache_length, bool is_fast_run_cache = true, | |||
const uint8_t* binary_cache = nullptr, size_t binary_cache_length = 0) { | |||
LITE_MARK_USED_VAR(binary_cache); | |||
LITE_MARK_USED_VAR(binary_cache_length); | |||
#if LITE_BUILD_WITH_MGE | |||
if (is_fast_run_cache) { | |||
mgb::PersistentCache::set_impl( | |||
std::make_shared<mgb::InFilePersistentCache>(cache, cache_length)); | |||
} | |||
#endif | |||
return true; | |||
} | |||
} // namespace lite | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -11,6 +11,7 @@ | |||
#include "model_parser.h" | |||
#include "decryption/decrypt_base.h" | |||
#include "parse_info/cache_parse.h" | |||
#include "parse_info/parse_info_base.h" | |||
using namespace lite; | |||
@@ -41,6 +42,10 @@ void ModelParser::parse_header() { | |||
m_model_decryption_name = model->header()->model_decryption_method()->c_str(); | |||
m_info_decryption_name = model->header()->info_decryption_method()->c_str(); | |||
m_info_parse_func_name = model->header()->info_parse_method()->c_str(); | |||
if (model->header()->info_cache_parse_method()) | |||
m_info_cache_parse_func_name = | |||
model->header()->info_cache_parse_method()->c_str(); | |||
m_is_fast_run_cache = model->header()->is_fast_run_cache(); | |||
m_info = model->info(); | |||
m_model_data = model->data(); | |||
@@ -54,31 +59,52 @@ bool ModelParser::parse_model_info( | |||
if (m_is_bare_model || !m_info) { | |||
return false; | |||
} | |||
size_t info_length = m_info->data()->size(); | |||
const uint8_t* info_data = m_info->data()->Data(); | |||
//! decryption the info | |||
auto info_ptr = | |||
decrypt_memory(info_data, info_length, m_info_decryption_name, info_length); | |||
//! parse the info | |||
LITE_LOCK_GUARD(parse_info_static_data().map_mutex); | |||
auto it_parse = | |||
parse_info_static_data().parse_info_methods.find(m_info_parse_func_name); | |||
if (it_parse == parse_info_static_data().parse_info_methods.end()) { | |||
LITE_THROW(ssprintf( | |||
"can't find model info parse function %s.", | |||
m_info_parse_func_name.c_str())); | |||
//! parse ModelInfo::data | |||
if (m_info->data()) { | |||
size_t info_length = m_info->data()->size(); | |||
const uint8_t* info_data = m_info->data()->Data(); | |||
//! decryption the info | |||
auto info_ptr = decrypt_memory( | |||
info_data, info_length, m_info_decryption_name, info_length); | |||
//! parse the info | |||
LITE_LOCK_GUARD(parse_info_static_data().map_mutex); | |||
auto it_parse = parse_info_static_data().parse_info_methods.find( | |||
m_info_parse_func_name); | |||
if (it_parse == parse_info_static_data().parse_info_methods.end()) { | |||
LITE_THROW(ssprintf( | |||
"can't find model info parse function %s.", | |||
m_info_parse_func_name.c_str())); | |||
} | |||
auto model_info_parse_func = | |||
parse_info_static_data().parse_info_methods[m_info_parse_func_name]; | |||
//! convert for NetworkIOInner to NetworkIO | |||
if (model_info_parse_func) { | |||
model_info_parse_func( | |||
info_ptr.get(), info_length, m_model_name, network_config, | |||
network_io, isolated_config_map, extra_info); | |||
} else { | |||
LITE_THROW(ssprintf( | |||
"model info parse function of %s is empty", | |||
m_info_parse_func_name.c_str())); | |||
} | |||
} | |||
auto model_info_parse_func = | |||
parse_info_static_data().parse_info_methods[m_info_parse_func_name]; | |||
//! convert for NetworkIOInner to NetworkIO | |||
if (model_info_parse_func) { | |||
model_info_parse_func( | |||
info_ptr.get(), info_length, m_model_name, network_config, network_io, | |||
isolated_config_map, extra_info); | |||
} else { | |||
LITE_THROW(ssprintf( | |||
"model info parse function of %s is empty", | |||
m_info_parse_func_name.c_str())); | |||
//! parse ModelInfo::algo_policy | |||
if (m_info->algo_policy()) { | |||
size_t cache_length = m_info->algo_policy()->size(); | |||
const uint8_t* cache = m_info->algo_policy()->Data(); | |||
if (m_info_cache_parse_func_name == "LITE_parse_cache") { | |||
if (m_is_fast_run_cache) { | |||
parse_info_cache(cache, cache_length); | |||
} else if (m_info->binary_cache()) { | |||
size_t binary_cache_length = m_info->binary_cache()->size(); | |||
const uint8_t* binary_cache = m_info->binary_cache()->Data(); | |||
parse_info_cache( | |||
cache, cache_length, m_is_fast_run_cache, binary_cache, | |||
binary_cache_length); | |||
} else { | |||
LITE_THROW("opencl binary cache is not given"); | |||
} | |||
} | |||
} | |||
return true; | |||
} | |||
@@ -60,6 +60,8 @@ private: | |||
std::string m_model_decryption_name; | |||
//! the function name to parse the model info | |||
std::string m_info_parse_func_name; | |||
std::string m_info_cache_parse_func_name; | |||
bool m_is_fast_run_cache; | |||
//! if a model is not added json info to the model is not crypted, the | |||
//! model is a bare model | |||
bool m_is_bare_model = true; | |||
@@ -5,10 +5,14 @@ table ModelHeader { | |||
info_decryption_method:string; | |||
info_parse_method:string; | |||
model_decryption_method:string; | |||
info_cache_parse_method:string; | |||
is_fast_run_cache:bool; | |||
} | |||
table ModelInfo { | |||
data:[ubyte]; | |||
algo_policy:[ubyte]; | |||
binary_cache:[ubyte]; | |||
} | |||
table ModelData { | |||
@@ -970,6 +970,25 @@ TEST(TestNetWork, LoadPackedModel) { | |||
network->wait(); | |||
} | |||
//! Load a model that carries a packed fast-run cache and run one forward
//! pass to make sure the packed cache is accepted at load time.
TEST(TestNetWork, LoadPackedCacheModel) {
    auto tensor = get_input_data("./input_data.npy");
    std::string model_path = "./test_pack_cache_to_model.lite";
    std::string input_name = "data";
    NetworkIO IO;
    Config config;
    std::shared_ptr<Network> network = std::make_shared<Network>(config, IO);
    network->load_model(model_path);

    //! feed the npy input into the network's "data" tensor without a copy
    std::shared_ptr<Tensor> input_tensor = network->get_io_tensor(input_name);
    auto src_ptr = tensor->get_memory_ptr();
    auto src_layout = tensor->get_layout();
    input_tensor->reset(src_ptr, src_layout);

    network->forward();
    network->wait();
}
TEST(TestNetWork, GlabalLayoutTransform) { | |||
auto tensor = get_input_data("./input_data.npy"); | |||
std::string model_path = "./shufflenet.mge"; | |||
@@ -216,6 +216,46 @@ void InFilePersistentCache::dump_cache(OutputFile* out_file) { | |||
} | |||
} | |||
} | |||
std::vector<uint8_t> InFilePersistentCache::dump_cache() { | |||
std::vector<uint8_t> ret; | |||
uint32_t nr_category = m_cache.size(); | |||
ret.resize(sizeof(nr_category)); | |||
memcpy(ret.data(), &nr_category, sizeof(nr_category)); | |||
auto write_to_buffer = [&ret](uint32_t val) { | |||
std::vector<uint8_t> vec(sizeof(val)); | |||
memcpy(vec.data(), &val, sizeof(val)); | |||
ret.insert(ret.end(), vec.begin(), vec.end()); | |||
}; | |||
for (const auto& cached_category : m_cache) { | |||
uint32_t category_size = cached_category.first.size(); | |||
write_to_buffer(category_size); | |||
std::vector<uint8_t> category( | |||
cached_category.first.begin(), cached_category.first.end()); | |||
ret.insert(ret.end(), category.begin(), category.end()); | |||
uint32_t nr_bobs = cached_category.second.size(); | |||
write_to_buffer(nr_bobs); | |||
for (const auto& item : cached_category.second) { | |||
uint32_t size_first = item.first.size; | |||
write_to_buffer(size_first); | |||
ret.insert( | |||
ret.end(), item.first.data_refhold.get(), | |||
item.first.data_refhold.get() + size_first); | |||
uint32_t size_second = item.second.size; | |||
write_to_buffer(size_second); | |||
ret.insert( | |||
ret.end(), item.second.data_refhold.get(), | |||
item.second.data_refhold.get() + size_second); | |||
} | |||
} | |||
return ret; | |||
} | |||
Maybe<InFilePersistentCache::Blob> InFilePersistentCache::get( | |||
const std::string& category, const Blob& key) { | |||
decltype(m_cache.begin()) iter0; | |||
@@ -71,6 +71,7 @@ public: | |||
*/ | |||
MGE_WIN_DECLSPEC_FUC void dump_cache(const char* path); | |||
MGE_WIN_DECLSPEC_FUC void dump_cache(OutputFile* out_file); | |||
MGE_WIN_DECLSPEC_FUC std::vector<uint8_t> dump_cache(); | |||
MGE_WIN_DECLSPEC_FUC Maybe<Blob> get( | |||
const std::string& category, const Blob& key) override; | |||