|
|
@@ -12,19 +12,79 @@ |
|
|
|
|
|
|
|
#pragma once |
|
|
|
|
|
|
|
#include <atomic> |
|
|
|
#include <cstring> |
|
|
|
#include <memory> |
|
|
|
#include <mutex> |
|
|
|
#include <tuple> |
|
|
|
#include <unordered_map> |
|
|
|
|
|
|
|
#include "megdnn/thin/function.h" |
|
|
|
|
|
|
|
#include "./utils.h" |
|
|
|
|
|
|
|
namespace megdnn { |
|
|
|
template <typename... TArgs> |
|
|
|
class FunctionCache { |
|
|
|
|
|
|
|
// https://jfdube.wordpress.com/2014/01/03/implementing-a-recursive-read-write-spinlock/ |
|
|
|
class RWSpin { |
|
|
|
public: |
|
|
|
class Lock { |
|
|
|
private: |
|
|
|
RWSpin* m_spin; |
|
|
|
void (RWSpin::*m_lock)(void); |
|
|
|
void (RWSpin::*m_unlock)(void); |
|
|
|
|
|
|
|
public: |
|
|
|
Lock(RWSpin* spin, decltype(m_lock) lock, decltype(m_unlock) unlock) |
|
|
|
: m_spin{spin}, m_lock{lock}, m_unlock{unlock} {} |
|
|
|
void lock() { (m_spin->*m_lock)(); } |
|
|
|
void unlock() { (m_spin->*m_unlock)(); } |
|
|
|
}; |
|
|
|
|
|
|
|
private: |
|
|
|
std::atomic<uint32_t> m_atomic{0}; |
|
|
|
|
|
|
|
static constexpr uint32_t sm_reader_mask = 0x7FFFFFFF; |
|
|
|
static constexpr uint32_t sm_writer_mask = 0x80000000; |
|
|
|
|
|
|
|
void _reader_lock() { |
|
|
|
uint32_t expected = m_atomic; |
|
|
|
do { |
|
|
|
expected &= sm_reader_mask; |
|
|
|
} while (!m_atomic.compare_exchange_strong(expected, expected + 1)); |
|
|
|
} |
|
|
|
void _reader_unlock() { m_atomic--; } |
|
|
|
void _writer_lock() { |
|
|
|
uint32_t expected = m_atomic; |
|
|
|
do { |
|
|
|
expected &= sm_reader_mask; |
|
|
|
} while (!m_atomic.compare_exchange_strong(expected, |
|
|
|
expected | sm_writer_mask)); |
|
|
|
while (m_atomic.load() != sm_writer_mask) |
|
|
|
; |
|
|
|
} |
|
|
|
void _writer_unlock() { |
|
|
|
// assert m_atomic == sm_writer_mask |
|
|
|
m_atomic = 0; |
|
|
|
} |
|
|
|
|
|
|
|
public: |
|
|
|
Lock reader() { |
|
|
|
return {this, &RWSpin::_reader_lock, &RWSpin::_reader_unlock}; |
|
|
|
} |
|
|
|
Lock writer() { |
|
|
|
return {this, &RWSpin::_writer_lock, &RWSpin::_writer_unlock}; |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
template <typename TSignature> |
|
|
|
class FunctionCache; |
|
|
|
|
|
|
|
template <typename TRet, typename... TArgs> |
|
|
|
class FunctionCache<TRet(TArgs...)> { |
|
|
|
public: |
|
|
|
using key_t = std::string; |
|
|
|
using value_t = std::string; |
|
|
|
using value_t = TRet; |
|
|
|
using key_mapper_t = thin_function<key_t(TArgs...)>; |
|
|
|
using value_mapper_t = thin_function<value_t(TArgs...)>; |
|
|
|
using storage_t = std::unordered_map<key_t, value_t>; |
|
|
@@ -33,12 +93,30 @@ public: |
|
|
|
key_mapper_t key_mapper; |
|
|
|
value_mapper_t value_mapper; |
|
|
|
|
|
|
|
value_t operator()(TArgs... args) { |
|
|
|
RWSpin spin; |
|
|
|
|
|
|
|
public: |
|
|
|
TRet operator()(TArgs... args) { |
|
|
|
key_t key = key_mapper(args...); |
|
|
|
if (storage.count(key) == 0) { |
|
|
|
storage[key] = value_mapper(std::forward<TArgs>(args)...); |
|
|
|
auto reader_lock = spin.reader(); |
|
|
|
auto writer_lock = spin.writer(); |
|
|
|
{ |
|
|
|
MEGDNN_LOCK_GUARD(reader_lock); |
|
|
|
auto iter = storage.find(key); |
|
|
|
if (iter != storage.end()) { |
|
|
|
return iter->second; |
|
|
|
} |
|
|
|
} |
|
|
|
// RWSpin doesn't support upgrade |
|
|
|
{ |
|
|
|
MEGDNN_LOCK_GUARD(writer_lock); |
|
|
|
if (storage.count(key) != 0) { |
|
|
|
return storage[key]; |
|
|
|
} |
|
|
|
value_t ret = value_mapper(std::forward<TArgs>(args)...); |
|
|
|
storage[key] = ret; |
|
|
|
return ret; |
|
|
|
} |
|
|
|
return storage[key]; |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
@@ -51,8 +129,8 @@ private: |
|
|
|
public: |
|
|
|
template <typename T> |
|
|
|
T read_plain() { |
|
|
|
static_assert(std::is_trivially_copyable<T>::value, "invalid type"); |
|
|
|
T ret; |
|
|
|
static_assert(std::is_trivially_copyable<T>::value, "invalid type"); |
|
|
|
T ret; |
|
|
|
memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T)); |
|
|
|
m_cursor += sizeof(T); |
|
|
|
return ret; |
|
|
@@ -63,10 +141,8 @@ public: |
|
|
|
"type should be trivially copyable"); |
|
|
|
m_buffer.append(reinterpret_cast<const char*>(&value), sizeof(T)); |
|
|
|
} |
|
|
|
std::string take() { |
|
|
|
return std::move(m_buffer); |
|
|
|
} |
|
|
|
void set(std::string new_buf) { |
|
|
|
std::string take() { return std::move(m_buffer); } |
|
|
|
void reset(std::string new_buf) { |
|
|
|
m_cursor = 0; |
|
|
|
m_buffer = new_buf; |
|
|
|
} |
|
|
@@ -74,26 +150,32 @@ public: |
|
|
|
|
|
|
|
struct Empty {}; |
|
|
|
|
|
|
|
// in: seq[1, 2, ..., m] |
|
|
|
// out: seq[N+1, N+2, ... N+m] |
|
|
|
template <std::size_t N, std::size_t... Seq> |
|
|
|
static std::index_sequence<N + Seq...> inc_index_sequence( |
|
|
|
std::index_sequence<Seq...>) { |
|
|
|
return {}; |
|
|
|
} |
|
|
|
|
|
|
|
template <typename... TParams> |
|
|
|
class ParamBundle { |
|
|
|
private: |
|
|
|
template <std::size_t N, std::size_t... Seq> |
|
|
|
static std::index_sequence<N + Seq...> add_all( |
|
|
|
std::index_sequence<Seq...>) { |
|
|
|
return {}; |
|
|
|
} |
|
|
|
|
|
|
|
// out: Min, Min+1, ..., Max |
|
|
|
template <std::size_t Min, std::size_t Max> |
|
|
|
using make_index_range = |
|
|
|
decltype(add_all<Min>(std::make_index_sequence<Max - Min>())); |
|
|
|
using make_index_range = decltype( |
|
|
|
inc_index_sequence<Min>(std::make_index_sequence<Max - Min>())); |
|
|
|
|
|
|
|
// store params in a tuple |
|
|
|
using storage_t = std::tuple<typename std::remove_reference_t<TParams>...>; |
|
|
|
storage_t m_storage; |
|
|
|
|
|
|
|
// deconstruct tuple and call functor |
|
|
|
template <typename TFunctor, size_t... Indices> |
|
|
|
auto call_helper(TFunctor functor, std::index_sequence<Indices...>) { |
|
|
|
return functor(std::get<Indices>(m_storage).value...); |
|
|
|
} |
|
|
|
|
|
|
|
template <size_t Index, size_t... Indices, typename TPrev> |
|
|
|
auto serialize_helper(StringSerializer& ser, TPrev&& prev, |
|
|
|
std::index_sequence<Index, Indices...>) { |
|
|
@@ -101,9 +183,11 @@ private: |
|
|
|
std::get<Index>(m_storage).serialize(ser, prev), |
|
|
|
std::index_sequence<Indices...>()); |
|
|
|
} |
|
|
|
|
|
|
|
template <typename TPrev> |
|
|
|
auto serialize_helper(StringSerializer& ser, TPrev&& prev, |
|
|
|
std::index_sequence<>) {} |
|
|
|
|
|
|
|
template <size_t Index, size_t... Indices, typename TPrev> |
|
|
|
auto deserialize_helper(StringSerializer& ser, TPrev&& prev, |
|
|
|
std::index_sequence<Index, Indices...>) { |
|
|
@@ -111,9 +195,11 @@ private: |
|
|
|
ser, std::get<Index>(m_storage).deserialize(ser, prev), |
|
|
|
std::index_sequence<Indices...>()); |
|
|
|
} |
|
|
|
|
|
|
|
template <typename TPrev> |
|
|
|
auto deserialize_helper(StringSerializer& ser, TPrev&& prev, |
|
|
|
std::index_sequence<>) {} |
|
|
|
|
|
|
|
template <size_t Index, size_t... Indices, typename TArg, typename... TArgs> |
|
|
|
void set_values_helper(std::index_sequence<Index, Indices...>, TArg&& arg, |
|
|
|
TArgs&&... args) { |
|
|
@@ -121,6 +207,7 @@ private: |
|
|
|
set_values_helper(std::index_sequence<Indices...>(), |
|
|
|
std::forward<TArgs>(args)...); |
|
|
|
} |
|
|
|
|
|
|
|
template <size_t... Indices> |
|
|
|
void set_values_helper(std::index_sequence<Indices...>) { |
|
|
|
static_assert(sizeof...(Indices) == 0, "redundant indices"); |
|
|
@@ -132,25 +219,26 @@ public: |
|
|
|
return call_helper(std::forward<TFunctor>(functor), |
|
|
|
std::make_index_sequence<sizeof...(TParams)>()); |
|
|
|
} |
|
|
|
|
|
|
|
// recursively store params into ser |
|
|
|
template <size_t NBegin, size_t NEnd> |
|
|
|
void serialize_params(StringSerializer& ser) { |
|
|
|
static_assert(NEnd >= NBegin, "invalid range"); |
|
|
|
serialize_helper( |
|
|
|
ser, Empty{}, |
|
|
|
add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>())); |
|
|
|
serialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>()); |
|
|
|
} |
|
|
|
|
|
|
|
// recursively load params from ser |
|
|
|
template <size_t NBegin, size_t NEnd> |
|
|
|
void deserialize_params(StringSerializer& ser) { |
|
|
|
static_assert(NEnd >= NBegin, "invalid range"); |
|
|
|
deserialize_helper( |
|
|
|
ser, Empty{}, |
|
|
|
add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>())); |
|
|
|
deserialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>()); |
|
|
|
} |
|
|
|
|
|
|
|
// recursively set params into m_storage |
|
|
|
template <size_t NBegin, size_t NEnd, typename... TArgs> |
|
|
|
void set_values(TArgs&&... args) { |
|
|
|
set_values_helper( |
|
|
|
add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>()), |
|
|
|
std::forward<TArgs>(args)...); |
|
|
|
set_values_helper(make_index_range<NBegin, NEnd>(), |
|
|
|
std::forward<TArgs>(args)...); |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
@@ -158,10 +246,12 @@ template <typename T> |
|
|
|
class Param { |
|
|
|
public: |
|
|
|
T value; |
|
|
|
|
|
|
|
Empty serialize(StringSerializer& ser, Empty) { |
|
|
|
ser.write_plain(value); |
|
|
|
return Empty{}; |
|
|
|
} |
|
|
|
|
|
|
|
Empty deserialize(StringSerializer& ser, Empty) { |
|
|
|
value = ser.read_plain<T>(); |
|
|
|
return Empty{}; |
|
|
@@ -172,42 +262,54 @@ template <typename TRet = Param<Empty>, typename TInputs = std::tuple<>, |
|
|
|
typename TOutputs = std::tuple<>> |
|
|
|
class FunctionCacheBuilder { |
|
|
|
private: |
|
|
|
// decl value with type of tuple-of-args |
|
|
|
static auto declargs() |
|
|
|
-> decltype(std::tuple_cat(std::declval<TInputs>(), |
|
|
|
std::declval<TOutputs>())) { |
|
|
|
return {}; |
|
|
|
} |
|
|
|
|
|
|
|
template <size_t... Indices> |
|
|
|
static auto declfunction_helper(std::index_sequence<Indices...>) |
|
|
|
-> thin_function<decltype(std::declval<TRet>().value)( |
|
|
|
decltype(std::get<Indices>(declargs()).value)...)> { |
|
|
|
return {}; |
|
|
|
} |
|
|
|
|
|
|
|
// decl value with type of original function |
|
|
|
static auto declfunction() { |
|
|
|
return declfunction_helper( |
|
|
|
std::make_index_sequence<std::tuple_size<TInputs>::value + |
|
|
|
std::tuple_size<TOutputs>::value>()); |
|
|
|
} |
|
|
|
|
|
|
|
template <size_t... Indices> |
|
|
|
static auto declbundle_helper(std::index_sequence<Indices...>) |
|
|
|
-> ParamBundle<decltype(std::get<Indices>(declargs()))...> { |
|
|
|
return {}; |
|
|
|
} |
|
|
|
|
|
|
|
// decl value with type of bundle-of-args |
|
|
|
static auto declbundle() { |
|
|
|
return declbundle_helper( |
|
|
|
std::make_index_sequence<std::tuple_size<TInputs>::value + |
|
|
|
std::tuple_size<TOutputs>::value>()); |
|
|
|
} |
|
|
|
|
|
|
|
// type of original function |
|
|
|
using function_t = decltype(declfunction()); |
|
|
|
// type of bundle-of-args |
|
|
|
using bundle_t = decltype(declbundle()); |
|
|
|
|
|
|
|
public: |
|
|
|
// declare new return type, cannot be override |
|
|
|
template <typename TNewRet> |
|
|
|
auto ret() { |
|
|
|
static_assert(std::is_same<TRet, Param<Empty>>::value, |
|
|
|
"return value redefinition"); |
|
|
|
return FunctionCacheBuilder<TNewRet, TInputs, TOutputs>{}; |
|
|
|
} |
|
|
|
// declare new input |
|
|
|
template <typename TNewInput> |
|
|
|
auto input() { |
|
|
|
using TNewInputs = decltype( |
|
|
@@ -215,6 +317,7 @@ public: |
|
|
|
std::make_tuple(std::declval<TNewInput>()))); |
|
|
|
return FunctionCacheBuilder<TRet, TNewInputs, TOutputs>{}; |
|
|
|
} |
|
|
|
// declare new output |
|
|
|
template <typename TNewOutput> |
|
|
|
auto output() { |
|
|
|
using TNewOutputs = decltype( |
|
|
@@ -222,17 +325,20 @@ public: |
|
|
|
std::make_tuple(std::declval<TNewOutput>()))); |
|
|
|
return FunctionCacheBuilder<TRet, TInputs, TNewOutputs>{}; |
|
|
|
} |
|
|
|
// summary |
|
|
|
template <typename TFunctor> |
|
|
|
function_t build(TFunctor func) { |
|
|
|
FunctionCache<bundle_t> cache; |
|
|
|
cache.key_mapper = [](bundle_t bundle) { |
|
|
|
auto cache = std::make_shared<FunctionCache<std::string(bundle_t)>>(); |
|
|
|
// bundle -> ser(in args) |
|
|
|
cache->key_mapper = [](bundle_t bundle) { |
|
|
|
StringSerializer ser; |
|
|
|
bundle.template serialize_params<0, |
|
|
|
std::tuple_size<TInputs>::value>( |
|
|
|
ser); |
|
|
|
return ser.take(); |
|
|
|
}; |
|
|
|
cache.value_mapper = [=](bundle_t bundle) { |
|
|
|
// bundle -> ser(out args) |
|
|
|
cache->value_mapper = [=](bundle_t bundle) { |
|
|
|
StringSerializer ser; |
|
|
|
TRet ret; |
|
|
|
ret.value = bundle.call_by(func); |
|
|
@@ -253,7 +359,7 @@ public: |
|
|
|
"args count mismatch"); |
|
|
|
bundle.template set_values<0, sizeof...(args)>( |
|
|
|
std::forward<decltype(args)>(args)...); |
|
|
|
ser.set(cache(bundle)); |
|
|
|
ser.reset((*cache)(bundle)); |
|
|
|
ret.deserialize(ser, Empty{}); |
|
|
|
constexpr size_t n_inputs = std::tuple_size<TInputs>::value; |
|
|
|
constexpr size_t n_outputs = std::tuple_size<TOutputs>::value; |
|
|
@@ -278,6 +384,7 @@ public: |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
// like RefParam but return *value while ser and deser. Working with ArrayParam |
|
|
|
template <typename T> |
|
|
|
class RefArraySizeParam { |
|
|
|
public: |
|
|
@@ -291,6 +398,7 @@ public: |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
// accept array length from previous param. Working with RefArraySizeParam |
|
|
|
template <typename TSize, typename TItem> |
|
|
|
class ArrayParam { |
|
|
|
public: |
|
|
|