|
@@ -131,6 +131,13 @@ public: |
|
|
} |
|
|
} |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
struct PreprocessedFilter { |
|
|
|
|
|
//! user data; its lifetime should be bound to MegDNN Convolution |
|
|
|
|
|
//! operator |
|
|
|
|
|
void* algorithm_id; |
|
|
|
|
|
TensorNDArray tensors; |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
protected: |
|
|
protected: |
|
|
// Check or deduce output DType |
|
|
// Check or deduce output DType |
|
|
void check_or_deduce_dtype_fwd(DType src, DType filter, DType& dst) const; |
|
|
void check_or_deduce_dtype_fwd(DType src, DType filter, DType& dst) const; |
|
@@ -200,13 +207,27 @@ public: |
|
|
* \param[out] dst (n, oc, oh, ow) |
|
|
* \param[out] dst (n, oc, oh, ow) |
|
|
*/ |
|
|
*/ |
|
|
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, |
|
|
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, |
|
|
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; |
|
|
|
|
|
|
|
|
_megdnn_tensor_out dst, |
|
|
|
|
|
const PreprocessedFilter* preprocessed_filter, |
|
|
|
|
|
_megdnn_workspace workspace) = 0; |
|
|
|
|
|
virtual void exec_preprocess(const TensorLayout& src_layout, |
|
|
|
|
|
_megdnn_tensor_in filter, |
|
|
|
|
|
const TensorLayout& dst_layout, |
|
|
|
|
|
PreprocessedFilter* preprocessed_filter, |
|
|
|
|
|
_megdnn_workspace workspace) = 0; |
|
|
void deduce_dtype(DType src, DType filter, DType& dst); |
|
|
void deduce_dtype(DType src, DType filter, DType& dst); |
|
|
void deduce_layout(const TensorLayout& src, const TensorLayout& filter, |
|
|
void deduce_layout(const TensorLayout& src, const TensorLayout& filter, |
|
|
TensorLayout& dst); |
|
|
TensorLayout& dst); |
|
|
virtual size_t get_workspace_in_bytes(const TensorLayout& src, |
|
|
|
|
|
const TensorLayout& filter, |
|
|
|
|
|
const TensorLayout& dst) = 0; |
|
|
|
|
|
|
|
|
virtual size_t get_workspace_in_bytes( |
|
|
|
|
|
const TensorLayout& src, const TensorLayout& filter, |
|
|
|
|
|
const TensorLayout& dst, |
|
|
|
|
|
PreprocessedFilter* preprocessed_filter) = 0; |
|
|
|
|
|
virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout( |
|
|
|
|
|
const TensorLayout& src, const TensorLayout& filter, |
|
|
|
|
|
const TensorLayout& dst) = 0; |
|
|
|
|
|
virtual size_t get_preprocess_workspace_in_bytes( |
|
|
|
|
|
const TensorLayout& src, const TensorLayout& filter, |
|
|
|
|
|
const TensorLayout& dst) = 0; |
|
|
|
|
|
|
|
|
protected: |
|
|
protected: |
|
|
CanonizedFilterMeta check_exec(const TensorLayout& src, |
|
|
CanonizedFilterMeta check_exec(const TensorLayout& src, |
|
@@ -297,17 +318,35 @@ public: |
|
|
*/ |
|
|
*/ |
|
|
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, |
|
|
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, |
|
|
_megdnn_tensor_in bias, _megdnn_tensor_in z, |
|
|
_megdnn_tensor_in bias, _megdnn_tensor_in z, |
|
|
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; |
|
|
|
|
|
|
|
|
_megdnn_tensor_out dst, |
|
|
|
|
|
const PreprocessedFilter* preprocessed_filter, |
|
|
|
|
|
_megdnn_workspace workspace) = 0; |
|
|
|
|
|
virtual void exec_preprocess(const TensorLayout& src_layout, |
|
|
|
|
|
_megdnn_tensor_in filter, |
|
|
|
|
|
const TensorLayout& bias_layout, |
|
|
|
|
|
const TensorLayout& z_layout, |
|
|
|
|
|
const TensorLayout& dst_layout, |
|
|
|
|
|
PreprocessedFilter* preprocessed_filter, |
|
|
|
|
|
_megdnn_workspace workspace) = 0; |
|
|
void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst); |
|
|
void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst); |
|
|
void deduce_layout(const TensorLayout& src, const TensorLayout& filter, |
|
|
void deduce_layout(const TensorLayout& src, const TensorLayout& filter, |
|
|
const TensorLayout& bias, const TensorLayout& z, |
|
|
const TensorLayout& bias, const TensorLayout& z, |
|
|
TensorLayout& dst); |
|
|
TensorLayout& dst); |
|
|
|
|
|
|
|
|
virtual size_t get_workspace_in_bytes(const TensorLayout& src, |
|
|
|
|
|
const TensorLayout& filter, |
|
|
|
|
|
const TensorLayout& bias, |
|
|
|
|
|
const TensorLayout& z, |
|
|
|
|
|
const TensorLayout& dst) = 0; |
|
|
|
|
|
|
|
|
virtual size_t get_workspace_in_bytes( |
|
|
|
|
|
const TensorLayout& src, const TensorLayout& filter, |
|
|
|
|
|
const TensorLayout& bias, const TensorLayout& z, |
|
|
|
|
|
const TensorLayout& dst, |
|
|
|
|
|
PreprocessedFilter* preprocessed_filter) = 0; |
|
|
|
|
|
virtual size_t get_preprocess_workspace_in_bytes( |
|
|
|
|
|
const TensorLayout& src, const TensorLayout& filter, |
|
|
|
|
|
const TensorLayout& bias, const TensorLayout& z, |
|
|
|
|
|
const TensorLayout& dst) = 0; |
|
|
|
|
|
virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout( |
|
|
|
|
|
const TensorLayout& src, const TensorLayout& filter, |
|
|
|
|
|
const TensorLayout& bias, const TensorLayout& z, |
|
|
|
|
|
const TensorLayout& dst) = 0; |
|
|
|
|
|
|
|
|
enum class BiasMode : uint32_t { |
|
|
enum class BiasMode : uint32_t { |
|
|
NO_BIAS = 0, //!< no bias |
|
|
NO_BIAS = 0, //!< no bias |
|
|
BROADCAST_CHANNEL_BIAS, //!< broadcast channel bias, [1, c, 1, 1] |
|
|
BROADCAST_CHANNEL_BIAS, //!< broadcast channel bias, [1, c, 1, 1] |
|
|