From b9cbc10120206ba3f114afeb4812ff951d4a766d Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 12 Apr 2022 17:21:23 +0800 Subject: [PATCH] feat(lite): add pack model GitOrigin-RevId: 1a150f2af346acd3139b848c3d9e61959df254f3 --- lite/include/lite/pack_model.h | 69 ++++++ lite/load_and_run/CMakeLists.txt | 2 +- lite/load_and_run/src/models/model.h | 2 + lite/load_and_run/src/models/model_lite.h | 2 + lite/load_and_run/src/models/model_mdl.h | 2 + lite/load_and_run/src/options/model_options.cpp | 87 ++++++++ lite/load_and_run/src/options/model_options.h | 45 ++++ lite/pylite/test/test_network.py | 6 + lite/src/pack_model/pack_model.cpp | 232 +++++++++++++++++++++ lite/src/parse_info/cache_parse.h | 36 ++++ lite/src/parse_model/model_parser.cpp | 74 ++++--- lite/src/parse_model/model_parser.h | 2 + lite/src/parse_model/pack_model.fbs | 4 + lite/test/test_network.cpp | 19 ++ src/core/impl/utils/infile_persistent_cache.cpp | 40 ++++ .../megbrain/utils/infile_persistent_cache.h | 1 + 16 files changed, 598 insertions(+), 25 deletions(-) create mode 100644 lite/include/lite/pack_model.h create mode 100644 lite/load_and_run/src/options/model_options.cpp create mode 100644 lite/load_and_run/src/options/model_options.h create mode 100644 lite/src/pack_model/pack_model.cpp create mode 100644 lite/src/parse_info/cache_parse.h diff --git a/lite/include/lite/pack_model.h b/lite/include/lite/pack_model.h new file mode 100644 index 00000000..14e7227b --- /dev/null +++ b/lite/include/lite/pack_model.h @@ -0,0 +1,69 @@ +/** + * \file inlude/lite/pack_model.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#pragma once +#include + +namespace lite { + +//! feature flags of a packed model, kept together in one 32-bit bit-field +struct FeatureBits32 { + //! 1 when the packed algo policy data is a fastrun cache + uint32_t is_fast_run_cache : 1; + //! reserved for new fields + uint32_t : 31; +}; + +//! header information that is stored in a packed model +struct Header { + std::string name; //! model name + std::string + model_decryption_method; //! model encryption method name, this is used to + //! find the right decryption method. [ + //! AES_default | RC4_default | + //! SIMPLE_FAST_RC4_default ], default is NONE. + std::string info_decryption_method; //! info data encryption method name, this is + //! used to find the right decryption method. [ + //! AES_default | RC4_default | + //! SIMPLE_FAST_RC4_default ], default is NONE. + std::string info_parse_method = "LITE_default"; //! info parse method name. + std::string info_cache_parse_method = + "LITE_parse_cache"; //! fastrun cache parse method name. + FeatureBits32 fb32; +}; + +class FbsHelper; + +/*! + * \brief pack a model file together with an optional info json and fastrun + * cache / algo policy file into one packed model file + */ +class ModelPacker { +public: + /*! + * \param model_path path of the model file to pack + * \param packed_model_path path the packed model is written to + * \param info_data_path optional path of the json info file; set_header() + * reads the model name from its "name" entry + * \param info_algo_policy_path optional path of the fastrun cache / algo + * policy file + * \param info_binary_cache_path optional path of the binary cache file + */ + ModelPacker( + std::string model_path, std::string packed_model_path, + std::string info_data_path = "", std::string info_algo_policy_path = "", + std::string info_binary_cache_path = ""); + + /*! + * \brief fill the header that is written into the packed model + * \param model_decryption_method name of the method used to decrypt the model + * \param info_decryption_method name of the method used to decrypt the info data + * \param is_fast_run_cache value stored in FeatureBits32::is_fast_run_cache + */ + void set_header( + std::string model_decryption_method = "NONE", + std::string info_decryption_method = "NONE", bool is_fast_run_cache = true); + + //! build the packed model and write it to the packed model path + void pack_model(); + +private: + std::string m_packed_model_path; + std::string m_info_data_path; + //! fastrun cache / algo policy + std::string m_info_algo_policy_path; + //! 
binary cache + std::string m_info_binary_cache_path; + + Header m_header; + + friend class FbsHelper; + //! NOTE(review): raw pointer and no destructor is declared in this class, + //! confirm who releases the helper + FbsHelper* m_fbs_helper; +}; + +} // namespace lite \ No newline at end of file diff --git a/lite/load_and_run/CMakeLists.txt b/lite/load_and_run/CMakeLists.txt index c655c209..87d8c9c2 100644 --- a/lite/load_and_run/CMakeLists.txt +++ b/lite/load_and_run/CMakeLists.txt @@ -1,7 +1,7 @@ # BUILD the load and run for lite include_directories(PUBLIC $) -file(GLOB_RECURSE SOURCES ./*.cpp) +file(GLOB_RECURSE SOURCES ./*.cpp ${PROJECT_SOURCE_DIR}/lite/src/pack_model/*.cpp) add_executable(load_and_run ${SOURCES}) target_link_libraries(load_and_run lite_static) diff --git a/lite/load_and_run/src/models/model.h b/lite/load_and_run/src/models/model.h index 240574f2..586c241f 100644 --- a/lite/load_and_run/src/models/model.h +++ b/lite/load_and_run/src/models/model.h @@ -43,6 +43,8 @@ public: virtual void wait() = 0; virtual ~ModelBase() = default; + + virtual const std::string& get_model_path() const = 0; }; } // namespace lar diff --git a/lite/load_and_run/src/models/model_lite.h b/lite/load_and_run/src/models/model_lite.h index 2e687e20..ed296a49 100644 --- a/lite/load_and_run/src/models/model_lite.h +++ b/lite/load_and_run/src/models/model_lite.h @@ -60,6 +60,8 @@ public: //! 
get algo strategy Strategy& get_lite_strategy() { return m_strategy; } + const std::string& get_model_path() const override { return model_path; } + private: bool share_model_mem; bool enable_layout_transform; diff --git a/lite/load_and_run/src/models/model_mdl.h b/lite/load_and_run/src/models/model_mdl.h index 07211e46..aab8deba 100644 --- a/lite/load_and_run/src/models/model_mdl.h +++ b/lite/load_and_run/src/models/model_mdl.h @@ -107,6 +107,8 @@ public: std::move(out_file), m_format.val()); } + const std::string& get_model_path() const override { return model_path; } + private: bool share_model_mem; std::string model_path; diff --git a/lite/load_and_run/src/options/model_options.cpp b/lite/load_and_run/src/options/model_options.cpp new file mode 100644 index 00000000..a04fe33e --- /dev/null +++ b/lite/load_and_run/src/options/model_options.cpp @@ -0,0 +1,87 @@ +/** + * \file lite/load_and_run/src/options/model_options.cpp + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. 
+ */ + +#include "model_options.h" +#include "device_options.h" +#include "lite/pack_model.h" +#include "misc.h" +#include "models/model_lite.h" +#include "models/model_mdl.h" + +namespace lar { +//! pack the model that was just run into a packed model file; only acts in the +//! AFTER_MODEL_RUNNING stage +template +void PackModelOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { + lite::ModelPacker packer( + model->get_model_path(), packed_model_dump, pack_info_json, pack_cache, + pack_binary_cache); + //! set_header() takes the model decryption method first and the info + //! decryption method second + packer.set_header(pack_model_cryption, pack_info_cryption, is_fast_run_cache); + packer.pack_model(); + } +} +} // namespace lar +using namespace lar; +////////////////////// PackModel options //////////////////////// + +//! copy every non-empty pack flag into the option +PackModelOption::PackModelOption() { + m_option_name = "pack_model"; + if (!FLAGS_packed_model_dump.empty()) + packed_model_dump = FLAGS_packed_model_dump; + if (!FLAGS_pack_info_json.empty()) + pack_info_json = FLAGS_pack_info_json; + if (!FLAGS_pack_cache.empty()) + pack_cache = FLAGS_pack_cache; + if (!FLAGS_pack_info_cryption.empty()) + pack_info_cryption = FLAGS_pack_info_cryption; + if (!FLAGS_pack_model_cryption.empty()) + pack_model_cryption = FLAGS_pack_model_cryption; +} + +//! the option is only active when an output path for the packed model is given +bool PackModelOption::is_valid() { + return !FLAGS_packed_model_dump.empty(); +} + +//! return the shared option instance, or nullptr when the option is not active +std::shared_ptr PackModelOption::create_option() { + static std::shared_ptr option(new PackModelOption); + if (PackModelOption::is_valid()) { + return std::static_pointer_cast(option); + } else { + return nullptr; + } +} + +void PackModelOption::config_model( + RuntimeParam& runtime_param, std::shared_ptr model) { + CONFIG_MODEL_FUN; +} + +////////////////////// PackModel gflags //////////////////////// + +DEFINE_string(packed_model_dump, "", "The output file path of packed model."); +DEFINE_string( + pack_info_json, "", + "An encrypted or not encrypted json format file to pack into the model."); +DEFINE_string(pack_cache, "", "Pack the fastrun cache or algo policy into the model."); +DEFINE_string( + 
pack_info_cryption, "NONE", + "The info data encryption method name, this is used to find the right " + "decryption method. --pack-info-cryption [ AES_default | RC4_default | " + "SIMPLE_FAST_RC4_default ], default is NONE. See " + "https://megengine.megvii-inc.com/user-guide/deployment/lite/advance/" + "pack-lite-model.html for more details."); +DEFINE_string( + pack_model_cryption, "NONE", + "The model encryption method name, this is used to find the right decryption " + "method. --pack-model-cryption [ AES_default | RC4_default | " + "SIMPLE_FAST_RC4_default ], default is NONE. See " + "https://megengine.megvii-inc.com/user-guide/deployment/lite/advance/" + "pack-lite-model.html for more details."); + +REGIST_OPTION_CREATOR(pack_model, lar::PackModelOption::create_option); diff --git a/lite/load_and_run/src/options/model_options.h b/lite/load_and_run/src/options/model_options.h new file mode 100644 index 00000000..3f8218e6 --- /dev/null +++ b/lite/load_and_run/src/options/model_options.h @@ -0,0 +1,45 @@ +/** + * \file lite/load_and_run/src/options/model_options.h + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. 
+ */ + +#pragma once +#include +#include "models/model.h" +#include "option_base.h" + +DECLARE_string(packed_model_dump); +DECLARE_string(pack_info_json); +DECLARE_string(pack_cache); +DECLARE_string(pack_info_cryption); +DECLARE_string(pack_model_cryption); + +namespace lar { +//! load_and_run option that packs the model into a packed model file, driven by +//! the pack flags declared above +class PackModelOption : public OptionBase { +public: + //! true when the packed_model_dump flag is not empty + static bool is_valid(); + //! return the option instance when is_valid(), otherwise nullptr + static std::shared_ptr create_option(); + void config_model( + RuntimeParam& runtime_param, std::shared_ptr model) override; + std::string option_name() const override { return m_option_name; } + +private: + PackModelOption(); + + template + void config_model_internel(RuntimeParam&, std::shared_ptr); + + std::string m_option_name; + //! output path of the packed model, from the packed_model_dump flag + std::string packed_model_dump; + //! json info file to pack, from the pack_info_json flag + std::string pack_info_json; + //! fastrun cache / algo policy file to pack, from the pack_cache flag + std::string pack_cache; + //! binary cache path handed to the packer; no flag shown here fills it + std::string pack_binary_cache; + //! info data encryption method name, from the pack_info_cryption flag + std::string pack_info_cryption; + //! model encryption method name, from the pack_model_cryption flag + std::string pack_model_cryption; + bool is_fast_run_cache = true; +}; +} // namespace lar diff --git a/lite/pylite/test/test_network.py b/lite/pylite/test/test_network.py index ee502c20..ad1b73b7 100644 --- a/lite/pylite/test/test_network.py +++ b/lite/pylite/test/test_network.py @@ -119,6 +119,12 @@ class TestNetwork(TestShuffleNet): network.load(model_path) self.do_forward(network) + def test_pack_cache_to_model(self): + model_path = os.path.join(self.source_dir, "test_pack_cache_to_model.lite") + network = LiteNetwork() + network.load(model_path) + self.do_forward(network) + def test_network_basic(self): network = LiteNetwork() network.load(self.model_path) diff --git a/lite/src/pack_model/pack_model.cpp b/lite/src/pack_model/pack_model.cpp new file mode 100644 index 00000000..d4594ea1 --- /dev/null +++ b/lite/src/pack_model/pack_model.cpp @@ -0,0 +1,232 @@ +/** + * \file src/pack_model/pack_model.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "lite/pack_model.h" +#include "../misc.h" +#if LITE_BUILD_WITH_MGE +#include "megbrain/utils/infile_persistent_cache.h" +#endif + +#include +#include "nlohmann/json.hpp" +#include "pack_model_generated.h" + +namespace lite { + +//! helper that loads the input model and builds the flatbuffers tables of the +//! packed model for a ModelPacker +class FbsHelper { +public: + FbsHelper() = default; + FbsHelper(ModelPacker* packer, std::string model_path); + flatbuffers::Offset build_header(); + flatbuffers::Offset build_info(); + flatbuffers::Offset build_data(); + flatbuffers::FlatBufferBuilder& builder() { return m_builder; } + +private: + //! packer this helper works for; null for a default constructed helper + ModelPacker* m_packer = nullptr; + flatbuffers::FlatBufferBuilder m_builder; + //! raw content of the input model file + std::vector m_model_buffer; + + //! only set when the input file is already a packed model + const model_parse::ModelHeader* m_model_header = nullptr; + const model_parse::ModelInfo* m_model_info = nullptr; + const model_parse::ModelData* m_model_data = nullptr; +}; + +} // namespace lite + +using namespace lite; +using namespace model_parse; + +//! read the whole file at path into a byte buffer; an empty file gives an empty +//! buffer, a file that can not be opened, sized or read fails the assertion +std::vector read_file(std::string path) { + FILE* fin = fopen(path.c_str(), "rb"); + LITE_ASSERT(fin, "failed to open %s: %s", path.c_str(), strerror(errno)); + fseek(fin, 0, SEEK_END); + long file_size = ftell(fin); + fseek(fin, 0, SEEK_SET); + std::vector buf; + size_t nr = 1; + //! fread() with a zero element size returns 0, so only read when there is data + if (file_size > 0) { + buf.resize(file_size); + nr = fread(buf.data(), buf.size(), 1, fin); + } + //! close before checking the results so the handle is released on failure too + fclose(fin); + LITE_ASSERT(file_size >= 0, "failed to get the size of %s", path.c_str()); + LITE_ASSERT(nr == 1, "failed to read %s", path.c_str()); + return buf; +} + +//! load the model file and, when it starts with the "packed_model" tag, keep +//! pointers to the header, info and data of its first model +FbsHelper::FbsHelper(ModelPacker* packer, std::string model_path) : m_packer(packer) { + m_model_buffer = read_file(model_path); + + //! a buffer shorter than the 12 byte tag can not be a packed model, and + //! reading the tag from it would run past the end of the buffer + if (m_model_buffer.size() < 12) { + return; + } + + const char* model_ptr = + static_cast(static_cast(m_model_buffer.data())); + std::string tag(model_ptr, 12); + if (tag == "packed_model") { + uint8_t* buffer = m_model_buffer.data() + 12; + auto model = GetPackModel(buffer)->models()->Get(0); + m_model_header = model->header(); + m_model_info = model->info(); + m_model_data = model->data(); + } +} + 
+flatbuffers::Offset FbsHelper::build_header() { + flatbuffers::Offset name, info_decryption_method, + info_parse_method, model_decryption_method, info_cache_parse_method; + bool is_fast_run_cache; + if (m_model_header) { + auto&& header = m_model_header; + name = m_builder.CreateSharedString(header->name()); + info_decryption_method = + m_builder.CreateSharedString(header->info_decryption_method()); + info_parse_method = m_builder.CreateSharedString(header->info_parse_method()); + model_decryption_method = + m_builder.CreateSharedString(header->model_decryption_method()); + info_cache_parse_method = + m_builder.CreateSharedString(header->info_cache_parse_method()); + is_fast_run_cache = header->is_fast_run_cache(); + } else { + auto&& header = m_packer->m_header; + name = m_builder.CreateSharedString(header.name); + info_decryption_method = + m_builder.CreateSharedString(header.info_decryption_method); + info_parse_method = m_builder.CreateSharedString(header.info_parse_method); + model_decryption_method = + m_builder.CreateSharedString(header.model_decryption_method); + info_cache_parse_method = + m_builder.CreateSharedString(header.info_cache_parse_method); + is_fast_run_cache = header.fb32.is_fast_run_cache; + } + return CreateModelHeader( + m_builder, name, info_decryption_method, info_parse_method, + model_decryption_method, info_cache_parse_method, is_fast_run_cache); +} + +flatbuffers::Offset FbsHelper::build_data() { + if (m_model_data) { + auto data = m_model_data->data()->Data(); + auto size = m_model_data->data()->size(); + return CreateModelData(m_builder, m_builder.CreateVector(data, size)); + } else { + return CreateModelData(m_builder, m_builder.CreateVector(m_model_buffer)); + } +} + +flatbuffers::Offset FbsHelper::build_info() { + flatbuffers::Offset> fb_data; + if (m_model_info && m_model_info->data() && m_packer->m_info_data_path.empty()) { + auto data = m_model_info->data()->Data(); + auto size = m_model_info->data()->size(); + fb_data = 
m_builder.CreateVector(data, size); + } else if (!m_packer->m_info_data_path.empty()) { + auto info_data = read_file(m_packer->m_info_data_path); + fb_data = m_builder.CreateVector(info_data); + } + + flatbuffers::Offset> fb_algo_policy; + flatbuffers::Offset> fb_binary_cache; + if (m_packer->m_header.fb32.is_fast_run_cache) { + std::vector info_algo_policy; + if (!m_packer->m_info_algo_policy_path.empty()) { + info_algo_policy = read_file(m_packer->m_info_algo_policy_path); + if (m_model_info && m_model_info->algo_policy()) { + auto cache = m_model_info->algo_policy()->Data(); + auto size = m_model_info->algo_policy()->size(); + + uint32_t nr_category_1, nr_category_2, nr_category; + memcpy(&nr_category_1, cache, sizeof(uint32_t)); + memcpy(&nr_category_2, info_algo_policy.data(), sizeof(uint32_t)); + nr_category = nr_category_1 + nr_category_2; + + std::vector cache_append; + cache_append.resize(sizeof(nr_category)); + memcpy(cache_append.data(), &nr_category, sizeof(nr_category)); + cache_append.insert( + cache_append.end(), cache + sizeof(nr_category), cache + size); + cache_append.insert( + cache_append.end(), + info_algo_policy.begin() + sizeof(nr_category), + info_algo_policy.end()); + + fb_algo_policy = m_builder.CreateVector(cache_append); + } else { + fb_algo_policy = m_builder.CreateVector(info_algo_policy); + } + } +#if LITE_BUILD_WITH_MGE + else { + info_algo_policy = static_cast( + mgb::PersistentCache::inst()) + .dump_cache(); + fb_algo_policy = m_builder.CreateVector(info_algo_policy); + } +#endif + } + + ModelInfoBuilder builder(m_builder); + builder.add_data(fb_data); + builder.add_algo_policy(fb_algo_policy); + builder.add_binary_cache(fb_binary_cache); + return builder.Finish(); +} + +ModelPacker::ModelPacker( + std::string model_path, std::string packed_model_path, + std::string info_data_path, std::string info_algo_policy_path, + std::string info_binary_cache_path) + : m_packed_model_path(packed_model_path), + m_info_data_path(info_data_path), 
+ m_info_algo_policy_path(info_algo_policy_path), + m_info_binary_cache_path(info_binary_cache_path) { + m_fbs_helper = new FbsHelper(this, model_path); +} + +void ModelPacker::set_header( + std::string model_decryption_method, std::string info_decryption_method, + bool is_fast_run_cache) { + m_header.model_decryption_method = model_decryption_method; + m_header.info_decryption_method = info_decryption_method; + memset(&m_header.fb32, 0, sizeof(m_header.fb32)); + m_header.fb32.is_fast_run_cache = is_fast_run_cache; + if (!m_info_data_path.empty()) { + auto buf = read_file(m_info_data_path); + std::string json_string( + static_cast(static_cast(buf.data())), buf.size()); + auto info = nlohmann::json::parse(json_string); + m_header.name = info["name"]; + } +} + +void ModelPacker::pack_model() { + auto fb_header = m_fbs_helper->build_header(); + auto fb_info = m_fbs_helper->build_info(); + auto fb_data = m_fbs_helper->build_data(); + + ModelBuilder model_builder(m_fbs_helper->builder()); + model_builder.add_header(fb_header); + model_builder.add_info(fb_info); + model_builder.add_data(fb_data); + + auto model = model_builder.Finish(); + std::vector> models; + models.emplace_back(model); + auto fb_models = m_fbs_helper->builder().CreateVector(models); + + PackModelBuilder pack_model_builder(m_fbs_helper->builder()); + pack_model_builder.add_models(fb_models); + m_fbs_helper->builder().Finish(pack_model_builder.Finish()); + + FILE* fptr = fopen(m_packed_model_path.c_str(), "wb"); + std::string packed_model_tag = "packed_model"; + auto nr_tag = fwrite(packed_model_tag.c_str(), 1, packed_model_tag.size(), fptr); + LITE_ASSERT(nr_tag == packed_model_tag.size()); + + auto fb_size = m_fbs_helper->builder().GetSize(); + auto nr_fb = fwrite(m_fbs_helper->builder().GetBufferPointer(), 1, fb_size, fptr); + LITE_ASSERT(nr_fb == fb_size); + fclose(fptr); +} \ No newline at end of file diff --git a/lite/src/parse_info/cache_parse.h b/lite/src/parse_info/cache_parse.h new file mode 
100644 index 00000000..46f2804e --- /dev/null +++ b/lite/src/parse_info/cache_parse.h @@ -0,0 +1,36 @@ +/** + * \file src/parse_info/cache_parse.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once +#include "lite/global.h" +#if LITE_BUILD_WITH_MGE +#include "megbrain/utils/infile_persistent_cache.h" +#endif + +namespace lite { +//! The LITE_parse_cache parse info function +bool parse_info_cache( + const uint8_t* cache, size_t cache_length, bool is_fast_run_cache = true, + const uint8_t* binary_cache = nullptr, size_t binary_cache_length = 0) { + LITE_MARK_USED_VAR(binary_cache); + LITE_MARK_USED_VAR(binary_cache_length); +#if LITE_BUILD_WITH_MGE + if (is_fast_run_cache) { + mgb::PersistentCache::set_impl( + std::make_shared(cache, cache_length)); + } +#endif + return true; +} + +} // namespace lite + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/lite/src/parse_model/model_parser.cpp b/lite/src/parse_model/model_parser.cpp index 6bedf078..15fd91ab 100644 --- a/lite/src/parse_model/model_parser.cpp +++ b/lite/src/parse_model/model_parser.cpp @@ -11,6 +11,7 @@ #include "model_parser.h" #include "decryption/decrypt_base.h" +#include "parse_info/cache_parse.h" #include "parse_info/parse_info_base.h" using namespace lite; @@ -41,6 +42,10 @@ void ModelParser::parse_header() { m_model_decryption_name = model->header()->model_decryption_method()->c_str(); m_info_decryption_name = model->header()->info_decryption_method()->c_str(); m_info_parse_func_name = model->header()->info_parse_method()->c_str(); + if (model->header()->info_cache_parse_method()) + m_info_cache_parse_func_name = + 
model->header()->info_cache_parse_method()->c_str(); + m_is_fast_run_cache = model->header()->is_fast_run_cache(); m_info = model->info(); m_model_data = model->data(); @@ -54,31 +59,52 @@ bool ModelParser::parse_model_info( if (m_is_bare_model || !m_info) { return false; } - size_t info_length = m_info->data()->size(); - const uint8_t* info_data = m_info->data()->Data(); - //! decryption the info - auto info_ptr = - decrypt_memory(info_data, info_length, m_info_decryption_name, info_length); - //! parse the info - LITE_LOCK_GUARD(parse_info_static_data().map_mutex); - auto it_parse = - parse_info_static_data().parse_info_methods.find(m_info_parse_func_name); - if (it_parse == parse_info_static_data().parse_info_methods.end()) { - LITE_THROW(ssprintf( - "can't find model info parse function %s.", - m_info_parse_func_name.c_str())); + //! parse ModelInfo::data + if (m_info->data()) { + size_t info_length = m_info->data()->size(); + const uint8_t* info_data = m_info->data()->Data(); + //! decryption the info + auto info_ptr = decrypt_memory( + info_data, info_length, m_info_decryption_name, info_length); + //! parse the info + LITE_LOCK_GUARD(parse_info_static_data().map_mutex); + auto it_parse = parse_info_static_data().parse_info_methods.find( + m_info_parse_func_name); + if (it_parse == parse_info_static_data().parse_info_methods.end()) { + LITE_THROW(ssprintf( + "can't find model info parse function %s.", + m_info_parse_func_name.c_str())); + } + auto model_info_parse_func = + parse_info_static_data().parse_info_methods[m_info_parse_func_name]; + //! 
convert for NetworkIOInner to NetworkIO + if (model_info_parse_func) { + model_info_parse_func( + info_ptr.get(), info_length, m_model_name, network_config, + network_io, isolated_config_map, extra_info); + } else { + LITE_THROW(ssprintf( + "model info parse function of %s is empty", + m_info_parse_func_name.c_str())); + } } - auto model_info_parse_func = - parse_info_static_data().parse_info_methods[m_info_parse_func_name]; - //! convert for NetworkIOInner to NetworkIO - if (model_info_parse_func) { - model_info_parse_func( - info_ptr.get(), info_length, m_model_name, network_config, network_io, - isolated_config_map, extra_info); - } else { - LITE_THROW(ssprintf( - "model info parse function of %s is empty", - m_info_parse_func_name.c_str())); + //! parse ModelInfo::algo_policy + if (m_info->algo_policy()) { + size_t cache_length = m_info->algo_policy()->size(); + const uint8_t* cache = m_info->algo_policy()->Data(); + if (m_info_cache_parse_func_name == "LITE_parse_cache") { + if (m_is_fast_run_cache) { + parse_info_cache(cache, cache_length); + } else if (m_info->binary_cache()) { + size_t binary_cache_length = m_info->binary_cache()->size(); + const uint8_t* binary_cache = m_info->binary_cache()->Data(); + parse_info_cache( + cache, cache_length, m_is_fast_run_cache, binary_cache, + binary_cache_length); + } else { + LITE_THROW("opencl binary cache is not given"); + } + } } return true; } diff --git a/lite/src/parse_model/model_parser.h b/lite/src/parse_model/model_parser.h index f556c600..b7d703af 100644 --- a/lite/src/parse_model/model_parser.h +++ b/lite/src/parse_model/model_parser.h @@ -60,6 +60,8 @@ private: std::string m_model_decryption_name; //! the function name to parse the model info std::string m_info_parse_func_name; + std::string m_info_cache_parse_func_name; + bool m_is_fast_run_cache; //! if a model is not added json info to the model is not crypted, the //! 
model is a bare model + bool m_is_bare_model = true; diff --git a/lite/src/parse_model/pack_model.fbs b/lite/src/parse_model/pack_model.fbs index d0bc442e..baa97502 100644 --- a/lite/src/parse_model/pack_model.fbs +++ b/lite/src/parse_model/pack_model.fbs @@ -5,10 +5,14 @@ table ModelHeader { info_decryption_method:string; info_parse_method:string; model_decryption_method:string; + info_cache_parse_method:string; + is_fast_run_cache:bool; } table ModelInfo { data:[ubyte]; + algo_policy:[ubyte]; + binary_cache:[ubyte]; } table ModelData { diff --git a/lite/test/test_network.cpp b/lite/test/test_network.cpp index f971033f..707fc19d 100644 --- a/lite/test/test_network.cpp +++ b/lite/test/test_network.cpp @@ -970,6 +970,25 @@ TEST(TestNetWork, LoadPackedModel) { network->wait(); } +TEST(TestNetWork, LoadPackedCacheModel) { + auto tensor = get_input_data("./input_data.npy"); + std::string model_path = "./test_pack_cache_to_model.lite"; + std::string input_name = "data"; + + NetworkIO IO; + Config config; + std::shared_ptr network = std::make_shared(config, IO); + network->load_model(model_path); + std::shared_ptr input_tensor = network->get_io_tensor(input_name); + + auto src_ptr = tensor->get_memory_ptr(); + auto src_layout = tensor->get_layout(); + input_tensor->reset(src_ptr, src_layout); + + network->forward(); + network->wait(); +} + TEST(TestNetWork, GlabalLayoutTransform) { auto tensor = get_input_data("./input_data.npy"); std::string model_path = "./shufflenet.mge"; diff --git a/src/core/impl/utils/infile_persistent_cache.cpp b/src/core/impl/utils/infile_persistent_cache.cpp index 43abea96..be96dfaf 100644 --- a/src/core/impl/utils/infile_persistent_cache.cpp +++ b/src/core/impl/utils/infile_persistent_cache.cpp @@ -216,6 +216,46 @@ void InFilePersistentCache::dump_cache(OutputFile* out_file) { } } } + +//! serialize the whole cache into a byte buffer and return it. Layout: a uint32_t +//! category count, then for every category its name (uint32_t length + bytes), a +//! uint32_t entry count and, for every entry, the first and the second blob each +//! written as uint32_t size + bytes. Integers are copied with memcpy as they are. +std::vector InFilePersistentCache::dump_cache() { + std::vector ret; + + //! the buffer starts with the number of categories + uint32_t nr_category = m_cache.size(); + ret.resize(sizeof(nr_category)); + 
memcpy(ret.data(), &nr_category, sizeof(nr_category)); + + //! append the raw bytes of a uint32_t to the end of the buffer + auto write_to_buffer = [&ret](uint32_t val) { + std::vector vec(sizeof(val)); + memcpy(vec.data(), &val, sizeof(val)); + ret.insert(ret.end(), vec.begin(), vec.end()); + }; + + for (const auto& cached_category : m_cache) { + //! category name: length followed by the characters + uint32_t category_size = cached_category.first.size(); + write_to_buffer(category_size); + std::vector category( + cached_category.first.begin(), cached_category.first.end()); + ret.insert(ret.end(), category.begin(), category.end()); + + //! number of entries in this category, then the entries themselves + uint32_t nr_bobs = cached_category.second.size(); + write_to_buffer(nr_bobs); + for (const auto& item : cached_category.second) { + uint32_t size_first = item.first.size; + write_to_buffer(size_first); + ret.insert( + ret.end(), item.first.data_refhold.get(), + item.first.data_refhold.get() + size_first); + + uint32_t size_second = item.second.size; + write_to_buffer(size_second); + ret.insert( + ret.end(), item.second.data_refhold.get(), + item.second.data_refhold.get() + size_second); + } + } + return ret; +} + Maybe InFilePersistentCache::get( const std::string& category, const Blob& key) { decltype(m_cache.begin()) iter0; diff --git a/src/core/include/megbrain/utils/infile_persistent_cache.h b/src/core/include/megbrain/utils/infile_persistent_cache.h index 008f2549..088f766a 100644 --- a/src/core/include/megbrain/utils/infile_persistent_cache.h +++ b/src/core/include/megbrain/utils/infile_persistent_cache.h @@ -71,6 +71,7 @@ public: */ MGE_WIN_DECLSPEC_FUC void dump_cache(const char* path); MGE_WIN_DECLSPEC_FUC void dump_cache(OutputFile* out_file); + MGE_WIN_DECLSPEC_FUC std::vector dump_cache(); MGE_WIN_DECLSPEC_FUC Maybe get( const std::string& category, const Blob& key) override;