GitOrigin-RevId: 8f21fda9d3
release-1.10
@@ -1,15 +1,5 @@ | |||||
/** | |||||
* \file inlude/lite/pack_model.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | #pragma once | ||||
#include <memory> | |||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
namespace lite { | namespace lite { | ||||
@@ -67,7 +57,7 @@ private: | |||||
Header m_header; | Header m_header; | ||||
friend class FbsHelper; | friend class FbsHelper; | ||||
FbsHelper* m_fbs_helper; | |||||
std::shared_ptr<FbsHelper> m_fbs_helper; | |||||
}; | }; | ||||
} // namespace lite | } // namespace lite |
@@ -5,7 +5,6 @@ cc_library( | |||||
hdrs = glob(["src/**/*.h"]), | hdrs = glob(["src/**/*.h"]), | ||||
includes = ["src"], | includes = ["src"], | ||||
features = if_opt([ | features = if_opt([ | ||||
"no_exceptions", | |||||
"no_rtti", | "no_rtti", | ||||
]), | ]), | ||||
@@ -1,15 +1,7 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/helpers/common.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | #pragma once | ||||
#include <gflags/gflags.h> | #include <gflags/gflags.h> | ||||
#include <memory> | #include <memory> | ||||
#include <unordered_map> | |||||
DECLARE_int32(thread); | DECLARE_int32(thread); | ||||
namespace lar { | namespace lar { | ||||
/*! | /*! | ||||
@@ -71,6 +63,122 @@ enum class OptLayoutType { | |||||
NHWCD4 = 1 << 6, | NHWCD4 = 1 << 6, | ||||
NCHW44_DOT = 1 << 7 | NCHW44_DOT = 1 << 7 | ||||
}; | }; | ||||
/** | |||||
* base class to story option value | |||||
*/ | |||||
enum class JsonValueType { | |||||
Bool = 0, | |||||
Number, | |||||
NumberInt32, | |||||
NumberUint64, | |||||
String, | |||||
}; | |||||
struct Value { | |||||
virtual JsonValueType get_type() const = 0; | |||||
virtual std::string type_string() const = 0; | |||||
virtual void reset_value() = 0; | |||||
virtual ~Value() = default; | |||||
}; | |||||
/** | |||||
* class for double option | |||||
*/ | |||||
struct Number final : public Value { | |||||
Number(double v) : m_val(v), m_default_val(v) {} | |||||
static std::shared_ptr<Number> make(double v) { | |||||
return std::make_shared<Number>(v); | |||||
} | |||||
void set_value(double v) { m_val = v; } | |||||
double get_value() { return m_val; } | |||||
double get_default() { return m_default_val; } | |||||
void reset_value() override { m_val = m_default_val; } | |||||
JsonValueType get_type() const override { return JsonValueType::Number; } | |||||
std::string type_string() const override { return "Number"; } | |||||
private: | |||||
double m_val; | |||||
double m_default_val; | |||||
}; | |||||
/** | |||||
* class for int32_t option | |||||
*/ | |||||
struct NumberInt32 final : public Value { | |||||
NumberInt32(int32_t v) : m_val(v), m_default_val(v) {} | |||||
static std::shared_ptr<NumberInt32> make(int32_t v) { | |||||
return std::make_shared<NumberInt32>(v); | |||||
} | |||||
void set_value(int32_t v) { m_val = v; } | |||||
int32_t get_value() { return m_val; } | |||||
int32_t get_default() { return m_default_val; } | |||||
void reset_value() override { m_val = m_default_val; } | |||||
JsonValueType get_type() const override { return JsonValueType::NumberInt32; } | |||||
std::string type_string() const override { return "NumberInt32"; } | |||||
private: | |||||
int32_t m_val; | |||||
int32_t m_default_val; | |||||
}; | |||||
/** | |||||
* class for uint64 option | |||||
*/ | |||||
struct NumberUint64 final : public Value { | |||||
NumberUint64(uint64_t v) : m_val(v), m_default_val(v) {} | |||||
static std::shared_ptr<NumberUint64> make(uint64_t v) { | |||||
return std::make_shared<NumberUint64>(v); | |||||
} | |||||
void set_value(uint64_t v) { m_val = v; } | |||||
uint64_t get_value() { return m_val; } | |||||
uint64_t get_default() { return m_default_val; } | |||||
void reset_value() override { m_val = m_default_val; } | |||||
JsonValueType get_type() const override { return JsonValueType::NumberUint64; } | |||||
std::string type_string() const override { return "NumberUint64"; } | |||||
private: | |||||
uint64_t m_val; | |||||
uint64_t m_default_val; | |||||
}; | |||||
/** | |||||
* class for boolean option | |||||
*/ | |||||
struct Bool final : public Value { | |||||
Bool(bool v) : m_val(v), m_default_val(v) {} | |||||
static std::shared_ptr<Bool> make(bool v) { return std::make_shared<Bool>(v); } | |||||
void set_value(bool v) { m_val = v; } | |||||
bool get_value() { return m_val; } | |||||
bool get_default() { return m_default_val; } | |||||
void reset_value() override { m_val = m_default_val; } | |||||
JsonValueType get_type() const override { return JsonValueType::Bool; } | |||||
std::string type_string() const override { return "Bool"; } | |||||
private: | |||||
bool m_val; | |||||
bool m_default_val; | |||||
}; | |||||
/** | |||||
* class for string option | |||||
*/ | |||||
struct String final : public Value { | |||||
String(std::string v) : m_val(v), m_default_val(v) {} | |||||
static std::shared_ptr<String> make(const std::string& v) { | |||||
return std::make_shared<String>(v); | |||||
} | |||||
void set_value(const std::string& v) { m_val = v; } | |||||
std::string& get_value() { return m_val; } | |||||
std::string get_default() { return m_default_val; } | |||||
void reset_value() override { m_val = m_default_val; } | |||||
JsonValueType get_type() const override { return JsonValueType::String; } | |||||
std::string type_string() const override { return "String"; } | |||||
private: | |||||
std::string m_val; | |||||
std::string m_default_val; | |||||
}; | |||||
using OptionValMap = std::unordered_map<std::string, std::shared_ptr<lar::Value>>; | |||||
} // namespace lar | } // namespace lar | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -5,9 +5,7 @@ using namespace mgb; | |||||
template <typename T> | template <typename T> | ||||
T* JsonLoader::Value::safe_cast() { | T* JsonLoader::Value::safe_cast() { | ||||
T* ptr = (T*)(this); | T* ptr = (T*)(this); | ||||
if (nullptr == ptr) { | |||||
fprintf(stderr, "cast ptr is null\n"); | |||||
} | |||||
mgb_assert(nullptr != ptr, "cast ptr is null\n"); | |||||
return ptr; | return ptr; | ||||
} | } | ||||
@@ -31,6 +29,12 @@ std::map<std::string, std::unique_ptr<JsonLoader::Value>>& JsonLoader::Value:: | |||||
return t->m_obj; | return t->m_obj; | ||||
} | } | ||||
std::vector<std::string>& JsonLoader::Value::keys() { | |||||
mgb_assert(Type::OBJECT == m_type); | |||||
auto t = safe_cast<JsonLoader::ObjectValue>(); | |||||
return t->m_keys; | |||||
} | |||||
size_t JsonLoader::Value::len() { | size_t JsonLoader::Value::len() { | ||||
if (Type::ARRAY == m_type) { | if (Type::ARRAY == m_type) { | ||||
auto t = safe_cast<JsonLoader::ArrayValue>(); | auto t = safe_cast<JsonLoader::ArrayValue>(); | ||||
@@ -54,6 +58,12 @@ double JsonLoader::Value::number() { | |||||
return t->value(); | return t->value(); | ||||
} | } | ||||
bool JsonLoader::Value::Bool() { | |||||
mgb_assert(Type::BOOL == m_type); | |||||
auto t = safe_cast<JsonLoader::BoolValue>(); | |||||
return t->value(); | |||||
} | |||||
std::string JsonLoader::Value::str() { | std::string JsonLoader::Value::str() { | ||||
if (Type::STRING == m_type) { | if (Type::STRING == m_type) { | ||||
auto t = safe_cast<StringValue>(); | auto t = safe_cast<StringValue>(); | ||||
@@ -69,7 +79,7 @@ void JsonLoader::expect(char c) { | |||||
void JsonLoader::skip_whitespace() { | void JsonLoader::skip_whitespace() { | ||||
const char* p = m_buf; | const char* p = m_buf; | ||||
while (*p == ' ' || *p == '\t' || *p == '\n' || *p == '\r') { | |||||
while (' ' == *p || '\t' == *p || '\n' == *p || '\r' == *p) { | |||||
++p; | ++p; | ||||
} | } | ||||
m_buf = p; | m_buf = p; | ||||
@@ -80,11 +90,12 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_object() { | |||||
skip_whitespace(); | skip_whitespace(); | ||||
std::unique_ptr<JsonLoader::Value> ret; | std::unique_ptr<JsonLoader::Value> ret; | ||||
JsonLoader::ObjectValue* pObject = new JsonLoader::ObjectValue(); | |||||
std::unique_ptr<JsonLoader::ObjectValue> pObject = | |||||
std::make_unique<JsonLoader::ObjectValue>(); | |||||
if ('}' == *m_buf) { | if ('}' == *m_buf) { | ||||
m_buf = m_buf + 1; | m_buf = m_buf + 1; | ||||
ret.reset((JsonLoader::Value*)(pObject)); | |||||
ret = std::move(pObject); | |||||
return ret; | return ret; | ||||
} | } | ||||
@@ -113,6 +124,7 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_object() { | |||||
} | } | ||||
pObject->m_obj.insert(std::make_pair(key->str(), std::move(pVal))); | pObject->m_obj.insert(std::make_pair(key->str(), std::move(pVal))); | ||||
pObject->m_keys.push_back(key->str()); | |||||
skip_whitespace(); | skip_whitespace(); | ||||
if (',' == (*m_buf)) { | if (',' == (*m_buf)) { | ||||
@@ -126,22 +138,21 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_object() { | |||||
break; | break; | ||||
} | } | ||||
} | } | ||||
ret.reset((JsonLoader::Value*)(pObject)); | |||||
ret = std::move(pObject); | |||||
return ret; | return ret; | ||||
} | } | ||||
std::unique_ptr<JsonLoader::Value> JsonLoader::parse_array() { | std::unique_ptr<JsonLoader::Value> JsonLoader::parse_array() { | ||||
expect('['); | expect('['); | ||||
skip_whitespace(); | skip_whitespace(); | ||||
std::unique_ptr<JsonLoader::Value> ret; | std::unique_ptr<JsonLoader::Value> ret; | ||||
JsonLoader::ArrayValue* pArray = new JsonLoader::ArrayValue(); | |||||
std::unique_ptr<JsonLoader::ArrayValue> pArray = | |||||
std::make_unique<JsonLoader::ArrayValue>(); | |||||
if (']' == *m_buf) { | if (']' == *m_buf) { | ||||
m_buf = m_buf + 1; | m_buf = m_buf + 1; | ||||
ret.reset((JsonLoader::Value*)(pArray)); | |||||
ret = std::move(pArray); | |||||
return ret; | return ret; | ||||
} | } | ||||
@@ -168,15 +179,14 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_array() { | |||||
} | } | ||||
} | } | ||||
ret.reset((JsonLoader::Value*)(pArray)); | |||||
ret = std::move(pArray); | |||||
return ret; | return ret; | ||||
} | } | ||||
std::unique_ptr<JsonLoader::Value> JsonLoader::parse_string() { | std::unique_ptr<JsonLoader::Value> JsonLoader::parse_string() { | ||||
expect('\"'); | expect('\"'); | ||||
std::unique_ptr<JsonLoader::Value> ret; | |||||
JsonLoader::StringValue* pStr = new JsonLoader::StringValue(); | |||||
std::unique_ptr<JsonLoader::StringValue> pStr = | |||||
std::make_unique<JsonLoader::StringValue>(); | |||||
const char* p = m_buf; | const char* p = m_buf; | ||||
while (true) { | while (true) { | ||||
@@ -189,7 +199,7 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_string() { | |||||
} | } | ||||
} | } | ||||
m_buf = p; | m_buf = p; | ||||
ret.reset((JsonLoader::Value*)(pStr)); | |||||
std::unique_ptr<JsonLoader::Value> ret = std::move(pStr); | |||||
return ret; | return ret; | ||||
} | } | ||||
@@ -207,31 +217,31 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_number() { | |||||
return; | return; | ||||
}; | }; | ||||
if (*p == '-') | |||||
if ('-' == *p) | |||||
p++; | p++; | ||||
if (*p == '0') | |||||
if ('0' == *p) | |||||
p++; | p++; | ||||
else { | else { | ||||
loop_digit(std::ref(p)); | loop_digit(std::ref(p)); | ||||
} | } | ||||
if (*p == '.') { | |||||
if ('.' == *p) { | |||||
p++; | p++; | ||||
loop_digit(std::ref(p)); | loop_digit(std::ref(p)); | ||||
} | } | ||||
if (*p == 'e' || *p == 'E') { | |||||
if ('e' == *p || 'E' == *p) { | |||||
p++; | p++; | ||||
if (*p == '+' || *p == '-') | |||||
if ('+' == *p || '-' == *p) | |||||
p++; | p++; | ||||
loop_digit(std::ref(p)); | loop_digit(std::ref(p)); | ||||
} | } | ||||
JsonLoader::NumberValue* pNum = new JsonLoader::NumberValue(); | |||||
std::unique_ptr<JsonLoader::NumberValue> pNum = | |||||
std::make_unique<JsonLoader::NumberValue>(); | |||||
pNum->m_value = strtod(m_buf, nullptr); | pNum->m_value = strtod(m_buf, nullptr); | ||||
m_buf = p; | m_buf = p; | ||||
std::unique_ptr<JsonLoader::Value> ret; | |||||
ret.reset((JsonLoader::Value*)(pNum)); | |||||
std::unique_ptr<JsonLoader::Value> ret = std::move(pNum); | |||||
return ret; | return ret; | ||||
} | } | ||||
@@ -243,6 +253,10 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_value() { | |||||
return parse_object(); | return parse_object(); | ||||
case '\"': | case '\"': | ||||
return parse_string(); | return parse_string(); | ||||
case 't': | |||||
return parse_bool(); | |||||
case 'f': | |||||
return parse_bool(); | |||||
case '\0': | case '\0': | ||||
m_state = State::BAD_TYPE; | m_state = State::BAD_TYPE; | ||||
break; | break; | ||||
@@ -252,6 +266,37 @@ std::unique_ptr<JsonLoader::Value> JsonLoader::parse_value() { | |||||
return nullptr; | return nullptr; | ||||
} | } | ||||
std::unique_ptr<JsonLoader::Value> JsonLoader::parse_bool() { | |||||
const char* p = m_buf; | |||||
std::string value; | |||||
if ('t' == *p) { | |||||
value = ""; | |||||
for (size_t idx = 0; idx < 4; ++idx) { | |||||
value += *p++; | |||||
} | |||||
} else if ('f' == *p) { | |||||
value = ""; | |||||
for (size_t idx = 0; idx < 5; ++idx) { | |||||
value += *p++; | |||||
} | |||||
} | |||||
bool val = false; | |||||
if ("true" == value) { | |||||
val = true; | |||||
} else if ("false" == value) { | |||||
val = false; | |||||
} else { | |||||
mgb_log_error("invalid value: %s for possible bool value", value.c_str()); | |||||
} | |||||
std::unique_ptr<JsonLoader::BoolValue> pBool = | |||||
std::make_unique<JsonLoader::BoolValue>(); | |||||
pBool->m_value = val; | |||||
m_buf = p; | |||||
std::unique_ptr<JsonLoader::Value> ret = std::move(pBool); | |||||
return ret; | |||||
} | |||||
std::unique_ptr<JsonLoader::Value> JsonLoader::load( | std::unique_ptr<JsonLoader::Value> JsonLoader::load( | ||||
const char* content, const size_t size) { | const char* content, const size_t size) { | ||||
m_buf = content; | m_buf = content; | ||||
@@ -18,7 +18,7 @@ public: | |||||
// base class for different value format | // base class for different value format | ||||
class Value { | class Value { | ||||
protected: | protected: | ||||
enum struct Type : uint8_t { UNKNOWN, NUMBER, STRING, OBJECT, ARRAY }; | |||||
enum struct Type : uint8_t { UNKNOWN, NUMBER, STRING, OBJECT, ARRAY, BOOL }; | |||||
Type m_type; | Type m_type; | ||||
public: | public: | ||||
@@ -39,12 +39,16 @@ public: | |||||
bool is_str() { return Type::STRING == m_type; } | bool is_str() { return Type::STRING == m_type; } | ||||
bool is_bool() { return Type::BOOL == m_type; } | |||||
std::unique_ptr<Value>& operator[](const std::string& key); | std::unique_ptr<Value>& operator[](const std::string& key); | ||||
std::unique_ptr<Value>& operator[](const size_t index); | std::unique_ptr<Value>& operator[](const size_t index); | ||||
std::map<std::string, std::unique_ptr<Value>>& objects(); | std::map<std::string, std::unique_ptr<Value>>& objects(); | ||||
std::vector<std::string>& keys(); | |||||
size_t len(); | size_t len(); | ||||
megdnn::SmallVector<std::unique_ptr<Value>>& array(); | megdnn::SmallVector<std::unique_ptr<Value>>& array(); | ||||
@@ -52,6 +56,8 @@ public: | |||||
double number(); | double number(); | ||||
std::string str(); | std::string str(); | ||||
bool Bool(); | |||||
}; | }; | ||||
void expect(char c); | void expect(char c); | ||||
@@ -68,6 +74,8 @@ public: | |||||
std::unique_ptr<Value> parse_value(); | std::unique_ptr<Value> parse_value(); | ||||
std::unique_ptr<Value> parse_bool(); | |||||
enum struct State : uint8_t { | enum struct State : uint8_t { | ||||
OK = 0, | OK = 0, | ||||
BAD_TYPE, | BAD_TYPE, | ||||
@@ -137,21 +145,26 @@ public: | |||||
class ObjectValue final : public Value { | class ObjectValue final : public Value { | ||||
std::map<std::string, std::unique_ptr<Value>> m_obj; | std::map<std::string, std::unique_ptr<Value>> m_obj; | ||||
std::vector<std::string> m_keys; | |||||
public: | public: | ||||
ObjectValue() : Value(Type::OBJECT) {} | ObjectValue() : Value(Type::OBJECT) {} | ||||
ObjectValue(ObjectValue& arr) : Value(arr) { | ObjectValue(ObjectValue& arr) : Value(arr) { | ||||
m_obj.clear(); | m_obj.clear(); | ||||
m_keys.clear(); | |||||
for (auto itra = arr.m_obj.begin(); itra != arr.m_obj.end(); ++itra) { | for (auto itra = arr.m_obj.begin(); itra != arr.m_obj.end(); ++itra) { | ||||
m_obj.emplace(std::make_pair(itra->first, std::move(itra->second))); | m_obj.emplace(std::make_pair(itra->first, std::move(itra->second))); | ||||
m_keys.push_back(itra->first); | |||||
} | } | ||||
} | } | ||||
ObjectValue(ObjectValue&& arr) : Value(arr) { | ObjectValue(ObjectValue&& arr) : Value(arr) { | ||||
m_obj.clear(); | m_obj.clear(); | ||||
m_keys.clear(); | |||||
for (auto itra = arr.m_obj.begin(); itra != arr.m_obj.end(); ++itra) { | for (auto itra = arr.m_obj.begin(); itra != arr.m_obj.end(); ++itra) { | ||||
m_obj.emplace(std::make_pair(itra->first, std::move(itra->second))); | m_obj.emplace(std::make_pair(itra->first, std::move(itra->second))); | ||||
m_keys.push_back(itra->first); | |||||
} | } | ||||
} | } | ||||
@@ -160,9 +173,19 @@ public: | |||||
const std::string&); | const std::string&); | ||||
friend std::map<std::string, std::unique_ptr<JsonLoader::Value>>& JsonLoader:: | friend std::map<std::string, std::unique_ptr<JsonLoader::Value>>& JsonLoader:: | ||||
Value::objects(); | Value::objects(); | ||||
friend std::vector<std::string>& JsonLoader::Value::keys(); | |||||
friend size_t JsonLoader::Value::len(); | friend size_t JsonLoader::Value::len(); | ||||
}; | }; | ||||
class BoolValue final : public Value { | |||||
bool m_value; | |||||
public: | |||||
BoolValue() : Value(Type::BOOL) {} | |||||
bool value() { return m_value; } | |||||
friend std::unique_ptr<Value> JsonLoader::parse_bool(); | |||||
}; | |||||
private: | private: | ||||
const char* m_buf; | const char* m_buf; | ||||
State m_state; | State m_state; | ||||
@@ -0,0 +1,362 @@ | |||||
#include "utils.h" | |||||
using namespace lar; | |||||
/////////////////// JsonOptionsCoder /////////////////// | |||||
#if MGB_ENABLE_JSON | |||||
//! encode option | |||||
void encode_single_options( | |||||
std::pair<std::string, std::shared_ptr<lar::Value>> item, | |||||
std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>>& | |||||
list, | |||||
bool encode_all) { | |||||
auto type = item.second->get_type(); | |||||
if (type == JsonValueType::Bool) { | |||||
auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second); | |||||
if (!encode_all && val_ptr->get_value() == val_ptr->get_default()) { | |||||
return; | |||||
} | |||||
list.push_back( | |||||
{mgb::json::String(item.first), | |||||
mgb::json::Bool::make(val_ptr->get_value())}); | |||||
} else if (type == JsonValueType::NumberInt32) { | |||||
auto val_ptr = std::static_pointer_cast<lar::NumberInt32>(item.second); | |||||
if (!encode_all && val_ptr->get_value() == val_ptr->get_default()) { | |||||
return; | |||||
} | |||||
list.push_back( | |||||
{mgb::json::String(item.first), | |||||
mgb::json::NumberInt::make( | |||||
static_cast<int64_t>(val_ptr->get_value()))}); | |||||
} else if (type == JsonValueType::NumberUint64) { | |||||
auto val_ptr = std::static_pointer_cast<lar::NumberUint64>(item.second); | |||||
list.push_back( | |||||
{mgb::json::String(item.first), | |||||
mgb::json::NumberInt::make( | |||||
static_cast<int64_t>(val_ptr->get_value()))}); | |||||
} else if (type == JsonValueType::Number) { | |||||
auto val_ptr = std::static_pointer_cast<lar::Number>(item.second); | |||||
list.push_back( | |||||
{mgb::json::String(item.first), | |||||
mgb::json::Number::make(val_ptr->get_value())}); | |||||
} else if (type == JsonValueType::String) { | |||||
auto val_ptr = std::static_pointer_cast<lar::String>(item.second); | |||||
if (!encode_all && val_ptr->get_value() == val_ptr->get_default()) { | |||||
return; | |||||
} | |||||
list.push_back( | |||||
{mgb::json::String(item.first), | |||||
mgb::json::String::make(val_ptr->get_value())}); | |||||
} else { | |||||
mgb_log_error( | |||||
"unsupport JsonValueType:%s for lar::Value", | |||||
item.second->type_string().c_str()); | |||||
} | |||||
} | |||||
std::string JsonOptionsCoder::encode(OptionValMap& option_val_map, bool encode_all) { | |||||
std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>> | |||||
json_options; | |||||
for (auto& item : option_val_map) { | |||||
encode_single_options(item, json_options, encode_all); | |||||
} | |||||
auto json_obj = mgb::json::Object::make( | |||||
{{"options", mgb::json::Object::make(json_options)}}); | |||||
return json_obj->to_string(1); | |||||
} | |||||
//! encode device | |||||
std::vector<std::shared_ptr<mgb::json::Object>> JsonOptionsCoder::encode( | |||||
OptionValMap& option_val_map) { | |||||
std::vector<std::shared_ptr<mgb::json::Object>> info; | |||||
std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>> | |||||
json_device; | |||||
std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>> | |||||
json_options; | |||||
for (auto& item : option_val_map) { | |||||
if ((item.first == "cpu" || item.first == "cpu_default" || | |||||
item.first == "multithread" || item.first == "multithread_default")) { | |||||
auto type = item.second->get_type(); | |||||
if (type == JsonValueType::Bool) { | |||||
auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second); | |||||
if (val_ptr->get_value() == val_ptr->get_default()) | |||||
continue; | |||||
} | |||||
if (type == JsonValueType::NumberInt32) { | |||||
auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second); | |||||
if (val_ptr->get_value() == val_ptr->get_default()) | |||||
continue; | |||||
} | |||||
json_device.push_back( | |||||
{mgb::json::String("type"), mgb::json::String::make("CPU")}); | |||||
if (item.first == "cpu_default" || item.first == "multithread_default") { | |||||
json_device.push_back( | |||||
{mgb::json::String("enable_inplace_model"), | |||||
mgb::json::Bool::make(true)}); | |||||
} | |||||
if (item.first == "multithread" || item.first == "multithread_default") { | |||||
json_device.push_back( | |||||
{mgb::json::String("number_threads"), | |||||
mgb::json::NumberInt::make( | |||||
std::static_pointer_cast<lar::NumberInt32>(item.second) | |||||
->get_value())}); | |||||
if (item.first == "multithread") { | |||||
json_device.push_back( | |||||
{mgb::json::String("device_id"), | |||||
mgb::json::NumberInt::make(0)}); | |||||
} | |||||
} | |||||
} else if (item.first == "cuda") { | |||||
auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second); | |||||
if (val_ptr->get_value() == val_ptr->get_default()) | |||||
continue; | |||||
json_device.push_back( | |||||
{mgb::json::String("type"), mgb::json::String::make("CUDA")}); | |||||
json_device.push_back( | |||||
{mgb::json::String("device_id"), mgb::json::NumberInt::make(0)}); | |||||
} else if (item.first == "opencl") { | |||||
auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second); | |||||
if (val_ptr->get_value() == val_ptr->get_default()) | |||||
continue; | |||||
json_device.push_back( | |||||
{mgb::json::String("type"), mgb::json::String::make("OPENCL")}); | |||||
} else if ( | |||||
item.first == "record_comp_seq" || item.first == "record_comp_seq2") { | |||||
auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second); | |||||
if (val_ptr->get_value() == val_ptr->get_default()) | |||||
continue; | |||||
int comp_node_seq_record_level = item.first == "record_comp_seq" ? 1 : 2; | |||||
json_options.push_back( | |||||
{mgb::json::String("comp_node_seq_record_level"), | |||||
mgb::json::NumberInt::make(comp_node_seq_record_level)}); | |||||
} else if (item.first == "fake_first") { | |||||
auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second); | |||||
if (val_ptr->get_value() == val_ptr->get_default()) | |||||
continue; | |||||
json_options.push_back( | |||||
{mgb::json::String("fake_next_exec"), | |||||
mgb::json::Bool::make(val_ptr->get_value())}); | |||||
} else if (item.first == "no_sanity_check") { | |||||
auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second); | |||||
if (val_ptr->get_value() == val_ptr->get_default()) | |||||
continue; | |||||
json_options.push_back( | |||||
{mgb::json::String("var_sanity_check_first_run"), | |||||
mgb::json::Bool::make(!val_ptr->get_value())}); | |||||
} else if (item.first == "weight_preprocess") { | |||||
auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second); | |||||
if (val_ptr->get_value() == val_ptr->get_default()) | |||||
continue; | |||||
json_options.push_back( | |||||
{mgb::json::String("weight_preprocess"), | |||||
mgb::json::Bool::make(val_ptr->get_value())}); | |||||
} | |||||
} | |||||
info.push_back(mgb::json::Object::make( | |||||
{{"options", mgb::json::Object::make(json_options)}})); | |||||
info.push_back(mgb::json::Object::make( | |||||
{{"device", mgb::json::Object::make(json_device)}})); | |||||
return info; | |||||
} | |||||
//! decode options note string into option map | |||||
OptionValMap& JsonOptionsCoder::decode( | |||||
const std::string& code, OptionValMap& option_val_map) { | |||||
std::shared_ptr<mgb::JsonLoader::Value> root = | |||||
m_json_loader.load(code.c_str(), code.size()); | |||||
for (auto& item : root->objects()) { | |||||
auto& value = *item.second; | |||||
//! get all keys in json object | |||||
auto keys = value.keys(); | |||||
//! set the json format options into internal options | |||||
for (auto& val : keys) { | |||||
if (value[val]->is_bool()) { | |||||
auto val_ptr = std::static_pointer_cast<lar::Bool>(option_val_map[val]); | |||||
val_ptr->set_value(value[val]->Bool()); | |||||
} else if (value[val]->is_number()) { | |||||
auto type = option_val_map[val]->get_type(); | |||||
if (type == JsonValueType::Number) { | |||||
auto val_ptr = | |||||
std::static_pointer_cast<lar::Number>(option_val_map[val]); | |||||
val_ptr->set_value(value[val]->number()); | |||||
} else if (type == JsonValueType::NumberInt32) { | |||||
auto val_ptr = std::static_pointer_cast<lar::NumberInt32>( | |||||
option_val_map[val]); | |||||
val_ptr->set_value(static_cast<int32_t>(value[val]->number())); | |||||
} else if (type == JsonValueType::NumberUint64) { | |||||
auto val_ptr = std::static_pointer_cast<lar::NumberUint64>( | |||||
option_val_map[val]); | |||||
val_ptr->set_value(static_cast<uint64_t>(value[val]->number())); | |||||
} else { | |||||
mgb_log_error( | |||||
"invalid number type:%s to set", | |||||
option_val_map[val]->type_string().c_str()); | |||||
} | |||||
} else if (value[val]->is_str()) { | |||||
auto val_ptr = | |||||
std::static_pointer_cast<lar::String>(option_val_map[val]); | |||||
val_ptr->set_value(value[val]->str()); | |||||
} else { | |||||
mgb_log_error("invalid value type for JsonLoader"); | |||||
} | |||||
} | |||||
} | |||||
return option_val_map; | |||||
} | |||||
#endif | |||||
std::string GflagsOptionsCoder::encode(OptionValMap& option_val_map, bool encode_all) { | |||||
std::vector<std::string> gflags_options; | |||||
for (auto& item : option_val_map) { | |||||
auto type = item.second->get_type(); | |||||
std::string val = "--"; | |||||
if (type == JsonValueType::Bool) { | |||||
auto val_ptr = std::static_pointer_cast<lar::Bool>(item.second); | |||||
if (!encode_all && val_ptr->get_value() == val_ptr->get_default()) { | |||||
continue; | |||||
} | |||||
val += item.first; | |||||
val += "="; | |||||
val += val_ptr->get_value() ? "true" : "false"; | |||||
gflags_options.push_back(val); | |||||
} else if (type == JsonValueType::NumberInt32) { | |||||
auto val_ptr = std::static_pointer_cast<lar::NumberInt32>(item.second); | |||||
if (!encode_all && val_ptr->get_value() == val_ptr->get_default()) { | |||||
continue; | |||||
} | |||||
val += item.first; | |||||
val += "="; | |||||
val += std::to_string(val_ptr->get_value()); | |||||
gflags_options.push_back(val); | |||||
} else if (type == JsonValueType::NumberUint64) { | |||||
auto val_ptr = std::static_pointer_cast<lar::NumberUint64>(item.second); | |||||
val += item.first; | |||||
val += "="; | |||||
val += std::to_string(val_ptr->get_value()); | |||||
gflags_options.push_back(val); | |||||
} else if (type == JsonValueType::Number) { | |||||
auto val_ptr = std::static_pointer_cast<lar::Number>(item.second); | |||||
val += item.first; | |||||
val += "="; | |||||
val += std::to_string(val_ptr->get_value()); | |||||
gflags_options.push_back(val); | |||||
} else if (type == JsonValueType::String) { | |||||
auto val_ptr = std::static_pointer_cast<lar::String>(item.second); | |||||
if (!encode_all && val_ptr->get_value() == val_ptr->get_default()) { | |||||
continue; | |||||
} | |||||
val += item.first; | |||||
val += "=\""; | |||||
val += val_ptr->get_value(); | |||||
val += "\""; | |||||
gflags_options.push_back(val); | |||||
} else { | |||||
mgb_log_error( | |||||
"unsupport JsonValueType:%s for lar::Value", | |||||
item.second->type_string().c_str()); | |||||
} | |||||
} | |||||
std::string ret; | |||||
for (auto& item : gflags_options) { | |||||
ret += item; | |||||
ret += "\n"; | |||||
} | |||||
return ret; | |||||
} | |||||
//! decode options note string into option map | |||||
OptionValMap& GflagsOptionsCoder::decode( | |||||
const std::string& code, OptionValMap& option_val_map) { | |||||
std::unordered_map<std::string, std::string> gflags_map; | |||||
auto to_raw_string = [](const std::string& str) { | |||||
auto size = str.size(); | |||||
std::string ret; | |||||
if ('\"' == str[0] && '\"' == str[size - 1]) { | |||||
ret = str.substr(1, size - 2); | |||||
} else { | |||||
ret = str; | |||||
} | |||||
return ret; | |||||
}; | |||||
size_t start = 0; | |||||
size_t end = code.find("\n", start); | |||||
while (end != std::string::npos) { | |||||
auto str = code.substr(start, end - start); | |||||
if (str.substr(0, 2) == "--") { | |||||
size_t idx = str.find("=", 0); | |||||
gflags_map.insert( | |||||
{str.substr(2, idx - 2), to_raw_string(str.substr(idx + 1))}); | |||||
} else { | |||||
mgb_log_error("invaid gflags argument %s", str.c_str()); | |||||
} | |||||
start = end + 1; | |||||
end = code.find("\n", start); | |||||
} | |||||
for (auto& item : gflags_map) { | |||||
if (option_val_map.count(item.first) != 0) { | |||||
auto& option_val = option_val_map[item.first]; | |||||
auto type = option_val->get_type(); | |||||
if (type == JsonValueType::Bool) { | |||||
auto val_ptr = std::static_pointer_cast<lar::Bool>(option_val); | |||||
if (item.second == "true" || item.second == "false") { | |||||
auto val = item.second == "true"; | |||||
val_ptr->set_value(val); | |||||
} | |||||
} else if (type == JsonValueType::NumberInt32) { | |||||
auto val_ptr = std::static_pointer_cast<lar::NumberInt32>(option_val); | |||||
MGB_TRY { | |||||
int32_t val = std::stoi(item.second); | |||||
val_ptr->set_value(val); | |||||
} | |||||
MGB_CATCH(std::exception & exc, { | |||||
mgb_log_error( | |||||
"invaid value: %s for %s", item.second.c_str(), | |||||
item.first.c_str()); | |||||
}); | |||||
} else if (type == JsonValueType::NumberUint64) { | |||||
auto val_ptr = std::static_pointer_cast<lar::NumberUint64>(option_val); | |||||
MGB_TRY { | |||||
uint64_t val = std::stoull(item.second); | |||||
val_ptr->set_value(val); | |||||
} | |||||
MGB_CATCH(std::exception & exc, { | |||||
mgb_log_error( | |||||
"invaid value: %s for %s", item.second.c_str(), | |||||
item.first.c_str()); | |||||
}); | |||||
} else if (type == JsonValueType::Number) { | |||||
auto val_ptr = std::static_pointer_cast<lar::Number>(option_val); | |||||
MGB_TRY { | |||||
double val = std::stod(item.second); | |||||
val_ptr->set_value(val); | |||||
} | |||||
MGB_CATCH(std::exception & exc, { | |||||
mgb_log_error( | |||||
"invaid value: %s for %s", item.second.c_str(), | |||||
item.first.c_str()); | |||||
}); | |||||
} else if (type == JsonValueType::String) { | |||||
auto val_ptr = std::static_pointer_cast<lar::String>(option_val); | |||||
val_ptr->set_value(item.second); | |||||
} else { | |||||
mgb_log_error( | |||||
"unsupport JsonValueType:%s for lar::Value", | |||||
option_val->type_string().c_str()); | |||||
} | |||||
} else { | |||||
mgb_log_error("invalid gflags when set runtime options in fitting mode"); | |||||
} | |||||
} | |||||
return option_val_map; | |||||
} |
@@ -0,0 +1,69 @@ | |||||
#pragma once | |||||
#include <vector> | |||||
#include "common.h" | |||||
#include "json_loader.h" | |||||
#include "megbrain/utils/json.h" | |||||
namespace lar { | |||||
/** | |||||
* fitting profiler type | |||||
*/ | |||||
enum class ProiflerType { | |||||
TIME_PROFILER = 0, | |||||
UNSPEC_PROFILER = 1, | |||||
}; | |||||
/** | |||||
* option coder type | |||||
*/ | |||||
enum class CoderType { | |||||
GFLAGS = 0, | |||||
JSON = 1, | |||||
UNSPEC = 2, | |||||
}; | |||||
/** | |||||
* option coder to transform internal option val into differnet form | |||||
*/ | |||||
class OptionsCoder { | |||||
public: | |||||
OptionsCoder(){}; | |||||
//! encode options into given format | |||||
virtual std::string encode(OptionValMap&, bool) = 0; | |||||
//! decode options with given format into option map | |||||
virtual OptionValMap& decode(const std::string&, OptionValMap& val_map) = 0; | |||||
//! destructor | |||||
virtual ~OptionsCoder() = default; | |||||
}; | |||||
#if MGB_ENABLE_JSON | |||||
class JsonOptionsCoder final : public OptionsCoder { | |||||
public: | |||||
JsonOptionsCoder(){}; | |||||
//! encode given options into json format | |||||
std::string encode(OptionValMap&, bool encode_all) override; | |||||
std::vector<std::shared_ptr<mgb::json::Object>> encode(OptionValMap&); | |||||
//! decode given json format options into given options map | |||||
OptionValMap& decode(const std::string&, OptionValMap&) override; | |||||
private: | |||||
mgb::JsonLoader m_json_loader; | |||||
}; | |||||
#endif | |||||
class GflagsOptionsCoder final : public OptionsCoder { | |||||
public: | |||||
GflagsOptionsCoder(){}; | |||||
//! encode given options into gflags format | |||||
std::string encode(OptionValMap&, bool encode_all = false) override; | |||||
//! decode given gflags format options into given options maps | |||||
OptionValMap& decode(const std::string&, OptionValMap&) override; | |||||
}; | |||||
} // namespace lar |
@@ -1,20 +1,37 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/main.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include <gflags/gflags.h> | #include <gflags/gflags.h> | ||||
#include <string> | #include <string> | ||||
#include "strategys/strategy.h" | #include "strategys/strategy.h" | ||||
std::string simple_usage = R"( | |||||
load_and_run: load_and_run <model_path> [options Flags...] | |||||
Flags from lite/load_and_run/src/models/model.cpp: | |||||
-lite type: bool default: false use megengine lite interface to run model | |||||
Flags from lite/load_and_run/src/options/strategy_options.cpp: | |||||
-iter type: int32 default: 10 iteration number for run model | |||||
-thread type: int32 default: 1 thread number for run model when <thread> is supported | |||||
-warmup_iter type: int32 default: 1 iteration number for warm up model before run | |||||
Flags from com_github_gflags_gflags/src/gflags.cc: | |||||
-flagfile type: string default: "" load flags from file | |||||
-fromenv type: string default: "" set flags from the environment [use 'export FLAGS_flag1=value'] | |||||
... | |||||
Flags from com_github_gflags_gflags/src/gflags_reporting.cc: | |||||
-help type: bool default: false show help on all flags | |||||
-helpmatch type: string default: "" show help on modules whose name contains the specified substr | |||||
-version type: bool default: false show version and build info and exit | |||||
... | |||||
More details using "--help" to get!! | |||||
)"; | |||||
int main(int argc, char** argv) { | int main(int argc, char** argv) { | ||||
std::string usage = "load_and_run <model_path> [options...]"; | |||||
std::string usage = "load_and_run <model_path> [options Flags...]"; | |||||
if (argc < 2) { | if (argc < 2) { | ||||
printf("usage: %s\n", usage.c_str()); | |||||
printf("usage: %s\n", simple_usage.c_str()); | |||||
return -1; | return -1; | ||||
} | } | ||||
gflags::SetUsageMessage(usage); | gflags::SetUsageMessage(usage); | ||||
@@ -1,12 +1,3 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/models/model.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include "model.h" | #include "model.h" | ||||
#include <iostream> | #include <iostream> | ||||
#include <memory> | #include <memory> | ||||
@@ -56,5 +47,5 @@ std::shared_ptr<ModelBase> ModelBase::create_model(std::string model_path) { | |||||
return nullptr; | return nullptr; | ||||
} | } | ||||
} | } | ||||
DEFINE_bool(lite, false, "using lite model to run mdl model"); | |||||
DEFINE_bool(lite, false, "use megengine lite interface to run model"); | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -1,17 +1,8 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/models/model.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | #pragma once | ||||
#include <gflags/gflags.h> | #include <gflags/gflags.h> | ||||
#include <string> | #include <string> | ||||
#include "helpers/common.h" | #include "helpers/common.h" | ||||
#include "megbrain/utils/json.h" | |||||
DECLARE_bool(lite); | DECLARE_bool(lite); | ||||
namespace lar { | namespace lar { | ||||
@@ -45,6 +36,12 @@ public: | |||||
virtual ~ModelBase() = default; | virtual ~ModelBase() = default; | ||||
virtual const std::string& get_model_path() const = 0; | virtual const std::string& get_model_path() const = 0; | ||||
virtual std::vector<uint8_t> get_model_data() = 0; | |||||
#if MGB_ENABLE_JSON | |||||
//! get model io information | |||||
virtual std::shared_ptr<mgb::json::Object> get_io_info() = 0; | |||||
#endif | |||||
}; | }; | ||||
} // namespace lar | } // namespace lar | ||||
@@ -1,14 +1,7 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/models/model_lite.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include "model_lite.h" | #include "model_lite.h" | ||||
#include <gflags/gflags.h> | #include <gflags/gflags.h> | ||||
#include <cstring> | #include <cstring> | ||||
#include <map> | |||||
#include "misc.h" | #include "misc.h" | ||||
DECLARE_bool(share_param_mem); | DECLARE_bool(share_param_mem); | ||||
@@ -51,3 +44,75 @@ void ModelLite::run_model() { | |||||
void ModelLite::wait() { | void ModelLite::wait() { | ||||
m_network->wait(); | m_network->wait(); | ||||
} | } | ||||
#if MGB_ENABLE_JSON | |||||
std::shared_ptr<mgb::json::Object> ModelLite::get_io_info() { | |||||
std::shared_ptr<mgb::json::Array> inputs = mgb::json::Array::make(); | |||||
std::shared_ptr<mgb::json::Array> outputs = mgb::json::Array::make(); | |||||
auto get_dtype = [&](lite::Layout& layout) { | |||||
std::map<LiteDataType, std::string> type_map = { | |||||
{LiteDataType::LITE_FLOAT, "float32"}, | |||||
{LiteDataType::LITE_HALF, "float16"}, | |||||
{LiteDataType::LITE_INT64, "int64"}, | |||||
{LiteDataType::LITE_INT, "int32"}, | |||||
{LiteDataType::LITE_UINT, "uint32"}, | |||||
{LiteDataType::LITE_INT16, "int16"}, | |||||
{LiteDataType::LITE_UINT16, "uint16"}, | |||||
{LiteDataType::LITE_INT8, "int8"}, | |||||
{LiteDataType::LITE_UINT8, "uint8"}}; | |||||
return type_map[layout.data_type]; | |||||
}; | |||||
auto make_shape = [](lite::Layout& layout) { | |||||
std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>> | |||||
shape; | |||||
for (size_t i = 0; i < layout.ndim; ++i) { | |||||
std::string lable = "dim"; | |||||
lable += std::to_string(layout.ndim - i - 1); | |||||
shape.push_back( | |||||
{mgb::json::String(lable), | |||||
mgb::json::NumberInt::make(layout.shapes[layout.ndim - i - 1])}); | |||||
} | |||||
return shape; | |||||
}; | |||||
auto input_name = m_network->get_all_input_name(); | |||||
for (auto& i : input_name) { | |||||
std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>> | |||||
json_inp; | |||||
auto layout = m_network->get_io_tensor(i)->get_layout(); | |||||
json_inp.push_back( | |||||
{mgb::json::String("shape"), | |||||
mgb::json::Object::make(make_shape(layout))}); | |||||
json_inp.push_back( | |||||
{mgb::json::String("dtype"), | |||||
mgb::json::String::make(get_dtype(layout))}); | |||||
json_inp.push_back({mgb::json::String("name"), mgb::json::String::make(i)}); | |||||
inputs->add(mgb::json::Object::make(json_inp)); | |||||
} | |||||
auto output_name = m_network->get_all_output_name(); | |||||
for (auto& i : output_name) { | |||||
std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>> | |||||
json_out; | |||||
auto layout = m_network->get_io_tensor(i)->get_layout(); | |||||
json_out.push_back( | |||||
{mgb::json::String("shape"), | |||||
mgb::json::Object::make(make_shape(layout))}); | |||||
json_out.push_back( | |||||
{mgb::json::String("dtype"), | |||||
mgb::json::String::make(get_dtype(layout))}); | |||||
json_out.push_back({mgb::json::String("name"), mgb::json::String::make(i)}); | |||||
inputs->add(mgb::json::Object::make(json_out)); | |||||
} | |||||
return mgb::json::Object::make( | |||||
{{"IO", | |||||
mgb::json::Object::make({{"outputs", outputs}, {"inputs", inputs}})}}); | |||||
} | |||||
#endif | |||||
std::vector<uint8_t> ModelLite::get_model_data() { | |||||
std::vector<uint8_t> out_data; | |||||
LITE_THROW("unsupported interface: ModelLite::get_model_data() \n"); | |||||
return out_data; | |||||
} |
@@ -1,12 +1,3 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/models/model_lite.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | #pragma once | ||||
#include <string> | #include <string> | ||||
@@ -39,6 +30,10 @@ public: | |||||
//! wait the end of asynchronous function execution | //! wait the end of asynchronous function execution | ||||
void wait() override; | void wait() override; | ||||
#if MGB_ENABLE_JSON | |||||
std::shared_ptr<mgb::json::Object> get_io_info() override; | |||||
#endif | |||||
//! enable global layout transform | //! enable global layout transform | ||||
void set_layout_transform(bool state) { enable_layout_transform = state; } | void set_layout_transform(bool state) { enable_layout_transform = state; } | ||||
@@ -62,6 +57,8 @@ public: | |||||
const std::string& get_model_path() const override { return model_path; } | const std::string& get_model_path() const override { return model_path; } | ||||
std::vector<uint8_t> get_model_data() override; | |||||
private: | private: | ||||
bool share_model_mem; | bool share_model_mem; | ||||
bool enable_layout_transform; | bool enable_layout_transform; | ||||
@@ -1,12 +1,3 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/models/model_mdl.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include "model_mdl.h" | #include "model_mdl.h" | ||||
#include <gflags/gflags.h> | #include <gflags/gflags.h> | ||||
#include <iostream> | #include <iostream> | ||||
@@ -109,3 +100,76 @@ void ModelMdl::run_model() { | |||||
void ModelMdl::wait() { | void ModelMdl::wait() { | ||||
m_asyc_exec->wait(); | m_asyc_exec->wait(); | ||||
} | } | ||||
#if MGB_ENABLE_JSON | |||||
std::shared_ptr<mgb::json::Object> ModelMdl::get_io_info() { | |||||
std::shared_ptr<mgb::json::Array> inputs = mgb::json::Array::make(); | |||||
std::shared_ptr<mgb::json::Array> outputs = mgb::json::Array::make(); | |||||
auto get_dtype = [&](megdnn::DType data_type) { | |||||
std::map<megdnn::DTypeEnum, std::string> type_map = { | |||||
{mgb::dtype::Float32().enumv(), "float32"}, | |||||
{mgb::dtype::Int32().enumv(), "int32"}, | |||||
{mgb::dtype::Int16().enumv(), "int16"}, | |||||
{mgb::dtype::Uint16().enumv(), "uint16"}, | |||||
{mgb::dtype::Int8().enumv(), "int8"}, | |||||
{mgb::dtype::Uint8().enumv(), "uint8"}}; | |||||
return type_map[data_type.enumv()]; | |||||
}; | |||||
auto make_shape = [](mgb::TensorShape& shape_) { | |||||
std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>> | |||||
shape; | |||||
for (size_t i = 0; i < shape_.ndim; ++i) { | |||||
std::string lable = "dim"; | |||||
lable += std::to_string(shape_.ndim - i - 1); | |||||
shape.push_back( | |||||
{mgb::json::String(lable), | |||||
mgb::json::NumberInt::make(shape_[shape_.ndim - i - 1])}); | |||||
} | |||||
return shape; | |||||
}; | |||||
for (auto&& i : m_load_result.tensor_map) { | |||||
std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>> | |||||
json_inp; | |||||
auto shape_ = i.second->shape(); | |||||
json_inp.push_back( | |||||
{mgb::json::String("shape"), | |||||
mgb::json::Object::make(make_shape(shape_))}); | |||||
json_inp.push_back( | |||||
{mgb::json::String("dtype"), | |||||
mgb::json::String::make(get_dtype(i.second->dtype()))}); | |||||
json_inp.push_back( | |||||
{mgb::json::String("name"), mgb::json::String::make(i.first)}); | |||||
inputs->add(mgb::json::Object::make(json_inp)); | |||||
} | |||||
for (auto&& i : m_load_result.output_var_list) { | |||||
std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>> | |||||
json_out; | |||||
auto shape_ = i.shape(); | |||||
json_out.push_back( | |||||
{mgb::json::String("shape"), | |||||
mgb::json::Object::make(make_shape(shape_))}); | |||||
json_out.push_back( | |||||
{mgb::json::String("dtype"), | |||||
mgb::json::String::make(get_dtype(i.dtype()))}); | |||||
json_out.push_back( | |||||
{mgb::json::String("name"), mgb::json::String::make(i.node()->name())}); | |||||
outputs->add(mgb::json::Object::make(json_out)); | |||||
} | |||||
return mgb::json::Object::make( | |||||
{{"IO", | |||||
mgb::json::Object::make({{"outputs", outputs}, {"inputs", inputs}})}}); | |||||
} | |||||
#endif | |||||
std::vector<uint8_t> ModelMdl::get_model_data() { | |||||
std::vector<uint8_t> out_data; | |||||
auto out_file = mgb::serialization::OutputFile::make_vector_proxy(&out_data); | |||||
using DumpConfig = mgb::serialization::GraphDumper::DumpConfig; | |||||
DumpConfig config{1, false, false}; | |||||
auto dumper = | |||||
mgb::serialization::GraphDumper::make(std::move(out_file), m_format.val()); | |||||
dumper->dump(m_load_result.output_var_list, config); | |||||
return out_data; | |||||
} |
@@ -1,12 +1,3 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/models/model_mdl.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | #pragma once | ||||
#include <string> | #include <string> | ||||
#include "megbrain/opr/search_policy/algo_chooser_helper.h" | #include "megbrain/opr/search_policy/algo_chooser_helper.h" | ||||
@@ -42,6 +33,10 @@ public: | |||||
void wait() override; | void wait() override; | ||||
#if MGB_ENABLE_JSON | |||||
std::shared_ptr<mgb::json::Object> get_io_info() override; | |||||
#endif | |||||
//! get load result for megDL model | //! get load result for megDL model | ||||
mgb::serialization::GraphLoader::LoadResult& get_mdl_load_result() { | mgb::serialization::GraphLoader::LoadResult& get_mdl_load_result() { | ||||
return m_load_result; | return m_load_result; | ||||
@@ -109,6 +104,8 @@ public: | |||||
const std::string& get_model_path() const override { return model_path; } | const std::string& get_model_path() const override { return model_path; } | ||||
std::vector<uint8_t> get_model_data() override; | |||||
private: | private: | ||||
bool share_model_mem; | bool share_model_mem; | ||||
std::string model_path; | std::string model_path; | ||||
@@ -1,12 +1,3 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/device_options.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include <iostream> | #include <iostream> | ||||
#include <sstream> | #include <sstream> | ||||
#include "lite/global.h" | #include "lite/global.h" | ||||
@@ -76,7 +67,7 @@ void XPUDeviceOption::config_model_internel<ModelMdl>( | |||||
loc.type = mgb::CompNode::DeviceType::CPU; | loc.type = mgb::CompNode::DeviceType::CPU; | ||||
}; | }; | ||||
} | } | ||||
#if MGB_CUDA | |||||
#if LITE_WITH_CUDA | |||||
if (enable_cuda) { | if (enable_cuda) { | ||||
mgb_log_warn("using cuda device\n"); | mgb_log_warn("using cuda device\n"); | ||||
model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) { | model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) { | ||||
@@ -134,7 +125,7 @@ void XPUDeviceOption::config_model_internel<ModelMdl>( | |||||
XPUDeviceOption::XPUDeviceOption() { | XPUDeviceOption::XPUDeviceOption() { | ||||
m_option_name = "xpu_device"; | m_option_name = "xpu_device"; | ||||
enable_cpu = FLAGS_cpu; | enable_cpu = FLAGS_cpu; | ||||
#if MGB_CUDA | |||||
#if LITE_WITH_CUDA | |||||
enable_cuda = FLAGS_cuda; | enable_cuda = FLAGS_cuda; | ||||
#endif | #endif | ||||
enable_cpu_default = FLAGS_cpu_default; | enable_cpu_default = FLAGS_cpu_default; | ||||
@@ -165,18 +156,41 @@ XPUDeviceOption::XPUDeviceOption() { | |||||
"core ids number should be same with thread number set before"); | "core ids number should be same with thread number set before"); | ||||
enable_set_core_ids = true; | enable_set_core_ids = true; | ||||
} | } | ||||
} | |||||
m_option = { | |||||
{"cpu", lar::Bool::make(false)}, | |||||
#if LITE_WITH_CUDA | |||||
{"cuda", lar::Bool::make(false)}, | |||||
#endif | |||||
{"cpu_default", lar::Bool::make(false)}, | |||||
{"multithread", lar::NumberInt32::make(-1)}, | |||||
{"multithread_default", lar::NumberInt32::make(-1)}, | |||||
{"multi_thread_core_ids", lar::String::make("")}, | |||||
}; | |||||
std::static_pointer_cast<lar::Bool>(m_option["cpu"])->set_value(FLAGS_cpu); | |||||
#if LITE_WITH_CUDA | |||||
std::static_pointer_cast<lar::Bool>(m_option["cuda"])->set_value(FLAGS_cuda); | |||||
#endif | |||||
std::static_pointer_cast<lar::Bool>(m_option["cpu_default"]) | |||||
->set_value(FLAGS_cpu_default); | |||||
std::static_pointer_cast<lar::NumberInt32>(m_option["multithread"]) | |||||
->set_value(FLAGS_multithread); | |||||
std::static_pointer_cast<lar::NumberInt32>(m_option["multithread_default"]) | |||||
->set_value(FLAGS_multithread_default); | |||||
std::static_pointer_cast<lar::String>(m_option["multi_thread_core_ids"]) | |||||
->set_value(FLAGS_multi_thread_core_ids); | |||||
} | |||||
bool XPUDeviceOption::m_valid; | |||||
bool XPUDeviceOption::is_valid() { | bool XPUDeviceOption::is_valid() { | ||||
bool ret = FLAGS_cpu || FLAGS_cpu_default; | bool ret = FLAGS_cpu || FLAGS_cpu_default; | ||||
#if MGB_CUDA | |||||
#if LITE_WITH_CUDA | |||||
ret = ret || FLAGS_cuda; | ret = ret || FLAGS_cuda; | ||||
#endif | #endif | ||||
ret = ret || FLAGS_multithread >= 0; | ret = ret || FLAGS_multithread >= 0; | ||||
ret = ret || FLAGS_multithread_default >= 0; | ret = ret || FLAGS_multithread_default >= 0; | ||||
ret = ret || !FLAGS_multi_thread_core_ids.empty(); | ret = ret || !FLAGS_multi_thread_core_ids.empty(); | ||||
return ret; | |||||
return ret || m_valid; | |||||
} | } | ||||
std::shared_ptr<OptionBase> XPUDeviceOption::create_option() { | std::shared_ptr<OptionBase> XPUDeviceOption::create_option() { | ||||
@@ -190,11 +204,46 @@ std::shared_ptr<OptionBase> XPUDeviceOption::create_option() { | |||||
void XPUDeviceOption::config_model( | void XPUDeviceOption::config_model( | ||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | ||||
enable_cpu = std::static_pointer_cast<lar::Bool>(m_option["cpu"])->get_value(); | |||||
#if LITE_WITH_CUDA | |||||
enable_cuda = std::static_pointer_cast<lar::Bool>(m_option["cuda"])->get_value(); | |||||
#endif | |||||
enable_cpu_default = | |||||
std::static_pointer_cast<lar::Bool>(m_option["cpu_default"])->get_value(); | |||||
int32_t num_of_thread = | |||||
std::static_pointer_cast<lar::NumberInt32>(m_option["multithread"]) | |||||
->get_value(); | |||||
enable_multithread = num_of_thread >= 0; | |||||
num_of_thread = | |||||
std::static_pointer_cast<lar::NumberInt32>(m_option["multithread_default"]) | |||||
->get_value(); | |||||
enable_multithread_default = num_of_thread >= 0; | |||||
thread_num = num_of_thread >= 0 ? num_of_thread : 0; | |||||
std::string core_id_str = | |||||
std::static_pointer_cast<lar::String>(m_option["multi_thread_core_ids"]) | |||||
->get_value(); | |||||
if (!core_id_str.empty()) { | |||||
mgb_assert( | |||||
enable_multithread || enable_multithread_default, | |||||
"core ids should be set after --multithread or --multithread-default"); | |||||
std::stringstream id_stream(core_id_str); | |||||
std::string id; | |||||
size_t thread_cnt = 0; | |||||
while (getline(id_stream, id, ',')) { | |||||
thread_cnt++; | |||||
core_ids.push_back(atoi(id.c_str())); | |||||
} | |||||
mgb_assert( | |||||
thread_cnt == thread_num, | |||||
"core ids number should be same with thread number set before"); | |||||
enable_set_core_ids = true; | |||||
} | |||||
CONFIG_MODEL_FUN; | CONFIG_MODEL_FUN; | ||||
} | } | ||||
///////////////////////// xpu gflags //////////////////////////// | ///////////////////////// xpu gflags //////////////////////////// | ||||
DEFINE_bool(cpu, false, "set CPU device as running device"); | DEFINE_bool(cpu, false, "set CPU device as running device"); | ||||
#if MGB_CUDA || LITE_WITH_CUDA | |||||
#if LITE_WITH_CUDA | |||||
DEFINE_bool(cuda, false, "set CUDA device as running device "); | DEFINE_bool(cuda, false, "set CUDA device as running device "); | ||||
#endif | #endif | ||||
DEFINE_bool(cpu_default, false, "set running device as CPU device with inplace mode"); | DEFINE_bool(cpu_default, false, "set running device as CPU device with inplace mode"); | ||||
@@ -204,3 +253,4 @@ DEFINE_int32( | |||||
"set multithread device as running device with inplace mode"); | "set multithread device as running device with inplace mode"); | ||||
DEFINE_string(multi_thread_core_ids, "", "set multithread core id"); | DEFINE_string(multi_thread_core_ids, "", "set multithread core id"); | ||||
REGIST_OPTION_CREATOR(xpu_device, lar::XPUDeviceOption::create_option); | REGIST_OPTION_CREATOR(xpu_device, lar::XPUDeviceOption::create_option); | ||||
REGIST_OPTION_VALIDATER(xpu_device, lar::XPUDeviceOption::set_valid); |
@@ -1,18 +1,10 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/device_options.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | #pragma once | ||||
#include <gflags/gflags.h> | #include <gflags/gflags.h> | ||||
#include "models/model.h" | #include "models/model.h" | ||||
#include "option_base.h" | #include "option_base.h" | ||||
DECLARE_bool(cpu); | DECLARE_bool(cpu); | ||||
#if MGB_CUDA || LITE_WITH_CUDA | |||||
#if LITE_WITH_CUDA | |||||
DECLARE_bool(cuda); | DECLARE_bool(cuda); | ||||
#endif | #endif | ||||
DECLARE_bool(cpu_default); | DECLARE_bool(cpu_default); | ||||
@@ -29,12 +21,16 @@ public: | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | ||||
std::string option_name() const override { return m_option_name; }; | std::string option_name() const override { return m_option_name; }; | ||||
static void set_valid(bool val) { m_valid = val; } | |||||
OptionValMap* get_option() override { return &m_option; } | |||||
private: | private: | ||||
XPUDeviceOption(); | XPUDeviceOption(); | ||||
template <typename ModelImpl> | template <typename ModelImpl> | ||||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | ||||
bool enable_cpu; | bool enable_cpu; | ||||
#if MGB_CUDA || LITE_WITH_CUDA | |||||
#if LITE_WITH_CUDA | |||||
bool enable_cuda; | bool enable_cuda; | ||||
#endif | #endif | ||||
bool enable_cpu_default; | bool enable_cpu_default; | ||||
@@ -44,5 +40,8 @@ private: | |||||
size_t thread_num; | size_t thread_num; | ||||
std::vector<int> core_ids; | std::vector<int> core_ids; | ||||
std::string m_option_name; | std::string m_option_name; | ||||
static bool m_valid; | |||||
OptionValMap m_option; | |||||
}; | }; | ||||
} // namespace lar | } // namespace lar |
@@ -1,12 +1,3 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/extern_c_opr_options.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include "extern_c_opr_options.h" | #include "extern_c_opr_options.h" | ||||
#include "megbrain/utils/debug.h" | #include "megbrain/utils/debug.h" | ||||
#include "misc.h" | #include "misc.h" | ||||
@@ -1,12 +1,3 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/extern_c_opr_options.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | #pragma once | ||||
#include <gflags/gflags.h> | #include <gflags/gflags.h> | ||||
#include "megbrain/graph/extern_copr_api.h" | #include "megbrain/graph/extern_copr_api.h" | ||||
@@ -1,12 +1,3 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/fastrun_options.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include <gflags/gflags.h> | #include <gflags/gflags.h> | ||||
#if defined(_WIN32) | #if defined(_WIN32) | ||||
@@ -153,7 +144,7 @@ void FastRunOption::config_model_internel<ModelMdl>( | |||||
} // namespace lar | } // namespace lar | ||||
using namespace lar; | using namespace lar; | ||||
bool FastRunOption::m_valid; | |||||
FastRunOption::FastRunOption() { | FastRunOption::FastRunOption() { | ||||
m_option_name = "fastrun"; | m_option_name = "fastrun"; | ||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
@@ -164,6 +155,25 @@ FastRunOption::FastRunOption() { | |||||
enable_reproducible = FLAGS_reproducible; | enable_reproducible = FLAGS_reproducible; | ||||
m_fast_run_cache = FLAGS_fast_run_algo_policy; | m_fast_run_cache = FLAGS_fast_run_algo_policy; | ||||
share_batch_size = FLAGS_fast_run_shared_batch_size; | share_batch_size = FLAGS_fast_run_shared_batch_size; | ||||
m_option = { | |||||
#if MGB_ENABLE_FASTRUN | |||||
{"fast_run", lar::Bool::make(false)}, | |||||
{"full_run", lar::Bool::make(false)}, | |||||
#endif | |||||
{"binary_equal_between_batch", lar::Bool::make(false)}, | |||||
{"reproducible", lar::Bool::make(false)} | |||||
}; | |||||
#if MGB_ENABLE_FASTRUN | |||||
std::static_pointer_cast<lar::Bool>(m_option["fast_run"]) | |||||
->set_value(FLAGS_fast_run); | |||||
std::static_pointer_cast<lar::Bool>(m_option["full_run"]) | |||||
->set_value(FLAGS_full_run); | |||||
#endif | |||||
std::static_pointer_cast<lar::Bool>(m_option["binary_equal_between_batch"]) | |||||
->set_value(FLAGS_binary_equal_between_batch); | |||||
std::static_pointer_cast<lar::Bool>(m_option["reproducible"]) | |||||
->set_value(FLAGS_reproducible); | |||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
//! while fastrun cache file path is not empty and can't be accessed | //! while fastrun cache file path is not empty and can't be accessed | ||||
if (!m_fast_run_cache.empty() && access(m_fast_run_cache.c_str(), F_OK)) { | if (!m_fast_run_cache.empty() && access(m_fast_run_cache.c_str(), F_OK)) { | ||||
@@ -191,7 +201,7 @@ bool FastRunOption::is_valid() { | |||||
ret = ret || FLAGS_reproducible; | ret = ret || FLAGS_reproducible; | ||||
ret = ret || FLAGS_fast_run_algo_policy.size() > 0; | ret = ret || FLAGS_fast_run_algo_policy.size() > 0; | ||||
return ret; | |||||
return ret || m_valid; | |||||
} | } | ||||
std::shared_ptr<OptionBase> FastRunOption::create_option() { | std::shared_ptr<OptionBase> FastRunOption::create_option() { | ||||
@@ -205,6 +215,21 @@ std::shared_ptr<OptionBase> FastRunOption::create_option() { | |||||
void FastRunOption::config_model( | void FastRunOption::config_model( | ||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | ||||
#if MGB_ENABLE_FASTRUN | |||||
enable_fast_run = | |||||
std::static_pointer_cast<lar::Bool>(m_option["fast_run"])->get_value(); | |||||
enable_full_run = | |||||
std::static_pointer_cast<lar::Bool>(m_option["full_run"])->get_value(); | |||||
mgb_throw_if( | |||||
enable_fast_run && enable_full_run, mgb::AssertionError, | |||||
"invalid options of both fast-run and full-run"); | |||||
#endif | |||||
batch_binary_equal = | |||||
std::static_pointer_cast<lar::Bool>(m_option["binary_equal_between_batch"]) | |||||
->get_value(); | |||||
enable_reproducible = | |||||
std::static_pointer_cast<lar::Bool>(m_option["reproducible"])->get_value(); | |||||
CONFIG_MODEL_FUN; | CONFIG_MODEL_FUN; | ||||
} | } | ||||
@@ -228,4 +253,5 @@ DEFINE_bool( | |||||
DEFINE_uint32(fast_run_shared_batch_size, 0, "Set the batch size used during fastrun"); | DEFINE_uint32(fast_run_shared_batch_size, 0, "Set the batch size used during fastrun"); | ||||
DEFINE_string(fast_run_algo_policy, "", "fast-run cache path."); | DEFINE_string(fast_run_algo_policy, "", "fast-run cache path."); | ||||
REGIST_OPTION_CREATOR(fastrun, lar::FastRunOption::create_option); | |||||
REGIST_OPTION_CREATOR(fastrun, lar::FastRunOption::create_option); | |||||
REGIST_OPTION_VALIDATER(fastrun, lar::FastRunOption::set_valid); |
@@ -1,12 +1,3 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/fastrun_options.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | #pragma once | ||||
#include <gflags/gflags.h> | #include <gflags/gflags.h> | ||||
@@ -38,6 +29,10 @@ public: | |||||
//! get options name for quickly search | //! get options name for quickly search | ||||
std::string option_name() const override { return m_option_name; } | std::string option_name() const override { return m_option_name; } | ||||
static void set_valid(bool val) { m_valid = val; } | |||||
OptionValMap* get_option() override { return &m_option; } | |||||
private: | private: | ||||
FastRunOption(); | FastRunOption(); | ||||
//! config template for different model | //! config template for different model | ||||
@@ -53,5 +48,8 @@ private: | |||||
size_t share_batch_size; //! fast run strategy share batch size setting | size_t share_batch_size; //! fast run strategy share batch size setting | ||||
std::string m_fast_run_cache; //! fast run cache file path | std::string m_fast_run_cache; //! fast run cache file path | ||||
std::string m_option_name; //! option name | std::string m_option_name; //! option name | ||||
static bool m_valid; | |||||
OptionValMap m_option; | |||||
}; | }; | ||||
} // namespace lar | } // namespace lar |
@@ -1,12 +1,3 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/io_options.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include <map> | #include <map> | ||||
#include "helpers/data_parser.h" | #include "helpers/data_parser.h" | ||||
@@ -1,12 +1,3 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/io_options.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | #pragma once | ||||
#include <gflags/gflags.h> | #include <gflags/gflags.h> | ||||
#include "helpers/outdumper.h" | #include "helpers/outdumper.h" | ||||
@@ -1,12 +1,3 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/layout_options.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include <gflags/gflags.h> | #include <gflags/gflags.h> | ||||
#include "misc.h" | #include "misc.h" | ||||
@@ -24,7 +15,7 @@ void LayoutOption::config_model_internel<ModelLite>( | |||||
model->get_config().options.enable_##layout = true; \ | model->get_config().options.enable_##layout = true; \ | ||||
break; | break; | ||||
switch (option_flag) { | |||||
switch (m_option_flag) { | |||||
case OptLayoutType::NCHW4: | case OptLayoutType::NCHW4: | ||||
ENABLE_LAYOUT(nchw4) | ENABLE_LAYOUT(nchw4) | ||||
@@ -59,13 +50,12 @@ template <> | |||||
void lar::LayoutOption::config_model_internel<ModelMdl>( | void lar::LayoutOption::config_model_internel<ModelMdl>( | ||||
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | ||||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | ||||
mgb_log_debug("mdl layout config start"); | |||||
#define ENABLE_LAYOUT(layout) \ | #define ENABLE_LAYOUT(layout) \ | ||||
mgb_log_warn("enable " #layout " optimization"); \ | mgb_log_warn("enable " #layout " optimization"); \ | ||||
model->get_mdl_config().comp_graph->options().graph_opt.enable_##layout(); \ | model->get_mdl_config().comp_graph->options().graph_opt.enable_##layout(); \ | ||||
break; | break; | ||||
switch (option_flag) { | |||||
switch (m_option_flag) { | |||||
case OptLayoutType::NCHW4: | case OptLayoutType::NCHW4: | ||||
ENABLE_LAYOUT(nchw4) | ENABLE_LAYOUT(nchw4) | ||||
@@ -93,7 +83,6 @@ void lar::LayoutOption::config_model_internel<ModelMdl>( | |||||
default: | default: | ||||
break; | break; | ||||
} | } | ||||
mgb_log_debug("mdl layout config end"); | |||||
#undef ENABLE_LAYOUT | #undef ENABLE_LAYOUT | ||||
} | } | ||||
@@ -101,48 +90,68 @@ void lar::LayoutOption::config_model_internel<ModelMdl>( | |||||
} // namespace lar | } // namespace lar | ||||
using namespace lar; | using namespace lar; | ||||
OptLayoutType LayoutOption::option_flag; | |||||
bool LayoutOption::m_valid; | |||||
LayoutOption::LayoutOption() { | LayoutOption::LayoutOption() { | ||||
m_option_name = "layout"; | m_option_name = "layout"; | ||||
m_option_flag = static_cast<OptLayoutType>(0); | |||||
m_option = { | |||||
{"enable_nchw4", lar::Bool::make(false)}, | |||||
{"enable_chwn4", lar::Bool::make(false)}, | |||||
{"enable_nchw44", lar::Bool::make(false)}, | |||||
{"enable_nchw88", lar::Bool::make(false)}, | |||||
{"enable_nchw32", lar::Bool::make(false)}, | |||||
{"enable_nchw64", lar::Bool::make(false)}, | |||||
{"enable_nhwcd4", lar::Bool::make(false)}, | |||||
{"enable_nchw44_dot", lar::Bool::make(false)}, | |||||
}; | |||||
std::static_pointer_cast<lar::Bool>(m_option["enable_nchw4"]) | |||||
->set_value(FLAGS_enable_nchw4); | |||||
std::static_pointer_cast<lar::Bool>(m_option["enable_chwn4"]) | |||||
->set_value(FLAGS_enable_chwn4); | |||||
std::static_pointer_cast<lar::Bool>(m_option["enable_nchw44"]) | |||||
->set_value(FLAGS_enable_nchw44); | |||||
std::static_pointer_cast<lar::Bool>(m_option["enable_nchw88"]) | |||||
->set_value(FLAGS_enable_nchw88); | |||||
std::static_pointer_cast<lar::Bool>(m_option["enable_nchw32"]) | |||||
->set_value(FLAGS_enable_nchw32); | |||||
std::static_pointer_cast<lar::Bool>(m_option["enable_nchw64"]) | |||||
->set_value(FLAGS_enable_nchw64); | |||||
std::static_pointer_cast<lar::Bool>(m_option["enable_nhwcd4"]) | |||||
->set_value(FLAGS_enable_nhwcd4); | |||||
std::static_pointer_cast<lar::Bool>(m_option["enable_nchw44_dot"]) | |||||
->set_value(FLAGS_enable_nchw44_dot); | |||||
} | } | ||||
bool LayoutOption::is_valid() { | bool LayoutOption::is_valid() { | ||||
size_t valid_flag = 0; | size_t valid_flag = 0; | ||||
if (FLAGS_enable_nchw4) { | if (FLAGS_enable_nchw4) { | ||||
valid_flag = valid_flag | (1 << 0); | |||||
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW4); | |||||
} | } | ||||
if (FLAGS_enable_chwn4) { | if (FLAGS_enable_chwn4) { | ||||
valid_flag = valid_flag | (1 << 1); | |||||
valid_flag |= static_cast<size_t>(OptLayoutType::CHWN4); | |||||
} | } | ||||
if (FLAGS_enable_nchw44) { | if (FLAGS_enable_nchw44) { | ||||
valid_flag = valid_flag | (1 << 2); | |||||
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW44); | |||||
} | } | ||||
if (FLAGS_enable_nchw88) { | if (FLAGS_enable_nchw88) { | ||||
valid_flag = valid_flag | (1 << 3); | |||||
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW88); | |||||
} | } | ||||
if (FLAGS_enable_nchw32) { | if (FLAGS_enable_nchw32) { | ||||
valid_flag = valid_flag | (1 << 4); | |||||
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW32); | |||||
} | } | ||||
if (FLAGS_enable_nchw64) { | if (FLAGS_enable_nchw64) { | ||||
valid_flag = valid_flag | (1 << 5); | |||||
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW64); | |||||
} | } | ||||
if (FLAGS_enable_nhwcd4) { | if (FLAGS_enable_nhwcd4) { | ||||
valid_flag = valid_flag | (1 << 6); | |||||
valid_flag |= static_cast<size_t>(OptLayoutType::NHWCD4); | |||||
} | } | ||||
if (FLAGS_enable_nchw44_dot) { | if (FLAGS_enable_nchw44_dot) { | ||||
valid_flag = valid_flag | (1 << 7); | |||||
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW44_DOT); | |||||
} | } | ||||
//! only one flag is valid | |||||
bool ret = valid_flag && !(valid_flag & (valid_flag - 1)); | bool ret = valid_flag && !(valid_flag & (valid_flag - 1)); | ||||
if (ret) { | |||||
option_flag = static_cast<OptLayoutType>(valid_flag); | |||||
} else { | |||||
option_flag = static_cast<OptLayoutType>(0); | |||||
} | |||||
return ret; | |||||
return ret | m_valid; | |||||
}; | }; | ||||
std::shared_ptr<OptionBase> LayoutOption::create_option() { | std::shared_ptr<OptionBase> LayoutOption::create_option() { | ||||
@@ -156,6 +165,37 @@ std::shared_ptr<OptionBase> LayoutOption::create_option() { | |||||
void LayoutOption::config_model( | void LayoutOption::config_model( | ||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | ||||
size_t valid_flag = 0; | |||||
if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw4"])->get_value()) { | |||||
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW4); | |||||
} | |||||
if (std::static_pointer_cast<lar::Bool>(m_option["enable_chwn4"])->get_value()) { | |||||
valid_flag |= static_cast<size_t>(OptLayoutType::CHWN4); | |||||
} | |||||
if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw44"])->get_value()) { | |||||
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW44); | |||||
} | |||||
if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw88"])->get_value()) { | |||||
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW88); | |||||
} | |||||
if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw32"])->get_value()) { | |||||
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW32); | |||||
} | |||||
if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw64"])->get_value()) { | |||||
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW64); | |||||
} | |||||
if (std::static_pointer_cast<lar::Bool>(m_option["enable_nhwcd4"])->get_value()) { | |||||
valid_flag |= static_cast<size_t>(OptLayoutType::NHWCD4); | |||||
} | |||||
if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw44_dot"]) | |||||
->get_value()) { | |||||
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW44_DOT); | |||||
} | |||||
mgb_throw_if( | |||||
valid_flag && (valid_flag & (valid_flag - 1)), mgb::AssertionError, | |||||
"invalid options of layout transform 0x%lx", valid_flag); | |||||
m_option_flag = static_cast<OptLayoutType>(valid_flag); | |||||
CONFIG_MODEL_FUN; | CONFIG_MODEL_FUN; | ||||
} | } | ||||
@@ -168,4 +208,5 @@ DEFINE_bool(enable_nchw64, false, "enable nchw64 layout optimization!!"); | |||||
DEFINE_bool(enable_nhwcd4, false, "enable nhwcd4 layout optimization!!"); | DEFINE_bool(enable_nhwcd4, false, "enable nhwcd4 layout optimization!!"); | ||||
DEFINE_bool(enable_nchw44_dot, false, "enable nchw444-dot layout optimization!!"); | DEFINE_bool(enable_nchw44_dot, false, "enable nchw444-dot layout optimization!!"); | ||||
REGIST_OPTION_CREATOR(layout, lar::LayoutOption::create_option); | |||||
REGIST_OPTION_CREATOR(layout, lar::LayoutOption::create_option); | |||||
REGIST_OPTION_VALIDATER(layout, lar::LayoutOption::set_valid); |
@@ -1,12 +1,3 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/layout_options.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | #pragma once | ||||
#include <gflags/gflags.h> | #include <gflags/gflags.h> | ||||
@@ -42,6 +33,10 @@ public: | |||||
//! get option name | //! get option name | ||||
std::string option_name() const override { return m_option_name; }; | std::string option_name() const override { return m_option_name; }; | ||||
static void set_valid(bool val) { m_valid = val; } | |||||
OptionValMap* get_option() override { return &m_option; } | |||||
private: | private: | ||||
//! Constructor | //! Constructor | ||||
LayoutOption(); | LayoutOption(); | ||||
@@ -50,7 +45,9 @@ private: | |||||
template <typename ModelImpl> | template <typename ModelImpl> | ||||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | ||||
static OptLayoutType option_flag; | |||||
OptLayoutType m_option_flag; | |||||
std::string m_option_name; | std::string m_option_name; | ||||
static bool m_valid; | |||||
OptionValMap m_option; | |||||
}; | }; | ||||
} // namespace lar | } // namespace lar |
@@ -1,11 +1,3 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/layout_trans_options.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include "layout_trans_options.h" | #include "layout_trans_options.h" | ||||
#include <gflags/gflags.h> | #include <gflags/gflags.h> | ||||
#include "megbrain/serialization/serializer.h" | #include "megbrain/serialization/serializer.h" | ||||
@@ -19,6 +11,7 @@ void GoptLayoutOption::config_model_internel<ModelLite>( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) { | ||||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | ||||
if (m_layout_transform) { | if (m_layout_transform) { | ||||
LITE_WARN("using global layout transform optimization\n"); | |||||
if (m_layout_transform_target == | if (m_layout_transform_target == | ||||
mgb::gopt::GraphTuningOptions::Target::CPU) { | mgb::gopt::GraphTuningOptions::Target::CPU) { | ||||
model->get_config().device_type = LiteDeviceType::LITE_CPU; | model->get_config().device_type = LiteDeviceType::LITE_CPU; | ||||
@@ -48,7 +41,47 @@ void GoptLayoutOption::config_model_internel<ModelMdl>( | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | ||||
if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) { | if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) { | ||||
if (m_layout_transform) { | if (m_layout_transform) { | ||||
mgb_log_warn("using global layout transform optimization\n"); | |||||
auto&& load_result = model->get_mdl_load_result(); | auto&& load_result = model->get_mdl_load_result(); | ||||
for (auto&& item : load_result.output_var_list) { | |||||
if (item.shape()[0] > 1) { | |||||
mgb_log_warn( | |||||
" model may be dumped with multi batch and will cost lots " | |||||
"of time to profile during global layout transform!!!\n"); | |||||
} | |||||
} | |||||
//! update output varlist when input shape maybe change(some pass excution | |||||
//! time depends on the shape of init input) | |||||
mgb::thin_hash_table::ThinHashMap<mgb::cg::SymbolVar, mgb::cg::SymbolVar> | |||||
varmap; | |||||
mgb::cg::DepOprIter dep([&](mgb::cg::OperatorNodeBase* opr) { | |||||
if (auto h2d = opr->try_cast_final<mgb::opr::Host2DeviceCopy>()) { | |||||
auto param = h2d->param(); | |||||
mgb::TensorShape new_shape = h2d->host_data()->shape(); | |||||
std::shared_ptr<mgb::HostTensorND> new_tensor = | |||||
std::make_shared<mgb::HostTensorND>( | |||||
h2d->host_data()->comp_node(), new_shape, | |||||
h2d->host_data()->dtype()); | |||||
new_tensor->only_reset_raw_storage(h2d->host_data()->storage()); | |||||
auto h2d_opr = mgb::opr::Host2DeviceCopy::make( | |||||
*h2d->owner_graph(), new_tensor, param, h2d->config()); | |||||
varmap[h2d->output(0)] = h2d_opr; | |||||
} | |||||
}); | |||||
for (auto&& i : load_result.output_var_list) | |||||
dep.add(i); | |||||
if (!varmap.empty()) { | |||||
auto output_vars = | |||||
mgb::cg::replace_vars(load_result.output_var_list, varmap); | |||||
for (size_t i = 0; i < load_result.output_var_list.size(); ++i) { | |||||
output_vars[i].rename( | |||||
load_result.output_var_list[i].node()->name()); | |||||
} | |||||
load_result.output_var_list = output_vars; | |||||
} | |||||
load_result.output_var_list = mgb::gopt::layout_transform( | load_result.output_var_list = mgb::gopt::layout_transform( | ||||
load_result.output_var_list, m_layout_transform_target); | load_result.output_var_list, m_layout_transform_target); | ||||
@@ -98,7 +131,7 @@ void GoptLayoutOption::config_model_internel<ModelMdl>( | |||||
} // namespace lar | } // namespace lar | ||||
using namespace lar; | using namespace lar; | ||||
bool GoptLayoutOption::m_valid; | |||||
GoptLayoutOption::GoptLayoutOption() { | GoptLayoutOption::GoptLayoutOption() { | ||||
m_option_name = "gopt_layout"; | m_option_name = "gopt_layout"; | ||||
if (FLAGS_layout_transform != "cpu" | if (FLAGS_layout_transform != "cpu" | ||||
@@ -122,6 +155,12 @@ GoptLayoutOption::GoptLayoutOption() { | |||||
#endif | #endif | ||||
} | } | ||||
m_layout_transform_dump_file = FLAGS_layout_transform_dump; | m_layout_transform_dump_file = FLAGS_layout_transform_dump; | ||||
m_option = { | |||||
{"layout_transform", lar::String::make("")}, | |||||
}; | |||||
std::static_pointer_cast<lar::String>(m_option["layout_transform"]) | |||||
->set_value(FLAGS_layout_transform); | |||||
} | } | ||||
bool GoptLayoutOption::is_valid() { | bool GoptLayoutOption::is_valid() { | ||||
@@ -143,7 +182,7 @@ bool GoptLayoutOption::is_valid() { | |||||
} | } | ||||
} | } | ||||
ret = ret || !FLAGS_layout_transform_dump.empty(); | ret = ret || !FLAGS_layout_transform_dump.empty(); | ||||
return ret; | |||||
return ret || m_valid; | |||||
} | } | ||||
std::shared_ptr<OptionBase> GoptLayoutOption::create_option() { | std::shared_ptr<OptionBase> GoptLayoutOption::create_option() { | ||||
@@ -157,6 +196,26 @@ std::shared_ptr<OptionBase> GoptLayoutOption::create_option() { | |||||
void GoptLayoutOption::config_model( | void GoptLayoutOption::config_model( | ||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | ||||
auto value = std::static_pointer_cast<lar::String>(m_option["layout_transform"]) | |||||
->get_value(); | |||||
if (value.empty()) { | |||||
return; | |||||
} | |||||
if (value == "cpu") { | |||||
m_layout_transform = true; | |||||
m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CPU; | |||||
} | |||||
#if LITE_WITH_CUDA | |||||
else if (value == "cuda") { | |||||
m_layout_transform = true; | |||||
m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CUDA; | |||||
} | |||||
#endif | |||||
else { | |||||
mgb_throw( | |||||
mgb::AssertionError, "invalid options of global layout transform %s", | |||||
value.c_str()); | |||||
} | |||||
CONFIG_MODEL_FUN; | CONFIG_MODEL_FUN; | ||||
} | } | ||||
@@ -175,3 +234,4 @@ DEFINE_string( | |||||
"file path."); | "file path."); | ||||
REGIST_OPTION_CREATOR(gopt_layout, lar::GoptLayoutOption::create_option); | REGIST_OPTION_CREATOR(gopt_layout, lar::GoptLayoutOption::create_option); | ||||
REGIST_OPTION_VALIDATER(gopt_layout, lar::GoptLayoutOption::set_valid); |
@@ -1,12 +1,3 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/layout_trans_options.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | #pragma once | ||||
#include <gflags/gflags.h> | #include <gflags/gflags.h> | ||||
@@ -32,6 +23,10 @@ public: | |||||
//! get options name for quickly search | //! get options name for quickly search | ||||
std::string option_name() const override { return m_option_name; } | std::string option_name() const override { return m_option_name; } | ||||
static void set_valid(bool val) { m_valid = val; } | |||||
OptionValMap* get_option() override { return &m_option; } | |||||
private: | private: | ||||
GoptLayoutOption(); | GoptLayoutOption(); | ||||
//! config template for different model | //! config template for different model | ||||
@@ -41,5 +36,7 @@ private: | |||||
std::string m_option_name; | std::string m_option_name; | ||||
std::string m_layout_transform_dump_file; | std::string m_layout_transform_dump_file; | ||||
mgb::gopt::GraphTuningOptions::Target m_layout_transform_target; | mgb::gopt::GraphTuningOptions::Target m_layout_transform_target; | ||||
static bool m_valid; | |||||
OptionValMap m_option; | |||||
}; | }; | ||||
} // namespace lar | } // namespace lar |
@@ -1,12 +1,3 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/model_options.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include "model_options.h" | #include "model_options.h" | ||||
#include "device_options.h" | #include "device_options.h" | ||||
#include "lite/pack_model.h" | #include "lite/pack_model.h" | ||||
@@ -1,12 +1,3 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/model_options.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | #pragma once | ||||
#include <gflags/gflags.h> | #include <gflags/gflags.h> | ||||
#include "megbrain/graph/operator_node.h" | #include "megbrain/graph/operator_node.h" | ||||
@@ -1,12 +1,3 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/optimize_options.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include "megbrain/gopt/inference.h" | #include "megbrain/gopt/inference.h" | ||||
#if MGB_ENABLE_TENSOR_RT | #if MGB_ENABLE_TENSOR_RT | ||||
#include "megbrain/tensorrt/tensorrt_engine_cache.h" | #include "megbrain/tensorrt/tensorrt_engine_cache.h" | ||||
@@ -43,15 +34,18 @@ void FusePreprocessOption::config_model_internel<ModelMdl>( | |||||
} | } | ||||
} // namespace lar | } // namespace lar | ||||
using namespace lar; | using namespace lar; | ||||
bool FusePreprocessOption::m_valid; | |||||
FusePreprocessOption::FusePreprocessOption() { | FusePreprocessOption::FusePreprocessOption() { | ||||
m_option_name = "fuse_preprocess"; | m_option_name = "fuse_preprocess"; | ||||
enable_fuse_preprocess = FLAGS_enable_fuse_preprocess; | enable_fuse_preprocess = FLAGS_enable_fuse_preprocess; | ||||
m_option = {{"enable_fuse_preprocess", lar::Bool::make(false)}}; | |||||
std::static_pointer_cast<lar::Bool>(m_option["enable_fuse_preprocess"]) | |||||
->set_value(FLAGS_enable_fuse_preprocess); | |||||
} | } | ||||
bool FusePreprocessOption::is_valid() { | bool FusePreprocessOption::is_valid() { | ||||
bool ret = FLAGS_enable_fuse_preprocess; | bool ret = FLAGS_enable_fuse_preprocess; | ||||
return ret; | |||||
return ret || m_valid; | |||||
} | } | ||||
std::shared_ptr<OptionBase> FusePreprocessOption::create_option() { | std::shared_ptr<OptionBase> FusePreprocessOption::create_option() { | ||||
@@ -65,10 +59,14 @@ std::shared_ptr<OptionBase> FusePreprocessOption::create_option() { | |||||
void FusePreprocessOption::config_model( | void FusePreprocessOption::config_model( | ||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | ||||
enable_fuse_preprocess = | |||||
std::static_pointer_cast<lar::Bool>(m_option["enable_fuse_preprocess"]) | |||||
->get_value(); | |||||
CONFIG_MODEL_FUN; | CONFIG_MODEL_FUN; | ||||
} | } | ||||
///////////////////////// weight preprocess optimize options /////////////// | ///////////////////////// weight preprocess optimize options /////////////// | ||||
bool WeightPreprocessOption::m_valid; | |||||
namespace lar { | namespace lar { | ||||
template <> | template <> | ||||
void WeightPreprocessOption::config_model_internel<ModelLite>( | void WeightPreprocessOption::config_model_internel<ModelLite>( | ||||
@@ -97,11 +95,14 @@ void WeightPreprocessOption::config_model_internel<ModelMdl>( | |||||
WeightPreprocessOption::WeightPreprocessOption() { | WeightPreprocessOption::WeightPreprocessOption() { | ||||
m_option_name = "weight_preprocess"; | m_option_name = "weight_preprocess"; | ||||
weight_preprocess = FLAGS_weight_preprocess; | weight_preprocess = FLAGS_weight_preprocess; | ||||
m_option = {{"weight_preprocess", lar::Bool::make(false)}}; | |||||
std::static_pointer_cast<lar::Bool>(m_option["weight_preprocess"]) | |||||
->set_value(FLAGS_weight_preprocess); | |||||
} | } | ||||
bool WeightPreprocessOption::is_valid() { | bool WeightPreprocessOption::is_valid() { | ||||
bool ret = FLAGS_weight_preprocess; | bool ret = FLAGS_weight_preprocess; | ||||
return ret; | |||||
return ret || m_valid; | |||||
} | } | ||||
std::shared_ptr<OptionBase> WeightPreprocessOption::create_option() { | std::shared_ptr<OptionBase> WeightPreprocessOption::create_option() { | ||||
@@ -115,10 +116,14 @@ std::shared_ptr<OptionBase> WeightPreprocessOption::create_option() { | |||||
void WeightPreprocessOption::config_model( | void WeightPreprocessOption::config_model( | ||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | ||||
weight_preprocess = | |||||
std::static_pointer_cast<lar::Bool>(m_option["weight_preprocess"]) | |||||
->get_value(); | |||||
CONFIG_MODEL_FUN; | CONFIG_MODEL_FUN; | ||||
} | } | ||||
///// fuse conv bias and nonlinear activation opr optimize options //////// | ///// fuse conv bias and nonlinear activation opr optimize options //////// | ||||
bool FuseConvBiasNonlinearOption::m_valid; | |||||
namespace lar { | namespace lar { | ||||
template <> | template <> | ||||
void FuseConvBiasNonlinearOption::config_model_internel<ModelLite>( | void FuseConvBiasNonlinearOption::config_model_internel<ModelLite>( | ||||
@@ -145,13 +150,16 @@ void FuseConvBiasNonlinearOption::config_model_internel<ModelMdl>( | |||||
} // namespace lar | } // namespace lar | ||||
FuseConvBiasNonlinearOption::FuseConvBiasNonlinearOption() { | FuseConvBiasNonlinearOption::FuseConvBiasNonlinearOption() { | ||||
m_option_name = "fuse_conv_bias_nonlinear"; | |||||
m_option_name = "fuse_conv_bias_nonlinearity"; | |||||
enable_fuse_conv_bias_nonlinearity = FLAGS_enable_fuse_conv_bias_nonlinearity; | enable_fuse_conv_bias_nonlinearity = FLAGS_enable_fuse_conv_bias_nonlinearity; | ||||
m_option = {{"enable_fuse_conv_bias_nonlinearity", lar::Bool::make(false)}}; | |||||
std::static_pointer_cast<lar::Bool>(m_option["enable_fuse_conv_bias_nonlinearity"]) | |||||
->set_value(FLAGS_enable_fuse_conv_bias_nonlinearity); | |||||
} | } | ||||
bool FuseConvBiasNonlinearOption::is_valid() { | bool FuseConvBiasNonlinearOption::is_valid() { | ||||
bool ret = FLAGS_enable_fuse_conv_bias_nonlinearity; | bool ret = FLAGS_enable_fuse_conv_bias_nonlinearity; | ||||
return ret; | |||||
return ret || m_valid; | |||||
} | } | ||||
std::shared_ptr<OptionBase> FuseConvBiasNonlinearOption::create_option() { | std::shared_ptr<OptionBase> FuseConvBiasNonlinearOption::create_option() { | ||||
@@ -166,10 +174,15 @@ std::shared_ptr<OptionBase> FuseConvBiasNonlinearOption::create_option() { | |||||
void FuseConvBiasNonlinearOption::config_model( | void FuseConvBiasNonlinearOption::config_model( | ||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | ||||
enable_fuse_conv_bias_nonlinearity = | |||||
std::static_pointer_cast<lar::Bool>( | |||||
m_option["enable_fuse_conv_bias_nonlinearity"]) | |||||
->get_value(); | |||||
CONFIG_MODEL_FUN; | CONFIG_MODEL_FUN; | ||||
} | } | ||||
///////////////////////// fuse and preprocess optimize options /////////////// | ///////////////////////// fuse and preprocess optimize options /////////////// | ||||
bool FuseConvBiasElemwiseAddOption::m_valid; | |||||
namespace lar { | namespace lar { | ||||
template <> | template <> | ||||
void FuseConvBiasElemwiseAddOption::config_model_internel<ModelLite>( | void FuseConvBiasElemwiseAddOption::config_model_internel<ModelLite>( | ||||
@@ -198,13 +211,16 @@ void FuseConvBiasElemwiseAddOption::config_model_internel<ModelMdl>( | |||||
} // namespace lar | } // namespace lar | ||||
FuseConvBiasElemwiseAddOption::FuseConvBiasElemwiseAddOption() { | FuseConvBiasElemwiseAddOption::FuseConvBiasElemwiseAddOption() { | ||||
m_option_name = "fuse_conv_bias_z"; | |||||
m_option_name = "fuse_conv_bias_with_z"; | |||||
enable_fuse_conv_bias_with_z = FLAGS_enable_fuse_conv_bias_with_z; | enable_fuse_conv_bias_with_z = FLAGS_enable_fuse_conv_bias_with_z; | ||||
m_option = {{"enable_fuse_conv_bias_with_z", lar::Bool::make(false)}}; | |||||
std::static_pointer_cast<lar::Bool>(m_option["enable_fuse_conv_bias_with_z"]) | |||||
->set_value(FLAGS_enable_fuse_conv_bias_with_z); | |||||
} | } | ||||
bool FuseConvBiasElemwiseAddOption::is_valid() { | bool FuseConvBiasElemwiseAddOption::is_valid() { | ||||
bool ret = FLAGS_enable_fuse_conv_bias_with_z; | bool ret = FLAGS_enable_fuse_conv_bias_with_z; | ||||
return ret; | |||||
return ret || m_valid; | |||||
} | } | ||||
std::shared_ptr<OptionBase> FuseConvBiasElemwiseAddOption::create_option() { | std::shared_ptr<OptionBase> FuseConvBiasElemwiseAddOption::create_option() { | ||||
@@ -219,10 +235,14 @@ std::shared_ptr<OptionBase> FuseConvBiasElemwiseAddOption::create_option() { | |||||
void FuseConvBiasElemwiseAddOption::config_model( | void FuseConvBiasElemwiseAddOption::config_model( | ||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | ||||
enable_fuse_conv_bias_with_z = std::static_pointer_cast<lar::Bool>( | |||||
m_option["enable_fuse_conv_bias_with_z"]) | |||||
->get_value(); | |||||
CONFIG_MODEL_FUN; | CONFIG_MODEL_FUN; | ||||
} | } | ||||
///////////////////////// graph retrict options ///////////////////////// | ///////////////////////// graph retrict options ///////////////////////// | ||||
bool GraphRecordOption::m_valid; | |||||
namespace lar { | namespace lar { | ||||
template <> | template <> | ||||
void GraphRecordOption::config_model_internel<ModelLite>( | void GraphRecordOption::config_model_internel<ModelLite>( | ||||
@@ -299,6 +319,23 @@ GraphRecordOption::GraphRecordOption() { | |||||
if (FLAGS_record_comp_seq2) { | if (FLAGS_record_comp_seq2) { | ||||
m_record_comp_seq = 2; | m_record_comp_seq = 2; | ||||
} | } | ||||
m_option = { | |||||
{"record_comp_seq", lar::Bool::make(false)}, | |||||
{"record_comp_seq2", lar::Bool::make(false)}, | |||||
{"const_shape", lar::Bool::make(false)}, | |||||
{"fake_first", lar::Bool::make(false)}, | |||||
{"no_sanity_check", lar::Bool::make(false)}}; | |||||
std::static_pointer_cast<lar::Bool>(m_option["const_shape"]) | |||||
->set_value(FLAGS_const_shape); | |||||
std::static_pointer_cast<lar::Bool>(m_option["fake_first"]) | |||||
->set_value(FLAGS_fake_first); | |||||
std::static_pointer_cast<lar::Bool>(m_option["no_sanity_check"]) | |||||
->set_value(FLAGS_no_sanity_check); | |||||
std::static_pointer_cast<lar::Bool>(m_option["record_comp_seq"]) | |||||
->set_value(FLAGS_record_comp_seq); | |||||
std::static_pointer_cast<lar::Bool>(m_option["record_comp_seq2"]) | |||||
->set_value(FLAGS_record_comp_seq2); | |||||
} | } | ||||
bool GraphRecordOption::is_valid() { | bool GraphRecordOption::is_valid() { | ||||
@@ -307,7 +344,7 @@ bool GraphRecordOption::is_valid() { | |||||
ret = ret || FLAGS_no_sanity_check; | ret = ret || FLAGS_no_sanity_check; | ||||
ret = ret || FLAGS_record_comp_seq; | ret = ret || FLAGS_record_comp_seq; | ||||
ret = ret || FLAGS_record_comp_seq2; | ret = ret || FLAGS_record_comp_seq2; | ||||
return ret; | |||||
return ret || m_valid; | |||||
} | } | ||||
std::shared_ptr<OptionBase> GraphRecordOption::create_option() { | std::shared_ptr<OptionBase> GraphRecordOption::create_option() { | ||||
@@ -321,6 +358,22 @@ std::shared_ptr<OptionBase> GraphRecordOption::create_option() { | |||||
void GraphRecordOption::config_model( | void GraphRecordOption::config_model( | ||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | ||||
const_shape = | |||||
std::static_pointer_cast<lar::Bool>(m_option["const_shape"])->get_value(); | |||||
fake_first = | |||||
std::static_pointer_cast<lar::Bool>(m_option["fake_first"])->get_value(); | |||||
no_sanity_check = std::static_pointer_cast<lar::Bool>(m_option["no_sanity_check"]) | |||||
->get_value(); | |||||
m_record_comp_seq = std::static_pointer_cast<lar::Bool>(m_option["record_comp_seq"]) | |||||
->get_value() | |||||
? 1 | |||||
: 0; | |||||
m_record_comp_seq = | |||||
std::static_pointer_cast<lar::Bool>(m_option["record_comp_seq2"]) | |||||
->get_value() | |||||
? 2 | |||||
: 0; | |||||
CONFIG_MODEL_FUN; | CONFIG_MODEL_FUN; | ||||
} | } | ||||
///////////////////////// graph retrict options ///////////////////////// | ///////////////////////// graph retrict options ///////////////////////// | ||||
@@ -569,13 +622,26 @@ DEFINE_string( | |||||
"Set the TensorRT engine cache path for serialized prebuilt " | "Set the TensorRT engine cache path for serialized prebuilt " | ||||
"ICudaEngine"); | "ICudaEngine"); | ||||
#endif | #endif | ||||
REGIST_OPTION_CREATOR(fuse_preprocess, lar::FusePreprocessOption::create_option); | REGIST_OPTION_CREATOR(fuse_preprocess, lar::FusePreprocessOption::create_option); | ||||
REGIST_OPTION_VALIDATER(fuse_preprocess, lar::FusePreprocessOption::set_valid); | |||||
REGIST_OPTION_CREATOR(weight_preprocess, lar::WeightPreprocessOption::create_option); | REGIST_OPTION_CREATOR(weight_preprocess, lar::WeightPreprocessOption::create_option); | ||||
REGIST_OPTION_VALIDATER(weight_preprocess, lar::WeightPreprocessOption::set_valid); | |||||
REGIST_OPTION_CREATOR( | REGIST_OPTION_CREATOR( | ||||
fuse_conv_bias_nonlinear, lar::FuseConvBiasNonlinearOption::create_option); | |||||
fuse_conv_bias_nonlinearity, lar::FuseConvBiasNonlinearOption::create_option); | |||||
REGIST_OPTION_VALIDATER( | |||||
fuse_conv_bias_nonlinearity, lar::FuseConvBiasNonlinearOption::set_valid); | |||||
REGIST_OPTION_CREATOR( | REGIST_OPTION_CREATOR( | ||||
fuse_conv_bias_z, lar::FuseConvBiasElemwiseAddOption::create_option); | |||||
fuse_conv_bias_with_z, lar::FuseConvBiasElemwiseAddOption::create_option); | |||||
REGIST_OPTION_VALIDATER( | |||||
fuse_conv_bias_with_z, lar::FuseConvBiasElemwiseAddOption::set_valid); | |||||
REGIST_OPTION_CREATOR(graph_record, lar::GraphRecordOption::create_option); | REGIST_OPTION_CREATOR(graph_record, lar::GraphRecordOption::create_option); | ||||
REGIST_OPTION_VALIDATER(graph_record, lar::GraphRecordOption::set_valid); | |||||
REGIST_OPTION_CREATOR(memory_optimize, lar::MemoryOptimizeOption::create_option); | REGIST_OPTION_CREATOR(memory_optimize, lar::MemoryOptimizeOption::create_option); | ||||
REGIST_OPTION_CREATOR(JIT, lar::JITOption::create_option); | REGIST_OPTION_CREATOR(JIT, lar::JITOption::create_option); | ||||
#if MGB_ENABLE_TENSOR_RT | #if MGB_ENABLE_TENSOR_RT | ||||
@@ -1,12 +1,3 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/optimize_options.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | #pragma once | ||||
#include <gflags/gflags.h> | #include <gflags/gflags.h> | ||||
#include "helpers/common.h" | #include "helpers/common.h" | ||||
@@ -44,6 +35,10 @@ public: | |||||
std::string option_name() const override { return m_option_name; }; | std::string option_name() const override { return m_option_name; }; | ||||
static void set_valid(bool val) { m_valid = val; } | |||||
OptionValMap* get_option() override { return &m_option; } | |||||
private: | private: | ||||
FusePreprocessOption(); | FusePreprocessOption(); | ||||
template <typename ModelImpl> | template <typename ModelImpl> | ||||
@@ -51,6 +46,8 @@ private: | |||||
std::string m_option_name; | std::string m_option_name; | ||||
bool enable_fuse_preprocess; | bool enable_fuse_preprocess; | ||||
static bool m_valid; | |||||
OptionValMap m_option; | |||||
}; | }; | ||||
///////////////////////// weight preprocess optimize options ////////////// | ///////////////////////// weight preprocess optimize options ////////////// | ||||
@@ -64,6 +61,9 @@ public: | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | ||||
std::string option_name() const override { return m_option_name; }; | std::string option_name() const override { return m_option_name; }; | ||||
static void set_valid(bool val) { m_valid = val; }; | |||||
OptionValMap* get_option() override { return &m_option; } | |||||
private: | private: | ||||
WeightPreprocessOption(); | WeightPreprocessOption(); | ||||
@@ -72,6 +72,8 @@ private: | |||||
std::string m_option_name; | std::string m_option_name; | ||||
bool weight_preprocess; | bool weight_preprocess; | ||||
static bool m_valid; | |||||
OptionValMap m_option; | |||||
}; | }; | ||||
/////////////// fuse_conv_bias_nonlinearity optimize options /////////////// | /////////////// fuse_conv_bias_nonlinearity optimize options /////////////// | ||||
@@ -85,6 +87,9 @@ public: | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | ||||
std::string option_name() const override { return m_option_name; }; | std::string option_name() const override { return m_option_name; }; | ||||
static void set_valid(bool val) { m_valid = val; } | |||||
OptionValMap* get_option() override { return &m_option; } | |||||
private: | private: | ||||
FuseConvBiasNonlinearOption(); | FuseConvBiasNonlinearOption(); | ||||
@@ -93,6 +98,8 @@ private: | |||||
std::string m_option_name; | std::string m_option_name; | ||||
bool enable_fuse_conv_bias_nonlinearity; | bool enable_fuse_conv_bias_nonlinearity; | ||||
static bool m_valid; | |||||
OptionValMap m_option; | |||||
}; | }; | ||||
///////////////////////// fuse_conv_bias_with_z optimize options ////////////// | ///////////////////////// fuse_conv_bias_with_z optimize options ////////////// | ||||
@@ -106,6 +113,9 @@ public: | |||||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | ||||
std::string option_name() const override { return m_option_name; }; | std::string option_name() const override { return m_option_name; }; | ||||
static void set_valid(bool val) { m_valid = val; } | |||||
OptionValMap* get_option() override { return &m_option; } | |||||
private: | private: | ||||
FuseConvBiasElemwiseAddOption(); | FuseConvBiasElemwiseAddOption(); | ||||
@@ -113,6 +123,8 @@ private: | |||||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | ||||
std::string m_option_name; | std::string m_option_name; | ||||
bool enable_fuse_conv_bias_with_z; | bool enable_fuse_conv_bias_with_z; | ||||
static bool m_valid; | |||||
OptionValMap m_option; | |||||
}; | }; | ||||
///////////////////////// graph record options /////////////////////////// | ///////////////////////// graph record options /////////////////////////// | ||||
@@ -127,6 +139,10 @@ public: | |||||
std::string option_name() const override { return m_option_name; }; | std::string option_name() const override { return m_option_name; }; | ||||
static void set_valid(bool val) { m_valid = val; } | |||||
OptionValMap* get_option() override { return &m_option; } | |||||
private: | private: | ||||
GraphRecordOption(); | GraphRecordOption(); | ||||
template <typename ModelImpl> | template <typename ModelImpl> | ||||
@@ -137,6 +153,8 @@ private: | |||||
bool const_shape; | bool const_shape; | ||||
bool fake_first; | bool fake_first; | ||||
bool no_sanity_check; | bool no_sanity_check; | ||||
static bool m_valid; | |||||
OptionValMap m_option; | |||||
}; | }; | ||||
///////////////////////// memory optimize options ///////////////////////// | ///////////////////////// memory optimize options ///////////////////////// | ||||
@@ -1,22 +1,13 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/option_base.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | #pragma once | ||||
#include <functional> | #include <functional> | ||||
#include <iostream> | #include <iostream> | ||||
#include <memory> | #include <memory> | ||||
#include <string> | #include <string> | ||||
#include <unordered_map> | |||||
#include <vector> | #include <vector> | ||||
#include "megbrain/common.h" | #include "megbrain/common.h" | ||||
#include "helpers/common.h" | #include "helpers/common.h" | ||||
#include "helpers/utils.h" | |||||
#include "models/model.h" | #include "models/model.h" | ||||
namespace lar { | namespace lar { | ||||
@@ -34,6 +25,9 @@ public: | |||||
//! get option name | //! get option name | ||||
virtual std::string option_name() const = 0; | virtual std::string option_name() const = 0; | ||||
//! get option map | |||||
virtual OptionValMap* get_option() { return nullptr; } | |||||
virtual ~OptionBase() = default; | virtual ~OptionBase() = default; | ||||
}; | }; | ||||
@@ -43,7 +37,10 @@ public: | |||||
class OptionFactory { | class OptionFactory { | ||||
public: | public: | ||||
using OptionCreator = std::function<std::shared_ptr<OptionBase>()>; | using OptionCreator = std::function<std::shared_ptr<OptionBase>()>; | ||||
using OptionMap = std::unordered_map<std::string, OptionCreator>; | |||||
using OptionValidater = std::function<void(bool)>; | |||||
using OptionCreatorMap = std::unordered_map<std::string, OptionCreator>; | |||||
using OptionValidaterMap = std::unordered_map<std::string, OptionValidater>; | |||||
//! get Singleton option factory | //! get Singleton option factory | ||||
static OptionFactory& get_Instance() { | static OptionFactory& get_Instance() { | ||||
@@ -52,29 +49,49 @@ public: | |||||
} | } | ||||
//! registe option creator into option map | //! registe option creator into option map | ||||
void registe_options(std::string name, OptionCreator creator) { | |||||
if (option_creator_map.count(name) == 0) { | |||||
option_creator_map[name] = creator; | |||||
void registe_options_creator(std::string name, OptionCreator creator) { | |||||
if (m_option_creator_map.count(name) == 0) { | |||||
m_option_creator_map[name] = creator; | |||||
} | |||||
} | |||||
//! registe option validater into option map | |||||
void registe_options_validater(std::string name, OptionValidater validater) { | |||||
if (m_option_validater_map.count(name) == 0) { | |||||
m_option_validater_map[name] = validater; | |||||
} | } | ||||
} | } | ||||
//! get creator map | //! get creator map | ||||
OptionMap* get_option_creator_map() { return &option_creator_map; } | |||||
OptionCreatorMap* get_option_creator_map() { return &m_option_creator_map; } | |||||
//! get validater map | |||||
OptionValidaterMap* get_option_validater_map() { return &m_option_validater_map; } | |||||
private: | private: | ||||
OptionFactory(){}; | OptionFactory(){}; | ||||
OptionMap option_creator_map; | |||||
OptionCreatorMap m_option_creator_map; | |||||
OptionValidaterMap m_option_validater_map; | |||||
}; | }; | ||||
} // namespace lar | } // namespace lar | ||||
#define REGIST_OPTION_CREATOR(name_, creator_) \ | |||||
struct OptionRegister_##name_ { \ | |||||
OptionRegister_##name_() { \ | |||||
lar::OptionFactory::get_Instance().registe_options(#name_, creator_); \ | |||||
} \ | |||||
}; \ | |||||
OptionRegister_##name_ name_; | |||||
#define REGIST_OPTION_CREATOR(_name, _creator) \ | |||||
struct CreatorRegister_##_name { \ | |||||
CreatorRegister_##_name() { \ | |||||
lar::OptionFactory::get_Instance().registe_options_creator( \ | |||||
#_name, _creator); \ | |||||
} \ | |||||
}; \ | |||||
CreatorRegister_##_name creator_##_name; | |||||
#define REGIST_OPTION_VALIDATER(_name, _validater) \ | |||||
struct ValitaterRegister_##_name { \ | |||||
ValitaterRegister_##_name() { \ | |||||
lar::OptionFactory::get_Instance().registe_options_validater( \ | |||||
#_name, _validater); \ | |||||
} \ | |||||
}; \ | |||||
ValitaterRegister_##_name validater_##_name; | |||||
#define CONFIG_MODEL_FUN \ | #define CONFIG_MODEL_FUN \ | ||||
if (model->type() == ModelType::LITE_MODEL) { \ | if (model->type() == ModelType::LITE_MODEL) { \ | ||||
@@ -1,12 +1,3 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/plugin_options.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include "plugin_options.h" | #include "plugin_options.h" | ||||
#include <map> | #include <map> | ||||
#include "misc.h" | #include "misc.h" | ||||
@@ -1,12 +1,3 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/plugin_options.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | #pragma once | ||||
#include <gflags/gflags.h> | #include <gflags/gflags.h> | ||||
#if __linux__ || __unix__ | #if __linux__ || __unix__ | ||||
@@ -1,24 +1,21 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/strategy_options.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include "strategy_options.h" | #include "strategy_options.h" | ||||
#include "models/model_mdl.h" | #include "models/model_mdl.h" | ||||
using namespace lar; | using namespace lar; | ||||
DECLARE_bool(c_opr_lib_with_param); | DECLARE_bool(c_opr_lib_with_param); | ||||
DECLARE_bool(fitting); | |||||
StrategyOption::StrategyOption() { | StrategyOption::StrategyOption() { | ||||
m_option_name = "run_strategy"; | m_option_name = "run_strategy"; | ||||
warmup_iter = FLAGS_warmup_iter; | |||||
run_iter = FLAGS_iter; | |||||
threads = FLAGS_thread; | |||||
warmup_iter = FLAGS_fitting ? 3 : FLAGS_warmup_iter; | |||||
run_iter = FLAGS_fitting ? 10 : FLAGS_iter; | |||||
threads = FLAGS_fitting ? 1 : FLAGS_thread; | |||||
m_option = { | |||||
{"iter", lar::NumberInt32::make(run_iter)}, | |||||
{"warmup_iter", lar::NumberInt32::make(warmup_iter)}, | |||||
{"thread", lar::NumberInt32::make(threads)}, | |||||
}; | |||||
} | } | ||||
std::shared_ptr<OptionBase> StrategyOption::create_option() { | std::shared_ptr<OptionBase> StrategyOption::create_option() { | ||||
@@ -60,8 +57,7 @@ void TestcaseOption::config_model( | |||||
if (model->type() == ModelType::MEGDL_MODEL) { | if (model->type() == ModelType::MEGDL_MODEL) { | ||||
auto model_ptr = std::static_pointer_cast<ModelMdl>(model); | auto model_ptr = std::static_pointer_cast<ModelMdl>(model); | ||||
if (model_ptr->get_testcase_num() && !FLAGS_c_opr_lib_with_param) { | if (model_ptr->get_testcase_num() && !FLAGS_c_opr_lib_with_param) { | ||||
if (runtime_param.stage == RunStage::MODEL_RUNNING) { | |||||
auto load_result = model_ptr->get_mdl_load_result(); | |||||
if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { | |||||
auto input_tensor = model_ptr->get_test_input(); | auto input_tensor = model_ptr->get_test_input(); | ||||
auto loader = model_ptr->reset_loader(); | auto loader = model_ptr->reset_loader(); | ||||
auto testcase = loader->load(model_ptr->get_mdl_config(), false); | auto testcase = loader->load(model_ptr->get_mdl_config(), false); | ||||
@@ -1,12 +1,3 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/options/strategy_options.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include <gflags/gflags.h> | #include <gflags/gflags.h> | ||||
#include "models/model.h" | #include "models/model.h" | ||||
#include "option_base.h" | #include "option_base.h" | ||||
@@ -32,6 +23,8 @@ public: | |||||
//! get option name | //! get option name | ||||
std::string option_name() const override { return m_option_name; }; | std::string option_name() const override { return m_option_name; }; | ||||
OptionValMap* get_option() override { return &m_option; } | |||||
private: | private: | ||||
//! Constructor | //! Constructor | ||||
StrategyOption(); | StrategyOption(); | ||||
@@ -43,6 +36,7 @@ private: | |||||
size_t run_iter; //! iteration number for running model | size_t run_iter; //! iteration number for running model | ||||
size_t threads; //! thread number for running model (NOTE:it's different | size_t threads; //! thread number for running model (NOTE:it's different | ||||
//! from multithread device ) | //! from multithread device ) | ||||
OptionValMap m_option; | |||||
}; | }; | ||||
class TestcaseOption final : public OptionBase { | class TestcaseOption final : public OptionBase { | ||||
@@ -1,18 +1,10 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/strategys/strategy.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include "strategy.h" | #include "strategy.h" | ||||
#include <iostream> | #include <iostream> | ||||
#include "strategy_fitting.h" | |||||
#include "strategy_normal.h" | |||||
using namespace lar; | using namespace lar; | ||||
DECLARE_bool(fitting); | |||||
std::shared_ptr<StrategyBase> StrategyBase::create_strategy(std::string model_path) { | std::shared_ptr<StrategyBase> StrategyBase::create_strategy(std::string model_path) { | ||||
if (FLAGS_fitting) { | if (FLAGS_fitting) { | ||||
return std::make_shared<FittingStrategy>(model_path); | return std::make_shared<FittingStrategy>(model_path); | ||||
@@ -1,23 +1,11 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/strategys/strategy.h | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#pragma once | #pragma once | ||||
#include <gflags/gflags.h> | |||||
#include <string> | #include <string> | ||||
#include <unordered_map> | #include <unordered_map> | ||||
#include "helpers/common.h" | #include "helpers/common.h" | ||||
#include "models/model.h" | #include "models/model.h" | ||||
#include "options/option_base.h" | #include "options/option_base.h" | ||||
DECLARE_bool(fitting); | |||||
namespace lar { | namespace lar { | ||||
using OptionMap = std::unordered_map<std::string, std::shared_ptr<OptionBase>>; | |||||
/*! | /*! | ||||
* \brief: load and run strategy base class | * \brief: load and run strategy base class | ||||
*/ | */ | ||||
@@ -30,34 +18,10 @@ public: | |||||
virtual ~StrategyBase() = default; | virtual ~StrategyBase() = default; | ||||
RuntimeParam m_runtime_param; | RuntimeParam m_runtime_param; | ||||
std::unordered_map<std::string, std::shared_ptr<OptionBase>> m_options; | |||||
}; | |||||
/*! | |||||
* \brief: normal strategy for running | |||||
*/ | |||||
class NormalStrategy : public StrategyBase { | |||||
public: | |||||
NormalStrategy(std::string model_path); | |||||
//! run model with runtime parameter | |||||
void run() override; | |||||
private: | |||||
//! run model subline for multiple thread | |||||
void run_subline(); | |||||
std::string m_model_path; | |||||
std::shared_ptr<OptionMap> m_options; | |||||
}; | }; | ||||
/*! | |||||
* \brief: Fitting strategy for running | |||||
*/ | |||||
class FittingStrategy : public StrategyBase { | |||||
public: | |||||
FittingStrategy(std::string model_path); | |||||
void run() override; | |||||
}; | |||||
} // namespace lar | } // namespace lar | ||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -1,24 +1,590 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/strategys/strategy_fitting.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include "strategy.h" | |||||
#include "strategy_fitting.h" | |||||
#if defined(_WIN32) | |||||
#include <io.h> | |||||
#define F_OK 0 | |||||
#define access(a, b) _access(a, b) | |||||
#elif __linux__ || __unix__ || __APPLE__ | |||||
#include <unistd.h> | |||||
#endif | |||||
#include <fstream> | |||||
#include <iostream> | |||||
#include <list> | |||||
#include <regex> | |||||
#include <thread> | |||||
#include "lite/pack_model.h" | |||||
#include "megbrain/common.h" | |||||
#include "megbrain/comp_node_env.h" | |||||
#include "megbrain/exception.h" | |||||
#include "megbrain/utils/timer.h" | |||||
#include "megbrain/version.h" | |||||
#include "megdnn/version.h" | |||||
#include "misc.h" | |||||
DECLARE_bool(cpu); | |||||
using namespace lar; | using namespace lar; | ||||
FittingStrategy::FittingStrategy(std::string) { | |||||
mgb_assert("this version don't support Fitting Strategy"); | |||||
// /////////////////// OptionsFastManager /////////////////// | |||||
void OptionsFastManager::init(std::shared_ptr<OptionMap>& options) { | |||||
m_option_group_cnt = 0; | |||||
m_fixed_option_cnt = 0; | |||||
m_internal_options_name = { | |||||
{"enable_fuse_conv_bias_with_z"}, | |||||
{"enable_fuse_preprocess"}, | |||||
{"record_comp_seq"}}; | |||||
//! record the independent option value | |||||
for (auto& option : *options) { | |||||
auto option_vals = option.second->get_option(); | |||||
if (option_vals) { | |||||
for (auto& item : *option_vals) { | |||||
m_valid_option_vals.insert(item); | |||||
} | |||||
} | |||||
} | |||||
}; | }; | ||||
std::string OptionsFastManager::set_next_fixed_options() { | |||||
reset_option(); | |||||
auto& fixed_options_name = m_fixed_options_name[m_fixed_option_cnt]; | |||||
for (auto& item : fixed_options_name) { | |||||
if (m_valid_option_vals.find(item) != m_valid_option_vals.end()) { | |||||
auto& option_val = m_valid_option_vals[item]; | |||||
auto type = option_val->get_type(); | |||||
if (type == JsonValueType::Bool) { | |||||
auto option_val_ptr = std::static_pointer_cast<lar::Bool>(option_val); | |||||
option_val_ptr->set_value(true); | |||||
} else if (type == JsonValueType::String && item == "layout_transform") { | |||||
auto option_val_ptr = std::static_pointer_cast<lar::String>(option_val); | |||||
//! device type | |||||
option_val_ptr->set_value(fixed_options_name[0]); | |||||
} else { | |||||
mgb_log_error( | |||||
"invalid JsonValueType:%s to set next value for fitting mode", | |||||
option_val->type_string().c_str()); | |||||
} | |||||
} | |||||
} | |||||
++m_fixed_option_cnt; | |||||
std::string code = m_gflags_coder.encode(m_valid_option_vals); | |||||
return code; | |||||
} | |||||
std::string OptionsFastManager::set_next_options() { | |||||
reset_option(); | |||||
auto& constraint = m_internal_options_name[m_option_group_cnt]; | |||||
for (auto& item : constraint) { | |||||
if (m_valid_option_vals.find(item) != m_valid_option_vals.end()) { | |||||
auto& option_val = m_valid_option_vals[item]; | |||||
auto type = option_val->get_type(); | |||||
if (type == JsonValueType::Bool) { | |||||
auto option_val_ptr = std::static_pointer_cast<lar::Bool>(option_val); | |||||
option_val_ptr->set_value(true); | |||||
} else { | |||||
mgb_log_error( | |||||
"invalid JsonValueType: %s to set next value for fitting mode", | |||||
option_val->type_string().c_str()); | |||||
} | |||||
} | |||||
} | |||||
++m_option_group_cnt; | |||||
std::string code = m_gflags_coder.encode(m_valid_option_vals); | |||||
return code; | |||||
} | |||||
bool OptionsFastManager::is_end_options() { | |||||
return m_option_group_cnt == m_internal_options_name.size(); | |||||
} | |||||
bool OptionsFastManager::is_fixed_end() { | |||||
return m_fixed_option_cnt == m_fixed_options_name.size(); | |||||
} | |||||
void OptionsFastManager::set_options(const std::string& code) { | |||||
reset_option(); | |||||
#if MGB_ENABLE_JSON | |||||
const std::regex json_regex(".\\{"); | |||||
#endif | |||||
const std::regex gflags_regex("--.*=.*"); | |||||
if (std::regex_search(code, gflags_regex)) { | |||||
m_gflags_coder.decode(code, m_valid_option_vals); | |||||
} | |||||
#if MGB_ENABLE_JSON | |||||
else if (std::regex_search(code, json_regex)) { | |||||
m_json_coder.decode(code, m_valid_option_vals); | |||||
} | |||||
#endif | |||||
else { | |||||
mgb_log_error("invalid options code format \"%s\" to decode", code.c_str()); | |||||
} | |||||
} | |||||
void OptionsFastManager::registe_fixed_options( | |||||
const std::vector<std::string>& option_name) { | |||||
m_fixed_options_name.push_back(option_name); | |||||
} | |||||
std::string OptionsFastManager::get_curr_options_code(CoderType type, bool encode_all) { | |||||
if (type == CoderType::GFLAGS) { | |||||
return m_gflags_coder.encode(m_valid_option_vals, encode_all); | |||||
} | |||||
#if MGB_ENABLE_JSON | |||||
else if (type == CoderType::JSON) { | |||||
return m_json_coder.encode(m_valid_option_vals, encode_all); | |||||
} | |||||
#endif | |||||
else { | |||||
mgb_log_error("coder should be implemented in furture"); | |||||
return ""; | |||||
} | |||||
} | |||||
#if MGB_ENABLE_JSON | |||||
std::vector<std::shared_ptr<mgb::json::Object>> OptionsFastManager::get_json() { | |||||
std::vector<std::shared_ptr<mgb::json::Object>> ret = | |||||
m_json_coder.encode(m_valid_option_vals); | |||||
return ret; | |||||
} | |||||
#endif | |||||
void OptionsFastManager::reset_option() { | |||||
for (auto& option : m_valid_option_vals) { | |||||
option.second->reset_value(); | |||||
} | |||||
} | |||||
////////////////// OptionsTimeProfiler ////////////////// | |||||
void OptionsTimeProfiler::profile_with_given_options( | |||||
const std::string& model_path, std::shared_ptr<OptionMap>& given_options, | |||||
const std::string& option_code) { | |||||
RuntimeParam runtime_param; | |||||
auto model = ModelBase::create_model(model_path); | |||||
mgb::RealTimer timer; | |||||
auto stage_config_model = [&]() { | |||||
for (auto& option : *given_options) { | |||||
option.second->config_model(runtime_param, model); | |||||
} | |||||
}; | |||||
auto warm_up = [&]() { | |||||
for (size_t i = 0; i < runtime_param.warmup_iter; i++) { | |||||
auto start = timer.get_msecs(); | |||||
model->run_model(); | |||||
model->wait(); | |||||
mgb_log_warn("warm up %ld time %f ms", i, timer.get_msecs() - start); | |||||
} | |||||
}; | |||||
double inference_time = 0.0; | |||||
auto run_iter = [&]() { | |||||
for (size_t i = 0; i < runtime_param.run_iter; i++) { | |||||
auto start = timer.get_msecs(); | |||||
model->run_model(); | |||||
model->wait(); | |||||
auto end = timer.get_msecs(); | |||||
mgb_log_warn("run iter %ld time %f ms", i, end - start); | |||||
inference_time += end - start; | |||||
mgb_throw_if( | |||||
inference_time > TIME_OUT, mgb::TimeoutError, | |||||
"time out while using fitting"); | |||||
} | |||||
}; | |||||
//! model with testcase | |||||
size_t case_num = runtime_param.testcase_num; | |||||
bool exception_state = false; | |||||
MGB_TRY { | |||||
timer.reset(); | |||||
runtime_param.stage = RunStage::BEFORE_MODEL_LOAD; | |||||
stage_config_model(); | |||||
model->load_model(); | |||||
//! after load configure | |||||
auto config_model_before_runing = [&]() { | |||||
for (auto stage : | |||||
{RunStage::AFTER_MODEL_LOAD, RunStage::GLOBAL_OPTIMIZATION, | |||||
RunStage::BEFORE_OUTSPEC_SET, RunStage::AFTER_OUTSPEC_SET, | |||||
RunStage::MODEL_RUNNING}) { | |||||
runtime_param.stage = stage; | |||||
stage_config_model(); | |||||
} | |||||
}; | |||||
timer.reset(); | |||||
for (size_t idx = 0; idx < case_num; idx++) { | |||||
auto start = timer.get_msecs(); | |||||
config_model_before_runing(); | |||||
auto end = timer.get_msecs(); | |||||
mgb_log_warn("config model time %f ms", end - start); | |||||
warm_up(); | |||||
run_iter(); | |||||
} | |||||
runtime_param.stage = RunStage::AFTER_MODEL_RUNNING; | |||||
stage_config_model(); | |||||
} | |||||
MGB_CATCH(std::exception & exc, { | |||||
mgb_log_error("catch exception: %s", exc.what()); | |||||
exception_state = true; | |||||
}); | |||||
auto average = inference_time / runtime_param.run_iter; | |||||
if (exception_state) { | |||||
average = TIME_OUT; | |||||
} | |||||
//! record profile result | |||||
printf("profile option:\n%s\naverage time = %.2f\n", option_code.c_str(), average); | |||||
m_options_profile_result.insert({option_code, average}); | |||||
//! record the best result | |||||
if (average < m_best_setting.second) { | |||||
m_best_setting.first = option_code; | |||||
m_best_setting.second = average; | |||||
} | |||||
} | |||||
/////////////////////////// UserInfoParser ///////////////////////////// | |||||
void UserInfoParser::get_user_info() { | |||||
//! register user information tips | |||||
std::vector<std::pair<std::string, std::string>> info_tips; | |||||
m_user_info["fitting_preference"] = "Inferspeed"; | |||||
info_tips.push_back( | |||||
{"use_const_shape", "whether the input shape is constant?(yes/no)?"}); | |||||
for (auto& tip : info_tips) { | |||||
std::cout << tip.second; | |||||
std::string answer = ""; | |||||
std::cin >> answer; | |||||
m_user_info[tip.first] = answer; | |||||
} | |||||
} | |||||
void UserInfoParser::parse_info(std::shared_ptr<OptionsFastManager>& manager) { | |||||
std::vector<std::string> fixed_options; | |||||
if (m_user_info["use_const_shape"] == "yes") { | |||||
fixed_options.push_back("const_shape"); | |||||
} else if (m_user_info["use_const_shape"] != "no") { | |||||
mgb_log_error("invalid user information for \"use_const_shape\""); | |||||
} | |||||
fixed_options.push_back("enable_fuse_conv_bias_nonlinearity"); | |||||
std::vector<std::string> tmp_options; | |||||
auto insert_common_cpu_options = [&]() { | |||||
tmp_options = {"cpu"}; | |||||
tmp_options.insert( | |||||
tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||||
manager->registe_fixed_options(tmp_options); | |||||
tmp_options = {"cpu", "weight_preprocess", "fast_run"}; | |||||
tmp_options.insert( | |||||
tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||||
manager->registe_fixed_options(tmp_options); | |||||
tmp_options = {"cpu", "layout_transform"}; | |||||
tmp_options.insert( | |||||
tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||||
manager->registe_fixed_options(tmp_options); | |||||
tmp_options = {"cpu", "layout_transform", "weight_preprocess"}; | |||||
tmp_options.insert( | |||||
tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||||
manager->registe_fixed_options(tmp_options); | |||||
}; | |||||
#if (MEGDNN_AARCH64 || MEGDNN_ARMV7) | |||||
//! arm cpu device | |||||
insert_common_cpu_options(); | |||||
tmp_options = {"cpu", "enable_nchw44"}; | |||||
tmp_options.insert( | |||||
tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||||
manager->registe_fixed_options(tmp_options); | |||||
tmp_options = {"cpu", "enable_nchw44", "weight_preprocess", "fast_run"}; | |||||
tmp_options.insert( | |||||
tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||||
manager->registe_fixed_options(tmp_options); | |||||
tmp_options = {"cpu", "enable_nchw44_dot"}; | |||||
tmp_options.insert( | |||||
tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||||
manager->registe_fixed_options(tmp_options); | |||||
tmp_options = {"cpu", "enable_nchw44_dot", "weight_preprocess", "fast_run"}; | |||||
tmp_options.insert( | |||||
tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||||
manager->registe_fixed_options(tmp_options); | |||||
#else | |||||
#if LITE_WITH_CUDA | |||||
//! build with cuda and not force to use cpu device | |||||
if (!FLAGS_cpu) { | |||||
tmp_options = {"cuda"}; | |||||
tmp_options.insert( | |||||
tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||||
manager->registe_fixed_options(tmp_options); | |||||
tmp_options = {"cuda", "enable_nchw4"}; | |||||
tmp_options.insert( | |||||
tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||||
manager->registe_fixed_options(tmp_options); | |||||
tmp_options = {"cuda", "enable_chwn4"}; | |||||
tmp_options.insert( | |||||
tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||||
manager->registe_fixed_options(tmp_options); | |||||
tmp_options = {"cuda", "enable_nchw64"}; | |||||
tmp_options.insert( | |||||
tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||||
manager->registe_fixed_options(tmp_options); | |||||
tmp_options = {"cuda", "enable_nchw32"}; | |||||
tmp_options.insert( | |||||
tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||||
manager->registe_fixed_options(tmp_options); | |||||
tmp_options = {"cuda", "layout_transform"}; | |||||
tmp_options.insert( | |||||
tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||||
manager->registe_fixed_options(tmp_options); | |||||
tmp_options = {"cuda", "layout_transform", "weight_preprocess"}; | |||||
tmp_options.insert( | |||||
tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||||
manager->registe_fixed_options(tmp_options); | |||||
} | |||||
#endif | |||||
#if LITE_WITH_CUDA | |||||
//! build with cuda force to use cpu | |||||
if (FLAGS_cpu) { | |||||
#endif | |||||
//!x86 cpu options | |||||
insert_common_cpu_options(); | |||||
tmp_options = {"cpu", "enable_nchw88"}; | |||||
tmp_options.insert( | |||||
tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||||
manager->registe_fixed_options(tmp_options); | |||||
tmp_options = {"cpu", "enable_nchw88", "weight_preprocess", "fast_run"}; | |||||
tmp_options.insert( | |||||
tmp_options.end(), fixed_options.begin(), fixed_options.end()); | |||||
manager->registe_fixed_options(tmp_options); | |||||
#if LITE_WITH_CUDA | |||||
} | |||||
#endif | |||||
#endif | |||||
m_proifler_type = ProiflerType::TIME_PROFILER; | |||||
} | |||||
// /////////////////// FittingStrategy ////////////////////////////////// | |||||
FittingStrategy::FittingStrategy(std::string model_path) { | |||||
m_manager = std::make_shared<OptionsFastManager>(); | |||||
m_dumped_model = FLAGS_dump_fitting_model; | |||||
mgb::set_log_level(mgb::LogLevel::WARN); | |||||
m_options = std::make_shared<OptionMap>(); | |||||
m_model_path = model_path; | |||||
auto option_creator_map = OptionFactory::get_Instance().get_option_creator_map(); | |||||
auto option_validater_map = | |||||
OptionFactory::get_Instance().get_option_validater_map(); | |||||
//! validate option used in fitting | |||||
auto validate_option = [&](std::string name) -> void { | |||||
if (option_validater_map->find(name) != option_validater_map->end()) { | |||||
auto& validater = (*option_validater_map).at(name); | |||||
if (validater) { | |||||
validater(true); | |||||
} | |||||
} | |||||
}; | |||||
//! construct option which is valid | |||||
auto construct_option = [&](std::string name) -> void { | |||||
auto& creator = (*option_creator_map)[name]; | |||||
auto option = creator(); | |||||
if (option) { | |||||
m_options->insert({name, option}); | |||||
} | |||||
}; | |||||
//! get all options which is valid | |||||
for (auto& creator : *option_creator_map) { | |||||
auto name = creator.first; | |||||
if (m_options->count(name) == 0) { | |||||
validate_option(name); | |||||
construct_option(name); | |||||
} | |||||
} | |||||
m_manager->init(m_options); | |||||
} | |||||
void FittingStrategy::dump_best_options_with_model() { | |||||
std::vector<uint8_t> info_algo_policy_data; | |||||
std::vector<uint8_t> info_binary_cache_data; | |||||
auto model = ModelBase::create_model(m_model_path); | |||||
RuntimeParam runtime_param; | |||||
auto stage_config_model = [&]() { | |||||
for (auto& option : *m_options) { | |||||
option.second->config_model(runtime_param, model); | |||||
} | |||||
}; | |||||
runtime_param.stage = RunStage::BEFORE_MODEL_LOAD; | |||||
stage_config_model(); | |||||
model->load_model(); | |||||
//! get json info vector | |||||
std::string json_info_str; | |||||
#if MGB_ENABLE_JSON | |||||
std::shared_ptr<mgb::json::Object> code_json = model->get_io_info(); | |||||
m_packed_info.push_back({mgb::json::String("IO"), (*code_json)["IO"]}); | |||||
auto info_json = m_manager->get_json(); | |||||
m_packed_info.push_back({mgb::json::String("options"), (*info_json[0])["options"]}); | |||||
m_packed_info.push_back({mgb::json::String("device"), (*info_json[1])["device"]}); | |||||
m_packed_info.push_back( | |||||
{mgb::json::String("backend"), mgb::json::String::make("MGE")}); | |||||
int lite_major, lite_minor, lite_patch; | |||||
lite::get_version(lite_major, lite_minor, lite_patch); | |||||
std::string version = std::to_string(lite_major); | |||||
version += "."; | |||||
version += std::to_string(lite_minor) + "."; | |||||
version += std::to_string(lite_patch); | |||||
m_packed_info.push_back( | |||||
{mgb::json::String("version"), mgb::json::String::make(version)}); | |||||
m_packed_info.push_back({mgb::json::String("valid"), mgb::json::Bool::make(true)}); | |||||
m_packed_info.push_back( | |||||
{mgb::json::String("name"), mgb::json::String::make("packed_model")}); | |||||
auto obj = mgb::json::Object::make(m_packed_info); | |||||
json_info_str = obj->to_string(); | |||||
#endif | |||||
std::vector<uint8_t> json_info(json_info_str.begin(), json_info_str.end()); | |||||
//! get model binary data after optimized | |||||
for (auto stage : | |||||
{RunStage::AFTER_MODEL_LOAD, RunStage::GLOBAL_OPTIMIZATION, | |||||
RunStage::BEFORE_OUTSPEC_SET, RunStage::AFTER_OUTSPEC_SET, | |||||
RunStage::MODEL_RUNNING}) { | |||||
runtime_param.stage = stage; | |||||
stage_config_model(); | |||||
} | |||||
model->run_model(); | |||||
model->wait(); | |||||
std::vector<uint8_t> model_data = model->get_model_data(); | |||||
mgb_log_warn("model_data size=%zu", model_data.size()); | |||||
mgb_log_warn("json_info size=%zu", json_info.size()); | |||||
mgb_log_warn("info_algo_policy_data size=%zu", info_algo_policy_data.size()); | |||||
mgb_log_warn("info_binary_cache_data size=%zu", info_binary_cache_data.size()); | |||||
lite::ModelPacker packer( | |||||
model_data, m_dumped_model, json_info, info_algo_policy_data, | |||||
info_binary_cache_data); | |||||
packer.set_header(); | |||||
packer.pack_model(); | |||||
} | |||||
///////////////////////// AutoCleanFile/////////////////////////// | |||||
FittingStrategy::AutoCleanFile::AutoCleanFile( | |||||
const std::string& model_path, std::shared_ptr<OptionMap>& options) | |||||
: m_model_path(model_path), m_options(options) { | |||||
m_filename = "fitting_tmp_model"; | |||||
if (!access(m_filename.c_str(), F_OK)) { | |||||
remove(m_filename.c_str()); | |||||
} | |||||
} | |||||
FittingStrategy::AutoCleanFile::~AutoCleanFile() { | |||||
if (!access(m_filename.c_str(), F_OK)) { | |||||
remove(m_filename.c_str()); | |||||
} | |||||
} | |||||
void FittingStrategy::AutoCleanFile::dump_model() { | |||||
auto model = ModelBase::create_model(m_model_path); | |||||
RuntimeParam runtime_param; | |||||
auto stage_config_model = [&]() { | |||||
for (auto& option : *m_options) { | |||||
option.second->config_model(runtime_param, model); | |||||
} | |||||
}; | |||||
runtime_param.stage = RunStage::BEFORE_MODEL_LOAD; | |||||
stage_config_model(); | |||||
model->load_model(); | |||||
//! get model binary data after optimized | |||||
for (auto stage : | |||||
{RunStage::AFTER_MODEL_LOAD, RunStage::GLOBAL_OPTIMIZATION, | |||||
RunStage::BEFORE_OUTSPEC_SET, RunStage::AFTER_OUTSPEC_SET, | |||||
RunStage::MODEL_RUNNING}) { | |||||
runtime_param.stage = stage; | |||||
stage_config_model(); | |||||
} | |||||
model->run_model(); | |||||
model->wait(); | |||||
std::vector<uint8_t> model_data = model->get_model_data(); | |||||
mgb_log_warn("dumped model_data size=%zu\n", model_data.size()); | |||||
auto fp = fopen(m_filename.c_str(), "wb"); | |||||
fwrite(model_data.data(), 1, model_data.size(), fp); | |||||
fclose(fp); | |||||
} | |||||
void FittingStrategy::run() { | void FittingStrategy::run() { | ||||
mgb_assert("this version don't support Fitting Strategy"); | |||||
}; | |||||
auto mgb_version = mgb::get_version(); | |||||
auto dnn_version = megdnn::get_version(); | |||||
printf("megbrain/lite/load_and_run:\nusing MegBrain " | |||||
"%d.%d.%d(%d) and MegDNN %d.%d.%d\n", | |||||
mgb_version.major, mgb_version.minor, mgb_version.patch, mgb_version.is_dev, | |||||
dnn_version.major, dnn_version.minor, dnn_version.patch); | |||||
// ! create profiler with given user info | |||||
m_info_parser.get_user_info(); | |||||
m_info_parser.parse_info(m_manager); | |||||
auto profiler = m_info_parser.create_profiler(); | |||||
mgb_throw_if( | |||||
profiler == nullptr, mgb::AssertionError, | |||||
"get empty profiler for fittting\n"); | |||||
//! profile model with fixed options | |||||
while (!m_manager->is_fixed_end()) { | |||||
std::string option_str = m_manager->set_next_fixed_options(); | |||||
profiler->profile_with_given_options(m_model_path, m_options, option_str); | |||||
#if (MEGDNN_AARCH64 || MEGDNN_ARMV7) | |||||
//! sleep to keep machine with stable cpu frequence | |||||
usleep(500000); | |||||
#endif | |||||
} | |||||
std::string m_tmp_model = m_model_path; | |||||
const std::regex layout_regex("layout_transform"); | |||||
auto best_fixed_options = profiler->get_best_setting(); | |||||
m_manager->set_options(best_fixed_options); | |||||
//! dump model for global layout transform | |||||
auto m_tmp_file = AutoCleanFile(m_model_path, m_options); | |||||
if (std::regex_search(best_fixed_options, layout_regex)) { | |||||
m_tmp_file.dump_model(); | |||||
m_model_path = m_tmp_file.filename(); | |||||
} | |||||
//! profile model with given profiler | |||||
while (!m_manager->is_end_options()) { | |||||
std::string curr_option_str = m_manager->set_next_options(); | |||||
//! set option with current option and fixed options | |||||
if (m_model_path == m_tmp_model) { | |||||
auto total_option_str = curr_option_str + best_fixed_options; | |||||
m_manager->set_options(total_option_str); | |||||
} | |||||
curr_option_str += best_fixed_options; | |||||
profiler->profile_with_given_options(m_model_path, m_options, curr_option_str); | |||||
#if (MEGDNN_AARCH64 || MEGDNN_ARMV7) | |||||
usleep(500000); | |||||
#endif | |||||
} | |||||
//! set with best options and inference | |||||
m_model_path = m_tmp_model; | |||||
auto best_options = profiler->get_best_setting(); | |||||
m_manager->set_options(best_options); | |||||
profiler->profile_with_given_options(m_model_path, m_options, best_options); | |||||
//! save best options into given dir | |||||
std::cout << "the best options:\n" << best_options << std::endl; | |||||
if (!m_dumped_model.empty()) { | |||||
dump_best_options_with_model(); | |||||
} | |||||
} | |||||
DEFINE_bool( | DEFINE_bool( | ||||
fitting, false, | |||||
"whether to use the fitting model, which will auto profile and get " | |||||
"the best option set!"); | |||||
fitting, false, "use the fitting mode profile and get the best option set."); | |||||
DEFINE_string(dump_fitting_model, "", "dump the best option and algo cache into model"); |
@@ -0,0 +1,152 @@ | |||||
#pragma once | |||||
#include <gflags/gflags.h> | |||||
#include "helpers/utils.h" | |||||
#include "strategy.h" | |||||
DECLARE_bool(fitting); | |||||
DECLARE_string(dump_fitting_model); | |||||
#define TIME_OUT 10000 | |||||
namespace lar { | |||||
class OptionsFastManager { | |||||
public: | |||||
using ConstraintMap = std::unordered_map<std::string, bool>; | |||||
OptionsFastManager(){}; | |||||
//! init the options value map with given options | |||||
void init(std::shared_ptr<OptionMap>&); | |||||
//! set next options group cyclely | |||||
std::string set_next_options(); | |||||
std::string set_next_fixed_options(); | |||||
//! check the end of options group | |||||
bool is_end_options(); | |||||
bool is_fixed_end(); | |||||
std::string get_curr_options_code(CoderType, bool encode_all = false); | |||||
//! set current options with given options | |||||
void set_options(const std::string&); | |||||
void registe_fixed_options(const std::vector<std::string>&); | |||||
#if MGB_ENABLE_JSON | |||||
std::vector<std::shared_ptr<mgb::json::Object>> get_json(); | |||||
#endif | |||||
private: | |||||
void reset_option(); | |||||
size_t m_option_group_cnt; | |||||
size_t m_fixed_option_cnt; | |||||
OptionValMap m_valid_option_vals; | |||||
std::vector<std::vector<std::string>> m_internal_options_name; | |||||
std::vector<std::vector<std::string>> m_fixed_options_name; | |||||
#if MGB_ENABLE_JSON | |||||
JsonOptionsCoder m_json_coder; | |||||
#endif | |||||
GflagsOptionsCoder m_gflags_coder; | |||||
}; | |||||
//! Options proifler to get the best settings with different evaluate standard | |||||
class OptionsProfiler { | |||||
public: | |||||
OptionsProfiler(){}; | |||||
//! run with m_options | |||||
virtual void profile_with_given_options( | |||||
const std::string&, std::shared_ptr<OptionMap>&, const std::string&) = 0; | |||||
//! get the best setting and inference time | |||||
virtual std::string get_best_setting() { return ""; } | |||||
virtual ~OptionsProfiler() = default; | |||||
}; | |||||
/** | |||||
* profiler to get the fast setting | |||||
*/ | |||||
class OptionsTimeProfiler final : public OptionsProfiler { | |||||
public: | |||||
OptionsTimeProfiler(){}; | |||||
void profile_with_given_options( | |||||
const std::string&, std::shared_ptr<OptionMap>&, | |||||
const std::string&) override; | |||||
std::string get_best_setting() override { return m_best_setting.first; } | |||||
private: | |||||
std::unordered_map<std::string, double> m_options_profile_result; | |||||
std::pair<std::string, double> m_best_setting = {"", TIME_OUT}; | |||||
}; | |||||
/** | |||||
* parse information from user given | |||||
*/ | |||||
class UserInfoParser { | |||||
public: | |||||
UserInfoParser(){}; | |||||
void get_user_info(); | |||||
void parse_info(std::shared_ptr<OptionsFastManager>&); | |||||
std::shared_ptr<OptionsProfiler> create_profiler() { | |||||
switch (m_proifler_type) { | |||||
case ProiflerType::TIME_PROFILER: | |||||
return std::make_shared<OptionsTimeProfiler>(); | |||||
case ProiflerType::UNSPEC_PROFILER: | |||||
return nullptr; | |||||
default: | |||||
return nullptr; | |||||
} | |||||
} | |||||
private: | |||||
ProiflerType m_proifler_type; | |||||
std::unordered_map<std::string, std::string> m_user_info; | |||||
}; | |||||
/*! | |||||
* \brief: Fitting strategy for running | |||||
*/ | |||||
class FittingStrategy : public StrategyBase { | |||||
public: | |||||
class AutoCleanFile { | |||||
public: | |||||
AutoCleanFile( | |||||
const std::string& model_path, std::shared_ptr<OptionMap>& options); | |||||
void dump_model(); | |||||
std::string filename() { return m_filename; } | |||||
~AutoCleanFile(); | |||||
private: | |||||
std::string m_model_path; | |||||
std::shared_ptr<OptionMap> m_options; | |||||
std::string m_filename; | |||||
}; | |||||
FittingStrategy(std::string model_path); | |||||
void run() override; | |||||
void dump_best_options_with_model(); | |||||
void dump_model(); | |||||
private: | |||||
std::string m_model_path; | |||||
std::string m_dumped_model; | |||||
std::shared_ptr<OptionsFastManager> m_manager; | |||||
UserInfoParser m_info_parser; | |||||
#if MGB_ENABLE_JSON | |||||
std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>> | |||||
m_packed_info; | |||||
#endif | |||||
}; | |||||
} // namespace lar |
@@ -1,11 +1,4 @@ | |||||
/** | |||||
* \file lite/load_and_run/src/strategys/strategy_normal.cpp | |||||
* | |||||
* This file is part of MegEngine, a deep learning framework developed by | |||||
* Megvii. | |||||
* | |||||
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. | |||||
*/ | |||||
#include "strategy_normal.h" | |||||
#include <iostream> | #include <iostream> | ||||
#include <thread> | #include <thread> | ||||
#include "megbrain/common.h" | #include "megbrain/common.h" | ||||
@@ -13,13 +6,13 @@ | |||||
#include "megbrain/version.h" | #include "megbrain/version.h" | ||||
#include "megdnn/version.h" | #include "megdnn/version.h" | ||||
#include "misc.h" | #include "misc.h" | ||||
#include "strategy.h" | |||||
using namespace lar; | using namespace lar; | ||||
NormalStrategy::NormalStrategy(std::string model_path) { | NormalStrategy::NormalStrategy(std::string model_path) { | ||||
mgb::set_log_level(mgb::LogLevel::WARN); | mgb::set_log_level(mgb::LogLevel::WARN); | ||||
lite::set_log_level(LiteLogLevel::WARN); | lite::set_log_level(LiteLogLevel::WARN); | ||||
m_options = std::make_shared<OptionMap>(); | |||||
m_model_path = model_path; | m_model_path = model_path; | ||||
auto option_creator_map = OptionFactory::get_Instance().get_option_creator_map(); | auto option_creator_map = OptionFactory::get_Instance().get_option_creator_map(); | ||||
mgb_log_debug("option map size: %lu", option_creator_map->size()); | mgb_log_debug("option map size: %lu", option_creator_map->size()); | ||||
@@ -27,13 +20,13 @@ NormalStrategy::NormalStrategy(std::string model_path) { | |||||
auto& creator = (*option_creator_map)[name]; | auto& creator = (*option_creator_map)[name]; | ||||
auto option = creator(); | auto option = creator(); | ||||
if (option) { | if (option) { | ||||
m_options.insert({name, option}); | |||||
m_options->insert({name, option}); | |||||
} | } | ||||
}; | }; | ||||
for (auto& creator : *option_creator_map) { | for (auto& creator : *option_creator_map) { | ||||
auto name = creator.first; | auto name = creator.first; | ||||
if (m_options.count(name) == 0) { | |||||
if (m_options->count(name) == 0) { | |||||
construct_option(name); | construct_option(name); | ||||
} | } | ||||
} | } | ||||
@@ -44,7 +37,7 @@ void NormalStrategy::run_subline() { | |||||
mgb_assert(model != nullptr, "create model failed!!"); | mgb_assert(model != nullptr, "create model failed!!"); | ||||
auto stage_config_model = [&]() { | auto stage_config_model = [&]() { | ||||
for (auto& option : m_options) { | |||||
for (auto& option : *m_options) { | |||||
option.second->config_model(m_runtime_param, model); | option.second->config_model(m_runtime_param, model); | ||||
} | } | ||||
}; | }; | ||||
@@ -57,18 +50,14 @@ void NormalStrategy::run_subline() { | |||||
printf("load model: %.3fms\n", timer.get_msecs_reset()); | printf("load model: %.3fms\n", timer.get_msecs_reset()); | ||||
//! after load configure | //! after load configure | ||||
m_runtime_param.stage = RunStage::AFTER_MODEL_LOAD; | |||||
stage_config_model(); | |||||
m_runtime_param.stage = RunStage::GLOBAL_OPTIMIZATION; | |||||
stage_config_model(); | |||||
m_runtime_param.stage = RunStage::BEFORE_OUTSPEC_SET; | |||||
stage_config_model(); | |||||
// for get static memmory information options | |||||
m_runtime_param.stage = RunStage::AFTER_OUTSPEC_SET; | |||||
stage_config_model(); | |||||
auto config_after_load = [&]() { | |||||
for (auto stage : | |||||
{RunStage::AFTER_MODEL_LOAD, RunStage::GLOBAL_OPTIMIZATION, | |||||
RunStage::BEFORE_OUTSPEC_SET, RunStage::AFTER_OUTSPEC_SET}) { | |||||
m_runtime_param.stage = stage; | |||||
stage_config_model(); | |||||
} | |||||
}; | |||||
auto warm_up = [&]() { | auto warm_up = [&]() { | ||||
auto warmup_num = m_runtime_param.warmup_iter; | auto warmup_num = m_runtime_param.warmup_iter; | ||||
@@ -117,6 +106,8 @@ void NormalStrategy::run_subline() { | |||||
double tot_time = 0; | double tot_time = 0; | ||||
for (size_t idx = 0; idx < iter_num; idx++) { | for (size_t idx = 0; idx < iter_num; idx++) { | ||||
//! config model | |||||
config_after_load(); | |||||
//! config when running model | //! config when running model | ||||
mgb_log_warn("run testcase: %zu ", idx); | mgb_log_warn("run testcase: %zu ", idx); | ||||
m_runtime_param.stage = RunStage::MODEL_RUNNING; | m_runtime_param.stage = RunStage::MODEL_RUNNING; | ||||
@@ -0,0 +1,22 @@ | |||||
#pragma once | |||||
#include "strategy.h" | |||||
namespace lar { | |||||
/*! | |||||
* \brief: normal strategy for running | |||||
*/ | |||||
class NormalStrategy : public StrategyBase { | |||||
public: | |||||
NormalStrategy(std::string model_path); | |||||
//! run model with runtime parameter | |||||
void run() override; | |||||
private: | |||||
//! run model subline for multiple thread | |||||
void run_subline(); | |||||
std::string m_model_path; | |||||
}; | |||||
} // namespace lar |
@@ -1,14 +1,3 @@ | |||||
/** | |||||
* \file src/pack_model/pack_model.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "lite/pack_model.h" | #include "lite/pack_model.h" | ||||
#include "../misc.h" | #include "../misc.h" | ||||
#if LITE_BUILD_WITH_MGE | #if LITE_BUILD_WITH_MGE | ||||
@@ -192,7 +181,7 @@ ModelPacker::ModelPacker( | |||||
std::string info_data_path, std::string info_algo_policy_path, | std::string info_data_path, std::string info_algo_policy_path, | ||||
std::string info_binary_cache_path) | std::string info_binary_cache_path) | ||||
: m_packed_model_path(packed_model_path) { | : m_packed_model_path(packed_model_path) { | ||||
m_fbs_helper = new FbsHelper(this, model_path); | |||||
m_fbs_helper = std::make_shared<FbsHelper>(this, model_path); | |||||
std::vector<uint8_t> empty_vec; | std::vector<uint8_t> empty_vec; | ||||
m_info_data = info_data_path.empty() ? empty_vec : read_file(info_data_path); | m_info_data = info_data_path.empty() ? empty_vec : read_file(info_data_path); | ||||
m_algo_policy_data = info_algo_policy_path.empty() | m_algo_policy_data = info_algo_policy_path.empty() | ||||
@@ -207,7 +196,7 @@ ModelPacker::ModelPacker( | |||||
std::vector<uint8_t> model_data, std::string packed_model_path, | std::vector<uint8_t> model_data, std::string packed_model_path, | ||||
std::vector<uint8_t> info_data, std::vector<uint8_t> info_algo_policy_data, | std::vector<uint8_t> info_data, std::vector<uint8_t> info_algo_policy_data, | ||||
std::vector<uint8_t> info_binary_cache_data) { | std::vector<uint8_t> info_binary_cache_data) { | ||||
m_fbs_helper = new FbsHelper(this, model_data); | |||||
m_fbs_helper = std::make_shared<FbsHelper>(this, model_data); | |||||
m_packed_model_path = packed_model_path; | m_packed_model_path = packed_model_path; | ||||
m_info_data = info_data; | m_info_data = info_data; | ||||
m_algo_policy_data = info_algo_policy_data; | m_algo_policy_data = info_algo_policy_data; | ||||
@@ -1,14 +1,3 @@ | |||||
/** | |||||
* \file src/parse_info/cache_parse.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | #pragma once | ||||
#include "lite/global.h" | #include "lite/global.h" | ||||
#if LITE_BUILD_WITH_MGE | #if LITE_BUILD_WITH_MGE | ||||