
feat(lite): add fitting mode for load and run

GitOrigin-RevId: 8f21fda9d3
Branch: release-1.10
Megvii Engine Team, 3 years ago
Commit: 115bcbce2b
43 changed files with 2057 additions and 481 deletions
  1. +2 -12  lite/include/lite/pack_model.h
  2. +0 -1  lite/load_and_run/BUILD
  3. +117 -9  lite/load_and_run/src/helpers/common.h
  4. +69 -24  lite/load_and_run/src/helpers/json_loader.cpp
  5. +24 -1  lite/load_and_run/src/helpers/json_loader.h
  6. +362 -0  lite/load_and_run/src/helpers/utils.cpp
  7. +69 -0  lite/load_and_run/src/helpers/utils.h
  8. +28 -11  lite/load_and_run/src/main.cpp
  9. +1 -10  lite/load_and_run/src/models/model.cpp
  10. +7 -10  lite/load_and_run/src/models/model.h
  11. +73 -8  lite/load_and_run/src/models/model_lite.cpp
  12. +6 -9  lite/load_and_run/src/models/model_lite.h
  13. +73 -9  lite/load_and_run/src/models/model_mdl.cpp
  14. +6 -9  lite/load_and_run/src/models/model_mdl.h
  15. +65 -15  lite/load_and_run/src/options/device_options.cpp
  16. +9 -10  lite/load_and_run/src/options/device_options.h
  17. +0 -9  lite/load_and_run/src/options/extern_c_opr_options.cpp
  18. +0 -9  lite/load_and_run/src/options/extern_c_opr_options.h
  19. +38 -12  lite/load_and_run/src/options/fastrun_options.cpp
  20. +7 -9  lite/load_and_run/src/options/fastrun_options.h
  21. +0 -9  lite/load_and_run/src/options/io_options.cpp
  22. +0 -9  lite/load_and_run/src/options/io_options.h
  23. +73 -32  lite/load_and_run/src/options/layout_options.cpp
  24. +7 -10  lite/load_and_run/src/options/layout_options.h
  25. +70 -10  lite/load_and_run/src/options/layout_trans_options.cpp
  26. +6 -9  lite/load_and_run/src/options/layout_trans_options.h
  27. +0 -9  lite/load_and_run/src/options/model_options.cpp
  28. +0 -9  lite/load_and_run/src/options/model_options.h
  29. +85 -19  lite/load_and_run/src/options/optimize_options.cpp
  30. +27 -9  lite/load_and_run/src/options/optimize_options.h
  31. +40 -23  lite/load_and_run/src/options/option_base.h
  32. +0 -9  lite/load_and_run/src/options/plugin_options.cpp
  33. +0 -9  lite/load_and_run/src/options/plugin_options.h
  34. +11 -15  lite/load_and_run/src/options/strategy_options.cpp
  35. +3 -9  lite/load_and_run/src/options/strategy_options.h
  36. +3 -11  lite/load_and_run/src/strategys/strategy.cpp
  37. +2 -38  lite/load_and_run/src/strategys/strategy.h
  38. +583 -17  lite/load_and_run/src/strategys/strategy_fitting.cpp
  39. +152 -0  lite/load_and_run/src/strategys/strategy_fitting.h
  40. +15 -24  lite/load_and_run/src/strategys/strategy_normal.cpp
  41. +22 -0  lite/load_and_run/src/strategys/strategy_normal.h
  42. +2 -13  lite/src/pack_model/pack_model.cpp
  43. +0 -11  lite/src/parse_info/cache_parse.h

+2 -12  lite/include/lite/pack_model.h

@@ -1,15 +1,5 @@
/**
* \file inlude/lite/pack_model.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once
#include <memory>
#include <string>
#include <vector>
namespace lite {
@@ -67,7 +57,7 @@ private:
Header m_header;

friend class FbsHelper;
FbsHelper* m_fbs_helper;
std::shared_ptr<FbsHelper> m_fbs_helper;
};

} // namespace lite

+0 -1  lite/load_and_run/BUILD

@@ -5,7 +5,6 @@ cc_library(
hdrs = glob(["src/**/*.h"]),
includes = ["src"],
features = if_opt([
"no_exceptions",
"no_rtti",
]),



+117 -9  lite/load_and_run/src/helpers/common.h

@@ -1,15 +1,7 @@
/**
* \file lite/load_and_run/src/helpers/common.h
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#pragma once
#include <gflags/gflags.h>
#include <memory>
#include <unordered_map>
DECLARE_int32(thread);
namespace lar {
/*!
@@ -71,6 +63,122 @@ enum class OptLayoutType {
NHWCD4 = 1 << 6,
NCHW44_DOT = 1 << 7
};
/**
* base class to store option values
*/
enum class JsonValueType {
Bool = 0,
Number,
NumberInt32,
NumberUint64,
String,

};
struct Value {
virtual JsonValueType get_type() const = 0;
virtual std::string type_string() const = 0;
virtual void reset_value() = 0;
virtual ~Value() = default;
};

/**
* class for double option
*/
struct Number final : public Value {
Number(double v) : m_val(v), m_default_val(v) {}
static std::shared_ptr<Number> make(double v) {
return std::make_shared<Number>(v);
}
void set_value(double v) { m_val = v; }
double get_value() { return m_val; }
double get_default() { return m_default_val; }
void reset_value() override { m_val = m_default_val; }
JsonValueType get_type() const override { return JsonValueType::Number; }
std::string type_string() const override { return "Number"; }

private:
double m_val;
double m_default_val;
};

/**
* class for int32_t option
*/
struct NumberInt32 final : public Value {
NumberInt32(int32_t v) : m_val(v), m_default_val(v) {}
static std::shared_ptr<NumberInt32> make(int32_t v) {
return std::make_shared<NumberInt32>(v);
}
void set_value(int32_t v) { m_val = v; }
int32_t get_value() { return m_val; }
int32_t get_default() { return m_default_val; }
void reset_value() override { m_val = m_default_val; }
JsonValueType get_type() const override { return JsonValueType::NumberInt32; }
std::string type_string() const override { return "NumberInt32"; }

private:
int32_t m_val;
int32_t m_default_val;
};
/**
* class for uint64 option
*/
struct NumberUint64 final : public Value {
NumberUint64(uint64_t v) : m_val(v), m_default_val(v) {}
static std::shared_ptr<NumberUint64> make(uint64_t v) {
return std::make_shared<NumberUint64>(v);
}
void set_value(uint64_t v) { m_val = v; }
uint64_t get_value() { return m_val; }
uint64_t get_default() { return m_default_val; }
void reset_value() override { m_val = m_default_val; }
JsonValueType get_type() const override { return JsonValueType::NumberUint64; }
std::string type_string() const override { return "NumberUint64"; }

private:
uint64_t m_val;
uint64_t m_default_val;
};

/**
* class for boolean option
*/
struct Bool final : public Value {
Bool(bool v) : m_val(v), m_default_val(v) {}
static std::shared_ptr<Bool> make(bool v) { return std::make_shared<Bool>(v); }
void set_value(bool v) { m_val = v; }
bool get_value() { return m_val; }
bool get_default() { return m_default_val; }
void reset_value() override { m_val = m_default_val; }
JsonValueType get_type() const override { return JsonValueType::Bool; }
std::string type_string() const override { return "Bool"; }

private:
bool m_val;
bool m_default_val;
};

/**
* class for string option
*/
struct String final : public Value {
String(std::string v) : m_val(v), m_default_val(v) {}
static std::shared_ptr<String> make(const std::string& v) {
return std::make_shared<String>(v);
}
void set_value(const std::string& v) { m_val = v; }
std::string& get_value() { return m_val; }
std::string get_default() { return m_default_val; }
void reset_value() override { m_val = m_default_val; }
JsonValueType get_type() const override { return JsonValueType::String; }
std::string type_string() const override { return "String"; }

private:
std::string m_val;
std::string m_default_val;
};

using OptionValMap = std::unordered_map<std::string, std::shared_ptr<lar::Value>>;

} // namespace lar
// vim: syntax=cpp.doxygen
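
For context, a minimal sketch (not part of this diff) of how these wrappers behave; the option name "iter" is illustrative:

    // the value passed to make() is also recorded as the default
    auto iter = lar::NumberInt32::make(10);
    iter->set_value(20);
    iter->get_value();     // 20
    iter->get_default();   // 10
    iter->reset_value();   // restores the default, 10
    lar::OptionValMap opts = {{"iter", iter}};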

+69 -24  lite/load_and_run/src/helpers/json_loader.cpp

@@ -5,9 +5,7 @@ using namespace mgb;
template <typename T>
T* JsonLoader::Value::safe_cast() {
T* ptr = (T*)(this);
if (nullptr == ptr) {
fprintf(stderr, "cast ptr is null\n");
}
mgb_assert(nullptr != ptr, "cast ptr is null\n");
return ptr;
}

@@ -31,6 +29,12 @@ std::map<std::string, std::unique_ptr<JsonLoader::Value>>& JsonLoader::Value::
return t->m_obj;
}

std::vector<std::string>& JsonLoader::Value::keys() {
mgb_assert(Type::OBJECT == m_type);
auto t = safe_cast<JsonLoader::ObjectValue>();
return t->m_keys;
}

size_t JsonLoader::Value::len() {
if (Type::ARRAY == m_type) {
auto t = safe_cast<JsonLoader::ArrayValue>();
@@ -54,6 +58,12 @@ double JsonLoader::Value::number() {
return t->value();
}

bool JsonLoader::Value::Bool() {
mgb_assert(Type::BOOL == m_type);
auto t = safe_cast<JsonLoader::BoolValue>();
return t->value();
}

std::string JsonLoader::Value::str() {
if (Type::STRING == m_type) {
auto t = safe_cast<StringValue>();
@@ -69,7 +79,7 @@ void JsonLoader::expect(char c) {

void JsonLoader::skip_whitespace() {
const char* p = m_buf;
while (*p == ' ' || *p == '\t' || *p == '\n' || *p == '\r') {
while (' ' == *p || '\t' == *p || '\n' == *p || '\r' == *p) {
++p;
}
m_buf = p;
@@ -80,11 +90,12 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_object() {
skip_whitespace();

std::unique_ptr<JsonLoader::Value> ret;
JsonLoader::ObjectValue* pObject = new JsonLoader::ObjectValue();
std::unique_ptr<JsonLoader::ObjectValue> pObject =
std::make_unique<JsonLoader::ObjectValue>();

if ('}' == *m_buf) {
m_buf = m_buf + 1;
ret.reset((JsonLoader::Value*)(pObject));
ret = std::move(pObject);
return ret;
}

@@ -113,6 +124,7 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_object() {
}

pObject->m_obj.insert(std::make_pair(key->str(), std::move(pVal)));
pObject->m_keys.push_back(key->str());

skip_whitespace();
if (',' == (*m_buf)) {
@@ -126,22 +138,21 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_object() {
break;
}
}

ret.reset((JsonLoader::Value*)(pObject));
ret = std::move(pObject);
return ret;
}

std::unique_ptr<JsonLoader::Value> JsonLoader::parse_array() {
expect('[');
skip_whitespace();

std::unique_ptr<JsonLoader::Value> ret;
JsonLoader::ArrayValue* pArray = new JsonLoader::ArrayValue();
std::unique_ptr<JsonLoader::ArrayValue> pArray =
std::make_unique<JsonLoader::ArrayValue>();

if (']' == *m_buf) {
m_buf = m_buf + 1;

ret.reset((JsonLoader::Value*)(pArray));
ret = std::move(pArray);
return ret;
}

@@ -168,15 +179,14 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_array() {
}
}

ret.reset((JsonLoader::Value*)(pArray));
ret = std::move(pArray);
return ret;
}

std::unique_ptr<JsonLoader::Value> JsonLoader::parse_string() {
expect('\"');

std::unique_ptr<JsonLoader::Value> ret;
JsonLoader::StringValue* pStr = new JsonLoader::StringValue();
std::unique_ptr<JsonLoader::StringValue> pStr =
std::make_unique<JsonLoader::StringValue>();

const char* p = m_buf;
while (true) {
@@ -189,7 +199,7 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_string() {
}
}
m_buf = p;
ret.reset((JsonLoader::Value*)(pStr));
std::unique_ptr<JsonLoader::Value> ret = std::move(pStr);
return ret;
}

@@ -207,31 +217,31 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_number() {
return;
};

if (*p == '-')
if ('-' == *p)
p++;
if (*p == '0')
if ('0' == *p)
p++;
else {
loop_digit(std::ref(p));
}
if (*p == '.') {
if ('.' == *p) {
p++;
loop_digit(std::ref(p));
}

if (*p == 'e' || *p == 'E') {
if ('e' == *p || 'E' == *p) {
p++;
if (*p == '+' || *p == '-')
if ('+' == *p || '-' == *p)
p++;
loop_digit(std::ref(p));
}
JsonLoader::NumberValue* pNum = new JsonLoader::NumberValue();
std::unique_ptr<JsonLoader::NumberValue> pNum =
std::make_unique<JsonLoader::NumberValue>();
pNum->m_value = strtod(m_buf, nullptr);

m_buf = p;

std::unique_ptr<JsonLoader::Value> ret;
ret.reset((JsonLoader::Value*)(pNum));
std::unique_ptr<JsonLoader::Value> ret = std::move(pNum);
return ret;
}

@@ -243,6 +253,10 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_value() {
return parse_object();
case '\"':
return parse_string();
case 't':
return parse_bool();
case 'f':
return parse_bool();
case '\0':
m_state = State::BAD_TYPE;
break;
@@ -252,6 +266,37 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_value() {
return nullptr;
}

std::unique_ptr<JsonLoader::Value> JsonLoader::parse_bool() {
const char* p = m_buf;
std::string value;
if ('t' == *p) {
value = "";
for (size_t idx = 0; idx < 4; ++idx) {
value += *p++;
}
} else if ('f' == *p) {
value = "";
for (size_t idx = 0; idx < 5; ++idx) {
value += *p++;
}
}
bool val = false;
if ("true" == value) {
val = true;
} else if ("false" == value) {
val = false;
} else {
mgb_log_error("invalid value: %s for possible bool value", value.c_str());
}

std::unique_ptr<JsonLoader::BoolValue> pBool =
std::make_unique<JsonLoader::BoolValue>();
pBool->m_value = val;
m_buf = p;
std::unique_ptr<JsonLoader::Value> ret = std::move(pBool);
return ret;
}

std::unique_ptr<JsonLoader::Value> JsonLoader::load(
const char* content, const size_t size) {
m_buf = content;


+24 -1  lite/load_and_run/src/helpers/json_loader.h

@@ -18,7 +18,7 @@ public:
// base class for different value format
class Value {
protected:
enum struct Type : uint8_t { UNKNOWN, NUMBER, STRING, OBJECT, ARRAY };
enum struct Type : uint8_t { UNKNOWN, NUMBER, STRING, OBJECT, ARRAY, BOOL };
Type m_type;

public:
@@ -39,12 +39,16 @@ public:

bool is_str() { return Type::STRING == m_type; }

bool is_bool() { return Type::BOOL == m_type; }

std::unique_ptr<Value>& operator[](const std::string& key);

std::unique_ptr<Value>& operator[](const size_t index);

std::map<std::string, std::unique_ptr<Value>>& objects();

std::vector<std::string>& keys();

size_t len();

megdnn::SmallVector<std::unique_ptr<Value>>& array();
@@ -52,6 +56,8 @@ public:
double number();

std::string str();

bool Bool();
};

void expect(char c);
@@ -68,6 +74,8 @@ public:

std::unique_ptr<Value> parse_value();

std::unique_ptr<Value> parse_bool();

enum struct State : uint8_t {
OK = 0,
BAD_TYPE,
@@ -137,21 +145,26 @@ public:

class ObjectValue final : public Value {
std::map<std::string, std::unique_ptr<Value>> m_obj;
std::vector<std::string> m_keys;

public:
ObjectValue() : Value(Type::OBJECT) {}

ObjectValue(ObjectValue& arr) : Value(arr) {
m_obj.clear();
m_keys.clear();
for (auto itra = arr.m_obj.begin(); itra != arr.m_obj.end(); ++itra) {
m_obj.emplace(std::make_pair(itra->first, std::move(itra->second)));
m_keys.push_back(itra->first);
}
}

ObjectValue(ObjectValue&& arr) : Value(arr) {
m_obj.clear();
m_keys.clear();
for (auto itra = arr.m_obj.begin(); itra != arr.m_obj.end(); ++itra) {
m_obj.emplace(std::make_pair(itra->first, std::move(itra->second)));
m_keys.push_back(itra->first);
}
}

@@ -160,9 +173,19 @@ public:
const std::string&);
friend std::map<std::string, std::unique_ptr<JsonLoader::Value>>& JsonLoader::
Value::objects();
friend std::vector<std::string>& JsonLoader::Value::keys();
friend size_t JsonLoader::Value::len();
};

class BoolValue final : public Value {
bool m_value;

public:
BoolValue() : Value(Type::BOOL) {}
bool value() { return m_value; }
friend std::unique_ptr<Value> JsonLoader::parse_bool();
};

private:
const char* m_buf;
State m_state;
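
A minimal sketch (not part of this diff) of the new loader features, bool values and ordered keys; the JSON text is illustrative:

    mgb::JsonLoader loader;
    std::string text = R"({"fast_run": true, "iter": 10})";
    auto root = loader.load(text.c_str(), text.size());
    bool fast_run = (*root)["fast_run"]->Bool();   // new Bool() accessor
    auto& keys = root->keys();                     // object keys in insertion order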


+362 -0  lite/load_and_run/src/helpers/utils.cpp

@@ -0,0 +1,362 @@
#include "utils.h"
using namespace lar;

/////////////////// JsonOptionsCoder ///////////////////
#if MGB_ENABLE_JSON
//! encode option

void encode_single_options(
std::pair<std::string, std::shared_ptr<lar::Value>> item,
std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>>&
list,
bool encode_all) {
auto type = item.second->get_type();
if (type == JsonValueType::Bool) {
auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second);
if (!encode_all && val_ptr->get_value() == val_ptr->get_default()) {
return;
}
list.push_back(
{mgb::json::String(item.first),
mgb::json::Bool::make(val_ptr->get_value())});
} else if (type == JsonValueType::NumberInt32) {
auto val_ptr = std::static_pointer_cast<lar::NumberInt32>(item.second);
if (!encode_all && val_ptr->get_value() == val_ptr->get_default()) {
return;
}
list.push_back(
{mgb::json::String(item.first),
mgb::json::NumberInt::make(
static_cast<int64_t>(val_ptr->get_value()))});
} else if (type == JsonValueType::NumberUint64) {
auto val_ptr = std::static_pointer_cast<lar::NumberUint64>(item.second);
list.push_back(
{mgb::json::String(item.first),
mgb::json::NumberInt::make(
static_cast<int64_t>(val_ptr->get_value()))});
} else if (type == JsonValueType::Number) {
auto val_ptr = std::static_pointer_cast<lar::Number>(item.second);
list.push_back(
{mgb::json::String(item.first),
mgb::json::Number::make(val_ptr->get_value())});
} else if (type == JsonValueType::String) {
auto val_ptr = std::static_pointer_cast<lar::String>(item.second);
if (!encode_all && val_ptr->get_value() == val_ptr->get_default()) {
return;
}
list.push_back(
{mgb::json::String(item.first),
mgb::json::String::make(val_ptr->get_value())});
} else {
mgb_log_error(
"unsupport JsonValueType:%s for lar::Value",
item.second->type_string().c_str());
}
}
std::string JsonOptionsCoder::encode(OptionValMap& option_val_map, bool encode_all) {
std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>>
json_options;
for (auto& item : option_val_map) {
encode_single_options(item, json_options, encode_all);
}

auto json_obj = mgb::json::Object::make(
{{"options", mgb::json::Object::make(json_options)}});

return json_obj->to_string(1);
}

//! encode device
std::vector<std::shared_ptr<mgb::json::Object>> JsonOptionsCoder::encode(
OptionValMap& option_val_map) {
std::vector<std::shared_ptr<mgb::json::Object>> info;
std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>>
json_device;
std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>>
json_options;
for (auto& item : option_val_map) {
if ((item.first == "cpu" || item.first == "cpu_default" ||
item.first == "multithread" || item.first == "multithread_default")) {
auto type = item.second->get_type();
if (type == JsonValueType::Bool) {
auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second);
if (val_ptr->get_value() == val_ptr->get_default())
continue;
}
if (type == JsonValueType::NumberInt32) {
auto val_ptr = std::static_pointer_cast<lar::NumberInt32>(item.second);
if (val_ptr->get_value() == val_ptr->get_default())
continue;
}
json_device.push_back(
{mgb::json::String("type"), mgb::json::String::make("CPU")});

if (item.first == "cpu_default" || item.first == "multithread_default") {
json_device.push_back(
{mgb::json::String("enable_inplace_model"),
mgb::json::Bool::make(true)});
}

if (item.first == "multithread" || item.first == "multithread_default") {
json_device.push_back(
{mgb::json::String("number_threads"),
mgb::json::NumberInt::make(
std::static_pointer_cast<lar::NumberInt32>(item.second)
->get_value())});
if (item.first == "multithread") {
json_device.push_back(
{mgb::json::String("device_id"),
mgb::json::NumberInt::make(0)});
}
}

} else if (item.first == "cuda") {
auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second);
if (val_ptr->get_value() == val_ptr->get_default())
continue;
json_device.push_back(
{mgb::json::String("type"), mgb::json::String::make("CUDA")});
json_device.push_back(
{mgb::json::String("device_id"), mgb::json::NumberInt::make(0)});
} else if (item.first == "opencl") {
auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second);
if (val_ptr->get_value() == val_ptr->get_default())
continue;
json_device.push_back(
{mgb::json::String("type"), mgb::json::String::make("OPENCL")});
} else if (
item.first == "record_comp_seq" || item.first == "record_comp_seq2") {
auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second);
if (val_ptr->get_value() == val_ptr->get_default())
continue;
int comp_node_seq_record_level = item.first == "record_comp_seq" ? 1 : 2;
json_options.push_back(
{mgb::json::String("comp_node_seq_record_level"),
mgb::json::NumberInt::make(comp_node_seq_record_level)});
} else if (item.first == "fake_first") {
auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second);
if (val_ptr->get_value() == val_ptr->get_default())
continue;
json_options.push_back(
{mgb::json::String("fake_next_exec"),
mgb::json::Bool::make(val_ptr->get_value())});
} else if (item.first == "no_sanity_check") {
auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second);
if (val_ptr->get_value() == val_ptr->get_default())
continue;
json_options.push_back(
{mgb::json::String("var_sanity_check_first_run"),
mgb::json::Bool::make(!val_ptr->get_value())});
} else if (item.first == "weight_preprocess") {
auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second);
if (val_ptr->get_value() == val_ptr->get_default())
continue;
json_options.push_back(
{mgb::json::String("weight_preprocess"),
mgb::json::Bool::make(val_ptr->get_value())});
}
}
info.push_back(mgb::json::Object::make(
{{"options", mgb::json::Object::make(json_options)}}));
info.push_back(mgb::json::Object::make(
{{"device", mgb::json::Object::make(json_device)}}));
return info;
}
//! decode a JSON-format options string into the option map
OptionValMap& JsonOptionsCoder::decode(
const std::string& code, OptionValMap& option_val_map) {
std::shared_ptr<mgb::JsonLoader::Value> root =
m_json_loader.load(code.c_str(), code.size());
for (auto& item : root->objects()) {
auto& value = *item.second;
//! get all keys in json object
auto keys = value.keys();

//! set the json format options into internal options
for (auto& val : keys) {
if (value[val]->is_bool()) {
auto val_ptr = std::static_pointer_cast<lar::Bool>(option_val_map[val]);
val_ptr->set_value(value[val]->Bool());
} else if (value[val]->is_number()) {
auto type = option_val_map[val]->get_type();
if (type == JsonValueType::Number) {
auto val_ptr =
std::static_pointer_cast<lar::Number>(option_val_map[val]);
val_ptr->set_value(value[val]->number());
} else if (type == JsonValueType::NumberInt32) {
auto val_ptr = std::static_pointer_cast<lar::NumberInt32>(
option_val_map[val]);
val_ptr->set_value(static_cast<int32_t>(value[val]->number()));
} else if (type == JsonValueType::NumberUint64) {
auto val_ptr = std::static_pointer_cast<lar::NumberUint64>(
option_val_map[val]);
val_ptr->set_value(static_cast<uint64_t>(value[val]->number()));
} else {
mgb_log_error(
"invalid number type:%s to set",
option_val_map[val]->type_string().c_str());
}
} else if (value[val]->is_str()) {
auto val_ptr =
std::static_pointer_cast<lar::String>(option_val_map[val]);
val_ptr->set_value(value[val]->str());
} else {
mgb_log_error("invalid value type for JsonLoader");
}
}
}
return option_val_map;
}

#endif

std::string GflagsOptionsCoder::encode(OptionValMap& option_val_map, bool encode_all) {
std::vector<std::string> gflags_options;
for (auto& item : option_val_map) {
auto type = item.second->get_type();
std::string val = "--";
if (type == JsonValueType::Bool) {
auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second);
if (!encode_all && val_ptr->get_value() == val_ptr->get_default()) {
continue;
}
val += item.first;
val += "=";
val += val_ptr->get_value() ? "true" : "false";
gflags_options.push_back(val);
} else if (type == JsonValueType::NumberInt32) {
auto val_ptr = std::static_pointer_cast<lar::NumberInt32>(item.second);
if (!encode_all && val_ptr->get_value() == val_ptr->get_default()) {
continue;
}
val += item.first;
val += "=";
val += std::to_string(val_ptr->get_value());
gflags_options.push_back(val);
} else if (type == JsonValueType::NumberUint64) {
auto val_ptr = std::static_pointer_cast<lar::NumberUint64>(item.second);
val += item.first;
val += "=";
val += std::to_string(val_ptr->get_value());
gflags_options.push_back(val);
} else if (type == JsonValueType::Number) {
auto val_ptr = std::static_pointer_cast<lar::Number>(item.second);
val += item.first;
val += "=";
val += std::to_string(val_ptr->get_value());
gflags_options.push_back(val);
} else if (type == JsonValueType::String) {
auto val_ptr = std::static_pointer_cast<lar::String>(item.second);
if (!encode_all && val_ptr->get_value() == val_ptr->get_default()) {
continue;
}
val += item.first;
val += "=\"";
val += val_ptr->get_value();
val += "\"";
gflags_options.push_back(val);
} else {
mgb_log_error(
"unsupport JsonValueType:%s for lar::Value",
item.second->type_string().c_str());
}
}
std::string ret;
for (auto& item : gflags_options) {
ret += item;
ret += "\n";
}

return ret;
}

//! decode a gflags-style options string into the option map
OptionValMap& GflagsOptionsCoder::decode(
const std::string& code, OptionValMap& option_val_map) {
std::unordered_map<std::string, std::string> gflags_map;
auto to_raw_string = [](const std::string& str) {
auto size = str.size();
std::string ret;
if ('\"' == str[0] && '\"' == str[size - 1]) {
ret = str.substr(1, size - 2);
} else {
ret = str;
}
return ret;
};

size_t start = 0;
size_t end = code.find("\n", start);

while (end != std::string::npos) {
auto str = code.substr(start, end - start);
if (str.substr(0, 2) == "--") {
size_t idx = str.find("=", 0);
gflags_map.insert(
{str.substr(2, idx - 2), to_raw_string(str.substr(idx + 1))});
} else {
mgb_log_error("invaid gflags argument %s", str.c_str());
}
start = end + 1;
end = code.find("\n", start);
}
for (auto& item : gflags_map) {
if (option_val_map.count(item.first) != 0) {
auto& option_val = option_val_map[item.first];
auto type = option_val->get_type();
if (type == JsonValueType::Bool) {
auto val_ptr = std::static_pointer_cast<lar::Bool>(option_val);
if (item.second == "true" || item.second == "false") {
auto val = item.second == "true";
val_ptr->set_value(val);
}
} else if (type == JsonValueType::NumberInt32) {
auto val_ptr = std::static_pointer_cast<lar::NumberInt32>(option_val);
MGB_TRY {
int32_t val = std::stoi(item.second);
val_ptr->set_value(val);
}
MGB_CATCH(std::exception & exc, {
mgb_log_error(
"invaid value: %s for %s", item.second.c_str(),
item.first.c_str());
});
} else if (type == JsonValueType::NumberUint64) {
auto val_ptr = std::static_pointer_cast<lar::NumberUint64>(option_val);
MGB_TRY {
uint64_t val = std::stoull(item.second);
val_ptr->set_value(val);
}
MGB_CATCH(std::exception & exc, {
mgb_log_error(
"invaid value: %s for %s", item.second.c_str(),
item.first.c_str());
});

} else if (type == JsonValueType::Number) {
auto val_ptr = std::static_pointer_cast<lar::Number>(option_val);
MGB_TRY {
double val = std::stod(item.second);
val_ptr->set_value(val);
}
MGB_CATCH(std::exception & exc, {
mgb_log_error(
"invaid value: %s for %s", item.second.c_str(),
item.first.c_str());
});
} else if (type == JsonValueType::String) {
auto val_ptr = std::static_pointer_cast<lar::String>(option_val);
val_ptr->set_value(item.second);

} else {
mgb_log_error(
"unsupport JsonValueType:%s for lar::Value",
option_val->type_string().c_str());
}
} else {
mgb_log_error("invalid gflags when set runtime options in fitting mode");
}
}

return option_val_map;
}
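
A minimal usage sketch of the two coders (not part of this diff; option names are illustrative):

    lar::OptionValMap opts = {
            {"cpu", lar::Bool::make(false)},
            {"iter", lar::NumberInt32::make(10)},
    };
    std::static_pointer_cast<lar::Bool>(opts["cpu"])->set_value(true);

    lar::GflagsOptionsCoder gflags_coder;
    // with encode_all == false only options whose value differs from the default are
    // emitted, so this yields the single line "--cpu=true\n"
    std::string flags = gflags_coder.encode(opts, false);
    gflags_coder.decode(flags, opts);   // parses "--name=value" lines back into opts

#if MGB_ENABLE_JSON
    lar::JsonOptionsCoder json_coder;
    std::string json = json_coder.encode(opts, false);   // roughly {"options": {"cpu": true}}
#endif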

+69 -0  lite/load_and_run/src/helpers/utils.h

@@ -0,0 +1,69 @@
#pragma once

#include <vector>
#include "common.h"
#include "json_loader.h"
#include "megbrain/utils/json.h"
namespace lar {
/**
* fitting profiler type
*/
enum class ProiflerType {
TIME_PROFILER = 0,
UNSPEC_PROFILER = 1,
};
/**
* option coder type
*/
enum class CoderType {
GFLAGS = 0,
JSON = 1,
UNSPEC = 2,
};

/**
* option coder to transform internal option values into different forms
*/
class OptionsCoder {
public:
OptionsCoder(){};
//! encode options into given format
virtual std::string encode(OptionValMap&, bool) = 0;

//! decode options with given format into option map
virtual OptionValMap& decode(const std::string&, OptionValMap& val_map) = 0;

//! destructor
virtual ~OptionsCoder() = default;
};

#if MGB_ENABLE_JSON
class JsonOptionsCoder final : public OptionsCoder {
public:
JsonOptionsCoder(){};

//! encode given options into json format
std::string encode(OptionValMap&, bool encode_all) override;

std::vector<std::shared_ptr<mgb::json::Object>> encode(OptionValMap&);

//! decode given json format options into given options map
OptionValMap& decode(const std::string&, OptionValMap&) override;

private:
mgb::JsonLoader m_json_loader;
};
#endif

class GflagsOptionsCoder final : public OptionsCoder {
public:
GflagsOptionsCoder(){};

//! encode given options into gflags format
std::string encode(OptionValMap&, bool encode_all = false) override;

//! decode given gflags format options into given options maps
OptionValMap& decode(const std::string&, OptionValMap&) override;
};

} // namespace lar

+28 -11  lite/load_and_run/src/main.cpp

@@ -1,20 +1,37 @@
/**
* \file lite/load_and_run/src/main.cpp
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#include <gflags/gflags.h>
#include <string>
#include "strategys/strategy.h"

std::string simple_usage = R"(
load_and_run: load_and_run <model_path> [options Flags...]

Flags from lite/load_and_run/src/models/model.cpp:
-lite type: bool default: false use megengine lite interface to run model

Flags from lite/load_and_run/src/options/strategy_options.cpp:
-iter type: int32 default: 10 iteration number for run model
-thread type: int32 default: 1 thread number for run model when <thread> is supported
-warmup_iter type: int32 default: 1 iteration number for warm up model before run
Flags from com_github_gflags_gflags/src/gflags.cc:
-flagfile type: string default: "" load flags from file
-fromenv type: string default: "" set flags from the environment [use 'export FLAGS_flag1=value']
...

Flags from com_github_gflags_gflags/src/gflags_reporting.cc:
-help type: bool default: false show help on all flags
-helpmatch type: string default: "" show help on modules whose name contains the specified substr
-version type: bool default: false show version and build info and exit
...

Use "--help" for more details!!

)";

int main(int argc, char** argv) {
std::string usage = "load_and_run <model_path> [options...]";
std::string usage = "load_and_run <model_path> [options Flags...]";
if (argc < 2) {
printf("usage: %s\n", usage.c_str());
printf("usage: %s\n", simple_usage.c_str());
return -1;
}
gflags::SetUsageMessage(usage);
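
An invocation following this usage pattern might look like this (model path and flag values are illustrative):

    load_and_run ./resnet50.mge --lite --iter=20 --cpu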


+1 -10  lite/load_and_run/src/models/model.cpp

@@ -1,12 +1,3 @@

/**
* \file lite/load_and_run/src/models/model.cpp
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/
#include "model.h"
#include <iostream>
#include <memory>
@@ -56,5 +47,5 @@ std::shared_ptr<ModelBase> ModelBase::create_model(std::string model_path) {
return nullptr;
}
}
DEFINE_bool(lite, false, "using lite model to run mdl model");
DEFINE_bool(lite, false, "use megengine lite interface to run model");
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+7 -10  lite/load_and_run/src/models/model.h

@@ -1,17 +1,8 @@
/**
* \file lite/load_and_run/src/models/model.h
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#pragma once
#include <gflags/gflags.h>
#include <string>
#include "helpers/common.h"
#include "megbrain/utils/json.h"
DECLARE_bool(lite);

namespace lar {
@@ -45,6 +36,12 @@ public:
virtual ~ModelBase() = default;

virtual const std::string& get_model_path() const = 0;

virtual std::vector<uint8_t> get_model_data() = 0;
#if MGB_ENABLE_JSON
//! get model io information
virtual std::shared_ptr<mgb::json::Object> get_io_info() = 0;
#endif
};
} // namespace lar



+73 -8  lite/load_and_run/src/models/model_lite.cpp

@@ -1,14 +1,7 @@
/**
* \file lite/load_and_run/src/models/model_lite.cpp
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/
#include "model_lite.h"
#include <gflags/gflags.h>
#include <cstring>
#include <map>
#include "misc.h"

DECLARE_bool(share_param_mem);
@@ -51,3 +44,75 @@ void ModelLite::run_model() {
void ModelLite::wait() {
m_network->wait();
}
#if MGB_ENABLE_JSON
std::shared_ptr<mgb::json::Object> ModelLite::get_io_info() {
std::shared_ptr<mgb::json::Array> inputs = mgb::json::Array::make();
std::shared_ptr<mgb::json::Array> outputs = mgb::json::Array::make();

auto get_dtype = [&](lite::Layout& layout) {
std::map<LiteDataType, std::string> type_map = {
{LiteDataType::LITE_FLOAT, "float32"},
{LiteDataType::LITE_HALF, "float16"},
{LiteDataType::LITE_INT64, "int64"},
{LiteDataType::LITE_INT, "int32"},
{LiteDataType::LITE_UINT, "uint32"},
{LiteDataType::LITE_INT16, "int16"},
{LiteDataType::LITE_UINT16, "uint16"},
{LiteDataType::LITE_INT8, "int8"},
{LiteDataType::LITE_UINT8, "uint8"}};
return type_map[layout.data_type];
};
auto make_shape = [](lite::Layout& layout) {
std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>>
shape;
for (size_t i = 0; i < layout.ndim; ++i) {
std::string label = "dim";
label += std::to_string(layout.ndim - i - 1);
shape.push_back(
{mgb::json::String(label),
mgb::json::NumberInt::make(layout.shapes[layout.ndim - i - 1])});
}
return shape;
};
auto input_name = m_network->get_all_input_name();
for (auto& i : input_name) {
std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>>
json_inp;
auto layout = m_network->get_io_tensor(i)->get_layout();
json_inp.push_back(
{mgb::json::String("shape"),
mgb::json::Object::make(make_shape(layout))});
json_inp.push_back(
{mgb::json::String("dtype"),
mgb::json::String::make(get_dtype(layout))});
json_inp.push_back({mgb::json::String("name"), mgb::json::String::make(i)});
inputs->add(mgb::json::Object::make(json_inp));
}

auto output_name = m_network->get_all_output_name();
for (auto& i : output_name) {
std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>>
json_out;
auto layout = m_network->get_io_tensor(i)->get_layout();
json_out.push_back(
{mgb::json::String("shape"),
mgb::json::Object::make(make_shape(layout))});
json_out.push_back(
{mgb::json::String("dtype"),
mgb::json::String::make(get_dtype(layout))});
json_out.push_back({mgb::json::String("name"), mgb::json::String::make(i)});
outputs->add(mgb::json::Object::make(json_out));
}

return mgb::json::Object::make(
{{"IO",
mgb::json::Object::make({{"outputs", outputs}, {"inputs", inputs}})}});
}
#endif

std::vector<uint8_t> ModelLite::get_model_data() {
std::vector<uint8_t> out_data;
LITE_THROW("unsupported interface: ModelLite::get_model_data() \n");

return out_data;
}
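
For reference, the IO description built by get_io_info() has roughly this shape (tensor name, dtype and sizes are illustrative):

    {"IO": {
        "inputs": [{
            "shape": {"dim3": 224, "dim2": 224, "dim1": 3, "dim0": 1},
            "dtype": "float32",
            "name": "data"
        }],
        "outputs": [ ... ]
    }}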

+6 -9  lite/load_and_run/src/models/model_lite.h

@@ -1,12 +1,3 @@
/**
* \file lite/load_and_run/src/models/model_lite.h
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#pragma once

#include <string>
@@ -39,6 +30,10 @@ public:
//! wait the end of asynchronous function execution
void wait() override;

#if MGB_ENABLE_JSON
std::shared_ptr<mgb::json::Object> get_io_info() override;
#endif

//! enable global layout transform
void set_layout_transform(bool state) { enable_layout_transform = state; }

@@ -62,6 +57,8 @@ public:

const std::string& get_model_path() const override { return model_path; }

std::vector<uint8_t> get_model_data() override;

private:
bool share_model_mem;
bool enable_layout_transform;


+73 -9  lite/load_and_run/src/models/model_mdl.cpp

@@ -1,12 +1,3 @@
/**
* \file lite/load_and_run/src/models/model_mdl.cpp
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#include "model_mdl.h"
#include <gflags/gflags.h>
#include <iostream>
@@ -109,3 +100,76 @@ void ModelMdl::run_model() {
void ModelMdl::wait() {
m_asyc_exec->wait();
}

#if MGB_ENABLE_JSON
std::shared_ptr<mgb::json::Object> ModelMdl::get_io_info() {
std::shared_ptr<mgb::json::Array> inputs = mgb::json::Array::make();
std::shared_ptr<mgb::json::Array> outputs = mgb::json::Array::make();
auto get_dtype = [&](megdnn::DType data_type) {
std::map<megdnn::DTypeEnum, std::string> type_map = {
{mgb::dtype::Float32().enumv(), "float32"},
{mgb::dtype::Int32().enumv(), "int32"},
{mgb::dtype::Int16().enumv(), "int16"},
{mgb::dtype::Uint16().enumv(), "uint16"},
{mgb::dtype::Int8().enumv(), "int8"},
{mgb::dtype::Uint8().enumv(), "uint8"}};
return type_map[data_type.enumv()];
};
auto make_shape = [](mgb::TensorShape& shape_) {
std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>>
shape;
for (size_t i = 0; i < shape_.ndim; ++i) {
std::string label = "dim";
label += std::to_string(shape_.ndim - i - 1);
shape.push_back(
{mgb::json::String(label),
mgb::json::NumberInt::make(shape_[shape_.ndim - i - 1])});
}
return shape;
};
for (auto&& i : m_load_result.tensor_map) {
std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>>
json_inp;
auto shape_ = i.second->shape();
json_inp.push_back(
{mgb::json::String("shape"),
mgb::json::Object::make(make_shape(shape_))});
json_inp.push_back(
{mgb::json::String("dtype"),
mgb::json::String::make(get_dtype(i.second->dtype()))});
json_inp.push_back(
{mgb::json::String("name"), mgb::json::String::make(i.first)});
inputs->add(mgb::json::Object::make(json_inp));
}

for (auto&& i : m_load_result.output_var_list) {
std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>>
json_out;
auto shape_ = i.shape();
json_out.push_back(
{mgb::json::String("shape"),
mgb::json::Object::make(make_shape(shape_))});
json_out.push_back(
{mgb::json::String("dtype"),
mgb::json::String::make(get_dtype(i.dtype()))});

json_out.push_back(
{mgb::json::String("name"), mgb::json::String::make(i.node()->name())});
outputs->add(mgb::json::Object::make(json_out));
}
return mgb::json::Object::make(
{{"IO",
mgb::json::Object::make({{"outputs", outputs}, {"inputs", inputs}})}});
}
#endif

std::vector<uint8_t> ModelMdl::get_model_data() {
std::vector<uint8_t> out_data;
auto out_file = mgb::serialization::OutputFile::make_vector_proxy(&out_data);
using DumpConfig = mgb::serialization::GraphDumper::DumpConfig;
DumpConfig config{1, false, false};
auto dumper =
mgb::serialization::GraphDumper::make(std::move(out_file), m_format.val());
dumper->dump(m_load_result.output_var_list, config);
return out_data;
}
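
A sketch (not part of this diff) of persisting the re-dumped model returned by get_model_data(); the file name is illustrative:

    #include <fstream>
    // `model` is assumed to be a loaded ModelMdl instance
    auto data = model->get_model_data();
    std::ofstream out("fitted_model.mge", std::ios::binary);
    out.write(reinterpret_cast<const char*>(data.data()), data.size());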

+6 -9  lite/load_and_run/src/models/model_mdl.h

@@ -1,12 +1,3 @@
/**
* \file lite/load_and_run/src/models/model_mdl.h
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#pragma once
#include <string>
#include "megbrain/opr/search_policy/algo_chooser_helper.h"
@@ -42,6 +33,10 @@ public:

void wait() override;

#if MGB_ENABLE_JSON
std::shared_ptr<mgb::json::Object> get_io_info() override;
#endif

//! get load result for megDL model
mgb::serialization::GraphLoader::LoadResult& get_mdl_load_result() {
return m_load_result;
@@ -109,6 +104,8 @@ public:

const std::string& get_model_path() const override { return model_path; }

std::vector<uint8_t> get_model_data() override;

private:
bool share_model_mem;
std::string model_path;


+65 -15  lite/load_and_run/src/options/device_options.cpp

@@ -1,12 +1,3 @@
/**
* \file lite/load_and_run/src/options/device_options.cpp
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#include <iostream>
#include <sstream>
#include "lite/global.h"
@@ -76,7 +67,7 @@ void XPUDeviceOption::config_model_internel<ModelMdl>(
loc.type = mgb::CompNode::DeviceType::CPU;
};
}
#if MGB_CUDA
#if LITE_WITH_CUDA
if (enable_cuda) {
mgb_log_warn("using cuda device\n");
model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) {
@@ -134,7 +125,7 @@ void XPUDeviceOption::config_model_internel<ModelMdl>(
XPUDeviceOption::XPUDeviceOption() {
m_option_name = "xpu_device";
enable_cpu = FLAGS_cpu;
#if MGB_CUDA
#if LITE_WITH_CUDA
enable_cuda = FLAGS_cuda;
#endif
enable_cpu_default = FLAGS_cpu_default;
@@ -165,18 +156,41 @@ XPUDeviceOption::XPUDeviceOption() {
"core ids number should be same with thread number set before");
enable_set_core_ids = true;
}
}

m_option = {
{"cpu", lar::Bool::make(false)},
#if LITE_WITH_CUDA
{"cuda", lar::Bool::make(false)},
#endif
{"cpu_default", lar::Bool::make(false)},
{"multithread", lar::NumberInt32::make(-1)},
{"multithread_default", lar::NumberInt32::make(-1)},
{"multi_thread_core_ids", lar::String::make("")},
};
std::static_pointer_cast<lar::Bool>(m_option["cpu"])->set_value(FLAGS_cpu);
#if LITE_WITH_CUDA
std::static_pointer_cast<lar::Bool>(m_option["cuda"])->set_value(FLAGS_cuda);
#endif
std::static_pointer_cast<lar::Bool>(m_option["cpu_default"])
->set_value(FLAGS_cpu_default);
std::static_pointer_cast<lar::NumberInt32>(m_option["multithread"])
->set_value(FLAGS_multithread);
std::static_pointer_cast<lar::NumberInt32>(m_option["multithread_default"])
->set_value(FLAGS_multithread_default);
std::static_pointer_cast<lar::String>(m_option["multi_thread_core_ids"])
->set_value(FLAGS_multi_thread_core_ids);
}
bool XPUDeviceOption::m_valid;
bool XPUDeviceOption::is_valid() {
bool ret = FLAGS_cpu || FLAGS_cpu_default;
#if MGB_CUDA
#if LITE_WITH_CUDA
ret = ret || FLAGS_cuda;
#endif
ret = ret || FLAGS_multithread >= 0;
ret = ret || FLAGS_multithread_default >= 0;
ret = ret || !FLAGS_multi_thread_core_ids.empty();

return ret;
return ret || m_valid;
}

std::shared_ptr<OptionBase> XPUDeviceOption::create_option() {
@@ -190,11 +204,46 @@ std::shared_ptr<OptionBase> XPUDeviceOption::create_option() {

void XPUDeviceOption::config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
enable_cpu = std::static_pointer_cast<lar::Bool>(m_option["cpu"])->get_value();
#if LITE_WITH_CUDA
enable_cuda = std::static_pointer_cast<lar::Bool>(m_option["cuda"])->get_value();
#endif
enable_cpu_default =
std::static_pointer_cast<lar::Bool>(m_option["cpu_default"])->get_value();
int32_t num_of_thread =
std::static_pointer_cast<lar::NumberInt32>(m_option["multithread"])
->get_value();
enable_multithread = num_of_thread >= 0;
num_of_thread =
std::static_pointer_cast<lar::NumberInt32>(m_option["multithread_default"])
->get_value();
enable_multithread_default = num_of_thread >= 0;
thread_num = num_of_thread >= 0 ? num_of_thread : 0;
std::string core_id_str =
std::static_pointer_cast<lar::String>(m_option["multi_thread_core_ids"])
->get_value();
if (!core_id_str.empty()) {
mgb_assert(
enable_multithread || enable_multithread_default,
"core ids should be set after --multithread or --multithread-default");
std::stringstream id_stream(core_id_str);
std::string id;
size_t thread_cnt = 0;
while (getline(id_stream, id, ',')) {
thread_cnt++;
core_ids.push_back(atoi(id.c_str()));
}
mgb_assert(
thread_cnt == thread_num,
"core ids number should be same with thread number set before");
enable_set_core_ids = true;
}

CONFIG_MODEL_FUN;
}
///////////////////////// xpu gflags ////////////////////////////
DEFINE_bool(cpu, false, "set CPU device as running device");
#if MGB_CUDA || LITE_WITH_CUDA
#if LITE_WITH_CUDA
DEFINE_bool(cuda, false, "set CUDA device as running device ");
#endif
DEFINE_bool(cpu_default, false, "set running device as CPU device with inplace mode");
@@ -204,3 +253,4 @@ DEFINE_int32(
"set multithread device as running device with inplace mode");
DEFINE_string(multi_thread_core_ids, "", "set multithread core id");
REGIST_OPTION_CREATOR(xpu_device, lar::XPUDeviceOption::create_option);
REGIST_OPTION_VALIDATER(xpu_device, lar::XPUDeviceOption::set_valid);

+9 -10  lite/load_and_run/src/options/device_options.h

@@ -1,18 +1,10 @@
/**
* \file lite/load_and_run/src/options/device_options.h
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/
#pragma once
#include <gflags/gflags.h>
#include "models/model.h"
#include "option_base.h"

DECLARE_bool(cpu);
#if MGB_CUDA || LITE_WITH_CUDA
#if LITE_WITH_CUDA
DECLARE_bool(cuda);
#endif
DECLARE_bool(cpu_default);
@@ -29,12 +21,16 @@ public:
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
std::string option_name() const override { return m_option_name; };

static void set_valid(bool val) { m_valid = val; }

OptionValMap* get_option() override { return &m_option; }

private:
XPUDeviceOption();
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
bool enable_cpu;
#if MGB_CUDA || LITE_WITH_CUDA
#if LITE_WITH_CUDA
bool enable_cuda;
#endif
bool enable_cpu_default;
@@ -44,5 +40,8 @@ private:
size_t thread_num;
std::vector<int> core_ids;
std::string m_option_name;

static bool m_valid;
OptionValMap m_option;
};
} // namespace lar
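
A sketch of setting the device through the option map instead of gflags (an assumption about how the fitting strategy is meant to drive this option, not code from the diff):

    lar::XPUDeviceOption::set_valid(true);   // force is_valid() so the option gets applied
    auto option = lar::XPUDeviceOption::create_option();
    auto& opts = *option->get_option();
    std::static_pointer_cast<lar::Bool>(opts["cpu"])->set_value(true);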

+0 -9  lite/load_and_run/src/options/extern_c_opr_options.cpp

@@ -1,12 +1,3 @@
/**
* \file lite/load_and_run/src/options/extern_c_opr_options.cpp
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#include "extern_c_opr_options.h"
#include "megbrain/utils/debug.h"
#include "misc.h"


+0 -9  lite/load_and_run/src/options/extern_c_opr_options.h

@@ -1,12 +1,3 @@
/**
* \file lite/load_and_run/src/options/extern_c_opr_options.h
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#pragma once
#include <gflags/gflags.h>
#include "megbrain/graph/extern_copr_api.h"


+38 -12  lite/load_and_run/src/options/fastrun_options.cpp

@@ -1,12 +1,3 @@
/**
* \file lite/load_and_run/src/options/fastrun_options.cpp
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#include <gflags/gflags.h>

#if defined(_WIN32)
@@ -153,7 +144,7 @@ void FastRunOption::config_model_internel<ModelMdl>(
} // namespace lar

using namespace lar;
bool FastRunOption::m_valid;
FastRunOption::FastRunOption() {
m_option_name = "fastrun";
#if MGB_ENABLE_FASTRUN
@@ -164,6 +155,25 @@ FastRunOption::FastRunOption() {
enable_reproducible = FLAGS_reproducible;
m_fast_run_cache = FLAGS_fast_run_algo_policy;
share_batch_size = FLAGS_fast_run_shared_batch_size;
m_option = {
#if MGB_ENABLE_FASTRUN
{"fast_run", lar::Bool::make(false)},
{"full_run", lar::Bool::make(false)},
#endif
{"binary_equal_between_batch", lar::Bool::make(false)},
{"reproducible", lar::Bool::make(false)}
};
#if MGB_ENABLE_FASTRUN
std::static_pointer_cast<lar::Bool>(m_option["fast_run"])
->set_value(FLAGS_fast_run);
std::static_pointer_cast<lar::Bool>(m_option["full_run"])
->set_value(FLAGS_full_run);
#endif
std::static_pointer_cast<lar::Bool>(m_option["binary_equal_between_batch"])
->set_value(FLAGS_binary_equal_between_batch);
std::static_pointer_cast<lar::Bool>(m_option["reproducible"])
->set_value(FLAGS_reproducible);

#if MGB_ENABLE_FASTRUN
//! while fastrun cache file path is not empty and can't be accessed
if (!m_fast_run_cache.empty() && access(m_fast_run_cache.c_str(), F_OK)) {
@@ -191,7 +201,7 @@ bool FastRunOption::is_valid() {
ret = ret || FLAGS_reproducible;
ret = ret || FLAGS_fast_run_algo_policy.size() > 0;

return ret;
return ret || m_valid;
}

std::shared_ptr<OptionBase> FastRunOption::create_option() {
@@ -205,6 +215,21 @@ std::shared_ptr<OptionBase> FastRunOption::create_option() {

void FastRunOption::config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
#if MGB_ENABLE_FASTRUN
enable_fast_run =
std::static_pointer_cast<lar::Bool>(m_option["fast_run"])->get_value();
enable_full_run =
std::static_pointer_cast<lar::Bool>(m_option["full_run"])->get_value();
mgb_throw_if(
enable_fast_run && enable_full_run, mgb::AssertionError,
"invalid options of both fast-run and full-run");
#endif
batch_binary_equal =
std::static_pointer_cast<lar::Bool>(m_option["binary_equal_between_batch"])
->get_value();
enable_reproducible =
std::static_pointer_cast<lar::Bool>(m_option["reproducible"])->get_value();

CONFIG_MODEL_FUN;
}

@@ -228,4 +253,5 @@ DEFINE_bool(
DEFINE_uint32(fast_run_shared_batch_size, 0, "Set the batch size used during fastrun");
DEFINE_string(fast_run_algo_policy, "", "fast-run cache path.");

REGIST_OPTION_CREATOR(fastrun, lar::FastRunOption::create_option);
REGIST_OPTION_CREATOR(fastrun, lar::FastRunOption::create_option);
REGIST_OPTION_VALIDATER(fastrun, lar::FastRunOption::set_valid);

+7 -9  lite/load_and_run/src/options/fastrun_options.h

@@ -1,12 +1,3 @@
/**
* \file lite/load_and_run/src/options/fastrun_options.h
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#pragma once

#include <gflags/gflags.h>
@@ -38,6 +29,10 @@ public:
//! get options name for quickly search
std::string option_name() const override { return m_option_name; }

static void set_valid(bool val) { m_valid = val; }

OptionValMap* get_option() override { return &m_option; }

private:
FastRunOption();
//! config template for different model
@@ -53,5 +48,8 @@ private:
size_t share_batch_size; //! fast run strategy share batch size setting
std::string m_fast_run_cache; //! fast run cache file path
std::string m_option_name; //! option name

static bool m_valid;
OptionValMap m_option;
};
} // namespace lar

+0 -9  lite/load_and_run/src/options/io_options.cpp

@@ -1,12 +1,3 @@
/**
* \file lite/load_and_run/src/options/io_options.cpp
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#include <map>

#include "helpers/data_parser.h"


+0 -9  lite/load_and_run/src/options/io_options.h

@@ -1,12 +1,3 @@
/**
* \file lite/load_and_run/src/options/io_options.h
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#pragma once
#include <gflags/gflags.h>
#include "helpers/outdumper.h"


+73 -32  lite/load_and_run/src/options/layout_options.cpp

@@ -1,12 +1,3 @@
/**
* \file lite/load_and_run/src/options/layout_options.cpp
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#include <gflags/gflags.h>

#include "misc.h"
@@ -24,7 +15,7 @@ void LayoutOption::config_model_internel<ModelLite>(
model->get_config().options.enable_##layout = true; \
break;

switch (option_flag) {
switch (m_option_flag) {
case OptLayoutType::NCHW4:
ENABLE_LAYOUT(nchw4)

@@ -59,13 +50,12 @@ template <>
void lar::LayoutOption::config_model_internel<ModelMdl>(
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
mgb_log_debug("mdl layout config start");
#define ENABLE_LAYOUT(layout) \
mgb_log_warn("enable " #layout " optimization"); \
model->get_mdl_config().comp_graph->options().graph_opt.enable_##layout(); \
break;

switch (option_flag) {
switch (m_option_flag) {
case OptLayoutType::NCHW4:
ENABLE_LAYOUT(nchw4)

@@ -93,7 +83,6 @@ void lar::LayoutOption::config_model_internel<ModelMdl>(
default:
break;
}
mgb_log_debug("mdl layout config end");

#undef ENABLE_LAYOUT
}
@@ -101,48 +90,68 @@ void lar::LayoutOption::config_model_internel<ModelMdl>(
} // namespace lar

using namespace lar;

OptLayoutType LayoutOption::option_flag;

bool LayoutOption::m_valid;
LayoutOption::LayoutOption() {
m_option_name = "layout";
m_option_flag = static_cast<OptLayoutType>(0);
m_option = {
{"enable_nchw4", lar::Bool::make(false)},
{"enable_chwn4", lar::Bool::make(false)},
{"enable_nchw44", lar::Bool::make(false)},
{"enable_nchw88", lar::Bool::make(false)},
{"enable_nchw32", lar::Bool::make(false)},
{"enable_nchw64", lar::Bool::make(false)},
{"enable_nhwcd4", lar::Bool::make(false)},
{"enable_nchw44_dot", lar::Bool::make(false)},
};
std::static_pointer_cast<lar::Bool>(m_option["enable_nchw4"])
->set_value(FLAGS_enable_nchw4);
std::static_pointer_cast<lar::Bool>(m_option["enable_chwn4"])
->set_value(FLAGS_enable_chwn4);
std::static_pointer_cast<lar::Bool>(m_option["enable_nchw44"])
->set_value(FLAGS_enable_nchw44);
std::static_pointer_cast<lar::Bool>(m_option["enable_nchw88"])
->set_value(FLAGS_enable_nchw88);
std::static_pointer_cast<lar::Bool>(m_option["enable_nchw32"])
->set_value(FLAGS_enable_nchw32);
std::static_pointer_cast<lar::Bool>(m_option["enable_nchw64"])
->set_value(FLAGS_enable_nchw64);
std::static_pointer_cast<lar::Bool>(m_option["enable_nhwcd4"])
->set_value(FLAGS_enable_nhwcd4);
std::static_pointer_cast<lar::Bool>(m_option["enable_nchw44_dot"])
->set_value(FLAGS_enable_nchw44_dot);
}

bool LayoutOption::is_valid() {
size_t valid_flag = 0;
if (FLAGS_enable_nchw4) {
valid_flag = valid_flag | (1 << 0);
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW4);
}
if (FLAGS_enable_chwn4) {
valid_flag = valid_flag | (1 << 1);
valid_flag |= static_cast<size_t>(OptLayoutType::CHWN4);
}
if (FLAGS_enable_nchw44) {
valid_flag = valid_flag | (1 << 2);
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW44);
}
if (FLAGS_enable_nchw88) {
valid_flag = valid_flag | (1 << 3);
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW88);
}
if (FLAGS_enable_nchw32) {
valid_flag = valid_flag | (1 << 4);
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW32);
}
if (FLAGS_enable_nchw64) {
valid_flag = valid_flag | (1 << 5);
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW64);
}
if (FLAGS_enable_nhwcd4) {
valid_flag = valid_flag | (1 << 6);
valid_flag |= static_cast<size_t>(OptLayoutType::NHWCD4);
}
if (FLAGS_enable_nchw44_dot) {
valid_flag = valid_flag | (1 << 7);
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW44_DOT);
}
//! only one flag is valid
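//! valid_flag & (valid_flag - 1) clears the lowest set bit, so it is zero only
//! when at most one layout flag is set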
bool ret = valid_flag && !(valid_flag & (valid_flag - 1));
if (ret) {
option_flag = static_cast<OptLayoutType>(valid_flag);
} else {
option_flag = static_cast<OptLayoutType>(0);
}

return ret;
return ret | m_valid;
};

std::shared_ptr<OptionBase> LayoutOption::create_option() {
@@ -156,6 +165,37 @@ std::shared_ptr<OptionBase> LayoutOption::create_option() {

void LayoutOption::config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
size_t valid_flag = 0;
if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw4"])->get_value()) {
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW4);
}
if (std::static_pointer_cast<lar::Bool>(m_option["enable_chwn4"])->get_value()) {
valid_flag |= static_cast<size_t>(OptLayoutType::CHWN4);
}
if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw44"])->get_value()) {
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW44);
}
if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw88"])->get_value()) {
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW88);
}
if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw32"])->get_value()) {
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW32);
}
if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw64"])->get_value()) {
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW64);
}
if (std::static_pointer_cast<lar::Bool>(m_option["enable_nhwcd4"])->get_value()) {
valid_flag |= static_cast<size_t>(OptLayoutType::NHWCD4);
}
if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw44_dot"])
->get_value()) {
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW44_DOT);
}

mgb_throw_if(
valid_flag && (valid_flag & (valid_flag - 1)), mgb::AssertionError,
"invalid options of layout transform 0x%lx", valid_flag);
m_option_flag = static_cast<OptLayoutType>(valid_flag);
CONFIG_MODEL_FUN;
}

@@ -168,4 +208,5 @@ DEFINE_bool(enable_nchw64, false, "enable nchw64 layout optimization!!");
DEFINE_bool(enable_nhwcd4, false, "enable nhwcd4 layout optimization!!");
DEFINE_bool(enable_nchw44_dot, false, "enable nchw444-dot layout optimization!!");

REGIST_OPTION_CREATOR(layout, lar::LayoutOption::create_option);
REGIST_OPTION_CREATOR(layout, lar::LayoutOption::create_option);
REGIST_OPTION_VALIDATER(layout, lar::LayoutOption::set_valid);

+7 -10  lite/load_and_run/src/options/layout_options.h

@@ -1,12 +1,3 @@
/**
* \file lite/load_and_run/src/options/layout_options.h
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#pragma once

#include <gflags/gflags.h>
@@ -42,6 +33,10 @@ public:
//! get option name
std::string option_name() const override { return m_option_name; };

static void set_valid(bool val) { m_valid = val; }

OptionValMap* get_option() override { return &m_option; }

private:
//! Constructor
LayoutOption();
@@ -50,7 +45,9 @@ private:
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};

static OptLayoutType option_flag;
OptLayoutType m_option_flag;
std::string m_option_name;
static bool m_valid;
OptionValMap m_option;
};
} // namespace lar

+ 70
- 10
lite/load_and_run/src/options/layout_trans_options.cpp View File

@@ -1,11 +1,3 @@
/**
* \file lite/load_and_run/src/options/layout_trans_options.h
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/
#include "layout_trans_options.h"
#include <gflags/gflags.h>
#include "megbrain/serialization/serializer.h"
@@ -19,6 +11,7 @@ void GoptLayoutOption::config_model_internel<ModelLite>(
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
if (m_layout_transform) {
LITE_WARN("using global layout transform optimization\n");
if (m_layout_transform_target ==
mgb::gopt::GraphTuningOptions::Target::CPU) {
model->get_config().device_type = LiteDeviceType::LITE_CPU;
@@ -48,7 +41,47 @@ void GoptLayoutOption::config_model_internel<ModelMdl>(
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
if (m_layout_transform) {
mgb_log_warn("using global layout transform optimization\n");
auto&& load_result = model->get_mdl_load_result();
for (auto&& item : load_result.output_var_list) {
if (item.shape()[0] > 1) {
                    mgb_log_warn(
                            "model may be dumped with multiple batches and will cost "
                            "lots of time to profile during global layout transform!!!\n");
}
}
            //! update the output varlist when the input shape may change (some pass
            //! execution time depends on the shape of the initial input)
mgb::thin_hash_table::ThinHashMap<mgb::cg::SymbolVar, mgb::cg::SymbolVar>
varmap;
mgb::cg::DepOprIter dep([&](mgb::cg::OperatorNodeBase* opr) {
if (auto h2d = opr->try_cast_final<mgb::opr::Host2DeviceCopy>()) {
auto param = h2d->param();
mgb::TensorShape new_shape = h2d->host_data()->shape();
std::shared_ptr<mgb::HostTensorND> new_tensor =
std::make_shared<mgb::HostTensorND>(
h2d->host_data()->comp_node(), new_shape,
h2d->host_data()->dtype());
new_tensor->only_reset_raw_storage(h2d->host_data()->storage());
auto h2d_opr = mgb::opr::Host2DeviceCopy::make(
*h2d->owner_graph(), new_tensor, param, h2d->config());
varmap[h2d->output(0)] = h2d_opr;
}
});

for (auto&& i : load_result.output_var_list)
dep.add(i);

if (!varmap.empty()) {
auto output_vars =
mgb::cg::replace_vars(load_result.output_var_list, varmap);
for (size_t i = 0; i < load_result.output_var_list.size(); ++i) {
output_vars[i].rename(
load_result.output_var_list[i].node()->name());
}
load_result.output_var_list = output_vars;
}

load_result.output_var_list = mgb::gopt::layout_transform(
load_result.output_var_list, m_layout_transform_target);

@@ -98,7 +131,7 @@ void GoptLayoutOption::config_model_internel<ModelMdl>(
} // namespace lar

using namespace lar;
bool GoptLayoutOption::m_valid;
GoptLayoutOption::GoptLayoutOption() {
m_option_name = "gopt_layout";
if (FLAGS_layout_transform != "cpu"
@@ -122,6 +155,12 @@ GoptLayoutOption::GoptLayoutOption() {
#endif
}
m_layout_transform_dump_file = FLAGS_layout_transform_dump;

m_option = {
{"layout_transform", lar::String::make("")},
};
std::static_pointer_cast<lar::String>(m_option["layout_transform"])
->set_value(FLAGS_layout_transform);
}

bool GoptLayoutOption::is_valid() {
@@ -143,7 +182,7 @@ bool GoptLayoutOption::is_valid() {
}
}
ret = ret || !FLAGS_layout_transform_dump.empty();
return ret;
return ret || m_valid;
}

std::shared_ptr<OptionBase> GoptLayoutOption::create_option() {
@@ -157,6 +196,26 @@ std::shared_ptr<OptionBase> GoptLayoutOption::create_option() {

void GoptLayoutOption::config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
auto value = std::static_pointer_cast<lar::String>(m_option["layout_transform"])
->get_value();
if (value.empty()) {
return;
}
if (value == "cpu") {
m_layout_transform = true;
m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CPU;
}
#if LITE_WITH_CUDA
else if (value == "cuda") {
m_layout_transform = true;
m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CUDA;
}
#endif
else {
mgb_throw(
mgb::AssertionError, "invalid options of global layout transform %s",
value.c_str());
}
CONFIG_MODEL_FUN;
}

@@ -175,3 +234,4 @@ DEFINE_string(
"file path.");

REGIST_OPTION_CREATOR(gopt_layout, lar::GoptLayoutOption::create_option);
REGIST_OPTION_VALIDATER(gopt_layout, lar::GoptLayoutOption::set_valid);

+ 6
- 9
lite/load_and_run/src/options/layout_trans_options.h View File

@@ -1,12 +1,3 @@
/**
* \file lite/load_and_run/src/options/layout_trans_options.h
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#pragma once

#include <gflags/gflags.h>
@@ -32,6 +23,10 @@ public:
//! get options name for quickly search
std::string option_name() const override { return m_option_name; }

static void set_valid(bool val) { m_valid = val; }

OptionValMap* get_option() override { return &m_option; }

private:
GoptLayoutOption();
//! config template for different model
@@ -41,5 +36,7 @@ private:
std::string m_option_name;
std::string m_layout_transform_dump_file;
mgb::gopt::GraphTuningOptions::Target m_layout_transform_target;
static bool m_valid;
OptionValMap m_option;
};
} // namespace lar

+ 0
- 9
lite/load_and_run/src/options/model_options.cpp View File

@@ -1,12 +1,3 @@
/**
* \file lite/load_and_run/src/options/model_options.cpp
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#include "model_options.h"
#include "device_options.h"
#include "lite/pack_model.h"


+ 0
- 9
lite/load_and_run/src/options/model_options.h View File

@@ -1,12 +1,3 @@
/**
* \file lite/load_and_run/src/options/model_options.h
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#pragma once
#include <gflags/gflags.h>
#include "megbrain/graph/operator_node.h"


+ 85
- 19
lite/load_and_run/src/options/optimize_options.cpp View File

@@ -1,12 +1,3 @@
/**
* \file lite/load_and_run/src/options/optimize_options.cpp
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#include "megbrain/gopt/inference.h"
#if MGB_ENABLE_TENSOR_RT
#include "megbrain/tensorrt/tensorrt_engine_cache.h"
@@ -43,15 +34,18 @@ void FusePreprocessOption::config_model_internel<ModelMdl>(
}
} // namespace lar
using namespace lar;
bool FusePreprocessOption::m_valid;
FusePreprocessOption::FusePreprocessOption() {
m_option_name = "fuse_preprocess";
enable_fuse_preprocess = FLAGS_enable_fuse_preprocess;
m_option = {{"enable_fuse_preprocess", lar::Bool::make(false)}};
std::static_pointer_cast<lar::Bool>(m_option["enable_fuse_preprocess"])
->set_value(FLAGS_enable_fuse_preprocess);
}

bool FusePreprocessOption::is_valid() {
bool ret = FLAGS_enable_fuse_preprocess;
return ret;
return ret || m_valid;
}

std::shared_ptr<OptionBase> FusePreprocessOption::create_option() {
@@ -65,10 +59,14 @@ std::shared_ptr<OptionBase> FusePreprocessOption::create_option() {

void FusePreprocessOption::config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
enable_fuse_preprocess =
std::static_pointer_cast<lar::Bool>(m_option["enable_fuse_preprocess"])
->get_value();
CONFIG_MODEL_FUN;
}

///////////////////////// weight preprocess optimize options ///////////////
bool WeightPreprocessOption::m_valid;
namespace lar {
template <>
void WeightPreprocessOption::config_model_internel<ModelLite>(
@@ -97,11 +95,14 @@ void WeightPreprocessOption::config_model_internel<ModelMdl>(
WeightPreprocessOption::WeightPreprocessOption() {
m_option_name = "weight_preprocess";
weight_preprocess = FLAGS_weight_preprocess;
m_option = {{"weight_preprocess", lar::Bool::make(false)}};
std::static_pointer_cast<lar::Bool>(m_option["weight_preprocess"])
->set_value(FLAGS_weight_preprocess);
}

bool WeightPreprocessOption::is_valid() {
bool ret = FLAGS_weight_preprocess;
return ret;
return ret || m_valid;
}

std::shared_ptr<OptionBase> WeightPreprocessOption::create_option() {
@@ -115,10 +116,14 @@ std::shared_ptr<OptionBase> WeightPreprocessOption::create_option() {

void WeightPreprocessOption::config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
weight_preprocess =
std::static_pointer_cast<lar::Bool>(m_option["weight_preprocess"])
->get_value();
CONFIG_MODEL_FUN;
}

///// fuse conv bias and nonlinear activation opr optimize options ////////
bool FuseConvBiasNonlinearOption::m_valid;
namespace lar {
template <>
void FuseConvBiasNonlinearOption::config_model_internel<ModelLite>(
@@ -145,13 +150,16 @@ void FuseConvBiasNonlinearOption::config_model_internel<ModelMdl>(
} // namespace lar

FuseConvBiasNonlinearOption::FuseConvBiasNonlinearOption() {
m_option_name = "fuse_conv_bias_nonlinear";
m_option_name = "fuse_conv_bias_nonlinearity";
enable_fuse_conv_bias_nonlinearity = FLAGS_enable_fuse_conv_bias_nonlinearity;
m_option = {{"enable_fuse_conv_bias_nonlinearity", lar::Bool::make(false)}};
std::static_pointer_cast<lar::Bool>(m_option["enable_fuse_conv_bias_nonlinearity"])
->set_value(FLAGS_enable_fuse_conv_bias_nonlinearity);
}

bool FuseConvBiasNonlinearOption::is_valid() {
bool ret = FLAGS_enable_fuse_conv_bias_nonlinearity;
return ret;
return ret || m_valid;
}

std::shared_ptr<OptionBase> FuseConvBiasNonlinearOption::create_option() {
@@ -166,10 +174,15 @@ std::shared_ptr<OptionBase> FuseConvBiasNonlinearOption::create_option() {

void FuseConvBiasNonlinearOption::config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
enable_fuse_conv_bias_nonlinearity =
std::static_pointer_cast<lar::Bool>(
m_option["enable_fuse_conv_bias_nonlinearity"])
->get_value();
CONFIG_MODEL_FUN;
}

///////////////////////// fuse conv bias with z optimize options //////////////
bool FuseConvBiasElemwiseAddOption::m_valid;
namespace lar {
template <>
void FuseConvBiasElemwiseAddOption::config_model_internel<ModelLite>(
@@ -198,13 +211,16 @@ void FuseConvBiasElemwiseAddOption::config_model_internel<ModelMdl>(
} // namespace lar

FuseConvBiasElemwiseAddOption::FuseConvBiasElemwiseAddOption() {
m_option_name = "fuse_conv_bias_z";
m_option_name = "fuse_conv_bias_with_z";
enable_fuse_conv_bias_with_z = FLAGS_enable_fuse_conv_bias_with_z;
m_option = {{"enable_fuse_conv_bias_with_z", lar::Bool::make(false)}};
std::static_pointer_cast<lar::Bool>(m_option["enable_fuse_conv_bias_with_z"])
->set_value(FLAGS_enable_fuse_conv_bias_with_z);
}

bool FuseConvBiasElemwiseAddOption::is_valid() {
bool ret = FLAGS_enable_fuse_conv_bias_with_z;
return ret;
return ret || m_valid;
}

std::shared_ptr<OptionBase> FuseConvBiasElemwiseAddOption::create_option() {
@@ -219,10 +235,14 @@ std::shared_ptr<OptionBase> FuseConvBiasElemwiseAddOption::create_option() {

void FuseConvBiasElemwiseAddOption::config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
enable_fuse_conv_bias_with_z = std::static_pointer_cast<lar::Bool>(
m_option["enable_fuse_conv_bias_with_z"])
->get_value();
CONFIG_MODEL_FUN;
}

///////////////////////// graph record options /////////////////////////
bool GraphRecordOption::m_valid;
namespace lar {
template <>
void GraphRecordOption::config_model_internel<ModelLite>(
@@ -299,6 +319,23 @@ GraphRecordOption::GraphRecordOption() {
if (FLAGS_record_comp_seq2) {
m_record_comp_seq = 2;
}

m_option = {
{"record_comp_seq", lar::Bool::make(false)},
{"record_comp_seq2", lar::Bool::make(false)},
{"const_shape", lar::Bool::make(false)},
{"fake_first", lar::Bool::make(false)},
{"no_sanity_check", lar::Bool::make(false)}};
std::static_pointer_cast<lar::Bool>(m_option["const_shape"])
->set_value(FLAGS_const_shape);
std::static_pointer_cast<lar::Bool>(m_option["fake_first"])
->set_value(FLAGS_fake_first);
std::static_pointer_cast<lar::Bool>(m_option["no_sanity_check"])
->set_value(FLAGS_no_sanity_check);
std::static_pointer_cast<lar::Bool>(m_option["record_comp_seq"])
->set_value(FLAGS_record_comp_seq);
std::static_pointer_cast<lar::Bool>(m_option["record_comp_seq2"])
->set_value(FLAGS_record_comp_seq2);
}

bool GraphRecordOption::is_valid() {
@@ -307,7 +344,7 @@ bool GraphRecordOption::is_valid() {
ret = ret || FLAGS_no_sanity_check;
ret = ret || FLAGS_record_comp_seq;
ret = ret || FLAGS_record_comp_seq2;
return ret;
return ret || m_valid;
}

std::shared_ptr<OptionBase> GraphRecordOption::create_option() {
@@ -321,6 +358,22 @@ std::shared_ptr<OptionBase> GraphRecordOption::create_option() {

void GraphRecordOption::config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
const_shape =
std::static_pointer_cast<lar::Bool>(m_option["const_shape"])->get_value();
fake_first =
std::static_pointer_cast<lar::Bool>(m_option["fake_first"])->get_value();
no_sanity_check = std::static_pointer_cast<lar::Bool>(m_option["no_sanity_check"])
->get_value();
    m_record_comp_seq = std::static_pointer_cast<lar::Bool>(m_option["record_comp_seq"])
                                        ->get_value()
                                ? 1
                                : 0;
    //! record_comp_seq2 takes precedence; otherwise keep the value set above
    if (std::static_pointer_cast<lar::Bool>(m_option["record_comp_seq2"])->get_value()) {
        m_record_comp_seq = 2;
    }

CONFIG_MODEL_FUN;
}
///////////////////////// graph record options /////////////////////////
@@ -569,13 +622,26 @@ DEFINE_string(
"Set the TensorRT engine cache path for serialized prebuilt "
"ICudaEngine");
#endif

REGIST_OPTION_CREATOR(fuse_preprocess, lar::FusePreprocessOption::create_option);
REGIST_OPTION_VALIDATER(fuse_preprocess, lar::FusePreprocessOption::set_valid);

REGIST_OPTION_CREATOR(weight_preprocess, lar::WeightPreprocessOption::create_option);
REGIST_OPTION_VALIDATER(weight_preprocess, lar::WeightPreprocessOption::set_valid);

REGIST_OPTION_CREATOR(
fuse_conv_bias_nonlinear, lar::FuseConvBiasNonlinearOption::create_option);
fuse_conv_bias_nonlinearity, lar::FuseConvBiasNonlinearOption::create_option);
REGIST_OPTION_VALIDATER(
fuse_conv_bias_nonlinearity, lar::FuseConvBiasNonlinearOption::set_valid);

REGIST_OPTION_CREATOR(
fuse_conv_bias_z, lar::FuseConvBiasElemwiseAddOption::create_option);
fuse_conv_bias_with_z, lar::FuseConvBiasElemwiseAddOption::create_option);
REGIST_OPTION_VALIDATER(
fuse_conv_bias_with_z, lar::FuseConvBiasElemwiseAddOption::set_valid);

REGIST_OPTION_CREATOR(graph_record, lar::GraphRecordOption::create_option);
REGIST_OPTION_VALIDATER(graph_record, lar::GraphRecordOption::set_valid);

REGIST_OPTION_CREATOR(memory_optimize, lar::MemoryOptimizeOption::create_option);
REGIST_OPTION_CREATOR(JIT, lar::JITOption::create_option);
#if MGB_ENABLE_TENSOR_RT


+ 27
- 9
lite/load_and_run/src/options/optimize_options.h View File

@@ -1,12 +1,3 @@
/**
* \file lite/load_and_run/src/options/optimize_options.h
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#pragma once
#include <gflags/gflags.h>
#include "helpers/common.h"
@@ -44,6 +35,10 @@ public:

std::string option_name() const override { return m_option_name; };

static void set_valid(bool val) { m_valid = val; }

OptionValMap* get_option() override { return &m_option; }

private:
FusePreprocessOption();
template <typename ModelImpl>
@@ -51,6 +46,8 @@ private:

std::string m_option_name;
bool enable_fuse_preprocess;
static bool m_valid;
OptionValMap m_option;
};

///////////////////////// weight preprocess optimize options //////////////
@@ -64,6 +61,9 @@ public:
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;

std::string option_name() const override { return m_option_name; };
static void set_valid(bool val) { m_valid = val; };

OptionValMap* get_option() override { return &m_option; }

private:
WeightPreprocessOption();
@@ -72,6 +72,8 @@ private:

std::string m_option_name;
bool weight_preprocess;
static bool m_valid;
OptionValMap m_option;
};

/////////////// fuse_conv_bias_nonlinearity optimize options ///////////////
@@ -85,6 +87,9 @@ public:
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;

std::string option_name() const override { return m_option_name; };
static void set_valid(bool val) { m_valid = val; }

OptionValMap* get_option() override { return &m_option; }

private:
FuseConvBiasNonlinearOption();
@@ -93,6 +98,8 @@ private:

std::string m_option_name;
bool enable_fuse_conv_bias_nonlinearity;
static bool m_valid;
OptionValMap m_option;
};

///////////////////////// fuse_conv_bias_with_z optimize options //////////////
@@ -106,6 +113,9 @@ public:
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;

std::string option_name() const override { return m_option_name; };
static void set_valid(bool val) { m_valid = val; }

OptionValMap* get_option() override { return &m_option; }

private:
FuseConvBiasElemwiseAddOption();
@@ -113,6 +123,8 @@ private:
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
std::string m_option_name;
bool enable_fuse_conv_bias_with_z;
static bool m_valid;
OptionValMap m_option;
};

///////////////////////// graph record options ///////////////////////////
@@ -127,6 +139,10 @@ public:

std::string option_name() const override { return m_option_name; };

static void set_valid(bool val) { m_valid = val; }

OptionValMap* get_option() override { return &m_option; }

private:
GraphRecordOption();
template <typename ModelImpl>
@@ -137,6 +153,8 @@ private:
bool const_shape;
bool fake_first;
bool no_sanity_check;
static bool m_valid;
OptionValMap m_option;
};

///////////////////////// memory optimize options /////////////////////////


+ 40
- 23
lite/load_and_run/src/options/option_base.h View File

@@ -1,22 +1,13 @@
/**
* \file lite/load_and_run/src/options/option_base.h
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#pragma once
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "megbrain/common.h"

#include "helpers/common.h"
#include "helpers/utils.h"
#include "models/model.h"

namespace lar {
@@ -34,6 +25,9 @@ public:
//! get option name
virtual std::string option_name() const = 0;

//! get option map
virtual OptionValMap* get_option() { return nullptr; }

virtual ~OptionBase() = default;
};

@@ -43,7 +37,10 @@ public:
class OptionFactory {
public:
using OptionCreator = std::function<std::shared_ptr<OptionBase>()>;
using OptionMap = std::unordered_map<std::string, OptionCreator>;
using OptionValidater = std::function<void(bool)>;

using OptionCreatorMap = std::unordered_map<std::string, OptionCreator>;
using OptionValidaterMap = std::unordered_map<std::string, OptionValidater>;

//! get Singleton option factory
static OptionFactory& get_Instance() {
@@ -52,29 +49,49 @@ public:
}

    //! register option creator into option map
void registe_options(std::string name, OptionCreator creator) {
if (option_creator_map.count(name) == 0) {
option_creator_map[name] = creator;
void registe_options_creator(std::string name, OptionCreator creator) {
if (m_option_creator_map.count(name) == 0) {
m_option_creator_map[name] = creator;
}
}
    //! register option validater into option map
void registe_options_validater(std::string name, OptionValidater validater) {
if (m_option_validater_map.count(name) == 0) {
m_option_validater_map[name] = validater;
}
}

//! get creator map
OptionMap* get_option_creator_map() { return &option_creator_map; }
OptionCreatorMap* get_option_creator_map() { return &m_option_creator_map; }

//! get validater map
OptionValidaterMap* get_option_validater_map() { return &m_option_validater_map; }

private:
OptionFactory(){};
OptionMap option_creator_map;
OptionCreatorMap m_option_creator_map;
OptionValidaterMap m_option_validater_map;
};

} // namespace lar

#define REGIST_OPTION_CREATOR(name_, creator_) \
struct OptionRegister_##name_ { \
OptionRegister_##name_() { \
lar::OptionFactory::get_Instance().registe_options(#name_, creator_); \
} \
}; \
OptionRegister_##name_ name_;
#define REGIST_OPTION_CREATOR(_name, _creator) \
struct CreatorRegister_##_name { \
CreatorRegister_##_name() { \
lar::OptionFactory::get_Instance().registe_options_creator( \
#_name, _creator); \
} \
}; \
CreatorRegister_##_name creator_##_name;

#define REGIST_OPTION_VALIDATER(_name, _validater) \
struct ValitaterRegister_##_name { \
ValitaterRegister_##_name() { \
lar::OptionFactory::get_Instance().registe_options_validater( \
#_name, _validater); \
} \
}; \
ValitaterRegister_##_name validater_##_name;

#define CONFIG_MODEL_FUN \
if (model->type() == ModelType::LITE_MODEL) { \


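Roughly what REGIST_OPTION_CREATOR(layout, lar::LayoutOption::create_option) expands to: a namespace-scope object whose constructor runs during static initialization and inserts the creator into the factory map; REGIST_OPTION_VALIDATER follows the same pattern with the new validater map. In fitting mode the validater is invoked with true before the creator, so is_valid() passes and the option object is built even when its gflags stay at their defaults. A sketch of the expansion:

    struct CreatorRegister_layout {
        CreatorRegister_layout() {
            lar::OptionFactory::get_Instance().registe_options_creator(
                    "layout", lar::LayoutOption::create_option);
        }
    };
    CreatorRegister_layout creator_layout;  // static object: registration happens before main()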
+ 0
- 9
lite/load_and_run/src/options/plugin_options.cpp View File

@@ -1,12 +1,3 @@
/**
* \file lite/load_and_run/src/options/plugin_options.cpp
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#include "plugin_options.h"
#include <map>
#include "misc.h"


+ 0
- 9
lite/load_and_run/src/options/plugin_options.h View File

@@ -1,12 +1,3 @@
/**
* \file lite/load_and_run/src/options/plugin_options.h
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#pragma once
#include <gflags/gflags.h>
#if __linux__ || __unix__


+ 11
- 15
lite/load_and_run/src/options/strategy_options.cpp View File

@@ -1,24 +1,21 @@
/**
* \file lite/load_and_run/src/options/strategy_options.cpp
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#include "strategy_options.h"
#include "models/model_mdl.h"

using namespace lar;

DECLARE_bool(c_opr_lib_with_param);
DECLARE_bool(fitting);
StrategyOption::StrategyOption() {
m_option_name = "run_strategy";
warmup_iter = FLAGS_warmup_iter;
run_iter = FLAGS_iter;
threads = FLAGS_thread;
warmup_iter = FLAGS_fitting ? 3 : FLAGS_warmup_iter;
run_iter = FLAGS_fitting ? 10 : FLAGS_iter;
threads = FLAGS_fitting ? 1 : FLAGS_thread;
m_option = {
{"iter", lar::NumberInt32::make(run_iter)},
{"warmup_iter", lar::NumberInt32::make(warmup_iter)},
{"thread", lar::NumberInt32::make(threads)},

};
}

std::shared_ptr<OptionBase> StrategyOption::create_option() {
@@ -60,8 +57,7 @@ void TestcaseOption::config_model(
if (model->type() == ModelType::MEGDL_MODEL) {
auto model_ptr = std::static_pointer_cast<ModelMdl>(model);
if (model_ptr->get_testcase_num() && !FLAGS_c_opr_lib_with_param) {
if (runtime_param.stage == RunStage::MODEL_RUNNING) {
auto load_result = model_ptr->get_mdl_load_result();
if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
auto input_tensor = model_ptr->get_test_input();
auto loader = model_ptr->reset_loader();
auto testcase = loader->load(model_ptr->get_mdl_config(), false);


+ 3
- 9
lite/load_and_run/src/options/strategy_options.h View File

@@ -1,12 +1,3 @@
/**
* \file lite/load_and_run/src/options/strategy_options.h
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#include <gflags/gflags.h>
#include "models/model.h"
#include "option_base.h"
@@ -32,6 +23,8 @@ public:
//! get option name
std::string option_name() const override { return m_option_name; };

OptionValMap* get_option() override { return &m_option; }

private:
//! Constructor
StrategyOption();
@@ -43,6 +36,7 @@ private:
size_t run_iter; //! iteration number for running model
size_t threads; //! thread number for running model (NOTE:it's different
//! from multithread device )
OptionValMap m_option;
};

class TestcaseOption final : public OptionBase {


+ 3
- 11
lite/load_and_run/src/strategys/strategy.cpp View File

@@ -1,18 +1,10 @@

/**
* \file lite/load_and_run/src/strategys/strategy.cpp
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#include "strategy.h"
#include <iostream>
#include "strategy_fitting.h"
#include "strategy_normal.h"

using namespace lar;
DECLARE_bool(fitting);
std::shared_ptr<StrategyBase> StrategyBase::create_strategy(std::string model_path) {
if (FLAGS_fitting) {
return std::make_shared<FittingStrategy>(model_path);


+ 2
- 38
lite/load_and_run/src/strategys/strategy.h View File

@@ -1,23 +1,11 @@
/**
* \file lite/load_and_run/src/strategys/strategy.h
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#pragma once
#include <gflags/gflags.h>
#include <string>
#include <unordered_map>
#include "helpers/common.h"
#include "models/model.h"
#include "options/option_base.h"

DECLARE_bool(fitting);

namespace lar {
using OptionMap = std::unordered_map<std::string, std::shared_ptr<OptionBase>>;
/*!
* \brief: load and run strategy base class
*/
@@ -30,34 +18,10 @@ public:
virtual ~StrategyBase() = default;

RuntimeParam m_runtime_param;
std::unordered_map<std::string, std::shared_ptr<OptionBase>> m_options;
};

/*!
* \brief: normal strategy for running
*/
class NormalStrategy : public StrategyBase {
public:
NormalStrategy(std::string model_path);

//! run model with runtime parameter
void run() override;

private:
//! run model subline for multiple thread
void run_subline();

std::string m_model_path;
std::shared_ptr<OptionMap> m_options;
};

/*!
* \brief: Fitting strategy for running
*/
class FittingStrategy : public StrategyBase {
public:
FittingStrategy(std::string model_path);
void run() override;
};
} // namespace lar

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

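With NormalStrategy split out into strategy_normal.h and FittingStrategy into strategy_fitting.h, callers only see the base class. A minimal sketch of the intended entry-point usage follows; the real wiring lives in main.cpp (also touched by this commit), and the include path and function name here are illustrative:

    #include <string>
    #include "strategy.h"

    int run_load_and_run(const std::string& model_path) {
        // FittingStrategy when --fitting is given, NormalStrategy otherwise
        auto strategy = lar::StrategyBase::create_strategy(model_path);
        strategy->run();
        return 0;
    }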
+ 583
- 17
lite/load_and_run/src/strategys/strategy_fitting.cpp View File

@@ -1,24 +1,590 @@
/**
* \file lite/load_and_run/src/strategys/strategy_fitting.cpp
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/

#include "strategy.h"
#include "strategy_fitting.h"
#if defined(_WIN32)
#include <io.h>
#define F_OK 0
#define access(a, b) _access(a, b)
#elif __linux__ || __unix__ || __APPLE__
#include <unistd.h>
#endif
#include <fstream>
#include <iostream>
#include <list>
#include <regex>
#include <thread>
#include "lite/pack_model.h"
#include "megbrain/common.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/exception.h"
#include "megbrain/utils/timer.h"
#include "megbrain/version.h"
#include "megdnn/version.h"
#include "misc.h"
DECLARE_bool(cpu);
using namespace lar;

FittingStrategy::FittingStrategy(std::string) {
mgb_assert("this version don't support Fitting Strategy");
// /////////////////// OptionsFastManager ///////////////////
void OptionsFastManager::init(std::shared_ptr<OptionMap>& options) {
m_option_group_cnt = 0;
m_fixed_option_cnt = 0;
m_internal_options_name = {
{"enable_fuse_conv_bias_with_z"},
{"enable_fuse_preprocess"},
{"record_comp_seq"}};
//! record the independent option value
for (auto& option : *options) {
auto option_vals = option.second->get_option();
if (option_vals) {
for (auto& item : *option_vals) {
m_valid_option_vals.insert(item);
}
}
}
};

std::string OptionsFastManager::set_next_fixed_options() {
reset_option();
auto& fixed_options_name = m_fixed_options_name[m_fixed_option_cnt];
for (auto& item : fixed_options_name) {
if (m_valid_option_vals.find(item) != m_valid_option_vals.end()) {
auto& option_val = m_valid_option_vals[item];
auto type = option_val->get_type();
if (type == JsonValueType::Bool) {
auto option_val_ptr = std::static_pointer_cast<lar::Bool>(option_val);
option_val_ptr->set_value(true);
} else if (type == JsonValueType::String && item == "layout_transform") {
auto option_val_ptr = std::static_pointer_cast<lar::String>(option_val);
//! device type
option_val_ptr->set_value(fixed_options_name[0]);
} else {
mgb_log_error(
"invalid JsonValueType:%s to set next value for fitting mode",
option_val->type_string().c_str());
}
}
}
++m_fixed_option_cnt;
std::string code = m_gflags_coder.encode(m_valid_option_vals);
return code;
}

std::string OptionsFastManager::set_next_options() {
reset_option();
auto& constraint = m_internal_options_name[m_option_group_cnt];
for (auto& item : constraint) {
if (m_valid_option_vals.find(item) != m_valid_option_vals.end()) {
auto& option_val = m_valid_option_vals[item];
auto type = option_val->get_type();
if (type == JsonValueType::Bool) {
auto option_val_ptr = std::static_pointer_cast<lar::Bool>(option_val);
option_val_ptr->set_value(true);
} else {
mgb_log_error(
"invalid JsonValueType: %s to set next value for fitting mode",
option_val->type_string().c_str());
}
}
}
++m_option_group_cnt;
std::string code = m_gflags_coder.encode(m_valid_option_vals);
return code;
}

bool OptionsFastManager::is_end_options() {
return m_option_group_cnt == m_internal_options_name.size();
}

bool OptionsFastManager::is_fixed_end() {
return m_fixed_option_cnt == m_fixed_options_name.size();
}

void OptionsFastManager::set_options(const std::string& code) {
reset_option();
#if MGB_ENABLE_JSON
const std::regex json_regex(".\\{");
#endif
const std::regex gflags_regex("--.*=.*");
if (std::regex_search(code, gflags_regex)) {
m_gflags_coder.decode(code, m_valid_option_vals);
}
#if MGB_ENABLE_JSON
else if (std::regex_search(code, json_regex)) {
m_json_coder.decode(code, m_valid_option_vals);
}
#endif
else {
mgb_log_error("invalid options code format \"%s\" to decode", code.c_str());
}
}

void OptionsFastManager::registe_fixed_options(
const std::vector<std::string>& option_name) {
m_fixed_options_name.push_back(option_name);
}

std::string OptionsFastManager::get_curr_options_code(CoderType type, bool encode_all) {
if (type == CoderType::GFLAGS) {
return m_gflags_coder.encode(m_valid_option_vals, encode_all);
}
#if MGB_ENABLE_JSON
else if (type == CoderType::JSON) {
return m_json_coder.encode(m_valid_option_vals, encode_all);
}
#endif
else {
mgb_log_error("coder should be implemented in furture");
return "";
}
}
#if MGB_ENABLE_JSON
std::vector<std::shared_ptr<mgb::json::Object>> OptionsFastManager::get_json() {
std::vector<std::shared_ptr<mgb::json::Object>> ret =
m_json_coder.encode(m_valid_option_vals);
return ret;
}
#endif

void OptionsFastManager::reset_option() {
for (auto& option : m_valid_option_vals) {
option.second->reset_value();
}
}

////////////////// OptionsTimeProfiler //////////////////

void OptionsTimeProfiler::profile_with_given_options(
const std::string& model_path, std::shared_ptr<OptionMap>& given_options,
const std::string& option_code) {
RuntimeParam runtime_param;
auto model = ModelBase::create_model(model_path);
mgb::RealTimer timer;
auto stage_config_model = [&]() {
for (auto& option : *given_options) {
option.second->config_model(runtime_param, model);
}
};

auto warm_up = [&]() {
for (size_t i = 0; i < runtime_param.warmup_iter; i++) {
auto start = timer.get_msecs();
model->run_model();
model->wait();
mgb_log_warn("warm up %ld time %f ms", i, timer.get_msecs() - start);
}
};
double inference_time = 0.0;
auto run_iter = [&]() {
for (size_t i = 0; i < runtime_param.run_iter; i++) {
auto start = timer.get_msecs();
model->run_model();
model->wait();
auto end = timer.get_msecs();
mgb_log_warn("run iter %ld time %f ms", i, end - start);
inference_time += end - start;
mgb_throw_if(
inference_time > TIME_OUT, mgb::TimeoutError,
"time out while using fitting");
}
};

//! model with testcase
size_t case_num = runtime_param.testcase_num;

bool exception_state = false;
MGB_TRY {
timer.reset();
runtime_param.stage = RunStage::BEFORE_MODEL_LOAD;
stage_config_model();

model->load_model();
//! after load configure
auto config_model_before_runing = [&]() {
for (auto stage :
{RunStage::AFTER_MODEL_LOAD, RunStage::GLOBAL_OPTIMIZATION,
RunStage::BEFORE_OUTSPEC_SET, RunStage::AFTER_OUTSPEC_SET,
RunStage::MODEL_RUNNING}) {
runtime_param.stage = stage;
stage_config_model();
}
};
timer.reset();
for (size_t idx = 0; idx < case_num; idx++) {
auto start = timer.get_msecs();
config_model_before_runing();
auto end = timer.get_msecs();
mgb_log_warn("config model time %f ms", end - start);
warm_up();
run_iter();
}
runtime_param.stage = RunStage::AFTER_MODEL_RUNNING;
stage_config_model();
}
MGB_CATCH(std::exception & exc, {
mgb_log_error("catch exception: %s", exc.what());
exception_state = true;
});

auto average = inference_time / runtime_param.run_iter;
if (exception_state) {
average = TIME_OUT;
}

//! record profile result
printf("profile option:\n%s\naverage time = %.2f\n", option_code.c_str(), average);
m_options_profile_result.insert({option_code, average});

//! record the best result

if (average < m_best_setting.second) {
m_best_setting.first = option_code;
m_best_setting.second = average;
}
}
/////////////////////////// UserInfoParser /////////////////////////////
void UserInfoParser::get_user_info() {
//! register user information tips
std::vector<std::pair<std::string, std::string>> info_tips;
m_user_info["fitting_preference"] = "Inferspeed";
info_tips.push_back(
{"use_const_shape", "whether the input shape is constant?(yes/no)?"});

for (auto& tip : info_tips) {
std::cout << tip.second;
std::string answer = "";
std::cin >> answer;
m_user_info[tip.first] = answer;
}
}
void UserInfoParser::parse_info(std::shared_ptr<OptionsFastManager>& manager) {
std::vector<std::string> fixed_options;
if (m_user_info["use_const_shape"] == "yes") {
fixed_options.push_back("const_shape");
} else if (m_user_info["use_const_shape"] != "no") {
mgb_log_error("invalid user information for \"use_const_shape\"");
}
fixed_options.push_back("enable_fuse_conv_bias_nonlinearity");
std::vector<std::string> tmp_options;

auto insert_common_cpu_options = [&]() {
tmp_options = {"cpu"};
tmp_options.insert(
tmp_options.end(), fixed_options.begin(), fixed_options.end());
manager->registe_fixed_options(tmp_options);

tmp_options = {"cpu", "weight_preprocess", "fast_run"};
tmp_options.insert(
tmp_options.end(), fixed_options.begin(), fixed_options.end());
manager->registe_fixed_options(tmp_options);

tmp_options = {"cpu", "layout_transform"};
tmp_options.insert(
tmp_options.end(), fixed_options.begin(), fixed_options.end());
manager->registe_fixed_options(tmp_options);

tmp_options = {"cpu", "layout_transform", "weight_preprocess"};
tmp_options.insert(
tmp_options.end(), fixed_options.begin(), fixed_options.end());
manager->registe_fixed_options(tmp_options);
};

#if (MEGDNN_AARCH64 || MEGDNN_ARMV7)
//! arm cpu device
insert_common_cpu_options();
tmp_options = {"cpu", "enable_nchw44"};
tmp_options.insert(
tmp_options.end(), fixed_options.begin(), fixed_options.end());
manager->registe_fixed_options(tmp_options);

tmp_options = {"cpu", "enable_nchw44", "weight_preprocess", "fast_run"};
tmp_options.insert(
tmp_options.end(), fixed_options.begin(), fixed_options.end());
manager->registe_fixed_options(tmp_options);

tmp_options = {"cpu", "enable_nchw44_dot"};
tmp_options.insert(
tmp_options.end(), fixed_options.begin(), fixed_options.end());
manager->registe_fixed_options(tmp_options);

tmp_options = {"cpu", "enable_nchw44_dot", "weight_preprocess", "fast_run"};
tmp_options.insert(
tmp_options.end(), fixed_options.begin(), fixed_options.end());
manager->registe_fixed_options(tmp_options);

#else
#if LITE_WITH_CUDA
    //! built with cuda and not forced to use the cpu device
if (!FLAGS_cpu) {
tmp_options = {"cuda"};
tmp_options.insert(
tmp_options.end(), fixed_options.begin(), fixed_options.end());
manager->registe_fixed_options(tmp_options);

tmp_options = {"cuda", "enable_nchw4"};
tmp_options.insert(
tmp_options.end(), fixed_options.begin(), fixed_options.end());
manager->registe_fixed_options(tmp_options);

tmp_options = {"cuda", "enable_chwn4"};
tmp_options.insert(
tmp_options.end(), fixed_options.begin(), fixed_options.end());
manager->registe_fixed_options(tmp_options);

tmp_options = {"cuda", "enable_nchw64"};
tmp_options.insert(
tmp_options.end(), fixed_options.begin(), fixed_options.end());
manager->registe_fixed_options(tmp_options);

tmp_options = {"cuda", "enable_nchw32"};
tmp_options.insert(
tmp_options.end(), fixed_options.begin(), fixed_options.end());
manager->registe_fixed_options(tmp_options);

tmp_options = {"cuda", "layout_transform"};
tmp_options.insert(
tmp_options.end(), fixed_options.begin(), fixed_options.end());
manager->registe_fixed_options(tmp_options);

tmp_options = {"cuda", "layout_transform", "weight_preprocess"};
tmp_options.insert(
tmp_options.end(), fixed_options.begin(), fixed_options.end());
manager->registe_fixed_options(tmp_options);
}

#endif
#if LITE_WITH_CUDA
    //! built with cuda but forced to use the cpu
if (FLAGS_cpu) {
#endif
        //! x86 cpu options
insert_common_cpu_options();
tmp_options = {"cpu", "enable_nchw88"};
tmp_options.insert(
tmp_options.end(), fixed_options.begin(), fixed_options.end());
manager->registe_fixed_options(tmp_options);

tmp_options = {"cpu", "enable_nchw88", "weight_preprocess", "fast_run"};
tmp_options.insert(
tmp_options.end(), fixed_options.begin(), fixed_options.end());
manager->registe_fixed_options(tmp_options);
#if LITE_WITH_CUDA
}
#endif
#endif

m_proifler_type = ProiflerType::TIME_PROFILER;
}

// /////////////////// FittingStrategy //////////////////////////////////
FittingStrategy::FittingStrategy(std::string model_path) {
m_manager = std::make_shared<OptionsFastManager>();
m_dumped_model = FLAGS_dump_fitting_model;
mgb::set_log_level(mgb::LogLevel::WARN);
m_options = std::make_shared<OptionMap>();
m_model_path = model_path;
auto option_creator_map = OptionFactory::get_Instance().get_option_creator_map();
auto option_validater_map =
OptionFactory::get_Instance().get_option_validater_map();
//! validate option used in fitting
auto validate_option = [&](std::string name) -> void {
if (option_validater_map->find(name) != option_validater_map->end()) {
auto& validater = (*option_validater_map).at(name);
if (validater) {
validater(true);
}
}
};

//! construct option which is valid
auto construct_option = [&](std::string name) -> void {
auto& creator = (*option_creator_map)[name];
auto option = creator();
if (option) {
m_options->insert({name, option});
}
};

    //! get all options which are valid
for (auto& creator : *option_creator_map) {
auto name = creator.first;
if (m_options->count(name) == 0) {
validate_option(name);
construct_option(name);
}
}

m_manager->init(m_options);
}

void FittingStrategy::dump_best_options_with_model() {
std::vector<uint8_t> info_algo_policy_data;
std::vector<uint8_t> info_binary_cache_data;
auto model = ModelBase::create_model(m_model_path);
RuntimeParam runtime_param;
auto stage_config_model = [&]() {
for (auto& option : *m_options) {
option.second->config_model(runtime_param, model);
}
};
runtime_param.stage = RunStage::BEFORE_MODEL_LOAD;
stage_config_model();

model->load_model();

//! get json info vector
std::string json_info_str;
#if MGB_ENABLE_JSON
std::shared_ptr<mgb::json::Object> code_json = model->get_io_info();
m_packed_info.push_back({mgb::json::String("IO"), (*code_json)["IO"]});
auto info_json = m_manager->get_json();
m_packed_info.push_back({mgb::json::String("options"), (*info_json[0])["options"]});
m_packed_info.push_back({mgb::json::String("device"), (*info_json[1])["device"]});
m_packed_info.push_back(
{mgb::json::String("backend"), mgb::json::String::make("MGE")});
int lite_major, lite_minor, lite_patch;
lite::get_version(lite_major, lite_minor, lite_patch);
std::string version = std::to_string(lite_major);
version += ".";
version += std::to_string(lite_minor) + ".";
version += std::to_string(lite_patch);
m_packed_info.push_back(
{mgb::json::String("version"), mgb::json::String::make(version)});
m_packed_info.push_back({mgb::json::String("valid"), mgb::json::Bool::make(true)});
m_packed_info.push_back(
{mgb::json::String("name"), mgb::json::String::make("packed_model")});
auto obj = mgb::json::Object::make(m_packed_info);
json_info_str = obj->to_string();
#endif
std::vector<uint8_t> json_info(json_info_str.begin(), json_info_str.end());

//! get model binary data after optimized
for (auto stage :
{RunStage::AFTER_MODEL_LOAD, RunStage::GLOBAL_OPTIMIZATION,
RunStage::BEFORE_OUTSPEC_SET, RunStage::AFTER_OUTSPEC_SET,
RunStage::MODEL_RUNNING}) {
runtime_param.stage = stage;
stage_config_model();
}
model->run_model();
model->wait();

std::vector<uint8_t> model_data = model->get_model_data();
mgb_log_warn("model_data size=%zu", model_data.size());
mgb_log_warn("json_info size=%zu", json_info.size());
mgb_log_warn("info_algo_policy_data size=%zu", info_algo_policy_data.size());
mgb_log_warn("info_binary_cache_data size=%zu", info_binary_cache_data.size());
lite::ModelPacker packer(
model_data, m_dumped_model, json_info, info_algo_policy_data,
info_binary_cache_data);
packer.set_header();
packer.pack_model();
}
///////////////////////// AutoCleanFile///////////////////////////
FittingStrategy::AutoCleanFile::AutoCleanFile(
const std::string& model_path, std::shared_ptr<OptionMap>& options)
: m_model_path(model_path), m_options(options) {
m_filename = "fitting_tmp_model";
if (!access(m_filename.c_str(), F_OK)) {
remove(m_filename.c_str());
}
}

FittingStrategy::AutoCleanFile::~AutoCleanFile() {
if (!access(m_filename.c_str(), F_OK)) {
remove(m_filename.c_str());
}
}

void FittingStrategy::AutoCleanFile::dump_model() {
auto model = ModelBase::create_model(m_model_path);
RuntimeParam runtime_param;
auto stage_config_model = [&]() {
for (auto& option : *m_options) {
option.second->config_model(runtime_param, model);
}
};
runtime_param.stage = RunStage::BEFORE_MODEL_LOAD;
stage_config_model();

model->load_model();
//! get model binary data after optimized
for (auto stage :
{RunStage::AFTER_MODEL_LOAD, RunStage::GLOBAL_OPTIMIZATION,
RunStage::BEFORE_OUTSPEC_SET, RunStage::AFTER_OUTSPEC_SET,
RunStage::MODEL_RUNNING}) {
runtime_param.stage = stage;
stage_config_model();
}
model->run_model();
model->wait();

std::vector<uint8_t> model_data = model->get_model_data();
mgb_log_warn("dumped model_data size=%zu\n", model_data.size());
auto fp = fopen(m_filename.c_str(), "wb");
fwrite(model_data.data(), 1, model_data.size(), fp);
fclose(fp);
}

void FittingStrategy::run() {
mgb_assert("this version don't support Fitting Strategy");
};
auto mgb_version = mgb::get_version();
auto dnn_version = megdnn::get_version();
printf("megbrain/lite/load_and_run:\nusing MegBrain "
"%d.%d.%d(%d) and MegDNN %d.%d.%d\n",
mgb_version.major, mgb_version.minor, mgb_version.patch, mgb_version.is_dev,
dnn_version.major, dnn_version.minor, dnn_version.patch);
    //! create profiler with given user info
m_info_parser.get_user_info();
m_info_parser.parse_info(m_manager);
auto profiler = m_info_parser.create_profiler();
mgb_throw_if(
profiler == nullptr, mgb::AssertionError,
"get empty profiler for fittting\n");
//! profile model with fixed options
while (!m_manager->is_fixed_end()) {
std::string option_str = m_manager->set_next_fixed_options();
profiler->profile_with_given_options(m_model_path, m_options, option_str);
#if (MEGDNN_AARCH64 || MEGDNN_ARMV7)
        //! sleep to keep the machine at a stable cpu frequency
usleep(500000);
#endif
}
std::string m_tmp_model = m_model_path;
const std::regex layout_regex("layout_transform");
auto best_fixed_options = profiler->get_best_setting();
m_manager->set_options(best_fixed_options);
//! dump model for global layout transform
auto m_tmp_file = AutoCleanFile(m_model_path, m_options);
if (std::regex_search(best_fixed_options, layout_regex)) {
m_tmp_file.dump_model();
m_model_path = m_tmp_file.filename();
}
//! profile model with given profiler
while (!m_manager->is_end_options()) {
std::string curr_option_str = m_manager->set_next_options();
//! set option with current option and fixed options
if (m_model_path == m_tmp_model) {
auto total_option_str = curr_option_str + best_fixed_options;
m_manager->set_options(total_option_str);
}

curr_option_str += best_fixed_options;
profiler->profile_with_given_options(m_model_path, m_options, curr_option_str);
#if (MEGDNN_AARCH64 || MEGDNN_ARMV7)
usleep(500000);
#endif
}
    //! set the best options and run the final inference
m_model_path = m_tmp_model;
auto best_options = profiler->get_best_setting();

m_manager->set_options(best_options);
profiler->profile_with_given_options(m_model_path, m_options, best_options);

    //! report the best options and optionally pack them with the model
std::cout << "the best options:\n" << best_options << std::endl;

if (!m_dumped_model.empty()) {
dump_best_options_with_model();
}
}
DEFINE_bool(
fitting, false,
"whether to use the fitting model, which will auto profile and get "
"the best option set!");
fitting, false, "use the fitting mode profile and get the best option set.");
DEFINE_string(dump_fitting_model, "", "dump the best option and algo cache into model");

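dump_best_options_with_model() above funnels everything into lite::ModelPacker. A stripped-down sketch of that packing call follows, with the same argument order as the code above and a hypothetical output file name; on the command line the same path is reached with --fitting --dump_fitting_model <path>:

    #include <cstdint>
    #include <vector>
    #include "lite/pack_model.h"

    void pack_fitting_result(const std::vector<uint8_t>& model_data,
                             const std::vector<uint8_t>& json_info) {
        std::vector<uint8_t> algo_policy;    // algo policy cache, left empty in this sketch
        std::vector<uint8_t> binary_cache;   // binary cache, left empty in this sketch
        lite::ModelPacker packer(
                model_data, "fitting_packed.mge", json_info, algo_policy, binary_cache);
        packer.set_header();
        packer.pack_model();
    }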
+ 152
- 0
lite/load_and_run/src/strategys/strategy_fitting.h View File

@@ -0,0 +1,152 @@
#pragma once
#include <gflags/gflags.h>
#include "helpers/utils.h"
#include "strategy.h"
DECLARE_bool(fitting);
DECLARE_string(dump_fitting_model);
#define TIME_OUT 10000
namespace lar {

class OptionsFastManager {
public:
using ConstraintMap = std::unordered_map<std::string, bool>;

OptionsFastManager(){};

//! init the options value map with given options
void init(std::shared_ptr<OptionMap>&);

    //! set the next options group cyclically
std::string set_next_options();
std::string set_next_fixed_options();

//! check the end of options group
bool is_end_options();
bool is_fixed_end();

std::string get_curr_options_code(CoderType, bool encode_all = false);

//! set current options with given options
void set_options(const std::string&);

void registe_fixed_options(const std::vector<std::string>&);

#if MGB_ENABLE_JSON
std::vector<std::shared_ptr<mgb::json::Object>> get_json();
#endif

private:
void reset_option();
size_t m_option_group_cnt;
size_t m_fixed_option_cnt;
OptionValMap m_valid_option_vals;
std::vector<std::vector<std::string>> m_internal_options_name;
std::vector<std::vector<std::string>> m_fixed_options_name;

#if MGB_ENABLE_JSON
JsonOptionsCoder m_json_coder;
#endif

GflagsOptionsCoder m_gflags_coder;
};

//! Options profiler to get the best settings under different evaluation standards
class OptionsProfiler {
public:
OptionsProfiler(){};

//! run with m_options
virtual void profile_with_given_options(
const std::string&, std::shared_ptr<OptionMap>&, const std::string&) = 0;

//! get the best setting and inference time
virtual std::string get_best_setting() { return ""; }

virtual ~OptionsProfiler() = default;
};

/**
 * profiler to get the fastest setting
*/
class OptionsTimeProfiler final : public OptionsProfiler {
public:
OptionsTimeProfiler(){};

void profile_with_given_options(
const std::string&, std::shared_ptr<OptionMap>&,
const std::string&) override;

std::string get_best_setting() override { return m_best_setting.first; }

private:
std::unordered_map<std::string, double> m_options_profile_result;
std::pair<std::string, double> m_best_setting = {"", TIME_OUT};
};

/**
 * parse the information given by the user
*/
class UserInfoParser {
public:
UserInfoParser(){};

void get_user_info();

void parse_info(std::shared_ptr<OptionsFastManager>&);

std::shared_ptr<OptionsProfiler> create_profiler() {
switch (m_proifler_type) {
case ProiflerType::TIME_PROFILER:
return std::make_shared<OptionsTimeProfiler>();
case ProiflerType::UNSPEC_PROFILER:
return nullptr;
default:
return nullptr;
}
}

private:
ProiflerType m_proifler_type;
std::unordered_map<std::string, std::string> m_user_info;
};
/*!
* \brief: Fitting strategy for running
*/
class FittingStrategy : public StrategyBase {
public:
class AutoCleanFile {
public:
AutoCleanFile(
const std::string& model_path, std::shared_ptr<OptionMap>& options);
void dump_model();
std::string filename() { return m_filename; }
~AutoCleanFile();

private:
std::string m_model_path;
std::shared_ptr<OptionMap> m_options;
std::string m_filename;
};

FittingStrategy(std::string model_path);

void run() override;

void dump_best_options_with_model();

void dump_model();

private:
std::string m_model_path;

std::string m_dumped_model;
std::shared_ptr<OptionsFastManager> m_manager;

UserInfoParser m_info_parser;

#if MGB_ENABLE_JSON
std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>>
m_packed_info;
#endif
};
} // namespace lar

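To summarize the interplay declared above, here is a hedged sketch of how FittingStrategy drives OptionsFastManager and an OptionsProfiler. The option names are borrowed from UserInfoParser::parse_info, the wrapper function is hypothetical, and the returned code is whatever GflagsOptionsCoder emits (gflags-style --name=value pairs):

    #include <memory>
    #include <string>
    #include "strategy_fitting.h"

    void profile_fixed_groups(
            const std::string& model_path, std::shared_ptr<lar::OptionMap> options,
            std::shared_ptr<lar::OptionsProfiler> profiler) {
        auto manager = std::make_shared<lar::OptionsFastManager>();
        manager->init(options);                                    // snapshot all option values
        manager->registe_fixed_options({"cpu", "enable_nchw44"});  // one device/layout group
        while (!manager->is_fixed_end()) {
            // enable the next registered group and get its encoded option code
            std::string code = manager->set_next_fixed_options();
            profiler->profile_with_given_options(model_path, options, code);
        }
        manager->set_options(profiler->get_best_setting());        // re-apply the winner
    }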
+ 15
- 24
lite/load_and_run/src/strategys/strategy_normal.cpp View File

@@ -1,11 +1,4 @@
/**
* \file lite/load_and_run/src/strategys/strategy_normal.cpp
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/
#include "strategy_normal.h"
#include <iostream>
#include <thread>
#include "megbrain/common.h"
@@ -13,13 +6,13 @@
#include "megbrain/version.h"
#include "megdnn/version.h"
#include "misc.h"
#include "strategy.h"

using namespace lar;

NormalStrategy::NormalStrategy(std::string model_path) {
mgb::set_log_level(mgb::LogLevel::WARN);
lite::set_log_level(LiteLogLevel::WARN);
m_options = std::make_shared<OptionMap>();
m_model_path = model_path;
auto option_creator_map = OptionFactory::get_Instance().get_option_creator_map();
mgb_log_debug("option map size: %lu", option_creator_map->size());
@@ -27,13 +20,13 @@ NormalStrategy::NormalStrategy(std::string model_path) {
auto& creator = (*option_creator_map)[name];
auto option = creator();
if (option) {
m_options.insert({name, option});
m_options->insert({name, option});
}
};

for (auto& creator : *option_creator_map) {
auto name = creator.first;
if (m_options.count(name) == 0) {
if (m_options->count(name) == 0) {
construct_option(name);
}
}
@@ -44,7 +37,7 @@ void NormalStrategy::run_subline() {
mgb_assert(model != nullptr, "create model failed!!");

auto stage_config_model = [&]() {
for (auto& option : m_options) {
for (auto& option : *m_options) {
option.second->config_model(m_runtime_param, model);
}
};
@@ -57,18 +50,14 @@ void NormalStrategy::run_subline() {
printf("load model: %.3fms\n", timer.get_msecs_reset());

//! after load configure
m_runtime_param.stage = RunStage::AFTER_MODEL_LOAD;
stage_config_model();

m_runtime_param.stage = RunStage::GLOBAL_OPTIMIZATION;
stage_config_model();

m_runtime_param.stage = RunStage::BEFORE_OUTSPEC_SET;
stage_config_model();

// for get static memmory information options
m_runtime_param.stage = RunStage::AFTER_OUTSPEC_SET;
stage_config_model();
auto config_after_load = [&]() {
for (auto stage :
{RunStage::AFTER_MODEL_LOAD, RunStage::GLOBAL_OPTIMIZATION,
RunStage::BEFORE_OUTSPEC_SET, RunStage::AFTER_OUTSPEC_SET}) {
m_runtime_param.stage = stage;
stage_config_model();
}
};

auto warm_up = [&]() {
auto warmup_num = m_runtime_param.warmup_iter;
@@ -117,6 +106,8 @@ void NormalStrategy::run_subline() {

double tot_time = 0;
for (size_t idx = 0; idx < iter_num; idx++) {
//! config model
config_after_load();
//! config when running model
mgb_log_warn("run testcase: %zu ", idx);
m_runtime_param.stage = RunStage::MODEL_RUNNING;


+ 22
- 0
lite/load_and_run/src/strategys/strategy_normal.h View File

@@ -0,0 +1,22 @@
#pragma once

#include "strategy.h"

namespace lar {
/*!
* \brief: normal strategy for running
*/
class NormalStrategy : public StrategyBase {
public:
NormalStrategy(std::string model_path);

//! run model with runtime parameter
void run() override;

private:
//! run model subline for multiple thread
void run_subline();

std::string m_model_path;
};
} // namespace lar

+ 2
- 13
lite/src/pack_model/pack_model.cpp View File

@@ -1,14 +1,3 @@
/**
* \file src/pack_model/pack_model.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "lite/pack_model.h"
#include "../misc.h"
#if LITE_BUILD_WITH_MGE
@@ -192,7 +181,7 @@ ModelPacker::ModelPacker(
std::string info_data_path, std::string info_algo_policy_path,
std::string info_binary_cache_path)
: m_packed_model_path(packed_model_path) {
m_fbs_helper = new FbsHelper(this, model_path);
m_fbs_helper = std::make_shared<FbsHelper>(this, model_path);
std::vector<uint8_t> empty_vec;
m_info_data = info_data_path.empty() ? empty_vec : read_file(info_data_path);
m_algo_policy_data = info_algo_policy_path.empty()
@@ -207,7 +196,7 @@ ModelPacker::ModelPacker(
std::vector<uint8_t> model_data, std::string packed_model_path,
std::vector<uint8_t> info_data, std::vector<uint8_t> info_algo_policy_data,
std::vector<uint8_t> info_binary_cache_data) {
m_fbs_helper = new FbsHelper(this, model_data);
m_fbs_helper = std::make_shared<FbsHelper>(this, model_data);
m_packed_model_path = packed_model_path;
m_info_data = info_data;
m_algo_policy_data = info_algo_policy_data;


+ 0
- 11
lite/src/parse_info/cache_parse.h View File

@@ -1,14 +1,3 @@
/**
* \file src/parse_info/cache_parse.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once
#include "lite/global.h"
#if LITE_BUILD_WITH_MGE

