Browse Source

feat(dnn/naive): add convolution weight preprocess interface

GitOrigin-RevId: d0fd6c75a6
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
b2f0ceb2fc
1 changed file with 49 additions and 10 deletions
  1. +49
    -10
      dnn/include/megdnn/oprs/nn.h

+ 49
- 10
dnn/include/megdnn/oprs/nn.h View File

@@ -131,6 +131,13 @@ public:
}
};

/**
 * \brief Bundle passed as the `preprocessed_filter` argument of the
 * convolution interfaces declared in this header.
 *
 * NOTE(review): judging by the name, this carries filter (weight) data
 * that has already been transformed ahead of execution -- confirm against
 * the backend implementations.
 *
 * Neither member has an in-class initializer, so whoever creates an
 * instance is responsible for setting both fields.
 */
struct PreprocessedFilter {
//! user data; its lifetime should be bound to MegDNN Convolution
//! operator
void* algorithm_id;
//! tensors held by this preprocessed filter
//! NOTE(review): presumably one entry per layout reported by
//! deduce_preprocessed_filter_layout() -- confirm
TensorNDArray tensors;
};

protected:
// Check or deduce output DType
// dst is taken by non-const reference, so it can act both as the value
// to check and as the slot receiving a deduced DType.
// NOTE(review): the rule deciding between "check" and "deduce" lives in
// the implementation, which is not visible here.
void check_or_deduce_dtype_fwd(DType src, DType filter, DType& dst) const;
@@ -200,13 +207,27 @@ public:
* \param[out] dst (n, oc, oh, ow)
*/
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
_megdnn_tensor_out dst,
const PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) = 0;
/**
 * \brief Filter preprocessing entry point; pure virtual, implemented by
 * each backend.
 *
 * Only the filter is passed as a tensor; src and dst are described by
 * their layouts alone.
 *
 * NOTE(review): presumably transforms \p filter and stores the result in
 * \p preprocessed_filter for reuse by a later exec() -- confirm against
 * the implementations.
 *
 * \param src_layout layout of the source tensor
 * \param filter filter tensor
 * \param dst_layout layout of the destination tensor
 * \param preprocessed_filter non-const pointer (exec() takes a const one)
 * \param workspace workspace; NOTE(review): presumably sized via
 *      get_preprocess_workspace_in_bytes() -- confirm
 */
virtual void exec_preprocess(const TensorLayout& src_layout,
_megdnn_tensor_in filter,
const TensorLayout& dst_layout,
PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) = 0;
//! Deduce the output DType into \p dst from the \p src and \p filter
//! DTypes (per the name; the deduction rule is defined in the
//! implementation, not visible here).
void deduce_dtype(DType src, DType filter, DType& dst);
//! Deduce the output layout into \p dst from the \p src and \p filter
//! layouts (per the name; the deduction rule is defined in the
//! implementation, not visible here).
void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
TensorLayout& dst);
//! Workspace size query that takes only the src, filter and dst layouts
//! (no preprocessed-filter argument). Pure virtual.
//! NOTE(review): the name indicates the result is in bytes and matches
//! the workspace passed to exec() -- confirm.
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst) = 0;
/**
 * \brief Workspace size query that additionally receives the preprocessed
 * filter. Pure virtual.
 *
 * \param src layout of the source tensor
 * \param filter layout of the filter tensor
 * \param dst layout of the destination tensor
 * \param preprocessed_filter non-const pointer here, whereas exec() takes
 *      `const PreprocessedFilter*`.
 *      NOTE(review): if implementations only read it, a const pointer
 *      would be consistent with exec() -- confirm before changing, since
 *      every override must change with it.
 * \return workspace size; in bytes according to the function name
 */
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst,
PreprocessedFilter* preprocessed_filter) = 0;
/**
 * \brief Report the layouts of the preprocessed filter for the given
 * src/filter/dst layouts. Pure virtual.
 *
 * \return a list of layouts.
 *      NOTE(review): presumably one per entry of
 *      PreprocessedFilter::tensors, and possibly empty when no
 *      preprocessing applies -- confirm against the implementations.
 */
virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) = 0;
/**
 * \brief Workspace size query for the preprocessing step. Pure virtual.
 *
 * NOTE(review): presumably the size required by the workspace argument of
 * exec_preprocess() -- confirm.
 *
 * \return workspace size; in bytes according to the function name
 */
virtual size_t get_preprocess_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) = 0;

protected:
CanonizedFilterMeta check_exec(const TensorLayout& src,
@@ -297,17 +318,35 @@ public:
*/
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
_megdnn_tensor_in bias, _megdnn_tensor_in z,
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
_megdnn_tensor_out dst,
const PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) = 0;
/**
 * \brief Filter preprocessing entry point for the variant that also has
 * bias and z operands; pure virtual, implemented by each backend.
 *
 * Only the filter is passed as a tensor; src, bias, z and dst are
 * described by their layouts alone.
 *
 * NOTE(review): presumably transforms \p filter and stores the result in
 * \p preprocessed_filter for reuse by a later exec() -- confirm against
 * the implementations.
 *
 * \param src_layout layout of the source tensor
 * \param filter filter tensor
 * \param bias_layout layout of the bias tensor
 * \param z_layout layout of the z tensor
 * \param dst_layout layout of the destination tensor
 * \param preprocessed_filter non-const pointer (exec() takes a const one)
 * \param workspace workspace; NOTE(review): presumably sized via
 *      get_preprocess_workspace_in_bytes() -- confirm
 */
virtual void exec_preprocess(const TensorLayout& src_layout,
_megdnn_tensor_in filter,
const TensorLayout& bias_layout,
const TensorLayout& z_layout,
const TensorLayout& dst_layout,
PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) = 0;
//! Deduce the output DType into \p dst from the \p src, \p filter,
//! \p bias and \p z DTypes (per the name; the deduction rule is defined
//! in the implementation, not visible here).
void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst);
//! Deduce the output layout into \p dst from the \p src, \p filter,
//! \p bias and \p z layouts (per the name; the deduction rule is defined
//! in the implementation, not visible here).
void deduce_layout(const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
TensorLayout& dst);

//! Workspace size query that takes only the src, filter, bias, z and dst
//! layouts (no preprocessed-filter argument). Pure virtual.
//! NOTE(review): the name indicates the result is in bytes and matches
//! the workspace passed to exec() -- confirm.
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& bias,
const TensorLayout& z,
const TensorLayout& dst) = 0;
/**
 * \brief Workspace size query for the bias/z variant that additionally
 * receives the preprocessed filter. Pure virtual.
 *
 * \param src layout of the source tensor
 * \param filter layout of the filter tensor
 * \param bias layout of the bias tensor
 * \param z layout of the z tensor
 * \param dst layout of the destination tensor
 * \param preprocessed_filter non-const pointer here, whereas exec() takes
 *      `const PreprocessedFilter*`.
 *      NOTE(review): if implementations only read it, a const pointer
 *      would be consistent with exec() -- confirm before changing, since
 *      every override must change with it.
 * \return workspace size; in bytes according to the function name
 */
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst,
PreprocessedFilter* preprocessed_filter) = 0;
/**
 * \brief Workspace size query for the preprocessing step of the bias/z
 * variant. Pure virtual.
 *
 * NOTE(review): presumably the size required by the workspace argument of
 * exec_preprocess() -- confirm.
 *
 * \return workspace size; in bytes according to the function name
 */
virtual size_t get_preprocess_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) = 0;
/**
 * \brief Report the layouts of the preprocessed filter for the given
 * src/filter/bias/z/dst layouts. Pure virtual.
 *
 * \return a list of layouts.
 *      NOTE(review): presumably one per entry of
 *      PreprocessedFilter::tensors, and possibly empty when no
 *      preprocessing applies -- confirm against the implementations.
 */
virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) = 0;

enum class BiasMode : uint32_t {
NO_BIAS = 0, //!< no bias
BROADCAST_CHANNEL_BIAS, //!< broadcast channel bias, [1, c, 1, 1]


Loading…
Cancel
Save