Browse Source

feat(lite): add and fix some feature for load and run fitting mode

GitOrigin-RevId: bbddc9bb79
release-1.10
Megvii Engine Team 3 years ago
parent
commit
02bfb8f8b9
5 changed files with 93 additions and 37 deletions
  1. +10
    -6
      lite/include/lite/pack_model.h
  2. +45
    -13
      lite/src/pack_model/pack_model.cpp
  3. +31
    -16
      src/core/include/megbrain/utils/json.h
  4. +6
    -1
      src/gopt/impl/inference.cpp
  5. +1
    -1
      src/opr/impl/io.cpp

+ 10
- 6
lite/include/lite/pack_model.h View File

@@ -11,7 +11,7 @@


#pragma once #pragma once
#include <string> #include <string>
#include <vector>
namespace lite { namespace lite {


struct FeatureBits32 { struct FeatureBits32 {
@@ -45,6 +45,11 @@ public:
std::string model_path, std::string packed_model_path, std::string model_path, std::string packed_model_path,
std::string info_data_path = "", std::string info_algo_policy_path = "", std::string info_data_path = "", std::string info_algo_policy_path = "",
std::string info_binary_cache_path = ""); std::string info_binary_cache_path = "");
ModelPacker(
std::vector<uint8_t> model_data, std::string packed_model_path,
std::vector<uint8_t> info_data = {},
std::vector<uint8_t> info_algo_policy_data = {},
std::vector<uint8_t> info_binary_cache_data = {});


void set_header( void set_header(
std::string model_decryption_method = "NONE", std::string model_decryption_method = "NONE",
@@ -53,13 +58,12 @@ public:
void pack_model(); void pack_model();


private: private:
std::string m_packed_model_path;
std::string m_info_data_path;
std::vector<uint8_t> m_info_data;
//! fastrun cache / algo policy //! fastrun cache / algo policy
std::string m_info_algo_policy_path;
std::vector<uint8_t> m_algo_policy_data;
//! binary cache //! binary cache
std::string m_info_binary_cache_path;
std::vector<uint8_t> m_binary_cache_data;
std::string m_packed_model_path;
Header m_header; Header m_header;


friend class FbsHelper; friend class FbsHelper;


+ 45
- 13
lite/src/pack_model/pack_model.cpp View File

@@ -25,6 +25,7 @@ class FbsHelper {
public: public:
FbsHelper() = default; FbsHelper() = default;
FbsHelper(ModelPacker* packer, std::string model_path); FbsHelper(ModelPacker* packer, std::string model_path);
FbsHelper(ModelPacker* packer, std::vector<uint8_t>& model_data);
flatbuffers::Offset<model_parse::ModelHeader> build_header(); flatbuffers::Offset<model_parse::ModelHeader> build_header();
flatbuffers::Offset<model_parse::ModelInfo> build_info(); flatbuffers::Offset<model_parse::ModelInfo> build_info();
flatbuffers::Offset<model_parse::ModelData> build_data(); flatbuffers::Offset<model_parse::ModelData> build_data();
@@ -58,6 +59,19 @@ std::vector<uint8_t> read_file(std::string path) {
fclose(fin); fclose(fin);
return buf; return buf;
} }
FbsHelper::FbsHelper(ModelPacker* packer, std::vector<uint8_t>& model_data)
: m_packer(packer), m_model_buffer(model_data) {
const char* model_ptr =
static_cast<const char*>(static_cast<void*>(m_model_buffer.data()));
std::string tag(model_ptr, 12);
if (tag == "packed_model") {
uint8_t* buffer = m_model_buffer.data() + 12;
auto model = GetPackModel(buffer)->models()->Get(0);
m_model_header = model->header();
m_model_info = model->info();
m_model_data = model->data();
}
}


FbsHelper::FbsHelper(ModelPacker* packer, std::string model_path) : m_packer(packer) { FbsHelper::FbsHelper(ModelPacker* packer, std::string model_path) : m_packer(packer) {
m_model_buffer = read_file(model_path); m_model_buffer = read_file(model_path);
@@ -118,21 +132,20 @@ flatbuffers::Offset<ModelData> FbsHelper::build_data() {


flatbuffers::Offset<ModelInfo> FbsHelper::build_info() { flatbuffers::Offset<ModelInfo> FbsHelper::build_info() {
flatbuffers::Offset<flatbuffers::Vector<uint8_t>> fb_data; flatbuffers::Offset<flatbuffers::Vector<uint8_t>> fb_data;
if (m_model_info && m_model_info->data() && m_packer->m_info_data_path.empty()) {
if (m_model_info && m_model_info->data() && m_packer->m_info_data.empty()) {
auto data = m_model_info->data()->Data(); auto data = m_model_info->data()->Data();
auto size = m_model_info->data()->size(); auto size = m_model_info->data()->size();
fb_data = m_builder.CreateVector(data, size); fb_data = m_builder.CreateVector(data, size);
} else if (!m_packer->m_info_data_path.empty()) {
auto info_data = read_file(m_packer->m_info_data_path);
fb_data = m_builder.CreateVector(info_data);
} else if (!m_packer->m_info_data.empty()) {
fb_data = m_builder.CreateVector(m_packer->m_info_data);
} }


flatbuffers::Offset<flatbuffers::Vector<uint8_t>> fb_algo_policy; flatbuffers::Offset<flatbuffers::Vector<uint8_t>> fb_algo_policy;
flatbuffers::Offset<flatbuffers::Vector<uint8_t>> fb_binary_cache; flatbuffers::Offset<flatbuffers::Vector<uint8_t>> fb_binary_cache;
if (m_packer->m_header.fb32.is_fast_run_cache) { if (m_packer->m_header.fb32.is_fast_run_cache) {
std::vector<uint8_t> info_algo_policy; std::vector<uint8_t> info_algo_policy;
if (!m_packer->m_info_algo_policy_path.empty()) {
info_algo_policy = read_file(m_packer->m_info_algo_policy_path);
if (!m_packer->m_algo_policy_data.empty()) {
info_algo_policy = m_packer->m_algo_policy_data;
if (m_model_info && m_model_info->algo_policy()) { if (m_model_info && m_model_info->algo_policy()) {
auto cache = m_model_info->algo_policy()->Data(); auto cache = m_model_info->algo_policy()->Data();
auto size = m_model_info->algo_policy()->size(); auto size = m_model_info->algo_policy()->size();
@@ -178,11 +191,27 @@ ModelPacker::ModelPacker(
std::string model_path, std::string packed_model_path, std::string model_path, std::string packed_model_path,
std::string info_data_path, std::string info_algo_policy_path, std::string info_data_path, std::string info_algo_policy_path,
std::string info_binary_cache_path) std::string info_binary_cache_path)
: m_packed_model_path(packed_model_path),
m_info_data_path(info_data_path),
m_info_algo_policy_path(info_algo_policy_path),
m_info_binary_cache_path(info_binary_cache_path) {
: m_packed_model_path(packed_model_path) {
m_fbs_helper = new FbsHelper(this, model_path); m_fbs_helper = new FbsHelper(this, model_path);
std::vector<uint8_t> empty_vec;
m_info_data = info_data_path.empty() ? empty_vec : read_file(info_data_path);
m_algo_policy_data = info_algo_policy_path.empty()
? empty_vec
: read_file(info_algo_policy_path);
m_binary_cache_data = info_binary_cache_path.empty()
? empty_vec
: read_file(info_binary_cache_path);
}

ModelPacker::ModelPacker(
std::vector<uint8_t> model_data, std::string packed_model_path,
std::vector<uint8_t> info_data, std::vector<uint8_t> info_algo_policy_data,
std::vector<uint8_t> info_binary_cache_data) {
m_fbs_helper = new FbsHelper(this, model_data);
m_packed_model_path = packed_model_path;
m_info_data = info_data;
m_algo_policy_data = info_algo_policy_data;
m_binary_cache_data = info_binary_cache_data;
} }


void ModelPacker::set_header( void ModelPacker::set_header(
@@ -192,10 +221,10 @@ void ModelPacker::set_header(
m_header.info_decryption_method = info_decryption_method; m_header.info_decryption_method = info_decryption_method;
memset(&m_header.fb32, 0, sizeof(m_header.fb32)); memset(&m_header.fb32, 0, sizeof(m_header.fb32));
m_header.fb32.is_fast_run_cache = is_fast_run_cache; m_header.fb32.is_fast_run_cache = is_fast_run_cache;
if (!m_info_data_path.empty()) {
auto buf = read_file(m_info_data_path);
if (!m_info_data.empty()) {
std::string json_string( std::string json_string(
static_cast<const char*>(static_cast<void*>(buf.data())), buf.size());
static_cast<const char*>(static_cast<void*>(m_info_data.data())),
m_info_data.size());
auto info = nlohmann::json::parse(json_string); auto info = nlohmann::json::parse(json_string);
m_header.name = info["name"]; m_header.name = info["name"];
} }
@@ -221,6 +250,9 @@ void ModelPacker::pack_model() {
m_fbs_helper->builder().Finish(pack_model_builder.Finish()); m_fbs_helper->builder().Finish(pack_model_builder.Finish());


FILE* fptr = fopen(m_packed_model_path.c_str(), "wb"); FILE* fptr = fopen(m_packed_model_path.c_str(), "wb");
LITE_ASSERT(
fptr, "failed to open %s: %s", m_packed_model_path.c_str(),
strerror(errno));
std::string packed_model_tag = "packed_model"; std::string packed_model_tag = "packed_model";
auto nr_tag = fwrite(packed_model_tag.c_str(), 1, packed_model_tag.size(), fptr); auto nr_tag = fwrite(packed_model_tag.c_str(), 1, packed_model_tag.size(), fptr);
LITE_ASSERT(nr_tag == packed_model_tag.size()); LITE_ASSERT(nr_tag == packed_model_tag.size());


+ 31
- 16
src/core/include/megbrain/utils/json.h View File

@@ -15,7 +15,8 @@ namespace json {


class Value : public std::enable_shared_from_this<Value>, public DynTypeObj { class Value : public std::enable_shared_from_this<Value>, public DynTypeObj {
public: public:
virtual void writeto(std::string& fout, int indent = 0) const = 0;
MGE_WIN_DECLSPEC_FUC virtual void writeto(
std::string& fout, int indent = 0) const = 0;


MGE_WIN_DECLSPEC_FUC void writeto_fpath( MGE_WIN_DECLSPEC_FUC void writeto_fpath(
const std::string& fout_path, int indent = 0) const { const std::string& fout_path, int indent = 0) const {
@@ -38,11 +39,11 @@ class Number final : public Value {
public: public:
Number(double v) : m_val(v) {} Number(double v) : m_val(v) {}


static std::shared_ptr<Number> make(double v) {
MGE_WIN_DECLSPEC_FUC static std::shared_ptr<Number> make(double v) {
return std::make_shared<Number>(v); return std::make_shared<Number>(v);
} }


void writeto(std::string& fout, int indent = 0) const override;
MGE_WIN_DECLSPEC_FUC void writeto(std::string& fout, int indent = 0) const override;


auto&& get_impl() { return m_val; } auto&& get_impl() { return m_val; }


@@ -57,7 +58,7 @@ class NumberInt final : public Value {
public: public:
NumberInt(int64_t v) : m_val(v) {} NumberInt(int64_t v) : m_val(v) {}


static std::shared_ptr<NumberInt> make(int64_t v) {
MGE_WIN_DECLSPEC_FUC static std::shared_ptr<NumberInt> make(int64_t v) {
return std::make_shared<NumberInt>(v); return std::make_shared<NumberInt>(v);
} }


@@ -76,7 +77,7 @@ class Bool final : public Value {
public: public:
Bool(bool v) : m_val(v) {} Bool(bool v) : m_val(v) {}


static std::shared_ptr<Bool> make(bool v);
MGE_WIN_DECLSPEC_FUC static std::shared_ptr<Bool> make(bool v);


MGE_WIN_DECLSPEC_FUC void writeto(std::string& fout, int indent = 0) const override; MGE_WIN_DECLSPEC_FUC void writeto(std::string& fout, int indent = 0) const override;


@@ -95,11 +96,13 @@ public:


String(char const* v) : m_val(v) {} String(char const* v) : m_val(v) {}


static std::shared_ptr<String> make(const std::string& v) {
MGE_WIN_DECLSPEC_FUC static std::shared_ptr<String> make(const std::string& v) {
return std::make_shared<String>(v); return std::make_shared<String>(v);
} }


bool operator==(const String& rhs) const { return m_val == rhs.m_val; }
MGE_WIN_DECLSPEC_FUC bool operator==(const String& rhs) const {
return m_val == rhs.m_val;
}


MGE_WIN_DECLSPEC_FUC void writeto(std::string& fout, int indent = 0) const override; MGE_WIN_DECLSPEC_FUC void writeto(std::string& fout, int indent = 0) const override;


@@ -114,9 +117,11 @@ class Object final : public Value {
std::unordered_map<String, std::shared_ptr<Value>, StdHashAdaptor<String>> m_val; std::unordered_map<String, std::shared_ptr<Value>, StdHashAdaptor<String>> m_val;


public: public:
static std::shared_ptr<Object> make() { return std::make_shared<Object>(); }
MGE_WIN_DECLSPEC_FUC static std::shared_ptr<Object> make() {
return std::make_shared<Object>();
}


static std::shared_ptr<Object> make(
MGE_WIN_DECLSPEC_FUC static std::shared_ptr<Object> make(
const std::vector<std::pair<String, std::shared_ptr<Value>>>& val) { const std::vector<std::pair<String, std::shared_ptr<Value>>>& val) {
for (auto&& i : val) for (auto&& i : val)
mgb_assert(i.second); mgb_assert(i.second);
@@ -125,11 +130,17 @@ public:
return rst; return rst;
} }


std::shared_ptr<Value>& operator[](const String& s) { return m_val[s]; }
MGE_WIN_DECLSPEC_FUC std::shared_ptr<Value>& operator[](const String& s) {
return m_val[s];
}


std::shared_ptr<Value>& operator[](const std::string& s) { return m_val[s]; }
MGE_WIN_DECLSPEC_FUC std::shared_ptr<Value>& operator[](const std::string& s) {
return m_val[s];
}


std::shared_ptr<Value>& operator[](const char* s) { return m_val[std::string(s)]; }
MGE_WIN_DECLSPEC_FUC std::shared_ptr<Value>& operator[](const char* s) {
return m_val[std::string(s)];
}


MGE_WIN_DECLSPEC_FUC void writeto(std::string& fout, int indent = 0) const override; MGE_WIN_DECLSPEC_FUC void writeto(std::string& fout, int indent = 0) const override;


@@ -144,14 +155,18 @@ class Array final : public Value {
std::vector<std::shared_ptr<Value>> m_val; std::vector<std::shared_ptr<Value>> m_val;


public: public:
static std::shared_ptr<Array> make() { return std::make_shared<Array>(); }
MGE_WIN_DECLSPEC_FUC static std::shared_ptr<Array> make() {
return std::make_shared<Array>();
}


void add(std::shared_ptr<Value> val) {
MGE_WIN_DECLSPEC_FUC void add(std::shared_ptr<Value> val) {
mgb_assert(val); mgb_assert(val);
m_val.emplace_back(std::move(val)); m_val.emplace_back(std::move(val));
} }


std::shared_ptr<Value>& operator[](size_t idx) { return m_val.at(idx); }
MGE_WIN_DECLSPEC_FUC std::shared_ptr<Value>& operator[](size_t idx) {
return m_val.at(idx);
}


MGE_WIN_DECLSPEC_FUC void writeto(std::string& fout, int indent = 0) const override; MGE_WIN_DECLSPEC_FUC void writeto(std::string& fout, int indent = 0) const override;


@@ -164,7 +179,7 @@ class Null final : public Value {
MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT; MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT;


public: public:
static std::shared_ptr<Value> make() {
MGE_WIN_DECLSPEC_FUC static std::shared_ptr<Value> make() {
static std::shared_ptr<Null> v(new Null); static std::shared_ptr<Null> v(new Null);
return v; return v;
} }


+ 6
- 1
src/gopt/impl/inference.cpp View File

@@ -524,9 +524,14 @@ void ParamFusePass::apply(OptState& state) const {


{ {
auto orig_level = cg->options().log_level; auto orig_level = cg->options().log_level;
auto orig_record_level = cg->options().comp_node_seq_record_level;
cg->options().comp_node_seq_record_level = 0;
cg->options().log_level = 0; cg->options().log_level = 0;
MGB_TRY { cg->compile({{var, cb}})->execute(); } MGB_TRY { cg->compile({{var, cb}})->execute(); }
MGB_FINALLY(cg->options().log_level = orig_level);
MGB_FINALLY({
cg->options().comp_node_seq_record_level = orig_record_level;
cg->options().log_level = orig_level;
});
} }


SymbolVar new_var; SymbolVar new_var;


+ 1
- 1
src/opr/impl/io.cpp View File

@@ -398,7 +398,7 @@ void ImmutableTensor::Value::setup(CompNode cn, const HostTensorND& val) {
return true; return true;
}; };


if (one_elem(val.shape())) {
if (!val.empty() && one_elem(val.shape())) {
float v; float v;
static_cast_dtype(&v, val.dtype(), val.raw_ptr()); static_cast_dtype(&v, val.dtype(), val.raw_ptr());
m_summary = ssprintf("const<%.3g>", v); m_summary = ssprintf("const<%.3g>", v);


Loading…
Cancel
Save