@@ -160,7 +160,7 @@ private: | |||||
template <typename TItem> | template <typename TItem> | ||||
void register_converter() { | void register_converter() { | ||||
m_table[typeid(TItem)] = [](const any_t& input) { | m_table[typeid(TItem)] = [](const any_t& input) { | ||||
return variant_t(*input.as<TItem>()); | |||||
return variant_t(input.cast<TItem>()); | |||||
}; | }; | ||||
} | } | ||||
@@ -11,7 +11,6 @@ | |||||
#pragma once | #pragma once | ||||
#include <any> | |||||
#include <bitset> | #include <bitset> | ||||
#include <chrono> | #include <chrono> | ||||
#include <deque> | #include <deque> | ||||
@@ -28,6 +27,7 @@ | |||||
#include "megbrain/imperative/op_def.h" | #include "megbrain/imperative/op_def.h" | ||||
#include "megbrain/imperative/physical_tensor.h" | #include "megbrain/imperative/physical_tensor.h" | ||||
#include "megbrain/imperative/utils/any.h" | |||||
namespace mgb { | namespace mgb { | ||||
namespace imperative { | namespace imperative { | ||||
@@ -51,48 +51,6 @@ public: | |||||
static std::shared_ptr<CompNode::Event> record_device(CompNode device); | static std::shared_ptr<CompNode::Event> record_device(CompNode device); | ||||
}; | }; | ||||
class AnyPtr { | |||||
public: | |||||
struct Deleter { | |||||
void* object; | |||||
void (*method)(void*, void*); | |||||
void operator()(void* ptr) { method(object, ptr); } | |||||
}; | |||||
private: | |||||
using holder_t = std::unique_ptr<void, Deleter>; | |||||
const std::type_info* m_type = nullptr; | |||||
holder_t m_holder = nullptr; | |||||
public: | |||||
AnyPtr() = default; | |||||
template < | |||||
typename T, | |||||
typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, AnyPtr>>> | |||||
explicit AnyPtr(T* value, Deleter deleter) { | |||||
m_type = &typeid(T); | |||||
m_holder = {value, deleter}; | |||||
} | |||||
template <typename T> | |||||
T* as() { | |||||
mgb_assert(is_exactly<T>(), "type mismatch"); | |||||
return reinterpret_cast<T*>(m_holder.get()); | |||||
} | |||||
template <typename T> | |||||
const T* as() const { | |||||
mgb_assert(is_exactly<T>(), "type mismatch"); | |||||
return reinterpret_cast<const T*>(m_holder.get()); | |||||
} | |||||
template <typename T> | |||||
bool is_exactly() const { | |||||
return std::type_index{typeid(T)} == std::type_index{*m_type}; | |||||
} | |||||
const std::type_info& type() const { return *m_type; } | |||||
bool operator==(std::nullptr_t nptr) const { return m_holder == nullptr; } | |||||
operator bool() const { return m_holder != nullptr; } | |||||
}; | |||||
class Profiler { | class Profiler { | ||||
public: | public: | ||||
struct Record { | struct Record { | ||||
@@ -128,7 +86,6 @@ private: | |||||
std::thread::id m_thread_id; | std::thread::id m_thread_id; | ||||
std::vector<Record> m_records; | std::vector<Record> m_records; | ||||
std::atomic<Status> m_status = Running; | std::atomic<Status> m_status = Running; | ||||
std::unordered_map<std::type_index, AnyPtr> m_mem_pools; | |||||
static std::vector<entry_t> sm_records; | static std::vector<entry_t> sm_records; | ||||
static options_t sm_profile_options; | static options_t sm_profile_options; | ||||
@@ -161,42 +118,21 @@ public: | |||||
return *tm_profiler; | return *tm_profiler; | ||||
} | } | ||||
template <typename T> | |||||
static MemPool<T>& get_mem_pool() { | |||||
thread_local MemPool<T>* t_pool = nullptr; | |||||
if (t_pool == nullptr) { | |||||
auto& pool = get_instance().m_mem_pools[typeid(MemPool<T>)]; | |||||
if (pool == nullptr) { | |||||
pool = | |||||
AnyPtr(new MemPool<T>(), | |||||
{nullptr, [](void*, void* ptr) { | |||||
delete reinterpret_cast<MemPool<T>*>(ptr); | |||||
}}); | |||||
} | |||||
t_pool = pool.as<MemPool<T>>(); | |||||
} | |||||
return *t_pool; | |||||
} | |||||
static uint64_t next_id() { return sm_last_id++; } | static uint64_t next_id() { return sm_last_id++; } | ||||
template <typename T, typename... TArgs> | template <typename T, typename... TArgs> | ||||
static uint64_t record(TArgs&&... args) { | static uint64_t record(TArgs&&... args) { | ||||
auto& profiler = get_instance(); | auto& profiler = get_instance(); | ||||
auto& mem_pool = get_mem_pool<T>(); | |||||
// auto& mem_pool = get_mem_pool<T>(); | |||||
if constexpr (sm_debug) { | if constexpr (sm_debug) { | ||||
Status expected = Running; | Status expected = Running; | ||||
mgb_assert(profiler.m_status.compare_exchange_strong(expected, Recording)); | mgb_assert(profiler.m_status.compare_exchange_strong(expected, Recording)); | ||||
} | } | ||||
uint64_t id = next_id(); | uint64_t id = next_id(); | ||||
profiler::Time time = sm_timer.record_host(); | profiler::Time time = sm_timer.record_host(); | ||||
auto deleter = [](void* obj, void* ptr) { | |||||
reinterpret_cast<MemPool<T>*>(obj)->free(reinterpret_cast<T*>(ptr)); | |||||
}; | |||||
profiler.m_records.emplace_back( | profiler.m_records.emplace_back( | ||||
id, profiler.m_thread_id, time, | id, profiler.m_thread_id, time, | ||||
AnyPtr{mem_pool.alloc(T{std::forward<TArgs>(args)...}), | |||||
{&mem_pool, deleter}}); | |||||
AnyPtr::make<T>(T{std::forward<TArgs&&>(args)...})); | |||||
if constexpr (sm_debug) { | if constexpr (sm_debug) { | ||||
Status expected = Recording; | Status expected = Recording; | ||||
mgb_assert(profiler.m_status.compare_exchange_strong(expected, Running)); | mgb_assert(profiler.m_status.compare_exchange_strong(expected, Running)); | ||||
@@ -241,7 +177,7 @@ public: | |||||
bundle.options = get_options(); | bundle.options = get_options(); | ||||
bundle.start_at = sm_start_at; | bundle.start_at = sm_start_at; | ||||
bundle.thread_dict = get_thread_dict(); | bundle.thread_dict = get_thread_dict(); | ||||
return std::move(bundle); | |||||
return bundle; | |||||
} | } | ||||
static option_t get_option(std::string key, option_t default_val) { | static option_t get_option(std::string key, option_t default_val) { | ||||
@@ -0,0 +1,71 @@ | |||||
/** | |||||
* \file imperative/src/include/megbrain/imperative/utils/allocator.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <typeindex> | |||||
#include "megbrain/utils/mempool.h" | |||||
#include "megbrain/utils/metahelper.h" | |||||
namespace mgb::imperative { | |||||
template <typename T> | |||||
class Allocator { | |||||
public: | |||||
using pointer = T*; | |||||
using const_pointer = const T*; | |||||
using void_pointer = void*; | |||||
using const_void_pointer = const void*; | |||||
using value_type = T; | |||||
using size_type = std::size_t; | |||||
using diffenence_type = std::ptrdiff_t; | |||||
using pool_type = MemPoolStorage; | |||||
private: | |||||
pool_type* m_pool = nullptr; | |||||
public: | |||||
Allocator(pool_type* pool) : m_pool(pool) {} | |||||
T* allocate(size_type n) { | |||||
mgb_assert(n == 1); | |||||
return m_pool->alloc(sizeof(T)); | |||||
} | |||||
void deallocate(pointer* p, size_type n) { | |||||
mgb_assert(n == 1); | |||||
m_pool->free(p); | |||||
} | |||||
bool operator==(const Allocator& rhs) const { return m_pool == rhs.m_pool; } | |||||
bool operator!=(const Allocator& rhs) const { return m_pool != rhs.m_pool; } | |||||
}; | |||||
template <typename T> | |||||
class ThreadLocalAllocatorAdapter { | |||||
public: | |||||
using value_type = T; | |||||
using size_type = std::size_t; | |||||
using pointer = T*; | |||||
public: | |||||
T* allocate(size_type n) { mgb_assert(false); } | |||||
void deallocate(pointer* p, size_type n) { mgb_assert(false); } | |||||
bool operator==(const ThreadLocalAllocatorAdapter& rhs) const { return true; } | |||||
bool operator!=(const ThreadLocalAllocatorAdapter& rhs) const { return false; } | |||||
}; | |||||
} // namespace mgb::imperative |
@@ -0,0 +1,70 @@ | |||||
/** | |||||
* \file imperative/src/include/megbrain/imperative/utils/any.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <typeindex> | |||||
#include "megbrain/imperative/utils/local_ptr.h" | |||||
namespace mgb::imperative { | |||||
class AnyMixinBase { | |||||
private: | |||||
const std::type_info* m_type = nullptr; | |||||
public: | |||||
AnyMixinBase() = default; | |||||
const std::type_info& type() const { return *m_type; } | |||||
friend class AnyPtr; | |||||
}; | |||||
template <typename T> | |||||
class AnyMixin : public AnyMixinBase, public T { | |||||
public: | |||||
AnyMixin(T&& val) : T(std::move(val)) {} | |||||
}; | |||||
class AnyPtr { | |||||
public: | |||||
using storage_t = LocalPtr<AnyMixinBase>; | |||||
private: | |||||
storage_t m_storage; | |||||
public: | |||||
const std::type_info& type() const { return m_storage->type(); } | |||||
template <typename T> | |||||
const T& cast() const { | |||||
mgb_assert(is_exactly<T>(), "type mismatch"); | |||||
return *static_cast<const AnyMixin<T>*>(m_storage.get()); | |||||
} | |||||
template <typename T> | |||||
bool is_exactly() const { | |||||
return std::type_index{typeid(T)} == std::type_index{type()}; | |||||
} | |||||
bool operator==(std::nullptr_t nptr) const { return m_storage == nullptr; } | |||||
bool operator!=(std::nullptr_t nptr) const { return m_storage != nullptr; } | |||||
operator bool() const { return m_storage != nullptr; } | |||||
template <typename T, typename... TArgs> | |||||
static AnyPtr make(TArgs&&... args) { | |||||
AnyPtr ret; | |||||
ret.m_storage = LocalPtr<AnyMixinBase>::make<AnyMixin<T>>( | |||||
std::forward<TArgs&&>(args)...); | |||||
ret.m_storage->m_type = &typeid(T); | |||||
return ret; | |||||
} | |||||
}; | |||||
} // namespace mgb::imperative |
@@ -0,0 +1,96 @@ | |||||
/** | |||||
* \file imperative/src/include/megbrain/imperative/utils/visit.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <chrono> | |||||
#include <future> | |||||
#include <vector> | |||||
#include "megbrain/utils/metahelper.h" | |||||
#include "megbrain/utils/small_vector.h" | |||||
namespace mgb::imperative { | |||||
class BoxBase : public NonCopyableObj { | |||||
public: | |||||
virtual void reset() = 0; | |||||
virtual void set_exception(std::exception_ptr exc) = 0; | |||||
virtual bool try_set_exception(std::exception_ptr exc) = 0; | |||||
}; | |||||
/** | |||||
* \brief An reusable promise | |||||
* | |||||
* \tparam T type of value | |||||
*/ | |||||
template <typename T> | |||||
class Box final : public BoxBase { | |||||
private: | |||||
std::promise<T> m_promise; | |||||
std::shared_future<T> m_future; | |||||
std::mutex m_mutex; | |||||
bool m_value_set; | |||||
bool m_exception_set; | |||||
public: | |||||
Box() { reset(); } | |||||
const T& get_value() { return m_future.get(); } | |||||
T take_value() { | |||||
T value = m_future.get(); | |||||
reset(); | |||||
return value; | |||||
} | |||||
void set_value(T value) { | |||||
MGB_LOCK_GUARD(m_mutex); | |||||
m_promise.set_value(std::move(value)); | |||||
m_value_set = true; | |||||
} | |||||
bool try_set_value(T value) { | |||||
MGB_LOCK_GUARD(m_mutex); | |||||
if (m_exception_set) { | |||||
return false; | |||||
} | |||||
m_promise.set_value(std::move(value)); | |||||
m_value_set = true; | |||||
return true; | |||||
} | |||||
void set_exception(std::exception_ptr exc) override { | |||||
MGB_LOCK_GUARD(m_mutex); | |||||
m_promise.set_exception(exc); | |||||
m_exception_set = true; | |||||
} | |||||
bool try_set_exception(std::exception_ptr exc) override { | |||||
MGB_LOCK_GUARD(m_mutex); | |||||
if (m_value_set) { | |||||
return false; | |||||
} | |||||
m_promise.set_exception(exc); | |||||
m_exception_set = true; | |||||
return true; | |||||
} | |||||
void reset() override { | |||||
MGB_LOCK_GUARD(m_mutex); | |||||
m_promise = {}; | |||||
m_future = m_promise.get_future(); | |||||
m_value_set = false; | |||||
m_exception_set = false; | |||||
} | |||||
/** | |||||
* \brief make an empty box | |||||
* | |||||
* \return std::shared_ptr<Box> | |||||
*/ | |||||
static std::shared_ptr<Box> make() { return std::make_shared<Box>(); } | |||||
}; | |||||
} // namespace mgb::imperative |
@@ -0,0 +1,40 @@ | |||||
/** | |||||
* \file imperative/src/include/megbrain/imperative/utils/span.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <iomanip> | |||||
#include <memory> | |||||
#include <sstream> | |||||
namespace mgb { | |||||
namespace imperative { | |||||
template <typename T> | |||||
class CleanupGuard { | |||||
private: | |||||
T m_callback; | |||||
public: | |||||
explicit CleanupGuard(T cb) : m_callback{std::move(cb)} {} | |||||
~CleanupGuard() { m_callback(); } | |||||
}; | |||||
inline std::string quoted(std::string str) { | |||||
std::stringstream ss; | |||||
ss << std::quoted(str); | |||||
return ss.str(); | |||||
} | |||||
} // namespace imperative | |||||
} // namespace mgb |
@@ -0,0 +1,245 @@ | |||||
/** | |||||
* \file imperative/src/include/megbrain/imperative/utils/intrusive_list.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/utils/metahelper.h" | |||||
namespace mgb::imperative::utils::intrusive_list { | |||||
// copy policy | |||||
struct after_t {}; | |||||
struct before_t {}; | |||||
struct disable_t {}; | |||||
template <typename T> | |||||
struct Tail; | |||||
// invariant: next->prev == this | |||||
template <typename T> | |||||
struct Head { | |||||
Tail<T>* next; | |||||
Head(Tail<T>* node = nullptr) : next(node) {} | |||||
Head(const Head<T>&) = delete; | |||||
Head<T>& operator=(const Head<T>&) = delete; | |||||
Head(Head<T>&& rhs) : next(rhs.next) { | |||||
rhs.next = nullptr; | |||||
if (next) { | |||||
next->prev = this; | |||||
} | |||||
} | |||||
Head<T>& operator=(Head<T>&& rhs) { | |||||
mgb_assert(!next); | |||||
next = rhs.next; | |||||
rhs.next = nullptr; | |||||
if (next) { | |||||
next->prev = this; | |||||
} | |||||
return *this; | |||||
} | |||||
~Head() { | |||||
if (next) { | |||||
next->prev = nullptr; | |||||
} | |||||
} | |||||
}; | |||||
// invariant: prev->next == this | |||||
template <typename T> | |||||
struct Tail { | |||||
Head<T>* prev; | |||||
Tail(Head<T>* node = nullptr) : prev(node) {} | |||||
Tail(const Tail<T>&) = delete; | |||||
Tail<T>& operator=(const Tail<T>&) = delete; | |||||
Tail(Tail<T>&& rhs) : prev(rhs.prev) { | |||||
rhs.prev = nullptr; | |||||
if (prev) { | |||||
prev->next = this; | |||||
} | |||||
} | |||||
Tail<T>& operator=(Tail<T>&& rhs) { | |||||
mgb_assert(!prev); | |||||
prev = rhs.prev; | |||||
rhs.prev = nullptr; | |||||
if (prev) { | |||||
prev->next = this; | |||||
} | |||||
return *this; | |||||
} | |||||
~Tail() { | |||||
if (prev) { | |||||
prev->next = nullptr; | |||||
} | |||||
} | |||||
}; | |||||
template <typename T, typename policy> | |||||
struct Node; | |||||
template <typename T> | |||||
class Iterator { | |||||
T* ptr; | |||||
void inc() { ptr = static_cast<T*>(ptr->Head<T>::next); } | |||||
void dec() { ptr = static_cast<T*>(ptr->Head<T>::prev); } | |||||
public: | |||||
Iterator(Head<T>& head) : ptr(static_cast<T*>(head.next)) {} | |||||
Iterator(Tail<T>& tail) : ptr(static_cast<T*>(tail.prev)) {} | |||||
template <typename policy> | |||||
Iterator(Node<T, policy>& node) : ptr(static_cast<T*>(&node)) {} | |||||
T& operator*() { return *static_cast<T*>(ptr); } | |||||
T* operator->() { return static_cast<T*>(ptr); } | |||||
operator bool() { return ptr; } | |||||
bool operator==(const Iterator<T>& rhs) { return ptr == rhs.ptr; } | |||||
Iterator& operator++() { | |||||
inc(); | |||||
return *this; | |||||
} | |||||
Iterator& operator--() { | |||||
dec(); | |||||
return *this; | |||||
} | |||||
Iterator operator++(int) { | |||||
auto ret = *this; | |||||
inc(); | |||||
return ret; | |||||
} | |||||
Iterator operator--(int) { | |||||
auto ret = *this; | |||||
dec(); | |||||
return ret; | |||||
} | |||||
}; | |||||
// Node in a doubly linked list. Unlike std::list, nodes are not owned by a container. | |||||
// Instead, nodes may join or leave a list freely. | |||||
// NOTE: Derived classes have to explicitly declare copy / assignment as default, | |||||
// otherwise the compiler generated version would use the const T& signature, | |||||
// which is deleted. | |||||
template <typename T = void, typename policy = disable_t> | |||||
struct Node : Tail<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>>, | |||||
Head<std::conditional_t<std::is_same_v<T, void>, Node<T, policy>, T>> { | |||||
private: | |||||
using this_t = Node<T, policy>; | |||||
using U = std::conditional_t<std::is_same_v<T, void>, this_t, T>; | |||||
public: | |||||
using head_t = Head<U>; | |||||
using tail_t = Tail<U>; | |||||
using head_t::next; | |||||
using tail_t::prev; | |||||
Node() = default; | |||||
Node(const this_t&) = delete; | |||||
this_t& operator=(const this_t&) = delete; | |||||
//! constructed node is inserted after the input node | |||||
Node(after_t, head_t& node) : tail_t(&node), head_t(node.next) { | |||||
node.next = this; | |||||
if (next) { | |||||
next->prev = this; | |||||
} | |||||
} | |||||
//! constructed node is inserted before the input node | |||||
Node(before_t, tail_t& node) : head_t(&node), tail_t(node.prev) { | |||||
node.prev = this; | |||||
if (prev) { | |||||
prev->next = this; | |||||
} | |||||
} | |||||
Node(this_t&& rhs) : tail_t(rhs.prev), head_t(rhs.next) { | |||||
rhs.prev = nullptr; | |||||
rhs.next = nullptr; | |||||
if (prev) { | |||||
prev->next = this; | |||||
} | |||||
if (next) { | |||||
next->prev = this; | |||||
} | |||||
} | |||||
Node& operator=(this_t&& rhs) { | |||||
unlink(); | |||||
prev = rhs.prev; | |||||
next = rhs.next; | |||||
rhs.prev = nullptr; | |||||
rhs.next = nullptr; | |||||
if (prev) { | |||||
prev->next = this; | |||||
} | |||||
if (next) { | |||||
next->prev = this; | |||||
} | |||||
return *this; | |||||
} | |||||
template < | |||||
typename p = policy, | |||||
typename = std::enable_if_t< | |||||
std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>> | |||||
Node(this_t& rhs) : Node(policy{}, rhs) {} | |||||
template < | |||||
typename p = policy, | |||||
typename = std::enable_if_t< | |||||
std::is_same_v<p, before_t> || std::is_same_v<p, after_t>, void>> | |||||
this_t& operator=(this_t& rhs) { | |||||
insert(policy{}, rhs); | |||||
return *this; | |||||
} | |||||
void unlink() { | |||||
if (prev) { | |||||
prev->next = next; | |||||
} | |||||
if (next) { | |||||
next->prev = prev; | |||||
} | |||||
prev = nullptr; | |||||
next = nullptr; | |||||
} | |||||
//! this node is unlinked from its list and inserted after the input node | |||||
void insert(after_t, head_t& node) { | |||||
unlink(); | |||||
prev = &node; | |||||
next = node.next; | |||||
node.next = this; | |||||
if (next) { | |||||
next->prev = this; | |||||
} | |||||
} | |||||
//! this node is unlinked from its list and inserted before the input node | |||||
void insert(before_t, tail_t& node) { | |||||
unlink(); | |||||
next = &node; | |||||
prev = node.prev; | |||||
node.prev = this; | |||||
if (prev) { | |||||
prev->next = this; | |||||
} | |||||
} | |||||
void insert_before(tail_t& node) { insert(before_t{}, node); } | |||||
void insert_after(head_t& node) { insert(after_t{}, node); } | |||||
~Node() { unlink(); } | |||||
}; | |||||
} // namespace mgb::imperative::utils::intrusive_list |
@@ -0,0 +1,285 @@ | |||||
/** | |||||
* \file imperative/src/include/megbrain/imperative/utils/local_ptr.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <optional> | |||||
#include "megbrain/imperative/utils/mempool.h" | |||||
#include "megbrain/utils/metahelper.h" | |||||
namespace mgb::imperative { | |||||
template <typename T> | |||||
class LocalPtrStorage : public NonCopyableObj { | |||||
private: | |||||
size_t m_ref_count = 0; | |||||
size_t m_weak_count = 0; | |||||
T* m_pointer = nullptr; | |||||
void (*reset)(LocalPtrStorage*) = nullptr; | |||||
void (*free)(LocalPtrStorage*) = nullptr; | |||||
void inc_ref() { m_ref_count++; } | |||||
void dec_ref() { | |||||
m_ref_count--; | |||||
if (m_ref_count == 0) { | |||||
reset(this); | |||||
m_pointer = nullptr; | |||||
reset = nullptr; | |||||
if (m_weak_count == 0) { | |||||
free(this); | |||||
// dead | |||||
} | |||||
} | |||||
} | |||||
void inc_weak_ref() { m_weak_count++; } | |||||
void dec_weak_ref() { | |||||
m_weak_count--; | |||||
if ((m_weak_count + m_ref_count) == 0) { | |||||
free(this); | |||||
// dead | |||||
} | |||||
} | |||||
template <typename U> | |||||
friend class LocalPtr; | |||||
template <typename U> | |||||
friend class LocalWeakPtr; | |||||
public: | |||||
}; | |||||
template <typename T, typename TDerived> | |||||
class LocalPtrStorgeImpl : public LocalPtrStorage<T> { | |||||
private: | |||||
std::optional<TDerived> m_value; | |||||
void* m_pool = nullptr; | |||||
template <typename U> | |||||
friend class LocalPtr; | |||||
template <typename U> | |||||
friend class LocalWeakPtr; | |||||
}; | |||||
template <typename T> | |||||
class LocalWeakPtr; | |||||
/** | |||||
* \brief thread-unsafe smart pointer | |||||
* | |||||
* \tparam T type of value | |||||
*/ | |||||
template <typename T> | |||||
class LocalPtr { | |||||
public: | |||||
using storage_t = LocalPtrStorage<T>; | |||||
using pool_t = MemPool<storage_t>; | |||||
using weak_type = LocalWeakPtr<T>; | |||||
private: | |||||
storage_t* m_storage = nullptr; | |||||
void emplace(storage_t* ptr) { | |||||
if (ptr) { | |||||
ptr->inc_ref(); | |||||
m_storage = ptr; | |||||
} | |||||
} | |||||
LocalPtr(storage_t* ptr) { emplace(ptr); } | |||||
public: | |||||
LocalPtr() = default; | |||||
LocalPtr(const LocalPtr& rhs) { (*this) = rhs; } | |||||
LocalPtr(LocalPtr&& rhs) { (*this) = std::move(rhs); } | |||||
LocalPtr& operator=(const LocalPtr& rhs) { | |||||
if (this == &rhs) { | |||||
return *this; | |||||
} | |||||
auto storage = rhs.m_storage; | |||||
if (storage) { | |||||
storage->inc_ref(); | |||||
} | |||||
if (m_storage) { | |||||
m_storage->dec_ref(); | |||||
// rhs.m_storage may be invalid here | |||||
} | |||||
m_storage = storage; | |||||
return *this; | |||||
} | |||||
LocalPtr& operator=(LocalPtr&& rhs) { | |||||
if (this == &rhs) { | |||||
return *this; | |||||
} | |||||
std::swap(m_storage, rhs.m_storage); | |||||
rhs.reset(); | |||||
return *this; | |||||
} | |||||
bool operator==(const LocalPtr& rhs) const { return m_storage == rhs.m_storage; } | |||||
bool operator!=(const LocalPtr& rhs) const { return m_storage != rhs.m_storage; } | |||||
size_t hash() const { return reinterpret_cast<uintptr_t>(m_storage); } | |||||
~LocalPtr() { reset(); } | |||||
/** | |||||
* \brief Construct an instance of TDerived and return an LocalPtr | |||||
* | |||||
* There is an memory pool for each (T, TDerived) pair | |||||
* | |||||
* \tparam TDerived type of concrete instance, should be subclass of T | |||||
* \tparam TArgs | |||||
* \param args constructor arguments | |||||
* \return LocalPtr points to the instance | |||||
*/ | |||||
template <typename TDerived = T, typename... TArgs> | |||||
static LocalPtr make(TArgs&&... args) { | |||||
static_assert(std::is_base_of_v<T, TDerived>); | |||||
using storage_impl_t = LocalPtrStorgeImpl<T, TDerived>; | |||||
constexpr auto normalize_size = [](size_t size) { | |||||
size_t normalized_size = 64; | |||||
while (normalized_size < size) { | |||||
normalized_size *= 2; | |||||
} | |||||
return normalized_size; | |||||
}; | |||||
using raw_storage_t = | |||||
std::aligned_storage_t<normalize_size(sizeof(storage_impl_t))>; | |||||
static_assert(alignof(raw_storage_t) % alignof(storage_impl_t) == 0); | |||||
static_assert(sizeof(raw_storage_t) >= sizeof(storage_impl_t)); | |||||
using pool_t = MemPool<raw_storage_t>; | |||||
pool_t& pool = MemPoolUtils<raw_storage_t>::get_thread_local(); | |||||
auto* raw_storage = pool.alloc_raw(); | |||||
auto* storage = reinterpret_cast<storage_impl_t*>(raw_storage); | |||||
new (storage) storage_impl_t(); | |||||
storage->m_value.emplace(std::forward<TArgs&&>(args)...); | |||||
storage->m_pointer = &*storage->m_value; | |||||
storage->reset = [](storage_t* storage) { | |||||
auto* storage_impl = static_cast<storage_impl_t*>(storage); | |||||
storage_impl->m_value.reset(); | |||||
storage_impl->m_pointer = nullptr; | |||||
}; | |||||
storage->free = [](storage_t* storage_base) { | |||||
auto* storage = static_cast<storage_impl_t*>(storage_base); | |||||
auto* pool = reinterpret_cast<pool_t*>(storage->m_pool); | |||||
storage->m_pool = nullptr; | |||||
storage->~storage_impl_t(); | |||||
auto* raw_storage = reinterpret_cast<raw_storage_t*>(storage); | |||||
pool->free_raw(raw_storage); | |||||
}; | |||||
storage->m_pool = &pool; | |||||
return {(storage_t*)storage}; | |||||
} | |||||
T& operator*() const { return *get(); } | |||||
T* get() const { | |||||
if ((!m_storage) || !m_storage->m_pointer) { | |||||
return nullptr; | |||||
} | |||||
return m_storage->m_pointer; | |||||
} | |||||
T* operator->() const { return get(); } | |||||
size_t ref_count() const { return m_storage->m_ref_count; } | |||||
bool unique() const { return ref_count() == 1; } | |||||
void reset() { | |||||
if (m_storage) { | |||||
m_storage->dec_ref(); | |||||
m_storage = nullptr; | |||||
} | |||||
} | |||||
operator bool() const { return bool(m_storage); } | |||||
bool operator==(std::nullptr_t nptr) const { return m_storage == nullptr; } | |||||
bool operator!=(std::nullptr_t nptr) const { return m_storage != nullptr; } | |||||
template <typename U> | |||||
friend class LocalWeakPtr; | |||||
}; | |||||
template <typename T> | |||||
class LocalWeakPtr { | |||||
public: | |||||
using storage_t = LocalPtrStorage<T>; | |||||
private: | |||||
storage_t* m_storage = nullptr; | |||||
void emplace(storage_t* ptr) { | |||||
if (ptr) { | |||||
ptr->inc_weak_ref(); | |||||
m_storage = ptr; | |||||
} | |||||
} | |||||
public: | |||||
LocalWeakPtr() = default; | |||||
LocalWeakPtr(const LocalPtr<T>& rhs) { emplace(rhs.m_storage); } | |||||
LocalWeakPtr(const LocalWeakPtr& rhs) { (*this) = rhs; } | |||||
LocalWeakPtr(LocalWeakPtr&& rhs) { (*this) = std::move(rhs); } | |||||
LocalWeakPtr& operator=(const LocalWeakPtr& rhs) { | |||||
if (this == &rhs) { | |||||
return *this; | |||||
} | |||||
reset(); | |||||
emplace(rhs.m_storage); | |||||
return *this; | |||||
} | |||||
LocalWeakPtr& operator=(LocalWeakPtr&& rhs) { | |||||
if (this == &rhs) { | |||||
return *this; | |||||
} | |||||
std::swap(m_storage, rhs.m_storage); | |||||
rhs.reset(); | |||||
return *this; | |||||
} | |||||
~LocalWeakPtr() { reset(); } | |||||
void reset() { | |||||
if (m_storage) { | |||||
m_storage->dec_weak_ref(); | |||||
m_storage = nullptr; | |||||
} | |||||
} | |||||
LocalPtr<T> lock() const { | |||||
if (m_storage && m_storage->m_ref_count) { | |||||
return {m_storage}; | |||||
} | |||||
return {}; | |||||
} | |||||
bool operator==(const LocalWeakPtr& rhs) const { | |||||
return m_storage == rhs.m_storage; | |||||
} | |||||
bool operator!=(const LocalWeakPtr& rhs) const { | |||||
return m_storage != rhs.m_storage; | |||||
} | |||||
size_t hash() const { return reinterpret_cast<uintptr_t>(m_storage); } | |||||
}; | |||||
template <typename T, typename TDerived, typename... TArgs> | |||||
LocalPtr<T> make_local(TArgs&&... args) { | |||||
return LocalPtr<T>::template make<TDerived>(std::forward<TArgs&&>(args)...); | |||||
} | |||||
} // namespace mgb::imperative |
@@ -0,0 +1,157 @@ | |||||
/** | |||||
* \file imperative/src/include/megbrain/imperative/utils/map.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <optional> | |||||
#include "megbrain/utils/metahelper.h" | |||||
namespace mgb::imperative { | |||||
/** | |||||
* \brief an hash map optimized for weak pointer as key | |||||
* | |||||
* Keys were scanned automatically, so values referenced by invalid keys whould be | |||||
* released soon | |||||
* | |||||
* \tparam TKey key type, requires(bool(key.lock())) | |||||
* \tparam TValue value type | |||||
*/ | |||||
template <typename TKey, typename TValue> | |||||
class WeakKeyMap : public NonCopyableObj { | |||||
public: | |||||
using storage_t = std::unordered_map<TKey, TValue>; | |||||
private: | |||||
storage_t m_storage; | |||||
typename storage_t::iterator m_cursor = m_storage.begin(); | |||||
/** | |||||
* \brief select a key and verify that whether it is invalid. If yes, erase it | |||||
* | |||||
*/ | |||||
void _step() { | |||||
if (m_cursor == m_storage.end()) { | |||||
m_cursor = m_storage.begin(); | |||||
return; | |||||
} | |||||
auto key = m_cursor->first; | |||||
if (!key.lock()) { | |||||
m_cursor = m_storage.erase(m_cursor); | |||||
} else { | |||||
++m_cursor; | |||||
} | |||||
} | |||||
public: | |||||
size_t count(TKey key) { | |||||
_step(); | |||||
_step(); | |||||
return m_storage.count(key); | |||||
} | |||||
TValue& at(TKey key) const { return m_storage.at(key); } | |||||
TValue& at(TKey key) { | |||||
_step(); | |||||
_step(); | |||||
return m_storage.at(key); | |||||
} | |||||
TValue& operator[](TKey key) { | |||||
_step(); | |||||
_step(); | |||||
if (m_storage.count(key)) { | |||||
return m_storage.at(key); | |||||
} else { | |||||
size_t bucket_count = m_storage.bucket_count(); | |||||
TValue& result = m_storage[key]; | |||||
if (bucket_count != m_storage.bucket_count()) { | |||||
m_cursor = m_storage.begin(); | |||||
} | |||||
return result; | |||||
} | |||||
} | |||||
std::optional<TValue> try_get(TKey key) const { | |||||
auto iter = m_storage.find(key); | |||||
if (iter == m_storage.end()) { | |||||
return {}; | |||||
} | |||||
return {iter->second}; | |||||
} | |||||
std::optional<TValue> try_get(TKey key) { | |||||
_step(); | |||||
_step(); | |||||
return ((const WeakKeyMap*)this)->try_get(std::move(key)); | |||||
} | |||||
}; | |||||
template <typename TKey, typename TValue> | |||||
class WeakValueMap : public NonCopyableObj { | |||||
public: | |||||
using storage_t = std::unordered_map<TKey, TValue>; | |||||
private: | |||||
storage_t m_storage; | |||||
typename storage_t::iterator m_cursor = m_storage.begin(); | |||||
/** | |||||
* \brief select a key and verify that whether it is invalid. If yes, erase it | |||||
* | |||||
*/ | |||||
void _step() { | |||||
if (m_cursor == m_storage.end()) { | |||||
m_cursor = m_storage.begin(); | |||||
return; | |||||
} | |||||
auto value = m_cursor->second; | |||||
if (!value.lock()) { | |||||
m_cursor = m_storage.erase(m_cursor); | |||||
} else { | |||||
++m_cursor; | |||||
} | |||||
} | |||||
public: | |||||
size_t count(TKey key) { | |||||
_step(); | |||||
_step(); | |||||
return m_storage.count(key); | |||||
} | |||||
TValue& at(TKey key) const { return m_storage.at(key); } | |||||
TValue& at(TKey key) { | |||||
_step(); | |||||
_step(); | |||||
return m_storage.at(key); | |||||
} | |||||
TValue& operator[](TKey key) { | |||||
_step(); | |||||
_step(); | |||||
if (m_storage.count(key)) { | |||||
return m_storage.at(key); | |||||
} else { | |||||
size_t bucket_count = m_storage.bucket_count(); | |||||
TValue& result = m_storage[key]; | |||||
if (bucket_count != m_storage.bucket_count()) { | |||||
m_cursor = m_storage.begin(); | |||||
} | |||||
return result; | |||||
} | |||||
} | |||||
}; | |||||
} // namespace mgb::imperative |
@@ -0,0 +1,70 @@ | |||||
/** | |||||
* \file imperative/src/include/megbrain/imperative/utils/mempool.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <mutex> | |||||
#include <thread> | |||||
#include <unordered_map> | |||||
#include "megbrain/utils/mempool.h" | |||||
#include "megbrain/utils/metahelper.h" | |||||
namespace mgb::imperative { | |||||
template <typename T> | |||||
class MemPoolUtils { | |||||
private: | |||||
static std::mutex sm_mutex; | |||||
static std::unordered_map<std::thread::id, std::unique_ptr<MemPool<T>>> | |||||
sm_instances; | |||||
static thread_local MemPool<T>* tm_instance; | |||||
static MemPool<T>* sm_instance; | |||||
public: | |||||
static MemPool<T>& get_thread_local() { | |||||
if (!tm_instance) { | |||||
MGB_LOCK_GUARD(sm_mutex); | |||||
auto& instance = sm_instances[std::this_thread::get_id()]; | |||||
if (!instance) { // thread id may be duplicated | |||||
instance = std::make_unique<MemPool<T>>(); | |||||
} | |||||
tm_instance = instance.get(); | |||||
} | |||||
return *tm_instance; | |||||
} | |||||
static MemPool<T>& get_static() { | |||||
if (!sm_instance) { | |||||
MGB_LOCK_GUARD(sm_mutex); | |||||
auto& instance = sm_instances[{}]; | |||||
if (!instance) { // double check | |||||
instance = std::make_unique<MemPool<T>>(); | |||||
sm_instance = instance.get(); | |||||
} | |||||
mgb_assert(sm_instance); | |||||
} | |||||
} | |||||
}; | |||||
template <typename T> | |||||
std::mutex MemPoolUtils<T>::sm_mutex; | |||||
template <typename T> | |||||
std::unordered_map<std::thread::id, std::unique_ptr<MemPool<T>>> | |||||
MemPoolUtils<T>::sm_instances; | |||||
template <typename T> | |||||
thread_local MemPool<T>* MemPoolUtils<T>::tm_instance; | |||||
template <typename T> | |||||
MemPool<T>* MemPoolUtils<T>::sm_instance; | |||||
} // namespace mgb::imperative |
@@ -0,0 +1,69 @@ | |||||
/** | |||||
* \file imperative/src/include/megbrain/imperative/utils/span.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <array> | |||||
#include <vector> | |||||
#include "megbrain/utils/small_vector.h" | |||||
namespace mgb::imperative { | |||||
/** | |||||
* \brief wrapper for c-style array | |||||
* | |||||
* \tparam T value type | |||||
*/ | |||||
template <typename T> | |||||
class Span { | |||||
private: | |||||
const T* m_begin = nullptr; | |||||
const T* m_end = nullptr; | |||||
public: | |||||
Span() {} | |||||
Span(const T* begin, const T* end) : m_begin{begin}, m_end{end} {} | |||||
Span(const T* begin, size_t size) : Span(begin, begin + size) {} | |||||
template <typename TContainer> | |||||
Span(TContainer& container) : Span(container.data(), container.size()) {} | |||||
const T* begin() const { return m_begin; } | |||||
const T* end() const { return m_end; } | |||||
const T* data() const { return m_begin; } | |||||
size_t size() const { return m_end - m_begin; } | |||||
template <typename TContainer> | |||||
TContainer copy_into() { | |||||
return TContainer(m_begin, m_end); | |||||
} | |||||
const T& operator[](size_t idx) const { return m_begin[idx]; } | |||||
const T& at(size_t idx) const { return m_begin[idx]; } | |||||
const T& item() const { | |||||
mgb_assert( | |||||
m_end - m_begin == 1, "size mismatch: %zu vs %zu", (m_end - m_begin), | |||||
(size_t)1); | |||||
return m_begin[0]; | |||||
} | |||||
template <size_t N> | |||||
const std::array<T, N>& as_array() { | |||||
mgb_assert( | |||||
m_end - m_begin == N, "size mismatch: %zu vs %zu", (m_end - m_begin), | |||||
N); | |||||
return *reinterpret_cast<const std::array<T, N>*>(m_begin); | |||||
} | |||||
Span sub(size_t begin, size_t length) { | |||||
mgb_assert(begin + length <= m_end - m_begin); | |||||
return {m_begin + begin, length}; | |||||
} | |||||
}; | |||||
} // namespace mgb::imperative |
@@ -16,6 +16,7 @@ | |||||
#include <tuple> | #include <tuple> | ||||
#include <type_traits> | #include <type_traits> | ||||
#include "megbrain/imperative/utils/span.h" | |||||
#include "megbrain/tensor.h" | #include "megbrain/tensor.h" | ||||
#include "megbrain/utils/small_vector.h" | #include "megbrain/utils/small_vector.h" | ||||
@@ -60,6 +61,22 @@ struct ToStringTrait<SmallVector<T, N>> { | |||||
}; | }; | ||||
template <typename T> | template <typename T> | ||||
struct ToStringTrait<std::vector<T>> { | |||||
std::string operator()(const std::vector<T>& v) const { | |||||
if (v.empty()) { | |||||
return "[]"; | |||||
} | |||||
std::string result = "["; | |||||
result += to_string(v[0]); | |||||
for (size_t i = 1; i < v.size(); ++i) { | |||||
result += ", "; | |||||
result += to_string(v[i]); | |||||
} | |||||
return result + "]"; | |||||
} | |||||
}; | |||||
template <typename T> | |||||
struct ToStringTrait<std::shared_ptr<T>> { | struct ToStringTrait<std::shared_ptr<T>> { | ||||
std::string operator()(const std::shared_ptr<T>& sp) const { | std::string operator()(const std::shared_ptr<T>& sp) const { | ||||
return to_string(sp.get()); | return to_string(sp.get()); | ||||
@@ -115,4 +132,36 @@ struct ToStringTrait<CompNode> { | |||||
std::string operator()(CompNode device) const { return device.to_string(); } | std::string operator()(CompNode device) const { return device.to_string(); } | ||||
}; | }; | ||||
inline std::string string_join(Span<std::string> span, char delimiter = ',') { | |||||
std::string buffer = "["; | |||||
for (size_t i = 1; i < span.size(); ++i) { | |||||
if (i) { | |||||
buffer.push_back(delimiter); | |||||
} | |||||
buffer.append(span[0]); | |||||
} | |||||
return buffer + "]"; | |||||
} | |||||
template <typename T> | |||||
struct ToStringTrait<Span<T>> { | |||||
std::string operator()(Span<T> span) const { | |||||
if (span.size() == 0) { | |||||
return "[]"; | |||||
} | |||||
std::string result = "["; | |||||
result += to_string(span[0]); | |||||
for (size_t i = 1; i < span.size(); ++i) { | |||||
result += ", "; | |||||
result += to_string(span[i]); | |||||
} | |||||
return result + "]"; | |||||
} | |||||
}; | |||||
template <> | |||||
struct ToStringTrait<std::type_info> { | |||||
std::string operator()(const std::type_info& info) const { return info.name(); } | |||||
}; | |||||
} // namespace mgb::imperative | } // namespace mgb::imperative |
@@ -0,0 +1,104 @@ | |||||
/** | |||||
* \file imperative/src/include/megbrain/imperative/utils/visit.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <vector> | |||||
#include "megbrain/imperative/utils/span.h" | |||||
#include "megbrain/tensor.h" | |||||
namespace mgb::imperative { | |||||
/** | |||||
* \brief like TensorShape, but allow real scalar shape. | |||||
* | |||||
*/ | |||||
struct ValueShape { | |||||
size_t shape[TensorShape::MAX_NDIM]; | |||||
int ndim = 0; | |||||
ValueShape() = default; | |||||
ValueShape(std::initializer_list<size_t> dims) { | |||||
for (auto&& dim : dims) { | |||||
shape[ndim++] = dim; | |||||
} | |||||
} | |||||
ValueShape(Span<size_t> dims) { | |||||
for (auto&& dim : dims) { | |||||
shape[ndim++] = dim; | |||||
} | |||||
} | |||||
size_t& operator[](int axis) { return shape[axis]; } | |||||
size_t operator[](int axis) const { return shape[axis]; } | |||||
size_t at(int axis) const { | |||||
mgb_assert(axis < ndim); | |||||
return shape[axis]; | |||||
} | |||||
size_t total_nr_elems() const { | |||||
size_t prod = 1; | |||||
for (int i = 0; i < ndim; ++i) { | |||||
prod *= shape[i]; | |||||
} | |||||
return prod; | |||||
} | |||||
bool is_scalar() const { return ndim == 0; } | |||||
std::string to_string() const { | |||||
std::string buffer = "{"; | |||||
for (size_t i = 0; i < ndim; ++i) { | |||||
if (i) { | |||||
buffer.append(","); | |||||
} | |||||
buffer.append(std::to_string(shape[i])); | |||||
} | |||||
buffer.append("}"); | |||||
return buffer; | |||||
} | |||||
static ValueShape from(TensorShape tensor_shape) { | |||||
mgb_assert(tensor_shape.ndim); | |||||
return Span<size_t>{tensor_shape.shape, tensor_shape.ndim}; | |||||
} | |||||
TensorShape as_tensor_shape() const { | |||||
mgb_assert(ndim != 0); | |||||
TensorShape ret; | |||||
for (size_t i = 0; i < ndim; ++i) { | |||||
ret.shape[i] = shape[i]; | |||||
} | |||||
ret.ndim = ndim; | |||||
return ret; | |||||
} | |||||
bool operator==(const ValueShape& rhs) const { | |||||
if (ndim != rhs.ndim) { | |||||
return false; | |||||
} | |||||
for (size_t i = 0; i < ndim; ++i) { | |||||
if (shape[i] != rhs.shape[i]) { | |||||
return false; | |||||
} | |||||
} | |||||
return true; | |||||
} | |||||
}; | |||||
static_assert(sizeof(size_t) >= sizeof(int)); | |||||
static_assert(TensorShape::MAX_NDIM == 7); | |||||
static_assert(sizeof(ValueShape) <= sizeof(size_t) * 8); | |||||
} // namespace mgb::imperative |
@@ -0,0 +1,26 @@ | |||||
/** | |||||
* \file imperative/src/include/megbrain/imperative/utils/visit.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <vector> | |||||
#include "megbrain/utils/small_vector.h" | |||||
namespace mgb::imperative { | |||||
template <typename... TVisitors> | |||||
class Visitor : public TVisitors... { | |||||
public: | |||||
using TVisitors::operator()...; | |||||
}; | |||||
} // namespace mgb::imperative |
@@ -28,10 +28,10 @@ TEST(TestProfiler, ImperativeLogProfile) { | |||||
auto results = imperative::Profiler::collect(); | auto results = imperative::Profiler::collect(); | ||||
imperative::Profiler::stop_profile(); | imperative::Profiler::stop_profile(); | ||||
mgb_assert(results.entries.size() == 2); | mgb_assert(results.entries.size() == 2); | ||||
auto* event_start = results.entries[0].data.as<profiler::CustomEvent>(); | |||||
auto* event_finish = results.entries[1].data.as<profiler::CustomFinishEvent>(); | |||||
mgb_assert(event_start && event_start->title == "XXX"); | |||||
mgb_assert(event_finish && event_finish->title == "XXX"); | |||||
auto& event_start = results.entries[0].data.cast<profiler::CustomEvent>(); | |||||
auto& event_finish = results.entries[1].data.cast<profiler::CustomFinishEvent>(); | |||||
mgb_assert(event_start.title == "XXX"); | |||||
mgb_assert(event_finish.title == "XXX"); | |||||
mgb_assert(results.entries[0].time < results.entries[1].time); | mgb_assert(results.entries[0].time < results.entries[1].time); | ||||
mgb_assert(results.entries[0].id < results.entries[1].id); | mgb_assert(results.entries[0].id < results.entries[1].id); | ||||
} | } |