GitOrigin-RevId: 7a427bdab4
HuaHua404-patch-1
@@ -120,22 +120,13 @@ InFilePersistentCache::BlobStorage& InFilePersistentCache::BlobStorage::init_fro | |||||
return *this; | return *this; | ||||
} | } | ||||
template <typename OutputFile> | |||||
void InFilePersistentCache::BlobStorage::write_to_file(OutputFile& out_file) const { | void InFilePersistentCache::BlobStorage::write_to_file(OutputFile& out_file) const { | ||||
uint32_t u_size = size; | uint32_t u_size = size; | ||||
out_file.write(u_size); | out_file.write(u_size); | ||||
out_file.write(data_refhold.get(), u_size); | out_file.write(data_refhold.get(), u_size); | ||||
} | } | ||||
InFilePersistentCache::BlobStorage& InFilePersistentCache::BlobStorage::init_data_ref( | |||||
const Blob& b) { | |||||
data_refhold = std::make_unique<uint8_t[]>(b.size + 1); | |||||
memcpy(data_refhold.get(), b.ptr, b.size); | |||||
data_refhold.get()[b.size] = 0; // for C-string safety | |||||
ptr = data_refhold.get(); | |||||
size = b.size; | |||||
return *this; | |||||
} | |||||
//////////////////////// InFilePersistentCache ////////////////////// | //////////////////////// InFilePersistentCache ////////////////////// | ||||
template <typename Input> | template <typename Input> | ||||
@@ -21,10 +21,62 @@ std::shared_ptr<PersistentCache> PersistentCache::sm_impl = | |||||
std::shared_ptr<PersistentCache> PersistentCache::set_impl( | std::shared_ptr<PersistentCache> PersistentCache::set_impl( | ||||
std::shared_ptr<PersistentCache> impl) { | std::shared_ptr<PersistentCache> impl) { | ||||
mgb_assert(impl); | mgb_assert(impl); | ||||
merge_old_cache(impl); | |||||
sm_impl.swap(impl); | sm_impl.swap(impl); | ||||
return impl; | return impl; | ||||
} | } | ||||
void PersistentCache::merge_old_cache(std::shared_ptr<PersistentCache> impl) { | |||||
MGB_LOCK_GUARD(PersistentCache::inst().m_mtx); | |||||
if (sm_impl) { | |||||
auto& old_cache = sm_impl->m_cache; | |||||
if (old_cache.size() > 0) { | |||||
mgb_log_debug("find old persistent cache, now append to it!!"); | |||||
auto& new_cache = impl->m_cache; | |||||
CacheMap tmp_cache; | |||||
//! CacheMap do not imp deepcopy and = operator, so we insert manually | |||||
auto insert = [](CacheMap& dst, CacheMap& in) { | |||||
for (auto& x : in) { | |||||
auto category = x.first; | |||||
for (auto& y : x.second) { | |||||
auto& key = y.first; | |||||
auto& value = y.second; | |||||
BlobStorage key_storage; | |||||
key_storage.init_data_ref(key).init_hash(); | |||||
dst[category][std::move(key_storage)].init_data_ref(value); | |||||
} | |||||
} | |||||
}; | |||||
insert(tmp_cache, old_cache); | |||||
insert(tmp_cache, new_cache); | |||||
impl->m_cache = std::move(tmp_cache); | |||||
} else { | |||||
mgb_log_debug("do not find any old persistent cache"); | |||||
} | |||||
} | |||||
} | |||||
PersistentCache::BlobStorage& PersistentCache::BlobStorage::init_data_ref( | |||||
const Blob& b) { | |||||
data_refhold = std::make_unique<uint8_t[]>(b.size + 1); | |||||
memcpy(data_refhold.get(), b.ptr, b.size); | |||||
data_refhold.get()[b.size] = 0; // for C-string safety | |||||
ptr = data_refhold.get(); | |||||
size = b.size; | |||||
return *this; | |||||
} | |||||
PersistentCache::BlobStorage& PersistentCache::BlobStorage::init_hash() { | |||||
hash = XXHash{}.update(ptr, size).digest(); | |||||
return *this; | |||||
} | |||||
bool PersistentCache::BlobStorage::operator==(const BlobStorage& rhs) const { | |||||
return size == rhs.size && !memcmp(ptr, rhs.ptr, size); | |||||
} | |||||
std::string PersistentCache::make_category_from_comp_node(CompNode comp_node) { | std::string PersistentCache::make_category_from_comp_node(CompNode comp_node) { | ||||
auto&& env = CompNodeEnv::from_comp_node(comp_node); | auto&& env = CompNodeEnv::from_comp_node(comp_node); | ||||
switch (env.property().type) { | switch (env.property().type) { | ||||
@@ -65,26 +117,6 @@ std::string PersistentCache::make_category_from_comp_node(CompNode comp_node) { | |||||
// ================= InMemoryPersistentCache ================== | // ================= InMemoryPersistentCache ================== | ||||
using Blob = PersistentCache::Blob; | using Blob = PersistentCache::Blob; | ||||
InMemoryPersistentCache::BlobStorage& InMemoryPersistentCache::BlobStorage:: | |||||
init_data_ref(const Blob& b) { | |||||
data_refhold = std::make_unique<uint8_t[]>(b.size + 1); | |||||
memcpy(data_refhold.get(), b.ptr, b.size); | |||||
data_refhold.get()[b.size] = 0; // for C-string safety | |||||
ptr = data_refhold.get(); | |||||
size = b.size; | |||||
return *this; | |||||
} | |||||
InMemoryPersistentCache::BlobStorage& InMemoryPersistentCache::BlobStorage:: | |||||
init_hash() { | |||||
hash = XXHash{}.update(ptr, size).digest(); | |||||
return *this; | |||||
} | |||||
bool InMemoryPersistentCache::BlobStorage::operator==(const BlobStorage& rhs) const { | |||||
return size == rhs.size && !memcmp(ptr, rhs.ptr, size); | |||||
} | |||||
Maybe<Blob> InMemoryPersistentCache::get(const std::string& category, const Blob& key) { | Maybe<Blob> InMemoryPersistentCache::get(const std::string& category, const Blob& key) { | ||||
decltype(m_cache.begin()) iter0; | decltype(m_cache.begin()) iter0; | ||||
{ | { | ||||
@@ -17,33 +17,6 @@ class InFilePersistentCache final : public PersistentCache { | |||||
class InputFile; | class InputFile; | ||||
class InputMemory; | class InputMemory; | ||||
class OutputFile; | class OutputFile; | ||||
struct BlobStorage : public Blob { | |||||
std::unique_ptr<uint8_t[]> data_refhold; | |||||
size_t hash = 0; | |||||
template <typename Input> | |||||
BlobStorage& init_from_input(Input& inp); | |||||
void write_to_file(OutputFile& out_file) const; | |||||
BlobStorage& init_data_ref(const Blob& b); | |||||
BlobStorage& init_hash() { | |||||
hash = XXHash{}.update(ptr, size).digest(); | |||||
return *this; | |||||
} | |||||
bool operator==(const BlobStorage& rhs) const { | |||||
return size == rhs.size && !memcmp(ptr, rhs.ptr, size); | |||||
} | |||||
struct Hash { | |||||
size_t operator()(const BlobStorage& b) const { return b.hash; } | |||||
}; | |||||
}; | |||||
std::unordered_map< | |||||
std::string, | |||||
std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>> | |||||
m_cache; | |||||
MGB_MUTEX m_mtx; | |||||
std::shared_ptr<OutputFile> m_always_open_file; | std::shared_ptr<OutputFile> m_always_open_file; | ||||
template <typename Input> | template <typename Input> | ||||
@@ -68,13 +41,6 @@ public: | |||||
MGE_WIN_DECLSPEC_FUC void put( | MGE_WIN_DECLSPEC_FUC void put( | ||||
const std::string& category, const Blob& key, const Blob& value) override; | const std::string& category, const Blob& key, const Blob& value) override; | ||||
bool support_dump_cache() override { return true; } | bool support_dump_cache() override { return true; } | ||||
std::unordered_map< | |||||
std::string, | |||||
std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>> | |||||
get_cache() { | |||||
return std::move(m_cache); | |||||
} | |||||
}; | }; | ||||
} // namespace mgb | } // namespace mgb | ||||
@@ -23,6 +23,41 @@ public: | |||||
size_t size; | size_t size; | ||||
}; | }; | ||||
struct BlobStorage : public Blob { | |||||
std::unique_ptr<uint8_t[]> data_refhold; | |||||
size_t hash = 0; | |||||
BlobStorage& init_data_ref(const Blob& b); | |||||
BlobStorage& init_hash(); | |||||
bool operator==(const BlobStorage& rhs) const; | |||||
struct Hash { | |||||
size_t operator()(const BlobStorage& b) const { return b.hash; } | |||||
}; | |||||
template <typename Input> | |||||
BlobStorage& init_from_input(Input& inp); | |||||
template <typename OutputFile> | |||||
void write_to_file(OutputFile& out_file) const; | |||||
}; | |||||
typedef std::unordered_map< | |||||
std::string, | |||||
std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>> | |||||
CacheMap; | |||||
CacheMap m_cache; | |||||
MGB_MUTEX m_mtx; | |||||
//! will make m_cache empty | |||||
CacheMap get_cache() { return std::move(m_cache); } | |||||
//! clear cache | |||||
MGE_WIN_DECLSPEC_FUC void clear_cache() { m_cache.clear(); } | |||||
virtual Maybe<Blob> get(const std::string& category, const Blob& key) = 0; | virtual Maybe<Blob> get(const std::string& category, const Blob& key) = 0; | ||||
virtual void put( | virtual void put( | ||||
@@ -34,6 +69,9 @@ public: | |||||
MGE_WIN_DECLSPEC_FUC static std::shared_ptr<PersistentCache> set_impl( | MGE_WIN_DECLSPEC_FUC static std::shared_ptr<PersistentCache> set_impl( | ||||
std::shared_ptr<PersistentCache> impl); | std::shared_ptr<PersistentCache> impl); | ||||
//! merge sm_impl m_cache, use to append insert cache | |||||
MGE_WIN_DECLSPEC_FUC static void merge_old_cache(std::shared_ptr<PersistentCache>); | |||||
//! get the instance; the default implementation just caches in | //! get the instance; the default implementation just caches in | ||||
//! memory | //! memory | ||||
static PersistentCache& inst() { return *sm_impl; } | static PersistentCache& inst() { return *sm_impl; } | ||||
@@ -48,32 +86,11 @@ public: | |||||
* The implementation is thread safe. | * The implementation is thread safe. | ||||
*/ | */ | ||||
class InMemoryPersistentCache final : public PersistentCache { | class InMemoryPersistentCache final : public PersistentCache { | ||||
struct BlobStorage : public PersistentCache::Blob { | |||||
std::unique_ptr<uint8_t[]> data_refhold; | |||||
size_t hash = 0; | |||||
BlobStorage& init_data_ref(const Blob& b); | |||||
BlobStorage& init_hash(); | |||||
bool operator==(const BlobStorage& rhs) const; | |||||
struct Hash { | |||||
size_t operator()(const BlobStorage& b) const { return b.hash; } | |||||
}; | |||||
}; | |||||
MGE_WIN_DECLSPEC_FUC Maybe<Blob> get( | MGE_WIN_DECLSPEC_FUC Maybe<Blob> get( | ||||
const std::string& category, const Blob& key) override; | const std::string& category, const Blob& key) override; | ||||
MGE_WIN_DECLSPEC_FUC void put( | MGE_WIN_DECLSPEC_FUC void put( | ||||
const std::string& category, const Blob& key, const Blob& value) override; | const std::string& category, const Blob& key, const Blob& value) override; | ||||
std::unordered_map< | |||||
std::string, | |||||
std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>> | |||||
m_cache; | |||||
MGB_MUTEX m_mtx; | |||||
public: | public: | ||||
MGE_WIN_DECLSPEC_FUC InMemoryPersistentCache() = default; | MGE_WIN_DECLSPEC_FUC InMemoryPersistentCache() = default; | ||||
}; | }; | ||||
@@ -9,6 +9,7 @@ | |||||
#include "megbrain/test/autocheck.h" | #include "megbrain/test/autocheck.h" | ||||
#include "megbrain/test/helper.h" | #include "megbrain/test/helper.h" | ||||
#include "megbrain/test/megdnn_helper.h" | #include "megbrain/test/megdnn_helper.h" | ||||
#include "megbrain/utils/infile_persistent_cache.h" | |||||
#include "megdnn/algorithm_cache.h" | #include "megdnn/algorithm_cache.h" | ||||
#include "megdnn/dtype.h" | #include "megdnn/dtype.h" | ||||
#include "megdnn/oprs/base.h" | #include "megdnn/oprs/base.h" | ||||
@@ -354,6 +355,10 @@ TEST(TestOprDNN, ConvBiasExePolicy) { | |||||
HostTensorND host_y; | HostTensorND host_y; | ||||
auto func = graph->compile({make_callback_copy(conv_bias, host_y)}); | auto func = graph->compile({make_callback_copy(conv_bias, host_y)}); | ||||
func->execute(); | func->execute(); | ||||
//! force clear all PersistentCache by get_cache | |||||
PersistentCache::inst().clear_cache(); | |||||
size_t old_size = PersistentCache::inst().get_cache().size(); | |||||
ASSERT_EQ(old_size, 0); | |||||
//! set a new cache | //! set a new cache | ||||
PersistentCache::set_impl(std::make_shared<InMemoryPersistentCache>()); | PersistentCache::set_impl(std::make_shared<InMemoryPersistentCache>()); | ||||
}; | }; | ||||
@@ -372,6 +377,64 @@ TEST(TestOprDNN, ConvBiasExePolicy) { | |||||
PersistentCache::set_impl(orig_impl); | PersistentCache::set_impl(orig_impl); | ||||
} | } | ||||
TEST(TestOprDNN, PersistentCacheAppend) { | |||||
PersistentCache::inst().clear_cache(); | |||||
auto orig_impl = | |||||
PersistentCache::set_impl(std::make_shared<InMemoryPersistentCache>()); | |||||
auto orig_impl_size = orig_impl->get_cache().size(); | |||||
auto category_a = "test_category_a"; | |||||
std::vector<int8_t> blob_key{1, 2, 3, 4, 5, 6, 7, 8}; | |||||
std::vector<int8_t> blob_value{-1, -2, -3, -4, -5, -6, -7, -8}; | |||||
PersistentCache::Blob key = {.ptr = blob_key.data(), .size = blob_key.size()}; | |||||
PersistentCache::Blob value = {.ptr = blob_value.data(), .size = blob_value.size()}; | |||||
//! trigger call InMemoryPersistentCache put | |||||
PersistentCache::inst().put(category_a, key, value); | |||||
auto now_size = PersistentCache::inst().get_cache().size(); | |||||
//! assert new key not in InMemoryPersistentCache imp | |||||
ASSERT_EQ(orig_impl_size + 1, now_size); | |||||
//! trigger append call InFilePersistentCache init | |||||
PersistentCache::set_impl(std::make_shared<InFilePersistentCache>()); | |||||
auto size_after_restore = PersistentCache::inst().get_cache().size(); | |||||
//! assert key not in InFilePersistentCache imp | |||||
//! as memory instance do cache do not sync cache to file | |||||
ASSERT_EQ(size_after_restore, orig_impl_size); | |||||
auto t_file_imp = std::make_shared<InFilePersistentCache>(); | |||||
auto category_b = "test_category_b"; | |||||
//! trigger call InFilePersistentCache put | |||||
t_file_imp->put(category_b, key, value); | |||||
//! set new file imp | |||||
PersistentCache::set_impl(t_file_imp); | |||||
//! trigger InFilePersistentCache append init | |||||
auto old_cache = | |||||
PersistentCache::set_impl(std::make_shared<InFilePersistentCache>()); | |||||
//! assert set_impl return old cache exactly | |||||
ASSERT_EQ(old_cache->m_cache.size(), now_size); | |||||
//! test key get | |||||
auto get_value = PersistentCache::inst().get(category_b, key); | |||||
ASSERT_TRUE( | |||||
!memcmp(get_value.val().ptr, blob_value.data(), | |||||
blob_value.size() * sizeof(int8_t))); | |||||
size_after_restore = PersistentCache::inst().get_cache().size(); | |||||
//! assert key still in orig_impl imp | |||||
ASSERT_EQ(size_after_restore, now_size); | |||||
//! restore old impl, may memory or file, trigger may memory append init | |||||
PersistentCache::set_impl(orig_impl); | |||||
size_after_restore = PersistentCache::inst().get_cache().size(); | |||||
//! assert key not in orig_impl imp, caused by get_cache will clear m_cache | |||||
ASSERT_EQ(size_after_restore + 1, now_size); | |||||
} | |||||
TEST(TestOprDNN, ConvBiasExePolicy_Quantized8Asym) { | TEST(TestOprDNN, ConvBiasExePolicy_Quantized8Asym) { | ||||
using Param = opr::ConvBias::Param; | using Param = opr::ConvBias::Param; | ||||
Param param; | Param param; | ||||