@@ -160,7 +160,7 @@ private: | |||
template <typename TItem> | |||
void register_converter() { | |||
m_table[typeid(TItem)] = [](const any_t& input) { | |||
return variant_t(*input.as<TItem>()); | |||
return variant_t(input.cast<TItem>()); | |||
}; | |||
} | |||
@@ -11,7 +11,6 @@ | |||
#pragma once | |||
#include <any> | |||
#include <bitset> | |||
#include <chrono> | |||
#include <deque> | |||
@@ -28,6 +27,7 @@ | |||
#include "megbrain/imperative/op_def.h" | |||
#include "megbrain/imperative/physical_tensor.h" | |||
#include "megbrain/imperative/utils/any.h" | |||
namespace mgb { | |||
namespace imperative { | |||
@@ -51,48 +51,6 @@ public: | |||
static std::shared_ptr<CompNode::Event> record_device(CompNode device); | |||
}; | |||
class AnyPtr { | |||
public: | |||
struct Deleter { | |||
void* object; | |||
void (*method)(void*, void*); | |||
void operator()(void* ptr) { method(object, ptr); } | |||
}; | |||
private: | |||
using holder_t = std::unique_ptr<void, Deleter>; | |||
const std::type_info* m_type = nullptr; | |||
holder_t m_holder = nullptr; | |||
public: | |||
AnyPtr() = default; | |||
template < | |||
typename T, | |||
typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, AnyPtr>>> | |||
explicit AnyPtr(T* value, Deleter deleter) { | |||
m_type = &typeid(T); | |||
m_holder = {value, deleter}; | |||
} | |||
template <typename T> | |||
T* as() { | |||
mgb_assert(is_exactly<T>(), "type mismatch"); | |||
return reinterpret_cast<T*>(m_holder.get()); | |||
} | |||
template <typename T> | |||
const T* as() const { | |||
mgb_assert(is_exactly<T>(), "type mismatch"); | |||
return reinterpret_cast<const T*>(m_holder.get()); | |||
} | |||
template <typename T> | |||
bool is_exactly() const { | |||
return std::type_index{typeid(T)} == std::type_index{*m_type}; | |||
} | |||
const std::type_info& type() const { return *m_type; } | |||
bool operator==(std::nullptr_t nptr) const { return m_holder == nullptr; } | |||
operator bool() const { return m_holder != nullptr; } | |||
}; | |||
class Profiler { | |||
public: | |||
struct Record { | |||
@@ -128,7 +86,6 @@ private: | |||
std::thread::id m_thread_id; | |||
std::vector<Record> m_records; | |||
std::atomic<Status> m_status = Running; | |||
std::unordered_map<std::type_index, AnyPtr> m_mem_pools; | |||
static std::vector<entry_t> sm_records; | |||
static options_t sm_profile_options; | |||
@@ -161,42 +118,21 @@ public: | |||
return *tm_profiler; | |||
} | |||
template <typename T> | |||
static MemPool<T>& get_mem_pool() { | |||
thread_local MemPool<T>* t_pool = nullptr; | |||
if (t_pool == nullptr) { | |||
auto& pool = get_instance().m_mem_pools[typeid(MemPool<T>)]; | |||
if (pool == nullptr) { | |||
pool = | |||
AnyPtr(new MemPool<T>(), | |||
{nullptr, [](void*, void* ptr) { | |||
delete reinterpret_cast<MemPool<T>*>(ptr); | |||
}}); | |||
} | |||
t_pool = pool.as<MemPool<T>>(); | |||
} | |||
return *t_pool; | |||
} | |||
static uint64_t next_id() { return sm_last_id++; } | |||
template <typename T, typename... TArgs> | |||
static uint64_t record(TArgs&&... args) { | |||
auto& profiler = get_instance(); | |||
auto& mem_pool = get_mem_pool<T>(); | |||
// auto& mem_pool = get_mem_pool<T>(); | |||
if constexpr (sm_debug) { | |||
Status expected = Running; | |||
mgb_assert(profiler.m_status.compare_exchange_strong(expected, Recording)); | |||
} | |||
uint64_t id = next_id(); | |||
profiler::Time time = sm_timer.record_host(); | |||
auto deleter = [](void* obj, void* ptr) { | |||
reinterpret_cast<MemPool<T>*>(obj)->free(reinterpret_cast<T*>(ptr)); | |||
}; | |||
profiler.m_records.emplace_back( | |||
id, profiler.m_thread_id, time, | |||
AnyPtr{mem_pool.alloc(T{std::forward<TArgs>(args)...}), | |||
{&mem_pool, deleter}}); | |||
AnyPtr::make<T>(T{std::forward<TArgs&&>(args)...})); | |||
if constexpr (sm_debug) { | |||
Status expected = Recording; | |||
mgb_assert(profiler.m_status.compare_exchange_strong(expected, Running)); | |||
@@ -241,7 +177,7 @@ public: | |||
bundle.options = get_options(); | |||
bundle.start_at = sm_start_at; | |||
bundle.thread_dict = get_thread_dict(); | |||
return std::move(bundle); | |||
return bundle; | |||
} | |||
static option_t get_option(std::string key, option_t default_val) { | |||
@@ -0,0 +1,71 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/utils/allocator.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <typeindex> | |||
#include "megbrain/utils/mempool.h" | |||
#include "megbrain/utils/metahelper.h" | |||
namespace mgb::imperative { | |||
template <typename T> | |||
class Allocator { | |||
public: | |||
using pointer = T*; | |||
using const_pointer = const T*; | |||
using void_pointer = void*; | |||
using const_void_pointer = const void*; | |||
using value_type = T; | |||
using size_type = std::size_t; | |||
using diffenence_type = std::ptrdiff_t; | |||
using pool_type = MemPoolStorage; | |||
private: | |||
pool_type* m_pool = nullptr; | |||
public: | |||
Allocator(pool_type* pool) : m_pool(pool) {} | |||
T* allocate(size_type n) { | |||
mgb_assert(n == 1); | |||
return m_pool->alloc(sizeof(T)); | |||
} | |||
void deallocate(pointer* p, size_type n) { | |||
mgb_assert(n == 1); | |||
m_pool->free(p); | |||
} | |||
bool operator==(const Allocator& rhs) const { return m_pool == rhs.m_pool; } | |||
bool operator!=(const Allocator& rhs) const { return m_pool != rhs.m_pool; } | |||
}; | |||
template <typename T> | |||
class ThreadLocalAllocatorAdapter { | |||
public: | |||
using value_type = T; | |||
using size_type = std::size_t; | |||
using pointer = T*; | |||
public: | |||
T* allocate(size_type n) { mgb_assert(false); } | |||
void deallocate(pointer* p, size_type n) { mgb_assert(false); } | |||
bool operator==(const ThreadLocalAllocatorAdapter& rhs) const { return true; } | |||
bool operator!=(const ThreadLocalAllocatorAdapter& rhs) const { return false; } | |||
}; | |||
} // namespace mgb::imperative |
@@ -0,0 +1,70 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/utils/any.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <typeindex> | |||
#include "megbrain/imperative/utils/local_ptr.h" | |||
namespace mgb::imperative { | |||
class AnyMixinBase { | |||
private: | |||
const std::type_info* m_type = nullptr; | |||
public: | |||
AnyMixinBase() = default; | |||
const std::type_info& type() const { return *m_type; } | |||
friend class AnyPtr; | |||
}; | |||
template <typename T> | |||
class AnyMixin : public AnyMixinBase, public T { | |||
public: | |||
AnyMixin(T&& val) : T(std::move(val)) {} | |||
}; | |||
class AnyPtr { | |||
public: | |||
using storage_t = LocalPtr<AnyMixinBase>; | |||
private: | |||
storage_t m_storage; | |||
public: | |||
const std::type_info& type() const { return m_storage->type(); } | |||
template <typename T> | |||
const T& cast() const { | |||
mgb_assert(is_exactly<T>(), "type mismatch"); | |||
return *static_cast<const AnyMixin<T>*>(m_storage.get()); | |||
} | |||
template <typename T> | |||
bool is_exactly() const { | |||
return std::type_index{typeid(T)} == std::type_index{type()}; | |||
} | |||
bool operator==(std::nullptr_t nptr) const { return m_storage == nullptr; } | |||
bool operator!=(std::nullptr_t nptr) const { return m_storage != nullptr; } | |||
operator bool() const { return m_storage != nullptr; } | |||
template <typename T, typename... TArgs> | |||
static AnyPtr make(TArgs&&... args) { | |||
AnyPtr ret; | |||
ret.m_storage = LocalPtr<AnyMixinBase>::make<AnyMixin<T>>( | |||
std::forward<TArgs&&>(args)...); | |||
ret.m_storage->m_type = &typeid(T); | |||
return ret; | |||
} | |||
}; | |||
} // namespace mgb::imperative |
@@ -0,0 +1,96 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/utils/visit.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <chrono> | |||
#include <future> | |||
#include <vector> | |||
#include "megbrain/utils/metahelper.h" | |||
#include "megbrain/utils/small_vector.h" | |||
namespace mgb::imperative { | |||
class BoxBase : public NonCopyableObj { | |||
public: | |||
virtual void reset() = 0; | |||
virtual void set_exception(std::exception_ptr exc) = 0; | |||
virtual bool try_set_exception(std::exception_ptr exc) = 0; | |||
}; | |||
/** | |||
* \brief An reusable promise | |||
* | |||
* \tparam T type of value | |||
*/ | |||
template <typename T> | |||
class Box final : public BoxBase { | |||
private: | |||
std::promise<T> m_promise; | |||
std::shared_future<T> m_future; | |||
std::mutex m_mutex; | |||
bool m_value_set; | |||
bool m_exception_set; | |||
public: | |||
Box() { reset(); } | |||
const T& get_value() { return m_future.get(); } | |||
T take_value() { | |||
T value = m_future.get(); | |||
reset(); | |||
return value; | |||
} | |||
void set_value(T value) { | |||
MGB_LOCK_GUARD(m_mutex); | |||
m_promise.set_value(std::move(value)); | |||
m_value_set = true; | |||
} | |||
bool try_set_value(T value) { | |||
MGB_LOCK_GUARD(m_mutex); | |||
if (m_exception_set) { | |||
return false; | |||
} | |||
m_promise.set_value(std::move(value)); | |||
m_value_set = true; | |||
return true; | |||
} | |||
void set_exception(std::exception_ptr exc) override { | |||
MGB_LOCK_GUARD(m_mutex); | |||
m_promise.set_exception(exc); | |||
m_exception_set = true; | |||
} | |||
bool try_set_exception(std::exception_ptr exc) override { | |||
MGB_LOCK_GUARD(m_mutex); | |||
if (m_value_set) { | |||
return false; | |||
} | |||
m_promise.set_exception(exc); | |||
m_exception_set = true; | |||
return true; | |||
} | |||
void reset() override { | |||
MGB_LOCK_GUARD(m_mutex); | |||
m_promise = {}; | |||
m_future = m_promise.get_future(); | |||
m_value_set = false; | |||
m_exception_set = false; | |||
} | |||
/** | |||
* \brief make an empty box | |||
* | |||
* \return std::shared_ptr<Box> | |||
*/ | |||
static std::shared_ptr<Box> make() { return std::make_shared<Box>(); } | |||
}; | |||
} // namespace mgb::imperative |
@@ -0,0 +1,40 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/utils/span.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <iomanip> | |||
#include <memory> | |||
#include <sstream> | |||
namespace mgb { | |||
namespace imperative { | |||
template <typename T> | |||
class CleanupGuard { | |||
private: | |||
T m_callback; | |||
public: | |||
explicit CleanupGuard(T cb) : m_callback{std::move(cb)} {} | |||
~CleanupGuard() { m_callback(); } | |||
}; | |||
inline std::string quoted(std::string str) { | |||
std::stringstream ss; | |||
ss << std::quoted(str); | |||
return ss.str(); | |||
} | |||
} // namespace imperative | |||
} // namespace mgb |
@@ -0,0 +1,245 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/utils/intrusive_list.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/utils/metahelper.h" | |||
namespace mgb::imperative::utils::intrusive_list { | |||
// copy policy | |||
struct after_t {}; | |||
struct before_t {}; | |||
struct disable_t {}; | |||
template <typename T> | |||
struct Tail; | |||
// invariant: next->prev == this | |||
template <typename T> | |||
struct Head { | |||
Tail<T>* next; | |||
Head(Tail<T>* node = nullptr) : next(node) {} | |||
Head(const Head<T>&) = delete; | |||
Head<T>& operator=(const Head<T>&) = delete; | |||
Head(Head<T>&& rhs) : next(rhs.next) { | |||
rhs.next = nullptr; | |||
if (next) { | |||
next->prev = this; | |||
} | |||
} | |||
Head<T>& operator=(Head<T>&& rhs) { | |||
mgb_assert(!next); | |||
next = rhs.next; | |||
rhs.next = nullptr; | |||
if (next) { | |||
next->prev = this; | |||
} | |||
return *this; | |||
} | |||
~Head() { | |||
if (next) { | |||
next->prev = nullptr; | |||
} | |||
} | |||
}; | |||
// invariant: prev->next == this | |||
template <typename T> | |||
struct Tail { | |||
Head<T>* prev; | |||
Tail(Head<T>* node = nullptr) : prev(node) {} | |||
Tail(const Tail<T>&) = delete; | |||
Tail<T>& operator=(const Tail<T>&) = delete; | |||
Tail(Tail<T>&& rhs) : prev(rhs.prev) { | |||
rhs.prev = nullptr; | |||
if (prev) { | |||
prev->next = this; | |||
} | |||
} | |||
Tail<T>& operator=(Tail<T>&& rhs) { | |||
mgb_assert(!prev); | |||
prev = rhs.prev; | |||
rhs.prev = nullptr; | |||
if (prev) { | |||
prev->next = this; | |||
} | |||
return *this; | |||
} | |||
~Tail() { | |||
if (prev) { | |||
prev->next = nullptr; | |||
} | |||
} | |||
}; | |||
template <typename T, typename policy> | |||
struct Node; | |||
template <typename T> | |||
class Iterator { | |||
T* ptr; | |||
void inc() { ptr = static_cast<T*>(ptr->Head<T>::next); } | |||
void dec() { ptr = static_cast<T*>(ptr->Head<T>::prev); } | |||
public: | |||
Iterator(Head<T>& head) : ptr(static_cast<T*>(head.next)) {} | |||
Iterator(Tail<T>& tail) : ptr(static_cast<T*>(tail.prev)) {} | |||
template <typename policy> | |||
Iterator(Node<T, policy>& node) : ptr(static_cast<T*>(&node)) {} | |||
T& operator*() { return *static_cast<T*>(ptr); } | |||
T* operator->() { return static_cast<T*>(ptr); } | |||
operator bool() { return ptr; } | |||
bool operator==(const Iterator<T>& rhs) { return ptr == rhs.ptr; } | |||
Iterator& operator++() { | |||
inc(); | |||
return *this; | |||
} | |||
Iterator& operator--() { | |||
dec(); | |||
return *this; | |||
} | |||
Iterator operator++(int) { | |||
auto ret = *this; | |||
inc(); | |||
return ret; | |||
} | |||
Iterator operator--(int) { | |||
auto ret = *this; | |||
dec(); | |||
return ret; | |||
} | |||
}; | |||
// Node in a doubly linked list. Unlike std::list, nodes are not owned by a container. | |||
// Instead, nodes may join or leave a list freely. | |||
// NOTE: Derived classes have to explicitly declare copy / assignment as default, | |||
// otherwise the compiler generated version would use the const T& signature, | |||
// which is deleted. | |||
template <typename T = void, typename policy = disable_t> | |||
struct Node : Tail<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>>, | |||
Head<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>> { | |||
private: | |||
using this_t = Node<T, policy>; | |||
using U = std::conditional_t<std::is_same_v<T, void>, this_t, T>; | |||
public: | |||
using head_t = Head<U>; | |||
using tail_t = Tail<U>; | |||
using head_t::next; | |||
using tail_t::prev; | |||
Node() = default; | |||
Node(const this_t&) = delete; | |||
this_t& operator=(const this_t&) = delete; | |||
//! constructed node is inserted after the input node | |||
Node(after_t, head_t& node) : tail_t(&node), head_t(node.next) { | |||
node.next = this; | |||
if (next) { | |||
next->prev = this; | |||
} | |||
} | |||
//! constructed node is inserted before the input node | |||
Node(before_t, tail_t& node) : head_t(&node), tail_t(node.prev) { | |||
node.prev = this; | |||
if (prev) { | |||
prev->next = this; | |||
} | |||
} | |||
Node(this_t&& rhs) : tail_t(rhs.prev), head_t(rhs.next) { | |||
rhs.prev = nullptr; | |||
rhs.next = nullptr; | |||
if (prev) { | |||
prev->next = this; | |||
} | |||
if (next) { | |||
next->prev = this; | |||
} | |||
} | |||
Node& operator=(this_t&& rhs) { | |||
unlink(); | |||
prev = rhs.prev; | |||
next = rhs.next; | |||
rhs.prev = nullptr; | |||
rhs.next = nullptr; | |||
if (prev) { | |||
prev->next = this; | |||
} | |||
if (next) { | |||
next->prev = this; | |||
} | |||
return *this; | |||
} | |||
template < | |||
typename p = policy, | |||
typename = std::enable_if_t< | |||
std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>> | |||
Node(this_t& rhs) : Node(policy{}, rhs) {} | |||
template < | |||
typename p = policy, | |||
typename = std::enable_if_t< | |||
std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>> | |||
this_t& operator=(this_t& rhs) { | |||
insert(policy{}, rhs); | |||
return *this; | |||
} | |||
void unlink() { | |||
if (prev) { | |||
prev->next = next; | |||
} | |||
if (next) { | |||
next->prev = prev; | |||
} | |||
prev = nullptr; | |||
next = nullptr; | |||
} | |||
//! this node is unlinked from its list and inserted after the input node | |||
void insert(after_t, head_t& node) { | |||
unlink(); | |||
prev = &node; | |||
next = node.next; | |||
node.next = this; | |||
if (next) { | |||
next->prev = this; | |||
} | |||
} | |||
//! this node is unlinked from its list and inserted before the input node | |||
void insert(before_t, tail_t& node) { | |||
unlink(); | |||
next = &node; | |||
prev = node.prev; | |||
node.prev = this; | |||
if (prev) { | |||
prev->next = this; | |||
} | |||
} | |||
void insert_before(tail_t& node) { insert(before_t{}, node); } | |||
void insert_after(head_t& node) { insert(after_t{}, node); } | |||
~Node() { unlink(); } | |||
}; | |||
} // namespace mgb::imperative::utils::intrusive_list |
@@ -0,0 +1,285 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/utils/local_ptr.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <optional> | |||
#include "megbrain/imperative/utils/mempool.h" | |||
#include "megbrain/utils/metahelper.h" | |||
namespace mgb::imperative { | |||
template <typename T> | |||
class LocalPtrStorage : public NonCopyableObj { | |||
private: | |||
size_t m_ref_count = 0; | |||
size_t m_weak_count = 0; | |||
T* m_pointer = nullptr; | |||
void (*reset)(LocalPtrStorage*) = nullptr; | |||
void (*free)(LocalPtrStorage*) = nullptr; | |||
void inc_ref() { m_ref_count++; } | |||
void dec_ref() { | |||
m_ref_count--; | |||
if (m_ref_count == 0) { | |||
reset(this); | |||
m_pointer = nullptr; | |||
reset = nullptr; | |||
if (m_weak_count == 0) { | |||
free(this); | |||
// dead | |||
} | |||
} | |||
} | |||
void inc_weak_ref() { m_weak_count++; } | |||
void dec_weak_ref() { | |||
m_weak_count--; | |||
if ((m_weak_count + m_ref_count) == 0) { | |||
free(this); | |||
// dead | |||
} | |||
} | |||
template <typename U> | |||
friend class LocalPtr; | |||
template <typename U> | |||
friend class LocalWeakPtr; | |||
public: | |||
}; | |||
template <typename T, typename TDerived> | |||
class LocalPtrStorgeImpl : public LocalPtrStorage<T> { | |||
private: | |||
std::optional<TDerived> m_value; | |||
void* m_pool = nullptr; | |||
template <typename U> | |||
friend class LocalPtr; | |||
template <typename U> | |||
friend class LocalWeakPtr; | |||
}; | |||
template <typename T> | |||
class LocalWeakPtr; | |||
/** | |||
* \brief thread-unsafe smart pointer | |||
* | |||
* \tparam T type of value | |||
*/ | |||
template <typename T> | |||
class LocalPtr { | |||
public: | |||
using storage_t = LocalPtrStorage<T>; | |||
using pool_t = MemPool<storage_t>; | |||
using weak_type = LocalWeakPtr<T>; | |||
private: | |||
storage_t* m_storage = nullptr; | |||
void emplace(storage_t* ptr) { | |||
if (ptr) { | |||
ptr->inc_ref(); | |||
m_storage = ptr; | |||
} | |||
} | |||
LocalPtr(storage_t* ptr) { emplace(ptr); } | |||
public: | |||
LocalPtr() = default; | |||
LocalPtr(const LocalPtr& rhs) { (*this) = rhs; } | |||
LocalPtr(LocalPtr&& rhs) { (*this) = std::move(rhs); } | |||
LocalPtr& operator=(const LocalPtr& rhs) { | |||
if (this == &rhs) { | |||
return *this; | |||
} | |||
auto storage = rhs.m_storage; | |||
if (storage) { | |||
storage->inc_ref(); | |||
} | |||
if (m_storage) { | |||
m_storage->dec_ref(); | |||
// rhs.m_storage may be invalid here | |||
} | |||
m_storage = storage; | |||
return *this; | |||
} | |||
LocalPtr& operator=(LocalPtr&& rhs) { | |||
if (this == &rhs) { | |||
return *this; | |||
} | |||
std::swap(m_storage, rhs.m_storage); | |||
rhs.reset(); | |||
return *this; | |||
} | |||
bool operator==(const LocalPtr& rhs) const { return m_storage == rhs.m_storage; } | |||
bool operator!=(const LocalPtr& rhs) const { return m_storage != rhs.m_storage; } | |||
size_t hash() const { return reinterpret_cast<uintptr_t>(m_storage); } | |||
~LocalPtr() { reset(); } | |||
/** | |||
* \brief Construct an instance of TDerived and return an LocalPtr | |||
* | |||
* There is an memory pool for each (T, TDerived) pair | |||
* | |||
* \tparam TDerived type of concrete instance, should be subclass of T | |||
* \tparam TArgs | |||
* \param args constructor arguments | |||
* \return LocalPtr points to the instance | |||
*/ | |||
template <typename TDerived = T, typename... TArgs> | |||
static LocalPtr make(TArgs&&... args) { | |||
static_assert(std::is_base_of_v<T, TDerived>); | |||
using storage_impl_t = LocalPtrStorgeImpl<T, TDerived>; | |||
constexpr auto normalize_size = [](size_t size) { | |||
size_t normalized_size = 64; | |||
while (normalized_size < size) { | |||
normalized_size *= 2; | |||
} | |||
return normalized_size; | |||
}; | |||
using raw_storage_t = | |||
std::aligned_storage_t<normalize_size(sizeof(storage_impl_t))>; | |||
static_assert(alignof(raw_storage_t) % alignof(storage_impl_t) == 0); | |||
static_assert(sizeof(raw_storage_t) >= sizeof(storage_impl_t)); | |||
using pool_t = MemPool<raw_storage_t>; | |||
pool_t& pool = MemPoolUtils<raw_storage_t>::get_thread_local(); | |||
auto* raw_storage = pool.alloc_raw(); | |||
auto* storage = reinterpret_cast<storage_impl_t*>(raw_storage); | |||
new (storage) storage_impl_t(); | |||
storage->m_value.emplace(std::forward<TArgs&&>(args)...); | |||
storage->m_pointer = &*storage->m_value; | |||
storage->reset = [](storage_t* storage) { | |||
auto* storage_impl = static_cast<storage_impl_t*>(storage); | |||
storage_impl->m_value.reset(); | |||
storage_impl->m_pointer = nullptr; | |||
}; | |||
storage->free = [](storage_t* storage_base) { | |||
auto* storage = static_cast<storage_impl_t*>(storage_base); | |||
auto* pool = reinterpret_cast<pool_t*>(storage->m_pool); | |||
storage->m_pool = nullptr; | |||
storage->~storage_impl_t(); | |||
auto* raw_storage = reinterpret_cast<raw_storage_t*>(storage); | |||
pool->free_raw(raw_storage); | |||
}; | |||
storage->m_pool = &pool; | |||
return {(storage_t*)storage}; | |||
} | |||
T& operator*() const { return *get(); } | |||
T* get() const { | |||
if ((!m_storage) || !m_storage->m_pointer) { | |||
return nullptr; | |||
} | |||
return m_storage->m_pointer; | |||
} | |||
T* operator->() const { return get(); } | |||
size_t ref_count() const { return m_storage->m_ref_count; } | |||
bool unique() const { return ref_count() == 1; } | |||
void reset() { | |||
if (m_storage) { | |||
m_storage->dec_ref(); | |||
m_storage = nullptr; | |||
} | |||
} | |||
operator bool() const { return bool(m_storage); } | |||
bool operator==(std::nullptr_t nptr) const { return m_storage == nullptr; } | |||
bool operator!=(std::nullptr_t nptr) const { return m_storage != nullptr; } | |||
template <typename U> | |||
friend class LocalWeakPtr; | |||
}; | |||
template <typename T> | |||
class LocalWeakPtr { | |||
public: | |||
using storage_t = LocalPtrStorage<T>; | |||
private: | |||
storage_t* m_storage = nullptr; | |||
void emplace(storage_t* ptr) { | |||
if (ptr) { | |||
ptr->inc_weak_ref(); | |||
m_storage = ptr; | |||
} | |||
} | |||
public: | |||
LocalWeakPtr() = default; | |||
LocalWeakPtr(const LocalPtr<T>& rhs) { emplace(rhs.m_storage); } | |||
LocalWeakPtr(const LocalWeakPtr& rhs) { (*this) = rhs; } | |||
LocalWeakPtr(LocalWeakPtr&& rhs) { (*this) = std::move(rhs); } | |||
LocalWeakPtr& operator=(const LocalWeakPtr& rhs) { | |||
if (this == &rhs) { | |||
return *this; | |||
} | |||
reset(); | |||
emplace(rhs.m_storage); | |||
return *this; | |||
} | |||
LocalWeakPtr& operator=(LocalWeakPtr&& rhs) { | |||
if (this == &rhs) { | |||
return *this; | |||
} | |||
std::swap(m_storage, rhs.m_storage); | |||
rhs.reset(); | |||
return *this; | |||
} | |||
~LocalWeakPtr() { reset(); } | |||
void reset() { | |||
if (m_storage) { | |||
m_storage->dec_weak_ref(); | |||
m_storage = nullptr; | |||
} | |||
} | |||
LocalPtr<T> lock() const { | |||
if (m_storage && m_storage->m_ref_count) { | |||
return {m_storage}; | |||
} | |||
return {}; | |||
} | |||
bool operator==(const LocalWeakPtr& rhs) const { | |||
return m_storage == rhs.m_storage; | |||
} | |||
bool operator!=(const LocalWeakPtr& rhs) const { | |||
return m_storage != rhs.m_storage; | |||
} | |||
size_t hash() const { return reinterpret_cast<uintptr_t>(m_storage); } | |||
}; | |||
template <typename T, typename TDerived, typename... TArgs> | |||
LocalPtr<T> make_local(TArgs&&... args) { | |||
return LocalPtr<T>::template make<TDerived>(std::forward<TArgs&&>(args)...); | |||
} | |||
} // namespace mgb::imperative |
@@ -0,0 +1,157 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/utils/map.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <optional> | |||
#include "megbrain/utils/metahelper.h" | |||
namespace mgb::imperative { | |||
/** | |||
* \brief an hash map optimized for weak pointer as key | |||
* | |||
* Keys were scanned automatically, so values referenced by invalid keys whould be | |||
* released soon | |||
* | |||
* \tparam TKey key type, requires(bool(key.lock())) | |||
* \tparam TValue value type | |||
*/ | |||
template <typename TKey, typename TValue> | |||
class WeakKeyMap : public NonCopyableObj { | |||
public: | |||
using storage_t = std::unordered_map<TKey, TValue>; | |||
private: | |||
storage_t m_storage; | |||
typename storage_t::iterator m_cursor = m_storage.begin(); | |||
/** | |||
* \brief select a key and verify that whether it is invalid. If yes, erase it | |||
* | |||
*/ | |||
void _step() { | |||
if (m_cursor == m_storage.end()) { | |||
m_cursor = m_storage.begin(); | |||
return; | |||
} | |||
auto key = m_cursor->first; | |||
if (!key.lock()) { | |||
m_cursor = m_storage.erase(m_cursor); | |||
} else { | |||
++m_cursor; | |||
} | |||
} | |||
public: | |||
size_t count(TKey key) { | |||
_step(); | |||
_step(); | |||
return m_storage.count(key); | |||
} | |||
TValue& at(TKey key) const { return m_storage.at(key); } | |||
TValue& at(TKey key) { | |||
_step(); | |||
_step(); | |||
return m_storage.at(key); | |||
} | |||
TValue& operator[](TKey key) { | |||
_step(); | |||
_step(); | |||
if (m_storage.count(key)) { | |||
return m_storage.at(key); | |||
} else { | |||
size_t bucket_count = m_storage.bucket_count(); | |||
TValue& result = m_storage[key]; | |||
if (bucket_count != m_storage.bucket_count()) { | |||
m_cursor = m_storage.begin(); | |||
} | |||
return result; | |||
} | |||
} | |||
std::optional<TValue> try_get(TKey key) const { | |||
auto iter = m_storage.find(key); | |||
if (iter == m_storage.end()) { | |||
return {}; | |||
} | |||
return {iter->second}; | |||
} | |||
std::optional<TValue> try_get(TKey key) { | |||
_step(); | |||
_step(); | |||
return ((const WeakKeyMap*)this)->try_get(std::move(key)); | |||
} | |||
}; | |||
template <typename TKey, typename TValue> | |||
class WeakValueMap : public NonCopyableObj { | |||
public: | |||
using storage_t = std::unordered_map<TKey, TValue>; | |||
private: | |||
storage_t m_storage; | |||
typename storage_t::iterator m_cursor = m_storage.begin(); | |||
/** | |||
* \brief select a key and verify that whether it is invalid. If yes, erase it | |||
* | |||
*/ | |||
void _step() { | |||
if (m_cursor == m_storage.end()) { | |||
m_cursor = m_storage.begin(); | |||
return; | |||
} | |||
auto value = m_cursor->second; | |||
if (!value.lock()) { | |||
m_cursor = m_storage.erase(m_cursor); | |||
} else { | |||
++m_cursor; | |||
} | |||
} | |||
public: | |||
size_t count(TKey key) { | |||
_step(); | |||
_step(); | |||
return m_storage.count(key); | |||
} | |||
TValue& at(TKey key) const { return m_storage.at(key); } | |||
TValue& at(TKey key) { | |||
_step(); | |||
_step(); | |||
return m_storage.at(key); | |||
} | |||
TValue& operator[](TKey key) { | |||
_step(); | |||
_step(); | |||
if (m_storage.count(key)) { | |||
return m_storage.at(key); | |||
} else { | |||
size_t bucket_count = m_storage.bucket_count(); | |||
TValue& result = m_storage[key]; | |||
if (bucket_count != m_storage.bucket_count()) { | |||
m_cursor = m_storage.begin(); | |||
} | |||
return result; | |||
} | |||
} | |||
}; | |||
} // namespace mgb::imperative |
@@ -0,0 +1,70 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/utils/mempool.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <mutex> | |||
#include <thread> | |||
#include <unordered_map> | |||
#include "megbrain/utils/mempool.h" | |||
#include "megbrain/utils/metahelper.h" | |||
namespace mgb::imperative { | |||
template <typename T> | |||
class MemPoolUtils { | |||
private: | |||
static std::mutex sm_mutex; | |||
static std::unordered_map<std::thread::id, std::unique_ptr<MemPool<T>>> | |||
sm_instances; | |||
static thread_local MemPool<T>* tm_instance; | |||
static MemPool<T>* sm_instance; | |||
public: | |||
static MemPool<T>& get_thread_local() { | |||
if (!tm_instance) { | |||
MGB_LOCK_GUARD(sm_mutex); | |||
auto& instance = sm_instances[std::this_thread::get_id()]; | |||
if (!instance) { // thread id may be duplicated | |||
instance = std::make_unique<MemPool<T>>(); | |||
} | |||
tm_instance = instance.get(); | |||
} | |||
return *tm_instance; | |||
} | |||
static MemPool<T>& get_static() { | |||
if (!sm_instance) { | |||
MGB_LOCK_GUARD(sm_mutex); | |||
auto& instance = sm_instances[{}]; | |||
if (!instance) { // double check | |||
instance = std::make_unique<MemPool<T>>(); | |||
sm_instance = instance.get(); | |||
} | |||
mgb_assert(sm_instance); | |||
} | |||
} | |||
}; | |||
template <typename T> | |||
std::mutex MemPoolUtils<T>::sm_mutex; | |||
template <typename T> | |||
std::unordered_map<std::thread::id, std::unique_ptr<MemPool<T>>> | |||
MemPoolUtils<T>::sm_instances; | |||
template <typename T> | |||
thread_local MemPool<T>* MemPoolUtils<T>::tm_instance; | |||
template <typename T> | |||
MemPool<T>* MemPoolUtils<T>::sm_instance; | |||
} // namespace mgb::imperative |
@@ -0,0 +1,69 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/utils/span.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <array> | |||
#include <vector> | |||
#include "megbrain/utils/small_vector.h" | |||
namespace mgb::imperative { | |||
/** | |||
* \brief wrapper for c-style array | |||
* | |||
* \tparam T value type | |||
*/ | |||
template <typename T> | |||
class Span { | |||
private: | |||
const T* m_begin = nullptr; | |||
const T* m_end = nullptr; | |||
public: | |||
Span() {} | |||
Span(const T* begin, const T* end) : m_begin{begin}, m_end{end} {} | |||
Span(const T* begin, size_t size) : Span(begin, begin + size) {} | |||
template <typename TContainer> | |||
Span(TContainer& container) : Span(container.data(), container.size()) {} | |||
const T* begin() const { return m_begin; } | |||
const T* end() const { return m_end; } | |||
const T* data() const { return m_begin; } | |||
size_t size() const { return m_end - m_begin; } | |||
template <typename TContainer> | |||
TContainer copy_into() { | |||
return TContainer(m_begin, m_end); | |||
} | |||
const T& operator[](size_t idx) const { return m_begin[idx]; } | |||
const T& at(size_t idx) const { return m_begin[idx]; } | |||
const T& item() const { | |||
mgb_assert( | |||
m_end - m_begin == 1, "size mismatch: %zu vs %zu", (m_end - m_begin), | |||
(size_t)1); | |||
return m_begin[0]; | |||
} | |||
template <size_t N> | |||
const std::array<T, N>& as_array() { | |||
mgb_assert( | |||
m_end - m_begin == N, "size mismatch: %zu vs %zu", (m_end - m_begin), | |||
N); | |||
return *reinterpret_cast<const std::array<T, N>*>(m_begin); | |||
} | |||
Span sub(size_t begin, size_t length) { | |||
mgb_assert(begin + length <= m_end - m_begin); | |||
return {m_begin + begin, length}; | |||
} | |||
}; | |||
} // namespace mgb::imperative |
@@ -16,6 +16,7 @@ | |||
#include <tuple> | |||
#include <type_traits> | |||
#include "megbrain/imperative/utils/span.h" | |||
#include "megbrain/tensor.h" | |||
#include "megbrain/utils/small_vector.h" | |||
@@ -60,6 +61,22 @@ struct ToStringTrait<SmallVector<T, N>> { | |||
}; | |||
template <typename T> | |||
struct ToStringTrait<std::vector<T>> { | |||
std::string operator()(const std::vector<T>& v) const { | |||
if (v.empty()) { | |||
return "[]"; | |||
} | |||
std::string result = "["; | |||
result += to_string(v[0]); | |||
for (size_t i = 1; i < v.size(); ++i) { | |||
result += ", "; | |||
result += to_string(v[i]); | |||
} | |||
return result + "]"; | |||
} | |||
}; | |||
template <typename T> | |||
struct ToStringTrait<std::shared_ptr<T>> { | |||
std::string operator()(const std::shared_ptr<T>& sp) const { | |||
return to_string(sp.get()); | |||
@@ -115,4 +132,36 @@ struct ToStringTrait<CompNode> { | |||
std::string operator()(CompNode device) const { return device.to_string(); } | |||
}; | |||
inline std::string string_join(Span<std::string> span, char delimiter = ',') { | |||
std::string buffer = "["; | |||
for (size_t i = 1; i < span.size(); ++i) { | |||
if (i) { | |||
buffer.push_back(delimiter); | |||
} | |||
buffer.append(span[0]); | |||
} | |||
return buffer + "]"; | |||
} | |||
template <typename T> | |||
struct ToStringTrait<Span<T>> { | |||
std::string operator()(Span<T> span) const { | |||
if (span.size() == 0) { | |||
return "[]"; | |||
} | |||
std::string result = "["; | |||
result += to_string(span[0]); | |||
for (size_t i = 1; i < span.size(); ++i) { | |||
result += ", "; | |||
result += to_string(span[i]); | |||
} | |||
return result + "]"; | |||
} | |||
}; | |||
template <> | |||
struct ToStringTrait<std::type_info> { | |||
std::string operator()(const std::type_info& info) const { return info.name(); } | |||
}; | |||
} // namespace mgb::imperative |
@@ -0,0 +1,104 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/utils/visit.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <vector> | |||
#include "megbrain/imperative/utils/span.h" | |||
#include "megbrain/tensor.h" | |||
namespace mgb::imperative { | |||
/** | |||
* \brief like TensorShape, but allow real scalar shape. | |||
* | |||
*/ | |||
struct ValueShape { | |||
size_t shape[TensorShape::MAX_NDIM]; | |||
int ndim = 0; | |||
ValueShape() = default; | |||
ValueShape(std::initializer_list<size_t> dims) { | |||
for (auto&& dim : dims) { | |||
shape[ndim++] = dim; | |||
} | |||
} | |||
ValueShape(Span<size_t> dims) { | |||
for (auto&& dim : dims) { | |||
shape[ndim++] = dim; | |||
} | |||
} | |||
size_t& operator[](int axis) { return shape[axis]; } | |||
size_t operator[](int axis) const { return shape[axis]; } | |||
size_t at(int axis) const { | |||
mgb_assert(axis < ndim); | |||
return shape[axis]; | |||
} | |||
size_t total_nr_elems() const { | |||
size_t prod = 1; | |||
for (int i = 0; i < ndim; ++i) { | |||
prod *= shape[i]; | |||
} | |||
return prod; | |||
} | |||
bool is_scalar() const { return ndim == 0; } | |||
std::string to_string() const { | |||
std::string buffer = "{"; | |||
for (size_t i = 0; i < ndim; ++i) { | |||
if (i) { | |||
buffer.append(","); | |||
} | |||
buffer.append(std::to_string(shape[i])); | |||
} | |||
buffer.append("}"); | |||
return buffer; | |||
} | |||
static ValueShape from(TensorShape tensor_shape) { | |||
mgb_assert(tensor_shape.ndim); | |||
return Span<size_t>{tensor_shape.shape, tensor_shape.ndim}; | |||
} | |||
TensorShape as_tensor_shape() const { | |||
mgb_assert(ndim != 0); | |||
TensorShape ret; | |||
for (size_t i = 0; i < ndim; ++i) { | |||
ret.shape[i] = shape[i]; | |||
} | |||
ret.ndim = ndim; | |||
return ret; | |||
} | |||
bool operator==(const ValueShape& rhs) const { | |||
if (ndim != rhs.ndim) { | |||
return false; | |||
} | |||
for (size_t i = 0; i < ndim; ++i) { | |||
if (shape[i] != rhs.shape[i]) { | |||
return false; | |||
} | |||
} | |||
return true; | |||
} | |||
}; | |||
static_assert(sizeof(size_t) >= sizeof(int)); | |||
static_assert(TensorShape::MAX_NDIM == 7); | |||
static_assert(sizeof(ValueShape) <= sizeof(size_t) * 8); | |||
} // namespace mgb::imperative |
@@ -0,0 +1,26 @@ | |||
/** | |||
* \file imperative/src/include/megbrain/imperative/utils/visit.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <vector> | |||
#include "megbrain/utils/small_vector.h" | |||
namespace mgb::imperative { | |||
template <typename... TVisitors> | |||
class Visitor : public TVisitors... { | |||
public: | |||
using TVisitors::operator()...; | |||
}; | |||
} // namespace mgb::imperative |
@@ -28,10 +28,10 @@ TEST(TestProfiler, ImperativeLogProfile) { | |||
auto results = imperative::Profiler::collect(); | |||
imperative::Profiler::stop_profile(); | |||
mgb_assert(results.entries.size() == 2); | |||
auto* event_start = results.entries[0].data.as<profiler::CustomEvent>(); | |||
auto* event_finish = results.entries[1].data.as<profiler::CustomFinishEvent>(); | |||
mgb_assert(event_start && event_start->title == "XXX"); | |||
mgb_assert(event_finish && event_finish->title == "XXX"); | |||
auto& event_start = results.entries[0].data.cast<profiler::CustomEvent>(); | |||
auto& event_finish = results.entries[1].data.cast<profiler::CustomFinishEvent>(); | |||
mgb_assert(event_start.title == "XXX"); | |||
mgb_assert(event_finish.title == "XXX"); | |||
mgb_assert(results.entries[0].time < results.entries[1].time); | |||
mgb_assert(results.entries[0].id < results.entries[1].id); | |||
} |