diff --git a/dnn/src/common/api_cache.h b/dnn/src/common/api_cache.h index 9009f5e1..c39589bc 100644 --- a/dnn/src/common/api_cache.h +++ b/dnn/src/common/api_cache.h @@ -12,79 +12,19 @@ #pragma once -#include #include #include -#include #include #include #include "megdnn/thin/function.h" -#include "./utils.h" - namespace megdnn { - -// https://jfdube.wordpress.com/2014/01/03/implementing-a-recursive-read-write-spinlock/ -class RWSpin { -public: - class Lock { - private: - RWSpin* m_spin; - void (RWSpin::*m_lock)(void); - void (RWSpin::*m_unlock)(void); - - public: - Lock(RWSpin* spin, decltype(m_lock) lock, decltype(m_unlock) unlock) - : m_spin{spin}, m_lock{lock}, m_unlock{unlock} {} - void lock() { (m_spin->*m_lock)(); } - void unlock() { (m_spin->*m_unlock)(); } - }; - -private: - std::atomic m_atomic{0}; - - static constexpr uint32_t sm_reader_mask = 0x7FFFFFFF; - static constexpr uint32_t sm_writer_mask = 0x80000000; - - void _reader_lock() { - uint32_t expected = m_atomic; - do { - expected &= sm_reader_mask; - } while (!m_atomic.compare_exchange_strong(expected, expected + 1)); - } - void _reader_unlock() { m_atomic--; } - void _writer_lock() { - uint32_t expected = m_atomic; - do { - expected &= sm_reader_mask; - } while (!m_atomic.compare_exchange_strong(expected, - expected | sm_writer_mask)); - while (m_atomic.load() != sm_writer_mask) - ; - } - void _writer_unlock() { - // assert m_atomic == sm_writer_mask - m_atomic = 0; - } - -public: - Lock reader() { - return {this, &RWSpin::_reader_lock, &RWSpin::_reader_unlock}; - } - Lock writer() { - return {this, &RWSpin::_writer_lock, &RWSpin::_writer_unlock}; - } -}; - -template -class FunctionCache; - -template -class FunctionCache { +template +class FunctionCache { public: using key_t = std::string; - using value_t = TRet; + using value_t = std::string; using key_mapper_t = thin_function; using value_mapper_t = thin_function; using storage_t = std::unordered_map; @@ -93,30 +33,12 @@ public: key_mapper_t key_mapper; value_mapper_t value_mapper; - RWSpin spin; - -public: - TRet operator()(TArgs... args) { + value_t operator()(TArgs... args) { key_t key = key_mapper(args...); - auto reader_lock = spin.reader(); - auto writer_lock = spin.writer(); - { - MEGDNN_LOCK_GUARD(reader_lock); - auto iter = storage.find(key); - if (iter != storage.end()) { - return iter->second; - } - } - // RWSpin doesn't support upgrade - { - MEGDNN_LOCK_GUARD(writer_lock); - if (storage.count(key) != 0) { - return storage[key]; - } - value_t ret = value_mapper(std::forward(args)...); - storage[key] = ret; - return ret; + if (storage.count(key) == 0) { + storage[key] = value_mapper(std::forward(args)...); } + return storage[key]; } }; @@ -129,8 +51,8 @@ private: public: template T read_plain() { - static_assert(std::is_trivially_copyable::value, "invalid type"); - T ret; + static_assert(std::is_trivially_copyable::value, "invalid type"); + T ret; memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T)); m_cursor += sizeof(T); return ret; @@ -141,8 +63,10 @@ public: "type should be trivially copyable"); m_buffer.append(reinterpret_cast(&value), sizeof(T)); } - std::string take() { return std::move(m_buffer); } - void reset(std::string new_buf) { + std::string take() { + return std::move(m_buffer); + } + void set(std::string new_buf) { m_cursor = 0; m_buffer = new_buf; } @@ -150,32 +74,26 @@ public: struct Empty {}; -// in: seq[1, 2, ..., m] -// out: seq[N+1, N+2, ... N+m] -template -static std::index_sequence inc_index_sequence( - std::index_sequence) { - return {}; -} - template class ParamBundle { private: - // out: Min, Min+1, ..., Max + template + static std::index_sequence add_all( + std::index_sequence) { + return {}; + } + template - using make_index_range = decltype( - inc_index_sequence(std::make_index_sequence())); + using make_index_range = + decltype(add_all(std::make_index_sequence())); - // store params in a tuple using storage_t = std::tuple...>; storage_t m_storage; - // deconstruct tuple and call functor template auto call_helper(TFunctor functor, std::index_sequence) { return functor(std::get(m_storage).value...); } - template auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence) { @@ -183,11 +101,9 @@ private: std::get(m_storage).serialize(ser, prev), std::index_sequence()); } - template auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {} - template auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence) { @@ -195,11 +111,9 @@ private: ser, std::get(m_storage).deserialize(ser, prev), std::index_sequence()); } - template auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {} - template void set_values_helper(std::index_sequence, TArg&& arg, TArgs&&... args) { @@ -207,7 +121,6 @@ private: set_values_helper(std::index_sequence(), std::forward(args)...); } - template void set_values_helper(std::index_sequence) { static_assert(sizeof...(Indices) == 0, "redundant indices"); @@ -219,26 +132,25 @@ public: return call_helper(std::forward(functor), std::make_index_sequence()); } - - // recursively store params into ser template void serialize_params(StringSerializer& ser) { static_assert(NEnd >= NBegin, "invalid range"); - serialize_helper(ser, Empty{}, make_index_range()); + serialize_helper( + ser, Empty{}, + add_all(std::make_index_sequence())); } - - // recursively load params from ser template void deserialize_params(StringSerializer& ser) { static_assert(NEnd >= NBegin, "invalid range"); - deserialize_helper(ser, Empty{}, make_index_range()); + deserialize_helper( + ser, Empty{}, + add_all(std::make_index_sequence())); } - - // recursively set params into m_storage template void set_values(TArgs&&... args) { - set_values_helper(make_index_range(), - std::forward(args)...); + set_values_helper( + add_all(std::make_index_sequence()), + std::forward(args)...); } }; @@ -246,12 +158,10 @@ template class Param { public: T value; - Empty serialize(StringSerializer& ser, Empty) { ser.write_plain(value); return Empty{}; } - Empty deserialize(StringSerializer& ser, Empty) { value = ser.read_plain(); return Empty{}; @@ -262,54 +172,42 @@ template , typename TInputs = std::tuple<>, typename TOutputs = std::tuple<>> class FunctionCacheBuilder { private: - // decl value with type of tuple-of-args static auto declargs() -> decltype(std::tuple_cat(std::declval(), std::declval())) { return {}; } - template static auto declfunction_helper(std::index_sequence) -> thin_function().value)( decltype(std::get(declargs()).value)...)> { return {}; } - - // decl value with type of original function static auto declfunction() { return declfunction_helper( std::make_index_sequence::value + std::tuple_size::value>()); } - template static auto declbundle_helper(std::index_sequence) -> ParamBundle(declargs()))...> { return {}; } - - // decl value with type of bundle-of-args static auto declbundle() { return declbundle_helper( std::make_index_sequence::value + std::tuple_size::value>()); } - - // type of original function using function_t = decltype(declfunction()); - // type of bundle-of-args using bundle_t = decltype(declbundle()); public: - // declare new return type, cannot be override template auto ret() { static_assert(std::is_same>::value, "return value redefinition"); return FunctionCacheBuilder{}; } - // declare new input template auto input() { using TNewInputs = decltype( @@ -317,7 +215,6 @@ public: std::make_tuple(std::declval()))); return FunctionCacheBuilder{}; } - // declare new output template auto output() { using TNewOutputs = decltype( @@ -325,20 +222,17 @@ public: std::make_tuple(std::declval()))); return FunctionCacheBuilder{}; } - // summary template function_t build(TFunctor func) { - auto cache = std::make_shared>(); - // bundle -> ser(in args) - cache->key_mapper = [](bundle_t bundle) { + FunctionCache cache; + cache.key_mapper = [](bundle_t bundle) { StringSerializer ser; bundle.template serialize_params<0, std::tuple_size::value>( ser); return ser.take(); }; - // bundle -> ser(out args) - cache->value_mapper = [=](bundle_t bundle) { + cache.value_mapper = [=](bundle_t bundle) { StringSerializer ser; TRet ret; ret.value = bundle.call_by(func); @@ -359,7 +253,7 @@ public: "args count mismatch"); bundle.template set_values<0, sizeof...(args)>( std::forward(args)...); - ser.reset((*cache)(bundle)); + ser.set(cache(bundle)); ret.deserialize(ser, Empty{}); constexpr size_t n_inputs = std::tuple_size::value; constexpr size_t n_outputs = std::tuple_size::value; @@ -384,7 +278,6 @@ public: } }; -// like RefParam but return *value while ser and deser. Working with ArrayParam template class RefArraySizeParam { public: @@ -398,7 +291,6 @@ public: } }; -// accept array length from previous param. Working with RefArraySizeParam template class ArrayParam { public: diff --git a/dnn/src/cuda/api_cache.h b/dnn/src/cuda/api_cache.h index f58f6d75..f6f51b75 100644 --- a/dnn/src/cuda/api_cache.h +++ b/dnn/src/cuda/api_cache.h @@ -20,7 +20,7 @@ class CudnnConvDescParam { public: cudnnConvolutionDescriptor_t value; Empty serialize(StringSerializer& ser, Empty) { - int nbDims = MEGDNN_MAX_NDIM; + constexpr int nbDims = MEGDNN_MAX_NDIM; int padA[MEGDNN_MAX_NDIM]; int strideA[MEGDNN_MAX_NDIM]; int dilationA[MEGDNN_MAX_NDIM]; @@ -59,7 +59,7 @@ class CudnnTensorDescParam { public: cudnnTensorDescriptor_t value; Empty serialize(StringSerializer& ser, Empty) { - int nbDims = MEGDNN_MAX_NDIM; + constexpr int nbDims = MEGDNN_MAX_NDIM; cudnnDataType_t dataType; int dimA[MEGDNN_MAX_NDIM]; int strideA[MEGDNN_MAX_NDIM]; @@ -74,7 +74,7 @@ public: return Empty{}; } Empty deserialize(StringSerializer& ser, Empty) { - int nbDims = MEGDNN_MAX_NDIM; + constexpr int nbDims = MEGDNN_MAX_NDIM; cudnnDataType_t dataType; int dimA[MEGDNN_MAX_NDIM]; int strideA[MEGDNN_MAX_NDIM]; @@ -92,7 +92,7 @@ class CudnnFilterDescParam { public: cudnnFilterDescriptor_t value; Empty serialize(StringSerializer& ser, Empty) { - int nbDims = MEGDNN_MAX_NDIM; + constexpr int nbDims = MEGDNN_MAX_NDIM; cudnnDataType_t dataType; cudnnTensorFormat_t format; int filterDimA[MEGDNN_MAX_NDIM]; @@ -107,7 +107,7 @@ public: return Empty{}; } Empty deserialize(StringSerializer& ser, Empty) { - int nbDims = MEGDNN_MAX_NDIM; + constexpr int nbDims = MEGDNN_MAX_NDIM; cudnnDataType_t dataType; cudnnTensorFormat_t format; int filterDimA[MEGDNN_MAX_NDIM];