GitOrigin-RevId: 6046a2db0c
release-1.4
@@ -25,79 +25,9 @@ | |||
using namespace mgb; | |||
namespace { | |||
class InMemoryPersistentCache final: public PersistentCache { | |||
struct BlobStorage: public Blob { | |||
std::unique_ptr<uint8_t[]> data_refhold; | |||
size_t hash = 0; | |||
BlobStorage& init_data_ref(const Blob &b) { | |||
data_refhold = std::make_unique<uint8_t[]>(b.size + 1); | |||
memcpy(data_refhold.get(), b.ptr, b.size); | |||
data_refhold.get()[b.size] = 0; // for C-string safety | |||
ptr = data_refhold.get(); | |||
size = b.size; | |||
return *this; | |||
} | |||
BlobStorage& init_hash() { | |||
hash = XXHash{}.update(ptr, size).digest(); | |||
return *this; | |||
} | |||
bool operator == (const BlobStorage &rhs) const { | |||
return size == rhs.size && !memcmp(ptr, rhs.ptr, size); | |||
} | |||
struct Hash { | |||
size_t operator() (const BlobStorage &b) const { | |||
return b.hash; | |||
} | |||
}; | |||
}; | |||
std::unordered_map<std::string, | |||
std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>> | |||
m_cache; | |||
std::mutex m_mtx; | |||
Maybe<Blob> get(const std::string& category, const Blob& key) override { | |||
decltype(m_cache.begin()) iter0; | |||
{ | |||
MGB_LOCK_GUARD(m_mtx); | |||
iter0 = m_cache.find(category); | |||
if (iter0 == m_cache.end()) | |||
return None; | |||
} | |||
BlobStorage key_storage; | |||
key_storage.Blob::operator=(key); | |||
key_storage.init_hash(); | |||
MGB_LOCK_GUARD(m_mtx); | |||
auto iter1 = iter0->second.find(key_storage); | |||
if (iter1 == iter0->second.end()) | |||
return None; | |||
return iter1->second; | |||
} | |||
void put(const std::string& category, const Blob& key, | |||
const Blob& value) override { | |||
BlobStorage key_storage; | |||
key_storage.init_data_ref(key).init_hash(); | |||
MGB_LOCK_GUARD(m_mtx); | |||
auto size0 = m_cache.size(); | |||
m_cache[category][std::move(key_storage)].init_data_ref(value); | |||
if (m_cache.size() > size0) { | |||
mgb_log_debug("new cache category: %s", category.c_str()); | |||
} | |||
} | |||
}; | |||
} | |||
// ================= PersistentCache ====================== | |||
std::shared_ptr<PersistentCache> PersistentCache::sm_impl = | |||
std::make_shared<InMemoryPersistentCache>(); | |||
std::make_shared<InMemoryPersistentCache>(); | |||
std::shared_ptr<PersistentCache> PersistentCache::set_impl( | |||
std::shared_ptr<PersistentCache> impl) { | |||
@@ -141,6 +71,65 @@ std::string PersistentCache::make_category_from_comp_node(CompNode comp_node) { | |||
} | |||
} | |||
// ================= InMemoryPersistentCache ================== | |||
using Blob = PersistentCache::Blob; | |||
InMemoryPersistentCache::BlobStorage& | |||
InMemoryPersistentCache::BlobStorage::init_data_ref(const Blob& b) { | |||
data_refhold = std::make_unique<uint8_t[]>(b.size + 1); | |||
memcpy(data_refhold.get(), b.ptr, b.size); | |||
data_refhold.get()[b.size] = 0; // for C-string safety | |||
ptr = data_refhold.get(); | |||
size = b.size; | |||
return *this; | |||
} | |||
InMemoryPersistentCache::BlobStorage& | |||
InMemoryPersistentCache::BlobStorage::init_hash() { | |||
hash = XXHash{}.update(ptr, size).digest(); | |||
return *this; | |||
} | |||
bool InMemoryPersistentCache::BlobStorage::operator==( | |||
const BlobStorage& rhs) const { | |||
return size == rhs.size && !memcmp(ptr, rhs.ptr, size); | |||
} | |||
Maybe<Blob> InMemoryPersistentCache::get(const std::string& category, | |||
const Blob& key) { | |||
decltype(m_cache.begin()) iter0; | |||
{ | |||
MGB_LOCK_GUARD(m_mtx); | |||
iter0 = m_cache.find(category); | |||
if (iter0 == m_cache.end()) | |||
return None; | |||
} | |||
BlobStorage key_storage; | |||
key_storage.Blob::operator=(key); | |||
key_storage.init_hash(); | |||
MGB_LOCK_GUARD(m_mtx); | |||
auto iter1 = iter0->second.find(key_storage); | |||
if (iter1 == iter0->second.end()) | |||
return None; | |||
return iter1->second; | |||
} | |||
void InMemoryPersistentCache::put(const std::string& category, const Blob& key, | |||
const Blob& value) { | |||
BlobStorage key_storage; | |||
key_storage.init_data_ref(key).init_hash(); | |||
MGB_LOCK_GUARD(m_mtx); | |||
auto size0 = m_cache.size(); | |||
m_cache[category][std::move(key_storage)].init_data_ref(value); | |||
if (m_cache.size() > size0) { | |||
mgb_log_debug("new cache category: %s", category.c_str()); | |||
} | |||
} | |||
// ================= AlgoChooserProfileCache ================== | |||
AlgoChooserProfileCache::AlgoChooserProfileCache( | |||
CompNode cn, const char *opr_type) { | |||
m_category = "profile:"; | |||
@@ -56,6 +56,37 @@ namespace mgb { | |||
}; | |||
/*! | |||
* \brief persistent cache that keep in memory | |||
* The implementation is thread safe. | |||
*/ | |||
class InMemoryPersistentCache final : public PersistentCache { | |||
struct BlobStorage : public PersistentCache::Blob { | |||
std::unique_ptr<uint8_t[]> data_refhold; | |||
size_t hash = 0; | |||
BlobStorage& init_data_ref(const Blob& b); | |||
BlobStorage& init_hash(); | |||
bool operator==(const BlobStorage& rhs) const; | |||
struct Hash { | |||
size_t operator()(const BlobStorage& b) const { return b.hash; } | |||
}; | |||
}; | |||
Maybe<Blob> get(const std::string& category, const Blob& key) override; | |||
void put(const std::string& category, const Blob& key, | |||
const Blob& value) override; | |||
std::unordered_map< | |||
std::string, | |||
std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>> | |||
m_cache; | |||
std::mutex m_mtx; | |||
}; | |||
/*! | |||
* \brief proxy PersistentCache to be better suited for managing profiling | |||
* results of operator impl algorithms | |||
* | |||
@@ -68,7 +68,6 @@ std::string format_fixlayouts( | |||
ret.append(", "); | |||
} | |||
ret.append(layouts[i].to_string() + " "); | |||
ret.append(layouts[i].dtype.name()); | |||
} | |||
ret.append(") -> ("); | |||
for (size_t i = 0; i < arity_out; ++i) { | |||
@@ -76,7 +75,6 @@ std::string format_fixlayouts( | |||
ret.append(", "); | |||
} | |||
ret.append(layouts[i + arity_in].to_string() + " "); | |||
ret.append(layouts[i + arity_in].dtype.name()); | |||
} | |||
return ret; | |||
} | |||
@@ -420,6 +418,7 @@ AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, | |||
AlgoChooser<_Opr>::profile(sub_ctx, selected_strategy); | |||
}); | |||
} | |||
typename AlgoChooser<Opr>::ImplExecutionPolicy policy; | |||
ctx.construct_execution_policy(selected_strategy, policy); | |||
return policy; | |||
@@ -660,8 +659,28 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||
bool retrive_from_cache) const { | |||
if (!policy.algo.valid()) { | |||
if (retrive_from_cache) { | |||
policy.algo = | |||
get_profile_result_from_cache(selected_strategy).desc; | |||
policy.algo = get_profile_result_from_cache(selected_strategy).desc; | |||
if (!policy.algo.valid()) { | |||
auto target_attr = | |||
extract_algo_attribute_from_execution_strategy( | |||
selected_strategy); | |||
std::string layouts_str = | |||
format_fixlayouts<Opr>(m_layouts, arity_in, arity_out); | |||
std::string msg = ssprintf( | |||
"(mbg_opr : %s, layouts %s, with attribute(%s) and " | |||
"without attribute(%s)", | |||
m_base_mgb_opr->dyn_typeinfo()->name, | |||
layouts_str.c_str(), | |||
Algorithm::attribute_str(target_attr.first).c_str(), | |||
Algorithm::attribute_str(target_attr.second).c_str()); | |||
mgb_log_warn( | |||
"No algo get from cache for %s. This may caused by " | |||
"mismatch with model and cache file. ex. profiling " | |||
"with version1, but inferencing on version2 or " | |||
"profiling modelA but inferencing modelB", | |||
msg.c_str()); | |||
return; | |||
} | |||
} else { | |||
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | |||
owner_graph(), m_cn, m_execution_policy.workspace_limit); | |||
@@ -673,10 +692,12 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||
attr.second), | |||
m_layouts) | |||
.desc; | |||
mgb_assert(policy.algo.valid(), | |||
"No algo found from heuristic with strategy %u and " | |||
"workspace limit %zu", | |||
static_cast<uint32_t>(selected_strategy), | |||
workspace_limit); | |||
} | |||
mgb_assert(policy.algo.valid(), | |||
"No algo found from cache or heuristic, maybe some error " | |||
"occured"); | |||
} | |||
Algorithm* algo = m_megdnn_opr->get_algorithm_from_desc(policy.algo); | |||
@@ -697,9 +718,13 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||
sub_ctx.construct_execution_policy(selected_strategy, | |||
policy.sub_policy.back(), | |||
retrive_from_cache); | |||
if (!policy.sub_policy.back().algo.valid()) { | |||
// means sub_ctx.construct_execution_policy fails. clean up | |||
// policy.algo and return | |||
policy = {}; | |||
return; | |||
} | |||
}); | |||
return; | |||
} | |||
template <typename Opr> | |||
@@ -140,9 +140,10 @@ public: | |||
* \brief construct execution policy from cache or heuristic. | |||
* | |||
* \param selected_strategy select algo which matched this strategy | |||
* \param policy execution policy | |||
* \param [out] policy execution policy | |||
* \param retrive_from_cache retrive algo from cache if set True, get | |||
* from heuristic otherwise. | |||
* \note When contruction fail, the policy will be cleaned. | |||
*/ | |||
void construct_execution_policy(ExecutionStrategy selected_strategy, | |||
ImplExecutionPolicy& policy, | |||
@@ -152,14 +153,13 @@ public: | |||
Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter() const; | |||
}; | |||
template<typename U> | |||
template <typename U> | |||
friend class AlgoChooser; | |||
private: | |||
//! entrance for getting algorithm according to execution strategy | |||
static ImplExecutionPolicy get_policy(ExeContext& ctx); | |||
//! profile and save to cache | |||
static void profile(ExeContext& ctx, ExecutionStrategy selected_strategy); | |||
@@ -30,7 +30,6 @@ | |||
#include <random> | |||
using namespace mgb; | |||
namespace { | |||
using Param = opr::Convolution::Param; | |||
@@ -354,21 +353,26 @@ TEST(TestOprDNN, ConvBiasExePolicy) { | |||
auto cn = CompNode::load("cpux"); | |||
auto orig_impl = PersistentCache::set_impl( | |||
std::make_shared<InMemoryPersistentCache>()); | |||
#if MGB_ENABLE_FASTRUN | |||
for (auto strategy : | |||
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE, | |||
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTIMIZED}) { | |||
S::PROFILE | S::HEURISTIC}) { | |||
#else | |||
for (auto strategy : | |||
SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) { | |||
#endif | |||
auto graph = ComputingGraph::make(); | |||
HostTensorGenerator<> gen; | |||
auto mkvar = [&](const char* name, const TensorShape& shp, | |||
const DType& dtype) { | |||
return opr::TypeCvt::make( | |||
opr::Host2DeviceCopy::make(*graph, gen(shp), cn).rename(name), | |||
opr::Host2DeviceCopy::make(*graph, gen(shp), cn) | |||
.rename(name), | |||
dtype); | |||
}; | |||
@@ -388,7 +392,11 @@ TEST(TestOprDNN, ConvBiasExePolicy) { | |||
HostTensorND host_y; | |||
auto func = graph->compile({make_callback_copy(conv_bias, host_y)}); | |||
func->execute(); | |||
//! set a new cache | |||
PersistentCache::set_impl(std::make_shared<InMemoryPersistentCache>()); | |||
} | |||
PersistentCache::set_impl(orig_impl); | |||
} | |||
TEST(TestOprDNN, ConvBiasExePolicy_Quantized8Asym) { | |||
@@ -401,19 +409,21 @@ TEST(TestOprDNN, ConvBiasExePolicy_Quantized8Asym) { | |||
for (auto strategy : | |||
SmallVector<S>{S::PROFILE, S::PROFILE | S::REPRODUCIBLE}) { | |||
auto graph = ComputingGraph::make(); | |||
HostTensorGenerator<> gen; | |||
auto mkvar = [&](const char* name, const TensorShape& shp, | |||
const DType& dtype) { | |||
return opr::TypeCvt::make( | |||
opr::Host2DeviceCopy::make(*graph, gen(shp), cn).rename(name), | |||
opr::Host2DeviceCopy::make(*graph, gen(shp), cn) | |||
.rename(name), | |||
dtype); | |||
}; | |||
auto x = mkvar("x", {20, 50, 50, 16}, dtype::Quantized8Asymm(2.5f, static_cast<uint8_t>(0))); | |||
auto w = mkvar("w", {24, 3, 3, 16}, dtype::Quantized8Asymm(2.5f, static_cast<uint8_t>(0))); | |||
auto x = mkvar("x", {20, 50, 50, 16}, | |||
dtype::Quantized8Asymm(2.5f, static_cast<uint8_t>(0))); | |||
auto w = mkvar("w", {24, 3, 3, 16}, | |||
dtype::Quantized8Asymm(2.5f, static_cast<uint8_t>(0))); | |||
auto bias = mkvar("bias", {1, 1, 1, 24}, dtype::QuantizedS32(6.25f)); | |||
param.nonlineMode = Param::NonlineMode::RELU; | |||