This reverts commit 188c62cdd6 (release-1.7).

GitOrigin-RevId: 92a82b8cd9
@@ -12,28 +12,32 @@
 #pragma once
-#include <cstring>
-#include <unordered_map>
 #include <memory>
+#include <cstring>
 #include <tuple>
+#include <unordered_map>
 #include "megdnn/thin/function.h"
 namespace megdnn {
-template <typename... TArgs>
-class FunctionCache {
+template <typename TSignature>
+class FunctionCache;
+template <typename TRet, typename... TArgs>
+class FunctionCache<TRet(TArgs...)> {
 public:
     using key_t = std::string;
-    using value_t = std::string;
+    using value_t = TRet;
     using key_mapper_t = thin_function<key_t(TArgs...)>;
     using value_mapper_t = thin_function<value_t(TArgs...)>;
     using storage_t = std::unordered_map<key_t, value_t>;
+public:
     storage_t storage;
     key_mapper_t key_mapper;
     value_mapper_t value_mapper;
-    value_t operator()(TArgs... args) {
+public:
+    TRet operator()(TArgs... args) {
         key_t key = key_mapper(args...);
         if (storage.count(key) == 0) {
             storage[key] = value_mapper(std::forward<TArgs>(args)...);
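For orientation, a minimal usage sketch of the FunctionCache specialization restored on the + side. The mappers and the expensive_lookup() call are illustrative placeholders, not code from the tree:

    // Illustrative only: memoize a std::string-producing call keyed by its arguments.
    FunctionCache<std::string(int)> cache;
    cache.key_mapper = [](int x) { return std::to_string(x); };      // builds the cache key
    cache.value_mapper = [](int x) { return expensive_lookup(x); };  // runs once per distinct key
    std::string first = cache(42);   // invokes expensive_lookup(42)
    std::string again = cache(42);   // served from storage, no second call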
@@ -42,28 +46,28 @@ public:
     }
 };
 // FIFO
 class StringSerializer {
 private:
     std::string m_buffer;
     size_t m_cursor = 0;
 public:
     template <typename T>
     T read_plain() {
-        static_assert(std::is_trivially_copyable<T>::value, "invalid type");
-        T ret;
-        memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T));
+        T result;
+        std::memcpy(&result, m_buffer.data() + m_cursor, sizeof(T));
         m_cursor += sizeof(T);
-        return ret;
+        return result;
     }
     template <typename T>
     void write_plain(T value) {
-        static_assert(std::is_trivially_copyable<T>::value,
-                      "type should be trivially copyable");
-        m_buffer.append(reinterpret_cast<const char*>(&value), sizeof(T));
+        m_buffer.resize(m_buffer.size() + sizeof(T));
+        std::memcpy(const_cast<char*>(m_buffer.data()) + (m_buffer.size() - sizeof(T)), &value, sizeof(T));
     }
     std::string take() {
+        std::string result;
+        m_buffer.erase(0, m_cursor);
         return std::move(m_buffer);
     }
     void set(std::string new_buf) {
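A short sketch of how the FIFO StringSerializer is meant to be used (values and variable names are illustrative):

    // Illustrative only: trivially copyable values go in and come back out in order.
    StringSerializer ser;
    ser.write_plain<int>(3);
    ser.write_plain<size_t>(256);
    std::string bytes = ser.take();              // e.g. used as (part of) a cache key

    StringSerializer reader;
    reader.set(bytes);                           // load a buffer for reading
    int ndim = reader.read_plain<int>();         // 3
    size_t extent = reader.read_plain<size_t>(); // 256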
@@ -72,20 +76,20 @@ public:
     }
 };
 struct Empty {};
 template <typename... TParams>
 class ParamBundle {
 private:
-    template <std::size_t N, std::size_t... Seq>
-    static std::index_sequence<N + Seq...> add_all(
-            std::index_sequence<Seq...>) {
+    template<std::size_t N, std::size_t... Seq>
+    static std::index_sequence<N + Seq ...> add_all(std::index_sequence<Seq...>){
         return {};
     }
-    template <std::size_t Min, std::size_t Max>
-    using make_index_range =
-            decltype(add_all<Min>(std::make_index_sequence<Max - Min>()));
+    template<std::size_t Min, std::size_t Max>
+    using make_index_range = decltype(add_all<Min>(std::make_index_sequence<Max-Min>()));
     using storage_t = std::tuple<typename std::remove_reference_t<TParams>...>;
     storage_t m_storage;
@@ -95,31 +99,21 @@ private:
         return functor(std::get<Indices>(m_storage).value...);
     }
     template <size_t Index, size_t... Indices, typename TPrev>
-    auto serialize_helper(StringSerializer& ser, TPrev&& prev,
-                          std::index_sequence<Index, Indices...>) {
-        return serialize_helper(ser,
-                                std::get<Index>(m_storage).serialize(ser, prev),
-                                std::index_sequence<Indices...>());
+    auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<Index, Indices...>) {
+        return serialize_helper(ser, std::get<Index>(m_storage).serialize(ser, prev), std::index_sequence<Indices...>());
     }
     template <typename TPrev>
-    auto serialize_helper(StringSerializer& ser, TPrev&& prev,
-                          std::index_sequence<>) {}
+    auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {}
     template <size_t Index, size_t... Indices, typename TPrev>
-    auto deserialize_helper(StringSerializer& ser, TPrev&& prev,
-                            std::index_sequence<Index, Indices...>) {
-        return deserialize_helper(
-                ser, std::get<Index>(m_storage).deserialize(ser, prev),
-                std::index_sequence<Indices...>());
+    auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<Index, Indices...>) {
+        return deserialize_helper(ser, std::get<Index>(m_storage).deserialize(ser, prev), std::index_sequence<Indices...>());
     }
     template <typename TPrev>
-    auto deserialize_helper(StringSerializer& ser, TPrev&& prev,
-                            std::index_sequence<>) {}
+    auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {}
     template <size_t Index, size_t... Indices, typename TArg, typename... TArgs>
-    void set_values_helper(std::index_sequence<Index, Indices...>, TArg&& arg,
-                           TArgs&&... args) {
+    void set_values_helper(std::index_sequence<Index, Indices...>, TArg&& arg, TArgs&&... args) {
         std::get<Index>(m_storage).value = arg;
-        set_values_helper(std::index_sequence<Indices...>(),
-                          std::forward<TArgs>(args)...);
+        set_values_helper(std::index_sequence<Indices...>(), std::forward<TArgs>(args)...);
     }
     template <size_t... Indices>
     void set_values_helper(std::index_sequence<Indices...>) {
@@ -129,33 +123,27 @@ private:
 public:
     template <typename TFunctor>
     auto call_by(TFunctor&& functor) {
-        return call_helper(std::forward<TFunctor>(functor),
-                           std::make_index_sequence<sizeof...(TParams)>());
+        return call_helper(std::forward<TFunctor>(functor), std::make_index_sequence<sizeof...(TParams)>());
    }
    template <size_t NBegin, size_t NEnd>
    void serialize_params(StringSerializer& ser) {
        static_assert(NEnd >= NBegin, "invalid range");
-        serialize_helper(
-                ser, Empty{},
-                add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>()));
+        serialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>());
    }
    template <size_t NBegin, size_t NEnd>
    void deserialize_params(StringSerializer& ser) {
        static_assert(NEnd >= NBegin, "invalid range");
-        deserialize_helper(
-                ser, Empty{},
-                add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>()));
+        deserialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>());
    }
    template <size_t NBegin, size_t NEnd, typename... TArgs>
    void set_values(TArgs&&... args) {
-        set_values_helper(
-                add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>()),
-                std::forward<TArgs>(args)...);
+        set_values_helper(make_index_range<NBegin, NEnd>(), std::forward<TArgs>(args)...);
    }
 };
 template <typename T>
-class Param {
+class RetParam {
 public:
     T value;
     Empty serialize(StringSerializer& ser, Empty) {
@@ -168,68 +156,45 @@ public:
     }
 };
-template <typename TRet = Param<Empty>, typename TInputs = std::tuple<>,
-          typename TOutputs = std::tuple<>>
+template <typename TRet=RetParam<Empty>, typename TInputs=std::tuple<>, typename TOutputs=std::tuple<>>
 class FunctionCacheBuilder {
 private:
-    static auto declargs()
-            -> decltype(std::tuple_cat(std::declval<TInputs>(),
-                                       std::declval<TOutputs>())) {
-        return {};
-    }
+    static auto declargs() -> decltype(std::tuple_cat(std::declval<TInputs>(), std::declval<TOutputs>())) { return {}; }
     template <size_t... Indices>
-    static auto declfunction_helper(std::index_sequence<Indices...>)
-            -> thin_function<decltype(std::declval<TRet>().value)(
-                    decltype(std::get<Indices>(declargs()).value)...)> {
-        return {};
-    }
+    static auto declfunction_helper(std::index_sequence<Indices...>) -> thin_function<decltype(std::declval<TRet>().value)(decltype(std::get<Indices>(declargs()).value)...)> { return {}; }
     static auto declfunction() {
-        return declfunction_helper(
-                std::make_index_sequence<std::tuple_size<TInputs>::value +
-                                         std::tuple_size<TOutputs>::value>());
+        return declfunction_helper(std::make_index_sequence<std::tuple_size<TInputs>::value + std::tuple_size<TOutputs>::value>());
     }
     template <size_t... Indices>
-    static auto declbundle_helper(std::index_sequence<Indices...>)
-            -> ParamBundle<decltype(std::get<Indices>(declargs()))...> {
-        return {};
-    }
+    static auto declbundle_helper(std::index_sequence<Indices...>) -> ParamBundle<decltype(std::get<Indices>(declargs()))...> { return {}; }
     static auto declbundle() {
-        return declbundle_helper(
-                std::make_index_sequence<std::tuple_size<TInputs>::value +
-                                         std::tuple_size<TOutputs>::value>());
+        return declbundle_helper(std::make_index_sequence<std::tuple_size<TInputs>::value+std::tuple_size<TOutputs>::value>());
     }
     using function_t = decltype(declfunction());
     using bundle_t = decltype(declbundle());
 public:
     template <typename TNewRet>
     auto ret() {
-        static_assert(std::is_same<TRet, Param<Empty>>::value,
-                      "return value redefinition");
+        static_assert(std::is_same<TRet, RetParam<Empty>>::value, "return value redefinition");
         return FunctionCacheBuilder<TNewRet, TInputs, TOutputs>{};
     }
     template <typename TNewInput>
     auto input() {
-        using TNewInputs = decltype(
-                std::tuple_cat(std::declval<TInputs>(),
-                               std::make_tuple(std::declval<TNewInput>())));
+        using TNewInputs = decltype(std::tuple_cat(std::declval<TInputs>(), std::make_tuple(std::declval<TNewInput>())));
         return FunctionCacheBuilder<TRet, TNewInputs, TOutputs>{};
     }
     template <typename TNewOutput>
     auto output() {
-        using TNewOutputs = decltype(
-                std::tuple_cat(std::declval<TOutputs>(),
-                               std::make_tuple(std::declval<TNewOutput>())));
+        using TNewOutputs = decltype(std::tuple_cat(std::declval<TOutputs>(), std::make_tuple(std::declval<TNewOutput>())));
         return FunctionCacheBuilder<TRet, TInputs, TNewOutputs>{};
     }
     template <typename TFunctor>
     function_t build(TFunctor func) {
-        FunctionCache<bundle_t> cache;
+        FunctionCache<std::string(bundle_t)> cache;
         cache.key_mapper = [](bundle_t bundle) {
             StringSerializer ser;
-            bundle.template serialize_params<0,
-                                             std::tuple_size<TInputs>::value>(
-                    ser);
+            bundle.template serialize_params<0, std::tuple_size<TInputs>::value>(ser);
             return ser.take();
         };
         cache.value_mapper = [=](bundle_t bundle) {
@@ -237,33 +202,42 @@ public:
             TRet ret;
             ret.value = bundle.call_by(func);
             ret.serialize(ser, Empty{});
-            bundle.template serialize_params<
-                    std::tuple_size<TInputs>::value,
-                    std::tuple_size<TInputs>::value +
-                            std::tuple_size<TOutputs>::value>(ser);
+            bundle.template serialize_params<std::tuple_size<TInputs>::value, std::tuple_size<TInputs>::value+std::tuple_size<TOutputs>::value>(ser);
             return ser.take();
         };
         return [=](auto&&... args) mutable {
             bundle_t bundle;
             TRet ret;
             StringSerializer ser;
-            static_assert(
-                    sizeof...(args) == std::tuple_size<TInputs>::value +
-                                               std::tuple_size<TOutputs>::value,
-                    "args count mismatch");
-            bundle.template set_values<0, sizeof...(args)>(
-                    std::forward<decltype(args)>(args)...);
+            static_assert(sizeof...(args) == std::tuple_size<TInputs>::value+std::tuple_size<TOutputs>::value,
+                          "arg count mismatch");
+            bundle.template set_values<0, sizeof...(args)>(std::forward<decltype(args)>(args)...);
             ser.set(cache(bundle));
             ret.deserialize(ser, Empty{});
             constexpr size_t n_inputs = std::tuple_size<TInputs>::value;
             constexpr size_t n_outputs = std::tuple_size<TOutputs>::value;
-            bundle.template deserialize_params<n_inputs, n_inputs + n_outputs>(
-                    ser);
+            bundle.template deserialize_params<n_inputs, n_inputs+n_outputs>(ser);
             return ret.value;
         };
     }
 };
+template <typename T>
+class PlainParam {
+public:
+    T value;
+    Empty serialize(StringSerializer& ser, Empty) {
+        ser.write_plain(value);
+        return Empty{};
+    }
+    Empty deserialize(StringSerializer& ser, Empty) {
+        value = ser.read_plain<T>();
+        return Empty{};
+    }
+};
 template <typename T>
 class RefParam {
 public:
@@ -278,6 +252,7 @@ public:
     }
 };
 template <typename T>
 class RefArraySizeParam {
 public:
@@ -291,6 +266,7 @@ public:
     }
 };
 template <typename TSize, typename TItem>
 class ArrayParam {
 public:
@@ -309,4 +285,4 @@ public:
     }
 };
-}  // namespace megdnn
+}
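To make the builder above concrete, here is a hedged sketch of how FunctionCacheBuilder gets wired up. It mirrors the HandleImpl::CUDNN construction that this revert deletes from handle.cpp further below, but uses the RetParam/PlainParam names restored on the + side, so treat it as illustrative rather than code from the tree:

    // Illustrative only: cache cudnnGetConvolutionForwardWorkspaceSize keyed on the
    // serialized inputs; the size_t written through the out-pointer is replayed on a hit.
    auto cached_get_ws =
            FunctionCacheBuilder<>()
                    .input<PlainParam<cudnnHandle_t>>()
                    .input<CudnnTensorDescParam>()
                    .input<CudnnFilterDescParam>()
                    .input<CudnnConvDescParam>()
                    .input<CudnnTensorDescParam>()
                    .input<PlainParam<cudnnConvolutionFwdAlgo_t>>()
                    .output<RefParam<size_t>>()
                    .ret<RetParam<cudnnStatus_t>>()
                    .build(&cudnnGetConvolutionForwardWorkspaceSize);
    // cached_get_ws is callable with the same argument list as the raw cudnn function.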
@@ -16,109 +16,105 @@
 #include "src/cuda/cudnn_wrapper.h"
 namespace megdnn {
-class CudnnConvDescParam {
-public:
-    cudnnConvolutionDescriptor_t value;
-    Empty serialize(StringSerializer& ser, Empty) {
-        constexpr int nbDims = MEGDNN_MAX_NDIM;
-        int padA[MEGDNN_MAX_NDIM];
-        int strideA[MEGDNN_MAX_NDIM];
-        int dilationA[MEGDNN_MAX_NDIM];
-        cudnnConvolutionMode_t mode;
-        cudnnDataType_t computeType;
-        cudnnGetConvolutionNdDescriptor(value, nbDims, &nbDims, padA, strideA,
-                                        dilationA, &mode, &computeType);
-        ser.write_plain(nbDims);
-        for (int i = 0; i < nbDims; ++i) {
-            ser.write_plain(padA[i]);
-            ser.write_plain(strideA[i]);
-            ser.write_plain(dilationA[i]);
+class CudnnConvDescParam {
+public:
+    cudnnConvolutionDescriptor_t value;
+    Empty serialize(StringSerializer& ser, Empty) {
+        int ndim = MEGDNN_MAX_NDIM;
+        int padA[MEGDNN_MAX_NDIM];
+        int strideA[MEGDNN_MAX_NDIM];
+        int dilationA[MEGDNN_MAX_NDIM];
+        cudnnConvolutionMode_t mode;
+        cudnnDataType_t computeType;
+        cudnnGetConvolutionNdDescriptor(value, MEGDNN_MAX_NDIM, &ndim, padA, strideA, dilationA, &mode, &computeType);
+        ser.write_plain(ndim);
+        for (int i = 0; i < ndim; ++i) {
+            ser.write_plain(padA[i]);
+            ser.write_plain(strideA[i]);
+            ser.write_plain(dilationA[i]);
+        }
+        ser.write_plain(mode);
+        ser.write_plain(computeType);
+        return Empty{};
     }
-        ser.write_plain(mode);
-        ser.write_plain(computeType);
-        return Empty{};
-    }
-    Empty deserialize(StringSerializer& ser, Empty) {
-        int ndim = ser.read_plain<int>();
-        int padA[MEGDNN_MAX_NDIM];
-        int strideA[MEGDNN_MAX_NDIM];
-        int dilationA[MEGDNN_MAX_NDIM];
-        for (int i = 0; i < ndim; ++i) {
-            padA[i] = ser.read_plain<int>();
-            strideA[i] = ser.read_plain<int>();
-            dilationA[i] = ser.read_plain<int>();
+    Empty deserialize(StringSerializer& ser, Empty) {
+        int ndim = ser.read_plain<int>();
+        int padA[MEGDNN_MAX_NDIM];
+        int strideA[MEGDNN_MAX_NDIM];
+        int dilationA[MEGDNN_MAX_NDIM];
+        for (int i = 0; i < ndim; ++i) {
+            padA[i] = ser.read_plain<int>();
+            strideA[i] = ser.read_plain<int>();
+            dilationA[i] = ser.read_plain<int>();
+        }
+        cudnnConvolutionMode_t mode = ser.read_plain<cudnnConvolutionMode_t>();
+        cudnnDataType_t computeType = ser.read_plain<cudnnDataType_t>();
+        cudnnSetConvolutionNdDescriptor(value, ndim, padA, strideA, dilationA, mode, computeType);
+        return Empty{};
     }
-        cudnnConvolutionMode_t mode = ser.read_plain<cudnnConvolutionMode_t>();
-        cudnnDataType_t computeType = ser.read_plain<cudnnDataType_t>();
-        cudnnSetConvolutionNdDescriptor(value, ndim, padA, strideA, dilationA,
-                                        mode, computeType);
-        return Empty{};
-    }
-};
-class CudnnTensorDescParam {
-public:
-    cudnnTensorDescriptor_t value;
-    Empty serialize(StringSerializer& ser, Empty) {
-        constexpr int nbDims = MEGDNN_MAX_NDIM;
-        cudnnDataType_t dataType;
-        int dimA[MEGDNN_MAX_NDIM];
-        int strideA[MEGDNN_MAX_NDIM];
-        cudnnGetTensorNdDescriptor(value, nbDims, &dataType, &nbDims, dimA,
-                                   strideA);
-        ser.write_plain(nbDims);
-        for (int i = 0; i < nbDims; ++i) {
-            ser.write_plain(dimA[i]);
-            ser.write_plain(strideA[i]);
+};
+class CudnnTensorDescParam {
+public:
+    cudnnTensorDescriptor_t value;
+    Empty serialize(StringSerializer& ser, Empty) {
+        int nbDims = MEGDNN_MAX_NDIM;
+        cudnnDataType_t dataType;
+        int dimA[MEGDNN_MAX_NDIM];
+        int strideA[MEGDNN_MAX_NDIM];
+        cudnnGetTensorNdDescriptor(value, nbDims, &dataType, &nbDims, dimA, strideA);
+        ser.write_plain(nbDims);
+        for (int i = 0; i < nbDims; ++i) {
+            ser.write_plain(dimA[i]);
+            ser.write_plain(strideA[i]);
+        }
+        ser.write_plain(dataType);
+        return Empty{};
     }
-        ser.write_plain(dataType);
-        return Empty{};
-    }
-    Empty deserialize(StringSerializer& ser, Empty) {
-        constexpr int nbDims = MEGDNN_MAX_NDIM;
-        cudnnDataType_t dataType;
-        int dimA[MEGDNN_MAX_NDIM];
-        int strideA[MEGDNN_MAX_NDIM];
-        nbDims = ser.read_plain<int>();
-        for (int i = 0; i < nbDims; ++i) {
-            dimA[i] = ser.read_plain<int>();
-            strideA[i] = ser.read_plain<int>();
+    Empty deserialize(StringSerializer& ser, Empty) {
+        int nbDims = MEGDNN_MAX_NDIM;
+        cudnnDataType_t dataType;
+        int dimA[MEGDNN_MAX_NDIM];
+        int strideA[MEGDNN_MAX_NDIM];
+        nbDims = ser.read_plain<int>();
+        for (int i = 0; i < nbDims; ++i) {
+            dimA[i] = ser.read_plain<int>();
+            strideA[i] = ser.read_plain<int>();
+        }
+        dataType = ser.read_plain<cudnnDataType_t>();
+        cudnnSetTensorNdDescriptor(value, dataType, nbDims, dimA, strideA);
+        return Empty{};
     }
-        dataType = ser.read_plain<cudnnDataType_t>();
-        cudnnSetTensorNdDescriptor(value, dataType, nbDims, dimA, strideA);
-        return Empty{};
-    }
-};
-class CudnnFilterDescParam {
-public:
-    cudnnFilterDescriptor_t value;
-    Empty serialize(StringSerializer& ser, Empty) {
-        constexpr int nbDims = MEGDNN_MAX_NDIM;
-        cudnnDataType_t dataType;
-        cudnnTensorFormat_t format;
-        int filterDimA[MEGDNN_MAX_NDIM];
-        cudnnGetFilterNdDescriptor(value, nbDims, &dataType, &format, &nbDims,
-                                   filterDimA);
-        ser.write_plain(nbDims);
-        for (int i = 0; i < nbDims; ++i) {
-            ser.write_plain(filterDimA[i]);
+};
+class CudnnFilterDescParam {
+public:
+    cudnnFilterDescriptor_t value;
+    Empty serialize(StringSerializer& ser, Empty) {
+        int nbDims = MEGDNN_MAX_NDIM;
+        cudnnDataType_t dataType;
+        cudnnTensorFormat_t format;
+        int filterDimA[MEGDNN_MAX_NDIM];
+        cudnnGetFilterNdDescriptor(value, nbDims, &dataType, &format, &nbDims, filterDimA);
+        ser.write_plain(nbDims);
+        for (int i = 0; i < nbDims; ++i) {
+            ser.write_plain(filterDimA[i]);
+        }
+        ser.write_plain(dataType);
+        ser.write_plain(format);
+        return Empty{};
     }
-        ser.write_plain(dataType);
-        ser.write_plain(format);
-        return Empty{};
-    }
-    Empty deserialize(StringSerializer& ser, Empty) {
-        constexpr int nbDims = MEGDNN_MAX_NDIM;
-        cudnnDataType_t dataType;
-        cudnnTensorFormat_t format;
-        int filterDimA[MEGDNN_MAX_NDIM];
-        nbDims = ser.read_plain<int>();
-        for (int i = 0; i < nbDims; ++i) {
-            filterDimA[i] = ser.read_plain<int>();
+    Empty deserialize(StringSerializer& ser, Empty) {
+        int nbDims = MEGDNN_MAX_NDIM;
+        cudnnDataType_t dataType;
+        cudnnTensorFormat_t format;
+        int filterDimA[MEGDNN_MAX_NDIM];
+        nbDims = ser.read_plain<int>();
+        for (int i = 0; i < nbDims; ++i) {
+            filterDimA[i] = ser.read_plain<int>();
+        }
+        dataType = ser.read_plain<cudnnDataType_t>();
+        format = ser.read_plain<cudnnTensorFormat_t>();
+        cudnnSetFilterNdDescriptor(value, dataType, format, nbDims, filterDimA);
+        return Empty{};
     }
-        dataType = ser.read_plain<cudnnDataType_t>();
-        format = ser.read_plain<cudnnTensorFormat_t>();
-        cudnnSetFilterNdDescriptor(value, dataType, format, nbDims, filterDimA);
-        return Empty{};
-    }
-};
-}  // namespace megdnn
+};
+}
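The three *DescParam classes above let an opaque cudnn descriptor take part in a cache key: serialize() queries the descriptor's fields and writes them out, and deserialize() rebuilds a descriptor from the same bytes. A hedged sketch of one of them in isolation (the descriptor variable is a placeholder):

    // Illustrative only.
    CudnnTensorDescParam src;
    src.value = src_desc;        // an existing cudnnTensorDescriptor_t
    StringSerializer ser;
    src.serialize(ser, Empty{}); // dims, strides and dtype now contribute to the key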
@@ -56,8 +56,7 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available(
     conv_args.init_conv_desc(D);
     size_t workspace_size;
-    auto& cudnn = conv_args.handle->cudnn();
-    auto status = cudnn.GetConvolutionForwardWorkspaceSize(
+    auto status = cudnnGetConvolutionForwardWorkspaceSize(
             conv_args.handle->cudnn_handle(), D.src_desc.desc,
             D.filter_desc.desc, D.conv_desc.conv_desc, D.dst_desc.desc,
             m_cudnn_enum, &workspace_size);
@@ -83,8 +82,7 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoCUDNNConv::get_workspace_bundle(
     conv_args.init_conv_desc(D);
     size_t conv_workspace_size;
-    auto& cudnn = conv_args.handle->cudnn();
-    auto status = cudnn.GetConvolutionForwardWorkspaceSize(
+    auto status = cudnnGetConvolutionForwardWorkspaceSize(
             conv_args.handle->cudnn_handle(), D.src_desc.desc,
             D.filter_desc.desc, D.conv_desc.conv_desc, D.dst_desc.desc,
             m_cudnn_enum, &conv_workspace_size);
@@ -149,8 +149,7 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
         megdnn_throw("unsupported NonlineMode");
     }
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionForwardWorkspaceSize(
+    auto status = cudnnGetConvolutionForwardWorkspaceSize(
             args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc,
             D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum,
             &workspace_size);
@@ -163,8 +162,7 @@ size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::get_workspace_in_bytes(
     args.init_conv_bias_desc(D);
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionForwardWorkspaceSize(
+    auto status = cudnnGetConvolutionForwardWorkspaceSize(
             args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc,
             D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum,
             &workspace_size);
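The pattern in this and the following hunks is mechanical: before the revert, workspace-size and algorithm queries went through the per-handle cache returned by handle->cudnn(), whose members share the signatures of the corresponding cudnn entry points; after the revert they call the cudnn C API directly. Schematically (argument names are placeholders):

    // Pre-revert, cached:
    auto& cudnn = args.handle->cudnn();
    cudnn_check(cudnn.GetConvolutionForwardWorkspaceSize(
            handle, src_desc, filter_desc, conv_desc, dst_desc, algo, &workspace_size));
    // Post-revert, direct:
    cudnn_check(cudnnGetConvolutionForwardWorkspaceSize(
            handle, src_desc, filter_desc, conv_desc, dst_desc, algo, &workspace_size));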
@@ -95,13 +95,12 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
     CUDNNForwardDescs desc;
     conv_args.init_conv_desc(desc);
 #if CUDNN_MAJOR >= 7
-    auto& cudnn = static_cast<HandleImpl*>(this->handle())->cudnn();
     int max_count = 0;
-    cudnn_check(cudnn.GetConvolutionForwardAlgorithmMaxCount(cudnn_handle,
+    cudnn_check(cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle,
                                                              &max_count));
     SmallVector<cudnnConvolutionFwdAlgoPerf_t> algo_perf(max_count);
     int ret_count = 0;
-    cudnn_check(cudnn.GetConvolutionForwardAlgorithm_v7(
+    cudnn_check(cudnnGetConvolutionForwardAlgorithm_v7(
             cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc,
             desc.conv_desc.conv_desc, desc.dst_desc.desc, max_count,
             &ret_count, algo_perf.data()));
@@ -42,10 +42,9 @@ bool ConvolutionBackwardDataImpl::AlgoCUDNN::is_available(
     if (!conv_bias::is_cudnn_supported(bias_args))
         return false;
-    auto& cudnn = args.handle->cudnn();
     args.init_desc(D);
     size_t workspace_size;
-    auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
             args.handle->cudnn_handle(),
             D.filter_desc.desc,
             D.diff_desc.desc,
@@ -58,11 +57,10 @@ bool ConvolutionBackwardDataImpl::AlgoCUDNN::is_available(
 size_t ConvolutionBackwardDataImpl::AlgoCUDNN::get_workspace_in_bytes(
         const SizeArgs &args) const {
-    auto& cudnn = args.handle->cudnn();
     CUDNNBwdDataDescs D;
     args.init_desc(D);
     size_t workspace_size;
-    auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
             args.handle->cudnn_handle(),
             D.filter_desc.desc,
             D.diff_desc.desc,
@@ -29,7 +29,6 @@ bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available(
             return false;
         }
     }
-    auto& cudnn = args.handle->cudnn();
     CUDNNBwdFilterDescs D;
     TensorLayout bias_layout, z_layout;
@@ -44,7 +43,7 @@ bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available(
     args.init_desc(D);
     size_t workspace_size;
-    auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
             args.handle->cudnn_handle(),
             D.src_desc.desc,
             D.diff_desc.desc,
@@ -57,11 +56,10 @@ bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available(
 size_t ConvolutionBackwardFilterImpl::AlgoCUDNN::get_workspace_in_bytes(
         const SizeArgs &args) const {
-    auto& cudnn = args.handle->cudnn();
     CUDNNBwdFilterDescs D;
     args.init_desc(D);
     size_t workspace_size;
-    auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
             args.handle->cudnn_handle(),
             D.src_desc.desc,
             D.diff_desc.desc,
@@ -144,13 +144,12 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
 #if CUDNN_MAJOR >= 7
     MEGDNN_MARK_USED_VAR(negative_attr);
-    auto& cudnn = args.handle->cudnn();
     int max_count = 0;
-    cudnn_check(cudnn.GetConvolutionBackwardDataAlgorithmMaxCount(
+    cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(
             cudnn_handle, &max_count));
     SmallVector<cudnnConvolutionBwdDataAlgoPerf_t> algo_perf(max_count);
     int ret_count = 0;
-    cudnn_check(cudnn.GetConvolutionBackwardDataAlgorithm_v7(
+    cudnn_check(cudnnGetConvolutionBackwardDataAlgorithm_v7(
             cudnn_handle, desc.filter_desc.desc, desc.diff_desc.desc,
             desc.conv_desc.desc, desc.grad_desc.desc, max_count, &ret_count,
             algo_perf.data()));
@@ -280,13 +279,12 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
 #endif
 #if CUDNN_MAJOR >= 7
     MEGDNN_MARK_USED_VAR(negative_attr);
-    auto& cudnn = args.handle->cudnn();
     int max_count = 0;
-    cudnn_check(cudnn.GetConvolutionBackwardFilterAlgorithmMaxCount(
+    cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(
             cudnn_handle, &max_count));
     SmallVector<cudnnConvolutionBwdFilterAlgoPerf_t> algo_perf(max_count);
     int ret_count = 0;
-    cudnn_check(cudnn.GetConvolutionBackwardFilterAlgorithm_v7(
+    cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
             cudnn_handle, desc.src_desc.desc, desc.diff_desc.desc,
             desc.conv_desc.desc, desc.grad_desc.desc, max_count, &ret_count,
             algo_perf.data()));
@@ -28,8 +28,7 @@ bool Convolution3DBackwardDataImpl::AlgoCUDNN::is_available(
     args.init_desc(D);
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
             args.handle->cudnn_handle(),
             D.filter_desc.desc,
             D.diff_desc.desc,
@@ -45,8 +44,7 @@ size_t Convolution3DBackwardDataImpl::AlgoCUDNN::get_workspace_in_bytes(
     CUDNNBwdDataDescs D;
     args.init_desc(D);
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
             args.handle->cudnn_handle(),
             D.filter_desc.desc,
             D.diff_desc.desc,
@@ -28,8 +28,7 @@ bool Convolution3DBackwardFilterImpl::AlgoCUDNN::is_available(
     args.init_desc(D);
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
             args.handle->cudnn_handle(), D.src_desc.desc, D.diff_desc.desc,
             D.conv_desc.desc, D.grad_desc.desc, m_cudnn_enum, &workspace_size);
     return status == CUDNN_STATUS_SUCCESS;
@@ -41,8 +40,7 @@ size_t Convolution3DBackwardFilterImpl::AlgoCUDNN::get_workspace_in_bytes(
     args.init_desc(D);
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
             args.handle->cudnn_handle(), D.src_desc.desc, D.diff_desc.desc,
             D.conv_desc.desc, D.grad_desc.desc, m_cudnn_enum, &workspace_size);
     megdnn_assert(status == CUDNN_STATUS_SUCCESS,
@@ -27,8 +27,7 @@ bool Convolution3DForwardImpl::AlgoCUDNN::is_available(
     args.init_desc(D);
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionForwardWorkspaceSize(
+    auto status = cudnnGetConvolutionForwardWorkspaceSize(
             args.handle->cudnn_handle(),
             D.src_desc.desc,
             D.filter_desc.desc,
@@ -44,8 +43,7 @@ size_t Convolution3DForwardImpl::AlgoCUDNN::get_workspace_in_bytes(
     CUDNNForwardDescs D;
     args.init_desc(D);
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionForwardWorkspaceSize(
+    auto status = cudnnGetConvolutionForwardWorkspaceSize(
             args.handle->cudnn_handle(),
             D.src_desc.desc,
             D.filter_desc.desc,
@@ -93,7 +93,7 @@ namespace convolution3d {
             const Workspace &workspace, void *&raw_ptr);
     inline bool cudnn_get_convolution_fwd_algo_helper(
-            Handle* handle, const cudnnTensorDescriptor_t x_desc,
+            cudnnHandle_t cudnn_handle, const cudnnTensorDescriptor_t x_desc,
             const cudnnFilterDescriptor_t w_desc,
             const cudnnConvolutionDescriptor_t conv_desc,
             const cudnnTensorDescriptor_t y_desc,
@@ -103,14 +103,13 @@ namespace convolution3d {
         MEGDNN_MARK_USED_VAR(positive_attr);
         MEGDNN_MARK_USED_VAR(negative_attr);
 #if CUDNN_MAJOR >= 7
-        auto& cudnn = static_cast<HandleImpl*>(handle)->cudnn();
         int algo_max_count = 0;
-        cudnn_check(cudnn.GetConvolutionForwardAlgorithmMaxCount(
-                cuda::cudnn_handle(handle), &algo_max_count));
+        cudnn_check(cudnnGetConvolutionForwardAlgorithmMaxCount(
+                cudnn_handle, &algo_max_count));
         SmallVector<cudnnConvolutionFwdAlgoPerf_t> algo_perf(algo_max_count);
         int algo_count = 0;
-        cudnn_check(cudnn.GetConvolutionForwardAlgorithm_v7(
-                cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc, algo_max_count,
+        cudnn_check(cudnnGetConvolutionForwardAlgorithm_v7(
+                cudnn_handle, x_desc, w_desc, conv_desc, y_desc, algo_max_count,
                 &algo_count, algo_perf.data()));
         for (int i = 0; i < algo_count; ++i) {
             if (algo_perf[i].algo ==
@@ -118,8 +117,8 @@ namespace convolution3d {
                             CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING)
                 continue;
             size_t workspace_size = 0;
-            cudnn_check(cudnn.GetConvolutionForwardWorkspaceSize(
-                    cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc,
+            cudnn_check(cudnnGetConvolutionForwardWorkspaceSize(
+                    cudnn_handle, x_desc, w_desc, conv_desc, y_desc,
                     algo_perf[i].algo, &workspace_size));
             if (workspace_size > workspace_limit_in_bytes) continue;
             if (!(positive_attr & AlgoAttribute::REPRODUCIBLE)) {
@@ -135,7 +134,7 @@ namespace convolution3d {
         return false;
 #else
         cudnn_check(cudnnGetConvolutionForwardAlgorithm(
-                cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc,
+                cudnn_handle, x_desc, w_desc, conv_desc, y_desc,
                 CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
                 workspace_limit_in_bytes, algo));
         return true;
@@ -64,12 +64,13 @@ Convolution3DForwardImpl::get_algorithm_heuristic(
     auto get_cudnn_algo =
             [this, &args, workspace_limit_in_bytes, positive_attr,
              negative_attr]() -> Convolution3DForwardImpl::AlgoBase* {
+        auto cudnn_handle = cuda::cudnn_handle(this->handle());
         cudnnConvolutionFwdAlgo_t algo;
         CUDNNForwardDescs desc;
         args.init_desc(desc);
         bool got = cudnn_get_convolution_fwd_algo_helper(
-                this->handle(), desc.src_desc.desc, desc.filter_desc.desc,
+                cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc,
                 desc.conv_desc.desc, desc.dst_desc.desc,
                 workspace_limit_in_bytes, &algo, positive_attr, negative_attr);
         if (got) {
@@ -56,7 +56,7 @@ namespace convolution {
         using KernLayout = _kern_layout;     \
         using OutputLayout = _output_layout; \
         using Param = _conv_param;           \
-        static constexpr bool check_bounds = check_bounds_
+        static constexpr bool check_bounds = check_bounds_;
 #define MEGDNN_COMMA ,
 template <bool check_bounds_, typename src_ldg_dtype, typename filter_ldg_dtype,
@@ -53,7 +53,7 @@ namespace convolution {
         using KernLayout = _kern_layout;     \
        using OutputLayout = _output_layout; \
        using Param = _conv_param;           \
-        static constexpr bool check_bounds = check_bounds_
+        static constexpr bool check_bounds = check_bounds_;
 #define MEGDNN_COMMA ,
 template <bool check_bounds_, typename IMMAConfig_, typename WarpTileConfig_,
@@ -53,7 +53,7 @@ namespace convolution {
         using KernLayout = _kern_layout;     \
         using OutputLayout = _output_layout; \
         using Param = _conv_param;           \
-        static constexpr bool check_bounds = check_bounds_
+        static constexpr bool check_bounds = check_bounds_;
 #define MEGDNN_COMMA ,
 template <bool check_bounds_, typename ldg_dtype, typename RegBlockConfig_,
@@ -11,16 +11,13 @@
 #include "src/common/handle_impl.h"
 #include "src/common/version_symbol.h"
-#include "src/common/api_cache.h"
 #include "src/cuda/handle.h"
 #include "src/cuda/utils.h"
+#include "src/cuda/api_cache.h"
 #include "megdnn/common.h"
 #include <cuda.h>
 #include <cstring>
-#include <memory>
 #define STR_HELPER(x) #x
 #define STR(x) STR_HELPER(x)
@@ -94,8 +91,6 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle):
     // check tk1
     m_is_tegra_k1 = (strcmp(m_device_prop->name, "GK20A") == 0);
     m_cusolver_handle = nullptr;
-    m_cudnn_api_cache = std::make_unique<CUDNN>(m_cudnn_handle);
 }
 HandleImpl::~HandleImpl() noexcept {
@@ -141,111 +136,8 @@ HandleImpl::HandleVendorType HandleImpl::vendor_type() const {
     return HandleVendorType::CUDA;
 }
-HandleImpl::CUDNN& HandleImpl::cudnn() {
-    return *m_cudnn_api_cache;
-}
-HandleImpl::CUDNN::CUDNN(cudnnHandle_t handle) {
-    m_handle = handle;
-    GetConvolutionForwardWorkspaceSize =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnFilterDescParam>()
-                    .input<CudnnConvDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<Param<cudnnConvolutionFwdAlgo_t>>()
-                    .output<RefParam<size_t>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionForwardWorkspaceSize);
-#if CUDNN_MAJOR >= 7
-    GetConvolutionForwardAlgorithm_v7 =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnFilterDescParam>()
-                    .input<CudnnConvDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<Param<int>>()
-                    .output<RefArraySizeParam<int>>()
-                    .output<ArrayParam<int, cudnnConvolutionFwdAlgoPerf_t>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionForwardAlgorithm_v7);
-    GetConvolutionForwardAlgorithmMaxCount =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .output<RefParam<int>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionForwardAlgorithmMaxCount);
-#endif
-    GetConvolutionBackwardDataWorkspaceSize =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .input<CudnnFilterDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnConvDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<Param<cudnnConvolutionBwdDataAlgo_t>>()
-                    .output<RefParam<size_t>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionBackwardDataWorkspaceSize);
-#if CUDNN_MAJOR >= 7
-    GetConvolutionBackwardDataAlgorithm_v7 =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .input<CudnnFilterDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnConvDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<Param<int>>()
-                    .output<RefArraySizeParam<int>>()
-                    .output<ArrayParam<int,
-                                       cudnnConvolutionBwdDataAlgoPerf_t>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionBackwardDataAlgorithm_v7);
-    GetConvolutionBackwardDataAlgorithmMaxCount =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .output<RefParam<int>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionBackwardDataAlgorithmMaxCount);
-#endif
-    GetConvolutionBackwardFilterWorkspaceSize =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnConvDescParam>()
-                    .input<CudnnFilterDescParam>()
-                    .input<Param<cudnnConvolutionBwdFilterAlgo_t>>()
-                    .output<RefParam<size_t>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionBackwardFilterWorkspaceSize);
-#if CUDNN_MAJOR >= 7
-    GetConvolutionBackwardFilterAlgorithm_v7 =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnConvDescParam>()
-                    .input<CudnnFilterDescParam>()
-                    .input<Param<int>>()
-                    .output<RefArraySizeParam<int>>()
-                    .output<ArrayParam<int,
-                                       cudnnConvolutionBwdFilterAlgoPerf_t>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionBackwardFilterAlgorithm_v7);
-    GetConvolutionBackwardFilterAlgorithmMaxCount =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .output<RefParam<int>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionBackwardFilterAlgorithmMaxCount);
-#endif
-}
-}  // namespace cuda
-}  // namespace megdnn
+} // namespace cuda
+} // namespace megdnn
 MEGDNN_VERSION_SYMBOL(CUDA, CUDA_VERSION);
 MEGDNN_VERSION_SYMBOL3(CUDNN, CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);
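Reading one of the deleted assignments above: the .input<>() entries are serialized by the bundle to form the cache key, the .output<>() entries are captured from the first real call and replayed on later hits, and .ret<>() wraps the returned cudnnStatus_t. In effect (hypothetical sketch, variable names are placeholders):

    // First call: misses the cache, runs cudnnGetConvolutionForwardWorkspaceSize,
    // and stores {status, workspace size} under the serialized descriptors.
    cudnn.GetConvolutionForwardWorkspaceSize(h, x, w, conv, y, algo, &size1);
    // Second call with the same descriptors: replays the stored status and size
    // without re-entering cudnn.
    cudnn.GetConvolutionForwardWorkspaceSize(h, x, w, conv, y, algo, &size2);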
@@ -124,10 +124,6 @@ class HandleImpl: public HandleImplHelper {
     size_t image2d_pitch_alignment() const override;
     HandleVendorType vendor_type() const override;
-    class CUDNN;
-    CUDNN& cudnn();
 private:
     bool m_is_tegra_k1;
     int m_device_id;
@@ -160,34 +156,9 @@ class HandleImpl: public HandleImplHelper {
     //! device ptr to const scalars
     ConstScalars* m_const_scalars;
-    std::unique_ptr<CUDNN> m_cudnn_api_cache;
     void initialize_cusolver();
 };
-class HandleImpl::CUDNN {
-    cudnnHandle_t m_handle;
-public:
-    CUDNN(cudnnHandle_t handle);
-#define WRAP_CUDNN_API(NAME) thin_function<decltype(cudnn##NAME)> NAME;
-    WRAP_CUDNN_API(GetConvolutionForwardWorkspaceSize);
-#if CUDNN_MAJOR >= 7
-    WRAP_CUDNN_API(GetConvolutionForwardAlgorithm_v7);
-    WRAP_CUDNN_API(GetConvolutionForwardAlgorithmMaxCount);
-#endif
-#if CUDNN_MAJOR >= 7
-    WRAP_CUDNN_API(GetConvolutionBackwardDataAlgorithm_v7);
-    WRAP_CUDNN_API(GetConvolutionBackwardDataAlgorithmMaxCount);
-#endif
-    WRAP_CUDNN_API(GetConvolutionBackwardDataWorkspaceSize);
-#if CUDNN_MAJOR >= 7
-    WRAP_CUDNN_API(GetConvolutionBackwardFilterAlgorithmMaxCount);
-    WRAP_CUDNN_API(GetConvolutionBackwardFilterAlgorithm_v7);
-#endif
-    WRAP_CUDNN_API(GetConvolutionBackwardFilterWorkspaceSize);
-#undef WRAP_CUDNN_API
-};
 } // namespace cuda
 } // namespace megdnn
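For clarity, WRAP_CUDNN_API in the deleted class expands each name to a thin_function member whose type is the exact signature of the corresponding cudnn entry point, for example:

    // Expansion of WRAP_CUDNN_API(GetConvolutionForwardWorkspaceSize):
    thin_function<decltype(cudnnGetConvolutionForwardWorkspaceSize)>
            GetConvolutionForwardWorkspaceSize;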