@@ -12,32 +12,28 @@
#pragma once
#include <unordered_map>
#include <memory>
#include <cstring>
#include <memory>
#include <tuple>
#include <unordered_map>
#include "megdnn/thin/function.h"
namespace megdnn {
template <typename TSignature>
class FunctionCache;
template <typename TRet, typename... TArgs>
class FunctionCache<TRet(TArgs...)> {
template <typename... TArgs>
class FunctionCache {
public:
using key_t = std::string;
using value_t = TRet;
using value_t = std::string;
using key_mapper_t = thin_function<key_t(TArgs...)>;
using value_mapper_t = thin_function<value_t(TArgs...)>;
using storage_t = std::unordered_map<key_t, value_t>;
public:
storage_t storage;
key_mapper_t key_mapper;
value_mapper_t value_mapper;
public:
TRet operator()(TArgs... args) {
value_t operator()(TArgs... args) {
key_t key = key_mapper(args...);
if (storage.count(key) == 0) {
storage[key] = value_mapper(std::forward<TArgs>(args)...);
@@ -46,28 +42,28 @@ public:
}
};
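// Usage sketch (illustrative, not part of this header; the int-pair key below is an
// assumption for demonstration): FunctionCache<TArgs...> memoizes a call by folding
// the arguments into a std::string key and caching a std::string value, both of
// which are produced with the StringSerializer defined below.
//
//     FunctionCache<int, int> cache;
//     cache.key_mapper = [](int a, int b) {
//         StringSerializer ser;
//         ser.write_plain(a);
//         ser.write_plain(b);
//         return ser.take();
//     };
//     cache.value_mapper = [](int a, int b) {
//         StringSerializer ser;
//         ser.write_plain(a + b);  // stand-in for an expensive query
//         return ser.take();
//     };
//     std::string packed = cache(2, 3);  // first call computes, later calls hit storage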
// FIFO
class StringSerializer {
private:
std::string m_buffer;
size_t m_cursor = 0;
public:
template <typename T>
T read_plain() {
T result;
std::memcpy(&result, m_buffer.data() + m_cursor, sizeof(T));
static_assert(std::is_trivially_copyable<T>::value, "invalid type");
T ret;
memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T));
m_cursor += sizeof(T);
return result;
return ret;
}
template <typename T>
void write_plain(T value) {
m_buffer.resize(m_buffer.size() + sizeof(T));
std::memcpy(const_cast<char*>(m_buffer.data()) + (m_buffer.size() - sizeof(T)), &value, sizeof(T));
static_assert(std::is_trivially_copyable<T>::value,
"type should be trivially copyable");
m_buffer.append(reinterpret_cast<const char*>(&value), sizeof(T));
}
std::string take() {
std::string result;
m_buffer.erase(0, m_cursor);
return std::move(m_buffer);
}
void set(std::string new_buf) {
@@ -76,20 +72,20 @@ public:
}
};
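// Round-trip sketch (illustrative): write_plain() appends the raw bytes of a
// trivially copyable value, read_plain() consumes values in the same FIFO order
// from the cursor, and take()/set() move the underlying byte string in and out.
//
//     StringSerializer ser;
//     ser.write_plain<int>(42);
//     ser.write_plain<float>(1.5f);
//     std::string bytes = ser.take();
//
//     StringSerializer de;
//     de.set(bytes);
//     int i = de.read_plain<int>();      // 42
//     float f = de.read_plain<float>();  // 1.5f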
struct Empty {};
template <typename... TParams>
class ParamBundle {
private:
template<std::size_t N, std::size_t... Seq>
static std::index_sequence<N + Seq ...> add_all(std::index_sequence<Seq...>){
template <std::size_t N, std::size_t... Seq>
static std::index_sequence<N + Seq...> add_all(
std::index_sequence<Seq...>) {
return {};
}
template<std::size_t Min, std::size_t Max>
using make_index_range = decltype(add_all<Min>(std::make_index_sequence<Max-Min>()));
template <std::size_t Min, std::size_t Max>
using make_index_range =
decltype(add_all<Min>(std::make_index_sequence<Max - Min>()));
using storage_t = std::tuple<typename std::remove_reference_t<TParams>...>;
storage_t m_storage;
@@ -99,21 +95,31 @@ private:
return functor(std::get<Indices>(m_storage).value...);
}
template <size_t Index, size_t... Indices, typename TPrev>
auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<Index, Indices...>) {
return serialize_helper(ser, std::get<Index>(m_storage).serialize(ser, prev), std::index_sequence<Indices...>());
auto serialize_helper(StringSerializer& ser, TPrev&& prev,
std::index_sequence<Index, Indices...>) {
return serialize_helper(ser,
std::get<Index>(m_storage).serialize(ser, prev),
std::index_sequence<Indices...>());
}
template <typename TPrev>
auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {}
auto serialize_helper(StringSerializer& ser, TPrev&& prev,
std::index_sequence<>) {}
template <size_t Index, size_t... Indices, typename TPrev>
auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<Index, Indices...>) {
return deserialize_helper(ser, std::get<Index>(m_storage).deserialize(ser, prev), std::index_sequence<Indices...>());
auto deserialize_helper(StringSerializer& ser, TPrev&& prev,
std::index_sequence<Index, Indices...>) {
return deserialize_helper(
ser, std::get<Index>(m_storage).deserialize(ser, prev),
std::index_sequence<Indices...>());
}
template <typename TPrev>
auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {}
auto deserialize_helper(StringSerializer& ser, TPrev&& prev,
std::index_sequence<>) {}
template <size_t Index, size_t... Indices, typename TArg, typename... TArgs>
void set_values_helper(std::index_sequence<Index, Indices...>, TArg&& arg, TArgs&&... args) {
void set_values_helper(std::index_sequence<Index, Indices...>, TArg&& arg,
TArgs&&... args) {
std::get<Index>(m_storage).value = arg;
set_values_helper(std::index_sequence<Indices...>(), std::forward<TArgs>(args)...);
set_values_helper(std::index_sequence<Indices...>(),
std::forward<TArgs>(args)...);
}
template <size_t... Indices>
void set_values_helper(std::index_sequence<Indices...>) {
@@ -123,27 +129,33 @@ private:
public:
template <typename TFunctor>
auto call_by(TFunctor&& functor) {
return call_helper(std::forward<TFunctor>(functor), std::make_index_sequence<sizeof...(TParams)>());
return call_helper(std::forward<TFunctor>(functor),
std::make_index_sequence<sizeof...(TParams)>());
}
template <size_t NBegin, size_t NEnd>
void serialize_params(StringSerializer& ser) {
static_assert(NEnd >= NBegin, "invalid range");
serialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>());
serialize_helper(
ser, Empty{},
add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>()));
}
template <size_t NBegin, size_t NEnd>
void deserialize_params(StringSerializer& ser) {
static_assert(NEnd >= NBegin, "invalid range");
deserialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>());
deserialize_helper(
ser, Empty{},
add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>()));
}
template <size_t NBegin, size_t NEnd, typename... TArgs>
void set_values(TArgs&&... args) {
set_values_helper(make_index_range<NBegin, NEnd>(), std::forward<TArgs>(args)...);
set_values_helper(
add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>()),
std::forward<TArgs>(args)...);
}
};
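// Sketch of the ParamBundle protocol (illustrative): every element is a param
// wrapper that exposes a `value` member plus serialize()/deserialize(), set_values()
// fills the half-open index range [NBegin, NEnd), and call_by() forwards the stored
// values to a functor.
//
//     ParamBundle<Param<int>, Param<int>> bundle;
//     bundle.set_values<0, 2>(3, 4);
//     int sum = bundle.call_by([](int a, int b) { return a + b; });  // 7
//
//     StringSerializer ser;
//     bundle.serialize_params<0, 2>(ser);  // presumably write_plain()s both ints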
template <typename T>
class RetParam {
class Param {
public:
T value;
Empty serialize(StringSerializer& ser, Empty) {
@@ -156,45 +168,68 @@ public:
}
};
template <typename TRet=RetParam<Empty>, typename TInputs=std::tuple<>, typename TOutputs=std::tuple<>>
template <typename TRet = Param<Empty>, typename TInputs = std::tuple<>,
typename TOutputs = std::tuple<>>
class FunctionCacheBuilder {
private:
static auto declargs() -> decltype(std::tuple_cat(std::declval<TInputs>(), std::declval<TOutputs>())) { return {}; }
static auto declargs()
-> decltype(std::tuple_cat(std::declval<TInputs>(),
std::declval<TOutputs>())) {
return {};
}
template <size_t... Indices>
static auto declfunction_helper(std::index_sequence<Indices...>) -> thin_function<decltype(std::declval<TRet>().value)(decltype(std::get<Indices>(declargs()).value)...)> { return {}; }
static auto declfunction_helper(std::index_sequence<Indices...>)
-> thin_function<decltype(std::declval<TRet>().value)(
decltype(std::get<Indices>(declargs()).value)...)> {
return {};
}
static auto declfunction() {
return declfunction_helper(std::make_index_sequence<std::tuple_size<TInputs>::value + std::tuple_size<TOutputs>::value>());
return declfunction_helper(
std::make_index_sequence<std::tuple_size<TInputs>::value +
std::tuple_size<TOutputs>::value>());
}
template <size_t... Indices>
static auto declbundle_helper(std::index_sequence<Indices...>) -> ParamBundle<decltype(std::get<Indices>(declargs()))...> { return {}; }
static auto declbundle_helper(std::index_sequence<Indices...>)
-> ParamBundle<decltype(std::get<Indices>(declargs()))...> {
return {};
}
static auto declbundle() {
return declbundle_helper(std::make_index_sequence<std::tuple_size<TInputs>::value+std::tuple_size<TOutputs>::value>());
return declbundle_helper(
std::make_index_sequence<std::tuple_size<TInputs>::value +
std::tuple_size<TOutputs>::value>());
}
using function_t = decltype(declfunction());
using bundle_t = decltype(declbundle());
public:
template <typename TNewRet>
auto ret() {
static_assert(std::is_same<TRet, RetParam<Empty>>::value, "return value redefinition");
static_assert(std::is_same<TRet, Param<Empty>>::value,
"return value redefinition");
return FunctionCacheBuilder<TNewRet, TInputs, TOutputs>{};
}
template <typename TNewInput>
auto input() {
using TNewInputs = decltype(std::tuple_cat(std::declval<TInputs>(), std::make_tuple(std::declval<TNewInput>())));
using TNewInputs = decltype(
std::tuple_cat(std::declval<TInputs>(),
std::make_tuple(std::declval<TNewInput>())));
return FunctionCacheBuilder<TRet, TNewInputs, TOutputs>{};
}
template <typename TNewOutput>
auto output() {
using TNewOutputs = decltype(std::tuple_cat(std::declval<TOutputs>(), std::make_tuple(std::declval<TNewOutput>())));
using TNewOutputs = decltype(
std::tuple_cat(std::declval<TOutputs>(),
std::make_tuple(std::declval<TNewOutput>())));
return FunctionCacheBuilder<TRet, TInputs, TNewOutputs>{};
}
template <typename TFunctor>
function_t build(TFunctor func) {
FunctionCache<std::string(bundle_t)> cache;
FunctionCache<bundle_t> cache;
cache.key_mapper = [](bundle_t bundle) {
StringSerializer ser;
bundle.template serialize_params<0, std::tuple_size<TInputs>::value>(ser);
bundle.template serialize_params<0,
std::tuple_size<TInputs>::value>(
ser);
return ser.take();
};
cache.value_mapper = [=](bundle_t bundle) {
@@ -202,42 +237,33 @@ public:
TRet ret;
ret.value = bundle.call_by(func);
ret.serialize(ser, Empty{});
bundle.template serialize_params<std::tuple_size<TInputs>::value, std::tuple_size<TInputs>::value+std::tuple_size<TOutputs>::value>(ser);
bundle.template serialize_params<
std::tuple_size<TInputs>::value,
std::tuple_size<TInputs>::value +
std::tuple_size<TOutputs>::value>(ser);
return ser.take();
};
return [=](auto&&... args) mutable {
bundle_t bundle;
TRet ret;
StringSerializer ser;
static_assert(sizeof...(args) == std::tuple_size<TInputs>::value+std::tuple_size<TOutputs>::value,
"arg count mismatch");
bundle.template set_values<0, sizeof...(args)>(std::forward<decltype(args)>(args)...);
static_assert(
sizeof...(args) == std::tuple_size<TInputs>::value +
std::tuple_size<TOutputs>::value,
"args count mismatch");
bundle.template set_values<0, sizeof...(args)>(
std::forward<decltype(args)>(args)...);
ser.set(cache(bundle));
ret.deserialize(ser, Empty{});
constexpr size_t n_inputs = std::tuple_size<TInputs>::value;
constexpr size_t n_outputs = std::tuple_size<TOutputs>::value;
bundle.template deserialize_params<n_inputs, n_inputs+n_outputs>(ser);
bundle.template deserialize_params<n_inputs, n_inputs + n_outputs>(
ser);
return ret.value;
};
}
};
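// Builder sketch (illustrative; `query_size` is a hypothetical callable, not part of
// this diff): input<> params form the cache key, output<> params are replayed from
// the cached byte string, and ret<> declares the wrapped return value. This assumes
// RefParam<T> adapts a T* output argument (its definition is truncated below).
//
//     int query_size(int width, size_t* size);  // hypothetical expensive query
//
//     auto cached = FunctionCacheBuilder<>()
//                           .input<Param<int>>()
//                           .output<RefParam<size_t>>()
//                           .ret<Param<int>>()
//                           .build(&query_size);
//
//     size_t size;
//     int status = cached(8, &size);  // first call runs query_size; later calls with
//                                     // width == 8 replay both `size` and `status`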
template <typename T>
class PlainParam {
public:
T value;
Empty serialize(StringSerializer& ser, Empty) {
ser.write_plain(value);
return Empty{};
}
Empty deserialize(StringSerializer& ser, Empty) {
value = ser.read_plain<T>();
return Empty{};
}
};
template <typename T>
class RefParam {
public:
@@ -252,7 +278,6 @@ public:
}
};
template <typename T>
class RefArraySizeParam {
public:
@@ -266,7 +291,6 @@ public:
}
};
template <typename TSize, typename TItem>
class ArrayParam {
public:
@@ -285,4 +309,4 @@ public:
}
};
}
} // namespace megdnn
@@ -16,105 +16,109 @@
#include "src/cuda/cudnn_wrapper.h"
namespace megdnn {
class CudnnConvDescParam {
public:
cudnnConvolutionDescriptor_t value;
Empty serialize(StringSerializer& ser, Empty) {
int ndim = MEGDNN_MAX_NDIM;
int padA[MEGDNN_MAX_NDIM];
int strideA[MEGDNN_MAX_NDIM];
int dilationA[MEGDNN_MAX_NDIM];
cudnnConvolutionMode_t mode;
cudnnDataType_t computeType;
cudnnGetConvolutionNdDescriptor(value, MEGDNN_MAX_NDIM, &ndim, padA, strideA, dilationA, &mode, &computeType);
ser.write_plain(ndim);
for (int i = 0; i < ndim; ++i) {
ser.write_plain(padA[i]);
ser.write_plain(strideA[i]);
ser.write_plain(dilationA[i]);
}
ser.write_plain(mode);
ser.write_plain(computeType);
return Empty{};
class CudnnConvDescParam {
public:
cudnnConvolutionDescriptor_t value;
Empty serialize(StringSerializer& ser, Empty) {
int nbDims = MEGDNN_MAX_NDIM;
int padA[MEGDNN_MAX_NDIM];
int strideA[MEGDNN_MAX_NDIM];
int dilationA[MEGDNN_MAX_NDIM];
cudnnConvolutionMode_t mode;
cudnnDataType_t computeType;
cudnnGetConvolutionNdDescriptor(value, nbDims, &nbDims, padA, strideA,
dilationA, &mode, &computeType);
ser.write_plain(nbDims);
for (int i = 0; i < nbDims; ++i) {
ser.write_plain(padA[i]);
ser.write_plain(strideA[i]);
ser.write_plain(dilationA[i]);
}
Empty deserialize(StringSerializer& ser, Empty) {
int ndim = ser.read_plain<int>();
int padA[MEGDNN_MAX_NDIM];
int strideA[MEGDNN_MAX_NDIM];
int dilationA[MEGDNN_MAX_NDIM];
for (int i = 0; i < ndim; ++i) {
padA[i] = ser.read_plain<int>();
strideA[i] = ser.read_plain<int>();
dilationA[i] = ser.read_plain<int>();
}
cudnnConvolutionMode_t mode = ser.read_plain<cudnnConvolutionMode_t>();
cudnnDataType_t computeType = ser.read_plain<cudnnDataType_t>();
cudnnSetConvolutionNdDescriptor(value, ndim, padA, strideA, dilationA, mode, computeType);
return Empty{};
ser.write_plain(mode);
ser.write_plain(computeType);
return Empty{};
}
Empty deserialize(StringSerializer& ser, Empty) {
int ndim = ser.read_plain<int>();
int padA[MEGDNN_MAX_NDIM];
int strideA[MEGDNN_MAX_NDIM];
int dilationA[MEGDNN_MAX_NDIM];
for (int i = 0; i < ndim; ++i) {
padA[i] = ser.read_plain<int>();
strideA[i] = ser.read_plain<int>();
dilationA[i] = ser.read_plain<int>();
}
};
class CudnnTensorDescParam {
public:
cudnnTensorDescriptor_t value;
Empty serialize(StringSerializer& ser, Empty) {
int nbDims = MEGDNN_MAX_NDIM;
cudnnDataType_t dataType;
int dimA[MEGDNN_MAX_NDIM];
int strideA[MEGDNN_MAX_NDIM];
cudnnGetTensorNdDescriptor(value, nbDims, &dataType, &nbDims, dimA, strideA);
ser.write_plain(nbDims);
for (int i = 0; i < nbDims; ++i) {
ser.write_plain(dimA[i]);
ser.write_plain(strideA[i]);
}
ser.write_plain(dataType);
return Empty{};
cudnnConvolutionMode_t mode = ser.read_plain<cudnnConvolutionMode_t>();
cudnnDataType_t computeType = ser.read_plain<cudnnDataType_t>();
cudnnSetConvolutionNdDescriptor(value, ndim, padA, strideA, dilationA,
mode, computeType);
return Empty{};
}
};
class CudnnTensorDescParam {
public:
cudnnTensorDescriptor_t value;
Empty serialize(StringSerializer& ser, Empty) {
int nbDims = MEGDNN_MAX_NDIM;
cudnnDataType_t dataType;
int dimA[MEGDNN_MAX_NDIM];
int strideA[MEGDNN_MAX_NDIM];
cudnnGetTensorNdDescriptor(value, nbDims, &dataType, &nbDims, dimA,
strideA);
ser.write_plain(nbDims);
for (int i = 0; i < nbDims; ++i) {
ser.write_plain(dimA[i]);
ser.write_plain(strideA[i]);
}
Empty deserialize(StringSerializer& ser, Empty) {
int nbDims = MEGDNN_MAX_NDIM;
cudnnDataType_t dataType;
int dimA[MEGDNN_MAX_NDIM];
int strideA[MEGDNN_MAX_NDIM];
nbDims = ser.read_plain<int>();
for (int i = 0; i < nbDims; ++i) {
dimA[i] = ser.read_plain<int>();
strideA[i] = ser.read_plain<int>();
}
dataType = ser.read_plain<cudnnDataType_t>();
cudnnSetTensorNdDescriptor(value, dataType, nbDims, dimA, strideA);
return Empty{};
ser.write_plain(dataType);
return Empty{};
}
Empty deserialize(StringSerializer& ser, Empty) {
int nbDims = MEGDNN_MAX_NDIM;
cudnnDataType_t dataType;
int dimA[MEGDNN_MAX_NDIM];
int strideA[MEGDNN_MAX_NDIM];
nbDims = ser.read_plain<int>();
for (int i = 0; i < nbDims; ++i) {
dimA[i] = ser.read_plain<int>();
strideA[i] = ser.read_plain<int>();
}
};
class CudnnFilterDescParam {
public:
cudnnFilterDescriptor_t value;
Empty serialize(StringSerializer& ser, Empty) {
int nbDims = MEGDNN_MAX_NDIM;
cudnnDataType_t dataType;
cudnnTensorFormat_t format;
int filterDimA[MEGDNN_MAX_NDIM];
cudnnGetFilterNdDescriptor(value, nbDims, &dataType, &format, &nbDims, filterDimA);
ser.write_plain(nbDims);
for (int i = 0; i < nbDims; ++i) {
ser.write_plain(filterDimA[i]);
}
ser.write_plain(dataType);
ser.write_plain(format);
return Empty{};
dataType = ser.read_plain<cudnnDataType_t>();
cudnnSetTensorNdDescriptor(value, dataType, nbDims, dimA, strideA);
return Empty{};
}
};
class CudnnFilterDescParam {
public:
cudnnFilterDescriptor_t value;
Empty serialize(StringSerializer& ser, Empty) {
int nbDims = MEGDNN_MAX_NDIM;
cudnnDataType_t dataType;
cudnnTensorFormat_t format;
int filterDimA[MEGDNN_MAX_NDIM];
cudnnGetFilterNdDescriptor(value, nbDims, &dataType, &format, &nbDims,
filterDimA);
ser.write_plain(nbDims);
for (int i = 0; i < nbDims; ++i) {
ser.write_plain(filterDimA[i]);
}
Empty deserialize(StringSerializer& ser, Empty) {
int nbDims = MEGDNN_MAX_NDIM;
cudnnDataType_t dataType;
cudnnTensorFormat_t format;
int filterDimA[MEGDNN_MAX_NDIM];
nbDims = ser.read_plain<int>();
for (int i = 0; i < nbDims; ++i) {
filterDimA[i] = ser.read_plain<int>();
}
dataType = ser.read_plain<cudnnDataType_t>();
format = ser.read_plain<cudnnTensorFormat_t>();
cudnnSetFilterNdDescriptor(value, dataType, format, nbDims, filterDimA);
return Empty{};
ser.write_plain(dataType);
ser.write_plain(format);
return Empty{};
}
Empty deserialize(StringSerializer& ser, Empty) {
int nbDims = MEGDNN_MAX_NDIM;
cudnnDataType_t dataType;
cudnnTensorFormat_t format;
int filterDimA[MEGDNN_MAX_NDIM];
nbDims = ser.read_plain<int>();
for (int i = 0; i < nbDims; ++i) {
filterDimA[i] = ser.read_plain<int>();
}
};
}
dataType = ser.read_plain<cudnnDataType_t>();
format = ser.read_plain<cudnnTensorFormat_t>();
cudnnSetFilterNdDescriptor(value, dataType, format, nbDims, filterDimA);
return Empty{};
}
};
} // namespace megdnn
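// Note (illustrative): these adapters key the cache on descriptor *contents*
// (dims, strides, data types, modes) rather than on the descriptor handles, so two
// distinct cuDNN descriptors describing the same shapes serialize to the same byte
// string and share a cache entry. A minimal sketch of plugging one into the builder,
// using cudnnGetTensorSizeInBytes as an assumed example entry point:
//
//     auto cached_size = FunctionCacheBuilder<>()
//                                .input<CudnnTensorDescParam>()
//                                .output<RefParam<size_t>>()
//                                .ret<Param<cudnnStatus_t>>()
//                                .build(&cudnnGetTensorSizeInBytes);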
@@ -39,7 +39,8 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available(
conv_args.init_conv_desc(D);
size_t workspace_size;
auto status = cudnnGetConvolutionForwardWorkspaceSize(
auto& cudnn = conv_args.handle->cudnn();
auto status = cudnn.GetConvolutionForwardWorkspaceSize(
conv_args.handle->cudnn_handle(), D.src_desc.desc,
D.filter_desc.desc, D.conv_desc.conv_desc, D.dst_desc.desc,
m_cudnn_enum, &workspace_size);
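// Call-site pattern (the remaining hunks below repeat it): the raw cudnnGet*
// entry point is replaced by the per-handle cached wrapper, e.g.
//
//     auto& cudnn = args.handle->cudnn();
//     auto status = cudnn.GetConvolutionForwardWorkspaceSize(
//             args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc,
//             D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum,
//             &workspace_size);
//
// so repeated queries with identical descriptor contents are answered from the cache
// instead of re-entering cuDNN.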
@@ -65,7 +66,8 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoCUDNNConv::get_workspace_bundle(
conv_args.init_conv_desc(D);
size_t conv_workspace_size;
auto status = cudnnGetConvolutionForwardWorkspaceSize(
auto& cudnn = conv_args.handle->cudnn();
auto status = cudnn.GetConvolutionForwardWorkspaceSize(
conv_args.handle->cudnn_handle(), D.src_desc.desc,
D.filter_desc.desc, D.conv_desc.conv_desc, D.dst_desc.desc,
m_cudnn_enum, &conv_workspace_size);
@@ -108,7 +108,8 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
megdnn_throw("unsupported NonlineMode");
}
size_t workspace_size;
auto status = cudnnGetConvolutionForwardWorkspaceSize(
auto& cudnn = args.handle->cudnn();
auto status = cudnn.GetConvolutionForwardWorkspaceSize(
args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc,
D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum,
&workspace_size);
@@ -121,7 +122,8 @@ size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::get_workspace_in_bytes(
args.init_conv_bias_desc(D);
size_t workspace_size;
auto status = cudnnGetConvolutionForwardWorkspaceSize(
auto& cudnn = args.handle->cudnn();
auto status = cudnn.GetConvolutionForwardWorkspaceSize(
args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc,
D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum,
&workspace_size);
@@ -83,12 +83,13 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
CUDNNForwardDescs desc;
conv_args.init_conv_desc(desc);
#if CUDNN_MAJOR >= 7
auto& cudnn = static_cast<HandleImpl*>(this->handle())->cudnn();
int max_count = 0;
cudnn_check(cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle,
cudnn_check(cudnn.GetConvolutionForwardAlgorithmMaxCount(cudnn_handle,
&max_count));
SmallVector<cudnnConvolutionFwdAlgoPerf_t> algo_perf(max_count);
int ret_count = 0;
cudnn_check(cudnnGetConvolutionForwardAlgorithm_v7(
cudnn_check(cudnn.GetConvolutionForwardAlgorithm_v7(
cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc,
desc.conv_desc.conv_desc, desc.dst_desc.desc, max_count,
&ret_count, algo_perf.data()));
@@ -44,9 +44,10 @@ bool ConvolutionBackwardDataImpl::AlgoCUDNN::is_available(
}
#endif
auto& cudnn = args.handle->cudnn();
args.init_desc(D);
size_t workspace_size;
auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize(
args.handle->cudnn_handle(),
D.filter_desc.desc,
D.diff_desc.desc,
@@ -59,10 +60,11 @@ bool ConvolutionBackwardDataImpl::AlgoCUDNN::is_available(
size_t ConvolutionBackwardDataImpl::AlgoCUDNN::get_workspace_in_bytes(
const SizeArgs &args) const {
auto& cudnn = args.handle->cudnn();
CUDNNBwdDataDescs D;
args.init_desc(D);
size_t workspace_size;
auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize(
args.handle->cudnn_handle(),
D.filter_desc.desc,
D.diff_desc.desc,
@@ -21,6 +21,7 @@ using namespace convolution;
bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available(
const SizeArgs &args) const {
auto& cudnn = args.handle->cudnn();
CUDNNBwdFilterDescs D;
if (!is_cudnn_supported(args.as_fwd_args()))
@@ -28,7 +29,7 @@ bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available(
args.init_desc(D);
size_t workspace_size;
auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize(
args.handle->cudnn_handle(),
D.src_desc.desc,
D.diff_desc.desc,
@@ -41,10 +42,11 @@ bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available(
size_t ConvolutionBackwardFilterImpl::AlgoCUDNN::get_workspace_in_bytes(
const SizeArgs &args) const {
auto& cudnn = args.handle->cudnn();
CUDNNBwdFilterDescs D;
args.init_desc(D);
size_t workspace_size;
auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize(
args.handle->cudnn_handle(),
D.src_desc.desc,
D.diff_desc.desc,
@@ -141,12 +141,13 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter,
#if CUDNN_MAJOR >= 7
MEGDNN_MARK_USED_VAR(negative_attr);
auto& cudnn = args.handle->cudnn();
int max_count = 0;
cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(
cudnn_check(cudnn.GetConvolutionBackwardDataAlgorithmMaxCount(
cudnn_handle, &max_count));
SmallVector<cudnnConvolutionBwdDataAlgoPerf_t> algo_perf(max_count);
int ret_count = 0;
cudnn_check(cudnnGetConvolutionBackwardDataAlgorithm_v7(
cudnn_check(cudnn.GetConvolutionBackwardDataAlgorithm_v7(
cudnn_handle, desc.filter_desc.desc, desc.diff_desc.desc,
desc.conv_desc.desc, desc.grad_desc.desc, max_count, &ret_count,
algo_perf.data()));
@@ -286,12 +287,13 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
#endif
#if CUDNN_MAJOR >= 7
MEGDNN_MARK_USED_VAR(negative_attr);
auto& cudnn = args.handle->cudnn();
int max_count = 0;
cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(
cudnn_check(cudnn.GetConvolutionBackwardFilterAlgorithmMaxCount(
cudnn_handle, &max_count));
SmallVector<cudnnConvolutionBwdFilterAlgoPerf_t> algo_perf(max_count);
int ret_count = 0;
cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
cudnn_check(cudnn.GetConvolutionBackwardFilterAlgorithm_v7(
cudnn_handle, desc.src_desc.desc, desc.diff_desc.desc,
desc.conv_desc.desc, desc.grad_desc.desc, max_count, &ret_count,
algo_perf.data()));
@@ -28,7 +28,8 @@ bool Convolution3DBackwardDataImpl::AlgoCUDNN::is_available(
args.init_desc(D);
size_t workspace_size;
auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
auto& cudnn = args.handle->cudnn();
auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize(
args.handle->cudnn_handle(),
D.filter_desc.desc,
D.diff_desc.desc,
@@ -44,7 +45,8 @@ size_t Convolution3DBackwardDataImpl::AlgoCUDNN::get_workspace_in_bytes(
CUDNNBwdDataDescs D;
args.init_desc(D);
size_t workspace_size;
auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
auto& cudnn = args.handle->cudnn();
auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize(
args.handle->cudnn_handle(),
D.filter_desc.desc,
D.diff_desc.desc,
@@ -28,7 +28,8 @@ bool Convolution3DBackwardFilterImpl::AlgoCUDNN::is_available(
args.init_desc(D);
size_t workspace_size;
auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
auto& cudnn = args.handle->cudnn();
auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize(
args.handle->cudnn_handle(), D.src_desc.desc, D.diff_desc.desc,
D.conv_desc.desc, D.grad_desc.desc, m_cudnn_enum, &workspace_size);
return status == CUDNN_STATUS_SUCCESS;
@@ -40,7 +41,8 @@ size_t Convolution3DBackwardFilterImpl::AlgoCUDNN::get_workspace_in_bytes(
args.init_desc(D);
size_t workspace_size;
auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
auto& cudnn = args.handle->cudnn();
auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize(
args.handle->cudnn_handle(), D.src_desc.desc, D.diff_desc.desc,
D.conv_desc.desc, D.grad_desc.desc, m_cudnn_enum, &workspace_size);
megdnn_assert(status == CUDNN_STATUS_SUCCESS,
@@ -27,7 +27,8 @@ bool Convolution3DForwardImpl::AlgoCUDNN::is_available(
args.init_desc(D);
size_t workspace_size;
auto status = cudnnGetConvolutionForwardWorkspaceSize(
auto& cudnn = args.handle->cudnn();
auto status = cudnn.GetConvolutionForwardWorkspaceSize(
args.handle->cudnn_handle(),
D.src_desc.desc,
D.filter_desc.desc,
@@ -43,7 +44,8 @@ size_t Convolution3DForwardImpl::AlgoCUDNN::get_workspace_in_bytes(
CUDNNForwardDescs D;
args.init_desc(D);
size_t workspace_size;
auto status = cudnnGetConvolutionForwardWorkspaceSize(
auto& cudnn = args.handle->cudnn();
auto status = cudnn.GetConvolutionForwardWorkspaceSize(
args.handle->cudnn_handle(),
D.src_desc.desc,
D.filter_desc.desc,
@@ -92,7 +92,7 @@ namespace convolution3d {
const Workspace &workspace, void *&raw_ptr);
inline bool cudnn_get_convolution_fwd_algo_helper(
cudnnHandle_t cudnn_handle, const cudnnTensorDescriptor_t x_desc,
Handle* handle, const cudnnTensorDescriptor_t x_desc,
const cudnnFilterDescriptor_t w_desc,
const cudnnConvolutionDescriptor_t conv_desc,
const cudnnTensorDescriptor_t y_desc,
@@ -102,13 +102,14 @@ namespace convolution3d {
MEGDNN_MARK_USED_VAR(positive_attr);
MEGDNN_MARK_USED_VAR(negative_attr);
#if CUDNN_MAJOR >= 7
auto& cudnn = static_cast<HandleImpl*>(handle)->cudnn();
int algo_max_count = 0;
cudnn_check(cudnnGetConvolutionForwardAlgorithmMaxCount(
cudnn_handle, &algo_max_count));
cudnn_check(cudnn.GetConvolutionForwardAlgorithmMaxCount(
cuda::cudnn_handle(handle), &algo_max_count));
SmallVector<cudnnConvolutionFwdAlgoPerf_t> algo_perf(algo_max_count);
int algo_count = 0;
cudnn_check(cudnnGetConvolutionForwardAlgorithm_v7(
cudnn_handle, x_desc, w_desc, conv_desc, y_desc, algo_max_count,
cudnn_check(cudnn.GetConvolutionForwardAlgorithm_v7(
cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc, algo_max_count,
&algo_count, algo_perf.data()));
for (int i = 0; i < algo_count; ++i) {
if (algo_perf[i].algo ==
@@ -116,8 +117,8 @@ namespace convolution3d {
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING)
continue;
size_t workspace_size = 0;
cudnn_check(cudnnGetConvolutionForwardWorkspaceSize(
cudnn_handle, x_desc, w_desc, conv_desc, y_desc,
cudnn_check(cudnn.GetConvolutionForwardWorkspaceSize(
cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc,
algo_perf[i].algo, &workspace_size));
if (workspace_size > workspace_limit_in_bytes) continue;
if (!(positive_attr & AlgoAttribute::REPRODUCIBLE)) {
@@ -133,7 +134,7 @@ namespace convolution3d {
return false;
#else
cudnn_check(cudnnGetConvolutionForwardAlgorithm(
cudnn_handle, x_desc, w_desc, conv_desc, y_desc,
cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_limit_in_bytes, algo));
return true;
@@ -74,13 +74,12 @@ Convolution3DForwardImpl::get_algorithm_heuristic(
auto get_cudnn_algo =
[this, &args, workspace_limit_in_bytes, positive_attr,
negative_attr]() -> Convolution3DForwardImpl::AlgoBase* {
auto cudnn_handle = cuda::cudnn_handle(this->handle());
cudnnConvolutionFwdAlgo_t algo;
CUDNNForwardDescs desc;
args.init_desc(desc);
bool got = cudnn_get_convolution_fwd_algo_helper(
cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc,
this->handle(), desc.src_desc.desc, desc.filter_desc.desc,
desc.conv_desc.desc, desc.dst_desc.desc,
workspace_limit_in_bytes, &algo, positive_attr, negative_attr);
if (got) {
@@ -56,7 +56,7 @@ namespace convolution {
using KernLayout = _kern_layout; \
using OutputLayout = _output_layout; \
using Param = _conv_param; \
static constexpr bool check_bounds = check_bounds_;
static constexpr bool check_bounds = check_bounds_
#define MEGDNN_COMMA ,
template <bool check_bounds_, typename src_ldg_dtype, typename filter_ldg_dtype,
@@ -53,7 +53,7 @@ namespace convolution {
using KernLayout = _kern_layout; \
using OutputLayout = _output_layout; \
using Param = _conv_param; \
static constexpr bool check_bounds = check_bounds_;
static constexpr bool check_bounds = check_bounds_
#define MEGDNN_COMMA ,
template <bool check_bounds_, typename IMMAConfig_, typename WarpTileConfig_,
@@ -53,7 +53,7 @@ namespace convolution {
using KernLayout = _kern_layout; \
using OutputLayout = _output_layout; \
using Param = _conv_param; \
static constexpr bool check_bounds = check_bounds_;
static constexpr bool check_bounds = check_bounds_
#define MEGDNN_COMMA ,
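// Note: the trailing semicolon is dropped from the last line of this macro (here and
// in the two identical macros above), presumably so that the expansion site supplies
// it and the macro reads like an ordinary declaration when used.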
template <bool check_bounds_, typename ldg_dtype, typename RegBlockConfig_,
@@ -11,12 +11,15 @@
#include "src/common/handle_impl.h"
#include "src/common/version_symbol.h"
#include "src/common/api_cache.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"
#include "src/cuda/api_cache.h"
#include <cuda.h>
#include <cstring>
#include <memory>
#define STR_HELPER(x) #x
#define STR(x) STR_HELPER(x)
@@ -88,6 +91,8 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle):
// check tk1
m_is_tegra_k1 = (strcmp(m_device_prop->name, "GK20A") == 0);
m_cusolver_handle = nullptr;
m_cudnn_api_cache = std::make_unique<CUDNN>(m_cudnn_handle);
}
HandleImpl::~HandleImpl() noexcept {
@@ -133,8 +138,111 @@ HandleImpl::HandleVendorType HandleImpl::vendor_type() const {
return HandleVendorType::CUDA;
}
} // namespace cuda
} // namespace megdnn
HandleImpl::CUDNN& HandleImpl::cudnn() {
return *m_cudnn_api_cache;
}
HandleImpl::CUDNN::CUDNN(cudnnHandle_t handle) {
m_handle = handle;
GetConvolutionForwardWorkspaceSize =
FunctionCacheBuilder<>()
.input<Param<cudnnHandle_t>>()
.input<CudnnTensorDescParam>()
.input<CudnnFilterDescParam>()
.input<CudnnConvDescParam>()
.input<CudnnTensorDescParam>()
.input<Param<cudnnConvolutionFwdAlgo_t>>()
.output<RefParam<size_t>>()
.ret<Param<cudnnStatus_t>>()
.build(&cudnnGetConvolutionForwardWorkspaceSize);
#if CUDNN_MAJOR >= 7
GetConvolutionForwardAlgorithm_v7 =
FunctionCacheBuilder<>()
.input<Param<cudnnHandle_t>>()
.input<CudnnTensorDescParam>()
.input<CudnnFilterDescParam>()
.input<CudnnConvDescParam>()
.input<CudnnTensorDescParam>()
.input<Param<int>>()
.output<RefArraySizeParam<int>>()
.output<ArrayParam<int, cudnnConvolutionFwdAlgoPerf_t>>()
.ret<Param<cudnnStatus_t>>()
.build(&cudnnGetConvolutionForwardAlgorithm_v7);
GetConvolutionForwardAlgorithmMaxCount =
FunctionCacheBuilder<>()
.input<Param<cudnnHandle_t>>()
.output<RefParam<int>>()
.ret<Param<cudnnStatus_t>>()
.build(&cudnnGetConvolutionForwardAlgorithmMaxCount);
#endif
GetConvolutionBackwardDataWorkspaceSize =
FunctionCacheBuilder<>()
.input<Param<cudnnHandle_t>>()
.input<CudnnFilterDescParam>()
.input<CudnnTensorDescParam>()
.input<CudnnConvDescParam>()
.input<CudnnTensorDescParam>()
.input<Param<cudnnConvolutionBwdDataAlgo_t>>()
.output<RefParam<size_t>>()
.ret<Param<cudnnStatus_t>>()
.build(&cudnnGetConvolutionBackwardDataWorkspaceSize);
#if CUDNN_MAJOR >= 7
GetConvolutionBackwardDataAlgorithm_v7 =
FunctionCacheBuilder<>()
.input<Param<cudnnHandle_t>>()
.input<CudnnFilterDescParam>()
.input<CudnnTensorDescParam>()
.input<CudnnConvDescParam>()
.input<CudnnTensorDescParam>()
.input<Param<int>>()
.output<RefArraySizeParam<int>>()
.output<ArrayParam<int,
cudnnConvolutionBwdDataAlgoPerf_t>>()
.ret<Param<cudnnStatus_t>>()
.build(&cudnnGetConvolutionBackwardDataAlgorithm_v7);
GetConvolutionBackwardDataAlgorithmMaxCount =
FunctionCacheBuilder<>()
.input<Param<cudnnHandle_t>>()
.output<RefParam<int>>()
.ret<Param<cudnnStatus_t>>()
.build(&cudnnGetConvolutionBackwardDataAlgorithmMaxCount);
#endif
GetConvolutionBackwardFilterWorkspaceSize =
FunctionCacheBuilder<>()
.input<Param<cudnnHandle_t>>()
.input<CudnnTensorDescParam>()
.input<CudnnTensorDescParam>()
.input<CudnnConvDescParam>()
.input<CudnnFilterDescParam>()
.input<Param<cudnnConvolutionBwdFilterAlgo_t>>()
.output<RefParam<size_t>>()
.ret<Param<cudnnStatus_t>>()
.build(&cudnnGetConvolutionBackwardFilterWorkspaceSize);
#if CUDNN_MAJOR >= 7
GetConvolutionBackwardFilterAlgorithm_v7 =
FunctionCacheBuilder<>()
.input<Param<cudnnHandle_t>>()
.input<CudnnTensorDescParam>()
.input<CudnnTensorDescParam>()
.input<CudnnConvDescParam>()
.input<CudnnFilterDescParam>()
.input<Param<int>>()
.output<RefArraySizeParam<int>>()
.output<ArrayParam<int,
cudnnConvolutionBwdFilterAlgoPerf_t>>()
.ret<Param<cudnnStatus_t>>()
.build(&cudnnGetConvolutionBackwardFilterAlgorithm_v7);
GetConvolutionBackwardFilterAlgorithmMaxCount =
FunctionCacheBuilder<>()
.input<Param<cudnnHandle_t>>()
.output<RefParam<int>>()
.ret<Param<cudnnStatus_t>>()
.build(&cudnnGetConvolutionBackwardFilterAlgorithmMaxCount);
#endif
}
} // namespace cuda
} // namespace megdnn
MEGDNN_VERSION_SYMBOL(CUDA, CUDA_VERSION);
MEGDNN_VERSION_SYMBOL3(CUDNN, CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);
@@ -124,6 +124,10 @@ class HandleImpl: public HandleImplHelper {
size_t image2d_pitch_alignment() const override;
HandleVendorType vendor_type() const override;
class CUDNN;
CUDNN& cudnn();
private:
bool m_is_tegra_k1;
int m_device_id;
@@ -156,9 +160,34 @@ class HandleImpl: public HandleImplHelper {
//! device ptr to const scalars
ConstScalars* m_const_scalars;
std::unique_ptr<CUDNN> m_cudnn_api_cache;
void initialize_cusolver();
};
class HandleImpl::CUDNN {
cudnnHandle_t m_handle;
public:
CUDNN(cudnnHandle_t handle);
#define WRAP_CUDNN_API(NAME) thin_function<decltype(cudnn##NAME)> NAME;
WRAP_CUDNN_API(GetConvolutionForwardWorkspaceSize);
#if CUDNN_MAJOR >= 7
WRAP_CUDNN_API(GetConvolutionForwardAlgorithm_v7);
WRAP_CUDNN_API(GetConvolutionForwardAlgorithmMaxCount);
#endif
#if CUDNN_MAJOR >= 7
WRAP_CUDNN_API(GetConvolutionBackwardDataAlgorithm_v7);
WRAP_CUDNN_API(GetConvolutionBackwardDataAlgorithmMaxCount);
#endif
WRAP_CUDNN_API(GetConvolutionBackwardDataWorkspaceSize);
#if CUDNN_MAJOR >= 7
WRAP_CUDNN_API(GetConvolutionBackwardFilterAlgorithmMaxCount);
WRAP_CUDNN_API(GetConvolutionBackwardFilterAlgorithm_v7);
#endif
WRAP_CUDNN_API(GetConvolutionBackwardFilterWorkspaceSize);
#undef WRAP_CUDNN_API
};
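// Expansion sketch (illustrative): for each wrapped entry point the macro declares a
// thin_function member whose signature matches the raw cuDNN function, e.g.
// WRAP_CUDNN_API(GetConvolutionForwardWorkspaceSize) expands to
//
//     thin_function<decltype(cudnnGetConvolutionForwardWorkspaceSize)>
//             GetConvolutionForwardWorkspaceSize;
//
// and the CUDNN constructor in handle.cpp fills each member with the cache-wrapped
// callable built by FunctionCacheBuilder.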
} // namespace cuda
} // namespace megdnn | |||