GitOrigin-RevId: 479718ac75
release-1.2
@@ -6,7 +6,8 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
@@ -92,24 +93,72 @@ enum class AlgoDataType : uint32_t { | |||
/*! | |||
* \brief Abstract representation of an algorithm for implementing | |||
* the operator | |||
* | |||
* All pointers to Algorithm should be allocated globally and usable | |||
* across multiple megdnn handles, and they should not be freed by | |||
* the caller. | |||
*/ | |||
class Algorithm { | |||
public: | |||
static constexpr uint32_t INVALID_ALGO_TYPE = static_cast<uint32_t>(-1); | |||
/** | |||
* \brief Algorithm information, we can get real algo from | |||
* AlgorithmInfo::Info::Desc | |||
*/ | |||
struct Info { | |||
struct Desc { | |||
//! backend of the algo belonging to | |||
Handle::HandleType handle_type; | |||
//! indicate the real algo implementation | |||
uint32_t type = INVALID_ALGO_TYPE; | |||
//! serialized param of the algo type | |||
std::string param; | |||
bool valid() const { return type != INVALID_ALGO_TYPE; } | |||
void reset() { type = INVALID_ALGO_TYPE; } | |||
bool operator==(const Desc& rhs) const { | |||
return handle_type == rhs.handle_type && type == rhs.type && | |||
param == rhs.param; | |||
} | |||
} desc; | |||
//! algorithm name | |||
std::string name; | |||
bool is_reproducible; | |||
bool valid() const { return desc.valid(); } | |||
void reset() { desc.reset(); } | |||
//! desc donate the algo | |||
bool operator==(const Info& rhs) const { return desc == rhs.desc; } | |||
}; | |||
virtual ~Algorithm() = default; | |||
/** | |||
* \brief whether the execution result is | |||
* reproducible across multiple runs. | |||
*/ | |||
virtual bool is_reproducible() const = 0; | |||
virtual const char* name() const = 0; | |||
//! serialized param | |||
virtual std::string param() const { return {}; } | |||
virtual uint32_t type() const = 0; | |||
Handle::HandleType handle_type() const { return m_handle_type; } | |||
Info info() const { | |||
return {{handle_type(), type(), param()}, name(), is_reproducible()}; | |||
} | |||
template <typename T> | |||
static void serialize_write_pod(const T& val, std::string& result) { | |||
result.append(reinterpret_cast<const char*>(&val), sizeof(T)); | |||
} | |||
static void serialize_write_pod(const char* val, std::string& result) { | |||
result.append(val, strlen(val)); | |||
} | |||
template <typename T> | |||
static T deserialize_read_pod(const std::string& data, size_t offset = 0) { | |||
T ret = *reinterpret_cast<const T*>(&data[offset]); | |||
return ret; | |||
} | |||
protected: | |||
~Algorithm() = default; | |||
Handle::HandleType m_handle_type = Handle::HandleType::NAIVE; | |||
}; | |||
@@ -127,6 +176,8 @@ class MultiAlgoOpr; | |||
template <class Opr> | |||
class MultiAlgoOpr<Opr, -1> { | |||
public: | |||
using AlgorithmInfo = detail::Algorithm::Info; | |||
using AlgorithmDesc = detail::Algorithm::Info::Desc; | |||
using Algorithm = detail::Algorithm; | |||
/*! | |||
* \brief get a string representation for current algorithm set; | |||
@@ -139,8 +190,8 @@ public: | |||
//! policy for executing the operator | |||
struct ExecutionPolicy { | |||
//! nullptr means using heuristic | |||
Algorithm* algorithm = nullptr; | |||
//! INVALID_ALGO_TYPE algo_type means using heuristic | |||
AlgorithmInfo algo; | |||
}; | |||
ExecutionPolicy& execution_policy() { return m_execution_policy; } | |||
@@ -161,6 +212,39 @@ template <class Opr> | |||
class MultiAlgoOpr<Opr, 3> : public MultiAlgoOpr<Opr, -1> { | |||
public: | |||
using Algorithm = detail::Algorithm; | |||
using AlgorithmInfo = detail::Algorithm::Info; | |||
//! get all possible algorithm decriptions for the specified layouts | |||
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||
const TensorLayout& p1, | |||
const TensorLayout& p2) { | |||
std::vector<AlgorithmInfo> ret; | |||
for (auto&& algo : get_all_algorithms(p0, p1, p2)) { | |||
ret.emplace_back(algo->info()); | |||
} | |||
return ret; | |||
} | |||
/** | |||
* \brief Returns the best algorithm information which indicate the | |||
* algorithm by heuristic. | |||
* | |||
* The selected algorithm should not use workspace more than | |||
* \p workspace_limit_in_bytes. | |||
*/ | |||
AlgorithmInfo get_algorithm_info_heuristic( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, | |||
size_t workspace_limit_in_bytes = | |||
std::numeric_limits<size_t>::max(), | |||
bool reproducible = false) { | |||
return get_algorithm_heuristic(p0, p1, p2, workspace_limit_in_bytes, | |||
reproducible) | |||
->info(); | |||
} | |||
protected: | |||
~MultiAlgoOpr() = default; | |||
//! get all possible algorithms for the specified layouts | |||
virtual std::vector<Algorithm*> get_all_algorithms( | |||
@@ -179,9 +263,6 @@ public: | |||
size_t workspace_limit_in_bytes = | |||
std::numeric_limits<size_t>::max(), | |||
bool reproducible = false) = 0; | |||
protected: | |||
~MultiAlgoOpr() = default; | |||
}; | |||
//! specializae for nargs == 4 | |||
@@ -189,6 +270,40 @@ template <class Opr> | |||
class MultiAlgoOpr<Opr, 4> : public MultiAlgoOpr<Opr, -1> { | |||
public: | |||
using Algorithm = detail::Algorithm; | |||
using AlgorithmInfo = detail::Algorithm::Info; | |||
//! get all possible algorithm decriptions for the specified layouts | |||
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||
const TensorLayout& p1, | |||
const TensorLayout& p2, | |||
const TensorLayout& p3) { | |||
std::vector<AlgorithmInfo> ret; | |||
for (auto&& algo : get_all_algorithms(p0, p1, p2, p3)) { | |||
ret.emplace_back(algo->info()); | |||
} | |||
return ret; | |||
} | |||
/** | |||
* \brief Returns the best algorithm information which indicate the | |||
* algorithm by heuristic. | |||
* | |||
* The selected algorithm should not use workspace more than | |||
* \p workspace_limit_in_bytes. | |||
*/ | |||
AlgorithmInfo get_algorithm_info_heuristic( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3, | |||
size_t workspace_limit_in_bytes = | |||
std::numeric_limits<size_t>::max(), | |||
bool reproducible = false) { | |||
return get_algorithm_heuristic(p0, p1, p2, p3, workspace_limit_in_bytes, | |||
reproducible) | |||
->info(); | |||
} | |||
protected: | |||
~MultiAlgoOpr() = default; | |||
//! get all possible algorithms for the specified layouts | |||
virtual std::vector<Algorithm*> get_all_algorithms( | |||
@@ -207,9 +322,6 @@ public: | |||
size_t workspace_limit_in_bytes = | |||
std::numeric_limits<size_t>::max(), | |||
bool reproducible = false) = 0; | |||
protected: | |||
~MultiAlgoOpr() = default; | |||
}; | |||
//! specializae for nargs == 5 | |||
@@ -217,6 +329,42 @@ template <class Opr> | |||
class MultiAlgoOpr<Opr, 5> : public MultiAlgoOpr<Opr, -1> { | |||
public: | |||
using Algorithm = detail::Algorithm; | |||
using AlgorithmInfo = detail::Algorithm::Info; | |||
//! get all possible algorithm decriptions for the specified layouts | |||
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||
const TensorLayout& p1, | |||
const TensorLayout& p2, | |||
const TensorLayout& p3, | |||
const TensorLayout& p4) { | |||
std::vector<AlgorithmInfo> ret; | |||
for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4)) { | |||
ret.emplace_back(algo->info()); | |||
} | |||
return ret; | |||
} | |||
/** | |||
* \brief Returns the best algorithm information which indicate the | |||
* algorithm by heuristic. | |||
* | |||
* The selected algorithm should not use workspace more than | |||
* \p workspace_limit_in_bytes. | |||
*/ | |||
AlgorithmInfo get_algorithm_info_heuristic( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3, | |||
const TensorLayout& p4, | |||
size_t workspace_limit_in_bytes = | |||
std::numeric_limits<size_t>::max(), | |||
bool reproducible = false) { | |||
return get_algorithm_heuristic(p0, p1, p2, p3, p4, | |||
workspace_limit_in_bytes, reproducible) | |||
->info(); | |||
} | |||
protected: | |||
~MultiAlgoOpr() = default; | |||
//! get all possible algorithms for the specified layouts | |||
virtual std::vector<Algorithm*> get_all_algorithms( | |||
@@ -237,9 +385,6 @@ public: | |||
size_t workspace_limit_in_bytes = | |||
std::numeric_limits<size_t>::max(), | |||
bool reproducible = false) = 0; | |||
protected: | |||
~MultiAlgoOpr() = default; | |||
}; | |||
//! specializae for nargs == 8 | |||
@@ -247,6 +392,42 @@ template <class Opr> | |||
class MultiAlgoOpr<Opr, 8> : public MultiAlgoOpr<Opr, -1> { | |||
public: | |||
using Algorithm = detail::Algorithm; | |||
using AlgorithmInfo = detail::Algorithm::Info; | |||
//! get all possible algorithm decriptions for the specified layouts | |||
std::vector<AlgorithmInfo> get_all_algorithms_info( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3, | |||
const TensorLayout& p4, const TensorLayout& p5, | |||
const TensorLayout& p6, const TensorLayout& p7) { | |||
std::vector<AlgorithmInfo> ret; | |||
for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4, p5, p6, p7)) { | |||
ret.emplace_back(algo->info()); | |||
} | |||
return ret; | |||
} | |||
/** | |||
* \brief Returns the best algorithm information which indicate the | |||
* algorithm by heuristic. | |||
* | |||
* The selected algorithm should not use workspace more than | |||
*/ | |||
AlgorithmInfo get_algorithm_info_heuristic( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3, | |||
const TensorLayout& p4, const TensorLayout& p5, | |||
const TensorLayout& p6, const TensorLayout& p7, | |||
size_t workspace_limit_in_bytes = | |||
std::numeric_limits<size_t>::max(), | |||
bool reproducible = false) { | |||
return get_algorithm_heuristic(p0, p1, p2, p3, p4, p5, p6, p7, | |||
workspace_limit_in_bytes, reproducible) | |||
->info(); | |||
} | |||
protected: | |||
~MultiAlgoOpr() = default; | |||
//! get all possible algorithms for the specified layouts | |||
virtual std::vector<Algorithm*> get_all_algorithms( | |||
@@ -269,9 +450,6 @@ public: | |||
size_t workspace_limit_in_bytes = | |||
std::numeric_limits<size_t>::max(), | |||
bool reproducible = false) = 0; | |||
protected: | |||
~MultiAlgoOpr() = default; | |||
}; | |||
} // namespace detail | |||
} // namespace megdnn | |||
@@ -31,6 +31,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_DIRECT_STRD2_FP16) | |||
}; | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
@@ -36,6 +36,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_DIRECT_STRD2_FP32) | |||
}; | |||
} // namespace aarch64 | |||
@@ -48,6 +48,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_MATMUL_S8) | |||
}; | |||
} // namespace aarch64 | |||
@@ -32,28 +32,54 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
AlgoF16DirectStride2 f16_direct_stride2; | |||
#endif | |||
fallback::ConvBiasImpl::AlgoBase::Mapper m_all_algos_map; | |||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_direct_algos; | |||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_matmul_algos; | |||
public: | |||
AlgoPack() { | |||
matmul_algos.emplace_back(&qu8_matrix_mul); | |||
matmul_algos.emplace_back(&s8_matrix_mul); | |||
m_matmul_algos.emplace_back(&qu8_matrix_mul); | |||
m_matmul_algos.emplace_back(&s8_matrix_mul); | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
direct_algos.emplace_back(&f16_direct_stride2); | |||
m_direct_algos.emplace_back(&f16_direct_stride2); | |||
#endif | |||
direct_algos.emplace_back(&f32_direct_stride2); | |||
m_direct_algos.emplace_back(&f32_direct_stride2); | |||
for (auto&& algo : m_direct_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
for (auto&& algo : m_matmul_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& direct_algos() const { | |||
return m_direct_algos; | |||
} | |||
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& matmul_algos() | |||
const { | |||
return m_matmul_algos; | |||
} | |||
SmallVector<AlgoBase*> direct_algos; | |||
SmallVector<AlgoBase*> matmul_algos; | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() { | |||
static AlgoPack sl_algo_pack; | |||
auto&& algos = arm_common::ConvBiasImpl::algo_pack(); | |||
algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(), | |||
sl_algo_pack.direct_algos.end()); | |||
const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { | |||
static AlgoPack algo_pack; | |||
return algo_pack; | |||
} | |||
MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl) | |||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> | |||
ConvBiasImpl::get_all_packed_algo() { | |||
auto&& algos = arm_common::ConvBiasImpl::get_all_packed_algo(); | |||
algos.insert(algos.begin(), algo_pack().direct_algos().begin(), | |||
algo_pack().direct_algos().end()); | |||
//! We put matmul algos at the begin. Because matmul will get privilege when | |||
//! prefer return true. See | |||
algos.insert(algos.begin(), sl_algo_pack.matmul_algos.begin(), | |||
sl_algo_pack.matmul_algos.end()); | |||
algos.insert(algos.begin(), algo_pack().matmul_algos().begin(), | |||
algo_pack().matmul_algos().end()); | |||
return std::move(algos); | |||
} | |||
@@ -25,7 +25,9 @@ public: | |||
} | |||
}; | |||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> algo_pack() override; | |||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> get_all_packed_algo() override; | |||
MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvBiasImpl); | |||
protected: | |||
const char* get_algorithm_set_name() const override; | |||
@@ -38,6 +40,7 @@ private: | |||
class AlgoF16DirectStride2; | |||
#endif | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack(); | |||
}; | |||
} // namespace aarch64 | |||
@@ -48,6 +48,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_MATMUL_QU8) | |||
}; | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
@@ -27,6 +27,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_K8X12X1) | |||
}; | |||
class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { | |||
@@ -37,6 +38,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_MK4_K8X12X1) | |||
}; | |||
class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase { | |||
@@ -47,6 +49,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_K4X16X1) | |||
}; | |||
class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase { | |||
@@ -58,10 +61,17 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4, AlgoDataType::FLOAT32, MK4) | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_MK4_4x16) | |||
}; | |||
class MatrixMulImpl::AlgoF32Gemv final | |||
: public arm_common::MatrixMulImpl::AlgoF32Gemv {}; | |||
: public arm_common::MatrixMulImpl::AlgoF32Gemv { | |||
public: | |||
AlgoF32Gemv() : arm_common::MatrixMulImpl::AlgoF32Gemv() { | |||
m_handle_type = Handle::HandleType::AARCH64; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_GEMV) | |||
}; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase { | |||
@@ -72,6 +82,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_F16_K8X24X1) | |||
}; | |||
class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase { | |||
@@ -83,6 +94,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::FLOAT16, MK8) | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_F16_MK8_8X8) | |||
}; | |||
#endif | |||
@@ -98,6 +110,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_K8X12X4_DOTPROD) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase { | |||
@@ -110,6 +123,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD) | |||
}; | |||
#else | |||
@@ -124,6 +138,7 @@ public: | |||
PackMode packmode() const override { return PackMode::DEFAULT; } | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_MK4_4X4X16) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase { | |||
@@ -136,6 +151,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_K4X4X16) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase { | |||
@@ -147,6 +163,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_K8X8X8) | |||
}; | |||
#endif | |||
@@ -160,6 +177,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_K8X8X8) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase { | |||
@@ -171,6 +189,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_K4X4X16) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { | |||
@@ -186,6 +205,7 @@ public: | |||
PackMode packmode() const override { return PackMode::DEFAULT; } | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_MK4_16X12X4) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { | |||
@@ -201,6 +221,7 @@ public: | |||
PackMode packmode() const override { return PackMode::DEFAULT; } | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_MK4_K8X8X8) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { | |||
@@ -214,6 +235,7 @@ public: | |||
PackMode packmode() const override { return PackMode::DEFAULT; } | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_MK4_4X4X8) | |||
}; | |||
class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase { | |||
@@ -225,6 +247,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT16X16X32_K12X8X1) | |||
}; | |||
class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase { | |||
@@ -236,6 +259,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::INT16X16X32, MK8) | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT16X16X32_MK8_8X8) | |||
}; | |||
#if __ARM_FEATURE_DOTPROD | |||
@@ -249,6 +273,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_K8X8X4_DOTPROD) | |||
}; | |||
class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { | |||
@@ -262,6 +287,7 @@ public: | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QUINT8X8X32, DEFAULT) | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_GEMV_DOTPROD) | |||
}; | |||
#else | |||
@@ -273,6 +299,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_K8X8X8) | |||
}; | |||
#endif | |||
@@ -51,49 +51,66 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
AlgoQuint8K8x8x8 quint8_k8x8x8; | |||
#endif | |||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos; | |||
fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map; | |||
public: | |||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> all_algos; | |||
AlgoPack() { | |||
all_algos.emplace_back(&f32_gemv); | |||
all_algos.emplace_back(&f32K8x12x1); | |||
all_algos.emplace_back(&f32_mk4_8x12x1); | |||
all_algos.emplace_back(&f32k4x16x1); | |||
all_algos.emplace_back(&f32mk4_4x16); | |||
m_all_algos.emplace_back(&f32_gemv); | |||
m_all_algos.emplace_back(&f32K8x12x1); | |||
m_all_algos.emplace_back(&f32_mk4_8x12x1); | |||
m_all_algos.emplace_back(&f32k4x16x1); | |||
m_all_algos.emplace_back(&f32mk4_4x16); | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
all_algos.emplace_back(&f16_k8x24x1); | |||
all_algos.emplace_back(&f16_mk8_8x8); | |||
m_all_algos.emplace_back(&f16_k8x24x1); | |||
m_all_algos.emplace_back(&f16_mk8_8x8); | |||
#endif | |||
#if __ARM_FEATURE_DOTPROD | |||
all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod); | |||
all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod); | |||
m_all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod); | |||
m_all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod); | |||
#else | |||
all_algos.emplace_back(&int8x8x32_k4x4x16); | |||
all_algos.emplace_back(&int8x8x32_k8x8x8); | |||
all_algos.emplace_back(&int8x8x32_mk4_4x4x16); | |||
m_all_algos.emplace_back(&int8x8x32_k4x4x16); | |||
m_all_algos.emplace_back(&int8x8x32_k8x8x8); | |||
m_all_algos.emplace_back(&int8x8x32_mk4_4x4x16); | |||
#endif | |||
all_algos.emplace_back(&int8x8x16_k4x4x16); | |||
all_algos.emplace_back(&int8x8x16_k8x8x8); | |||
all_algos.emplace_back(&int8x8x16_mk4_k8x8x8); | |||
all_algos.emplace_back(&int8x8x16_mk4_4x4x8); | |||
all_algos.emplace_back(&int8x8x16_mk4_16x12x4); | |||
m_all_algos.emplace_back(&int8x8x16_k4x4x16); | |||
m_all_algos.emplace_back(&int8x8x16_k8x8x8); | |||
m_all_algos.emplace_back(&int8x8x16_mk4_k8x8x8); | |||
m_all_algos.emplace_back(&int8x8x16_mk4_4x4x8); | |||
m_all_algos.emplace_back(&int8x8x16_mk4_16x12x4); | |||
all_algos.emplace_back(&int16x16x32_k12x8x1); | |||
all_algos.emplace_back(&int16x16x32_mk8_8x8); | |||
m_all_algos.emplace_back(&int16x16x32_k12x8x1); | |||
m_all_algos.emplace_back(&int16x16x32_mk8_8x8); | |||
#if __ARM_FEATURE_DOTPROD | |||
all_algos.emplace_back(&quint8_gemv_dotprod); | |||
all_algos.emplace_back(&quint8_k8x8x4_dotprod); | |||
m_all_algos.emplace_back(&quint8_gemv_dotprod); | |||
m_all_algos.emplace_back(&quint8_k8x8x4_dotprod); | |||
#else | |||
all_algos.emplace_back(&quint8_k8x8x8); | |||
m_all_algos.emplace_back(&quint8_k8x8x8); | |||
#endif | |||
for (auto&& algo : m_all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
const SmallVector<fallback::MatrixMulImpl::AlgoBase*>& all_algos() const { | |||
return m_all_algos; | |||
} | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() { | |||
static AlgoPack s_algo_pack; | |||
auto&& algos = arm_common::MatrixMulImpl::algo_pack(); | |||
algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), | |||
s_algo_pack.all_algos.end()); | |||
const MatrixMulImpl::AlgoPack& MatrixMulImpl::algo_pack() { | |||
static AlgoPack algo_pack; | |||
return algo_pack; | |||
} | |||
MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(MatrixMulImpl) | |||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> | |||
MatrixMulImpl::get_all_packed_algo() { | |||
auto&& algos = arm_common::MatrixMulImpl::get_all_packed_algo(); | |||
algos.insert(algos.begin(), algo_pack().all_algos().begin(), | |||
algo_pack().all_algos().end()); | |||
return std::move(algos); | |||
} | |||
@@ -25,7 +25,10 @@ public: | |||
} | |||
}; | |||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> algo_pack() override; | |||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> get_all_packed_algo() | |||
override; | |||
MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl); | |||
private: | |||
class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1 | |||
@@ -66,6 +69,8 @@ private: | |||
class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int4x4x16 Kernel 4x4x16 | |||
class AlgoPack; | |||
public: | |||
static const AlgoPack& algo_pack(); | |||
}; | |||
} // namespace aarch64 | |||
@@ -30,6 +30,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_FP16) | |||
}; | |||
class ConvBiasImpl::AlgoFP16WinogradF45 final : public AlgoBase { | |||
@@ -45,7 +46,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F45_FP16) | |||
}; | |||
class ConvBiasImpl::AlgoFP16WinogradF63 final : public AlgoBase { | |||
public: | |||
@@ -61,6 +62,7 @@ public: | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_FP16) | |||
}; | |||
class ConvBiasImpl::AlgoFP16WinogradF23_8x8 final : public AlgoBase { | |||
public: | |||
@@ -75,6 +77,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_FP16) | |||
}; | |||
class ConvBiasImpl::AlgoF16Direct final : public AlgoBase { | |||
@@ -94,6 +97,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override{ | |||
return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_FP16) | |||
}; | |||
class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase { | |||
@@ -110,6 +114,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_FP16) | |||
}; | |||
} // namespace arm_common | |||
@@ -30,6 +30,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_FP32) | |||
}; | |||
class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase { | |||
@@ -45,6 +46,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_FP32) | |||
}; | |||
class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase { | |||
@@ -60,6 +62,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_FP32) | |||
}; | |||
class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase { | |||
@@ -75,6 +78,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F54_FP32) | |||
}; | |||
class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase { | |||
@@ -90,6 +94,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F45_FP32) | |||
}; | |||
//===================== NCHW44 Winograd Support =====================// | |||
@@ -107,6 +112,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_NCHW44_F32) | |||
}; | |||
class ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44 final : public AlgoBase { | |||
@@ -123,6 +129,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_NCHW44_F32) | |||
}; | |||
class ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44 final : public AlgoBase { | |||
@@ -139,6 +146,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F73_4X4_NCHW44_F32) | |||
}; | |||
// ================================================================= // | |||
@@ -157,6 +165,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_FP32) | |||
}; | |||
class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase { | |||
@@ -174,6 +183,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_FP32) | |||
}; | |||
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | |||
@@ -191,6 +201,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_FP32) | |||
}; | |||
class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase { | |||
@@ -209,6 +220,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44_FP32) | |||
}; | |||
class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase { | |||
@@ -227,6 +239,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_FP32) | |||
}; | |||
class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase { | |||
@@ -244,6 +257,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHWNWISE_NCHW44_F32) | |||
}; | |||
} // namespace arm_common | |||
@@ -33,6 +33,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_S8) | |||
}; | |||
class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase { | |||
@@ -49,6 +50,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_S8) | |||
}; | |||
class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase { | |||
@@ -65,6 +67,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44) | |||
}; | |||
class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase { | |||
@@ -81,6 +84,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_S8) | |||
}; | |||
class ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44 final : public AlgoBase { | |||
@@ -95,6 +99,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHANWISE_STRD1_NCHW44_S8) | |||
}; | |||
class ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44 final : public AlgoBase { | |||
@@ -109,6 +114,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHANWISE_STRD2_NCHW44_S8) | |||
}; | |||
#if __ARM_FEATURE_DOTPROD | |||
@@ -126,6 +132,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8) | |||
}; | |||
class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { | |||
@@ -142,6 +149,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_S8) | |||
}; | |||
class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase { | |||
@@ -159,6 +167,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_S8) | |||
}; | |||
class ConvBiasImpl::AlgoDotS8Direct_NCHW44 final : public AlgoBase { | |||
@@ -180,6 +189,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44_DOT_S8) | |||
}; | |||
#endif | |||
@@ -196,6 +206,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_S8) | |||
}; | |||
//=======================input int8 compute fp32 output int8============ | |||
@@ -213,6 +224,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8CF32) | |||
}; | |||
//=======================input int8 compute int16 output int8============ | |||
@@ -231,6 +243,7 @@ public: | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8) | |||
}; | |||
} // namespace arm_common | |||
@@ -39,6 +39,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_INT8X8X16) | |||
}; | |||
class ConvBiasImpl::AlgoS8x8x16DirectNCHW44 final : public AlgoBase { | |||
@@ -54,6 +55,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44_INT8X8X16) | |||
}; | |||
class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase { | |||
@@ -80,6 +82,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_INT8X8X16) | |||
}; | |||
class ConvBiasImpl::AlgoI8x8x16Stride2Filter2 final : public AlgoBase { | |||
@@ -96,12 +99,16 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_F2_INT8X8X16) | |||
}; | |||
class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final : public AlgoBase { | |||
class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final | |||
: public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"; } | |||
const char* name() const override { | |||
return "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"; | |||
} | |||
bool usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace( | |||
@@ -111,6 +118,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHWNWISE_STRD1_STRD2_NCHW44_INT8X8X16) | |||
}; | |||
class ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44 final : public AlgoBase { | |||
@@ -129,6 +137,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_INT8X8X16) | |||
}; | |||
} // namespace arm_common | |||
@@ -88,46 +88,50 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
#endif | |||
SmallVector<std::unique_ptr<AlgoBase>> refhold; | |||
fallback::ConvBiasImpl::AlgoBase::Mapper m_all_algos_map; | |||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_direct_algos; | |||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_winograd_algos; | |||
public: | |||
AlgoPack() { | |||
#if __ARM_FEATURE_DOTPROD | |||
direct_algos.emplace_back(&ds8_direct_stride1); | |||
direct_algos.emplace_back(&ds8_direct_stride2); | |||
direct_algos.emplace_back(&du8_direct_stride1); | |||
direct_algos.emplace_back(&du8_direct_stride2); | |||
m_direct_algos.emplace_back(&ds8_direct_stride1); | |||
m_direct_algos.emplace_back(&ds8_direct_stride2); | |||
m_direct_algos.emplace_back(&du8_direct_stride1); | |||
m_direct_algos.emplace_back(&du8_direct_stride2); | |||
direct_algos.emplace_back(&ds8_direct_nchw44); | |||
direct_algos.emplace_back(&ds8_direct_nchw_nchw44); | |||
m_direct_algos.emplace_back(&ds8_direct_nchw44); | |||
m_direct_algos.emplace_back(&ds8_direct_nchw_nchw44); | |||
#endif | |||
direct_algos.emplace_back(&qu8_direct_stride2); | |||
direct_algos.emplace_back(&qu8_direct_stride1); | |||
direct_algos.emplace_back(&s8_direct_stride2); | |||
direct_algos.emplace_back(&s8_direct_nchw44); | |||
direct_algos.emplace_back(&s8x8x16_direct_nchw44); | |||
direct_algos.emplace_back(&s8_direct_nchw_nchw44); | |||
direct_algos.emplace_back(&s8_direct_stride1); | |||
direct_algos.emplace_back(&s8x8x16_channel_wise_stride1_stride2_nchw44); | |||
direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44); | |||
direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44); | |||
m_direct_algos.emplace_back(&qu8_direct_stride2); | |||
m_direct_algos.emplace_back(&qu8_direct_stride1); | |||
m_direct_algos.emplace_back(&s8_direct_stride2); | |||
m_direct_algos.emplace_back(&s8_direct_nchw44); | |||
m_direct_algos.emplace_back(&s8x8x16_direct_nchw44); | |||
m_direct_algos.emplace_back(&s8_direct_nchw_nchw44); | |||
m_direct_algos.emplace_back(&s8_direct_stride1); | |||
m_direct_algos.emplace_back( | |||
&s8x8x16_channel_wise_stride1_stride2_nchw44); | |||
m_direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44); | |||
m_direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44); | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
direct_algos.emplace_back(&f16_direct_stride1); | |||
direct_algos.emplace_back(&f16_direct); | |||
m_direct_algos.emplace_back(&f16_direct_stride1); | |||
m_direct_algos.emplace_back(&f16_direct); | |||
#endif | |||
direct_algos.emplace_back(&i8x8x16_direct); | |||
direct_algos.emplace_back(&i8x8x16_stride2_filter2); | |||
direct_algos.emplace_back(&i8x8x16_stride2); | |||
direct_algos.emplace_back(&i8x8x16_nchw_nchw44); | |||
m_direct_algos.emplace_back(&i8x8x16_direct); | |||
m_direct_algos.emplace_back(&i8x8x16_stride2_filter2); | |||
m_direct_algos.emplace_back(&i8x8x16_stride2); | |||
m_direct_algos.emplace_back(&i8x8x16_nchw_nchw44); | |||
direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44); | |||
direct_algos.emplace_back(&f32_chanel_wise_nchw44); | |||
direct_algos.emplace_back(&f32_direct_nchw44); | |||
m_direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44); | |||
m_direct_algos.emplace_back(&f32_chanel_wise_nchw44); | |||
m_direct_algos.emplace_back(&f32_direct_nchw44); | |||
direct_algos.emplace_back(&f32_direct_stride1); | |||
direct_algos.emplace_back(&f32_direct_stride2); | |||
direct_algos.emplace_back(&f32_direct); | |||
m_direct_algos.emplace_back(&f32_direct_stride1); | |||
m_direct_algos.emplace_back(&f32_direct_stride2); | |||
m_direct_algos.emplace_back(&f32_direct); | |||
static CpuOprDelegationStorage<2> storage; | |||
auto matmul_opr = storage.get<MatrixMul, 0>(); | |||
@@ -143,31 +147,31 @@ public: | |||
refhold.emplace_back(new AlgoFP32WinogradF23_4x4( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
m_winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoFP32WinogradF63_4x4( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
m_winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
m_winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
m_winograd_algos.emplace_back(refhold.back().get()); | |||
//! uncomment this when low precision mode is done | |||
#if 0 | |||
refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
m_winograd_algos.emplace_back(refhold.back().get()); | |||
#endif | |||
//! Qint8x8x32 winograd compute with fp32 | |||
refhold.emplace_back(new AlgoS8CF32WinogradF23_4x4_NCHW44( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
m_winograd_algos.emplace_back(refhold.back().get()); | |||
} | |||
} | |||
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | |||
@@ -180,15 +184,15 @@ public: | |||
refhold.emplace_back(new AlgoFP32WinogradF63( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
m_winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoFP32WinogradF54( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
m_winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoFP32WinogradF45( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
m_winograd_algos.emplace_back(refhold.back().get()); | |||
} | |||
} | |||
@@ -203,15 +207,15 @@ public: | |||
refhold.emplace_back(new AlgoFP16WinogradF23( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
m_winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoFP16WinogradF45( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
m_winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoFP16WinogradF63( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
m_winograd_algos.emplace_back(refhold.back().get()); | |||
} | |||
} | |||
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | |||
@@ -224,7 +228,7 @@ public: | |||
refhold.emplace_back(new AlgoFP16WinogradF23_8x8( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
m_winograd_algos.emplace_back(refhold.back().get()); | |||
} | |||
} | |||
#endif | |||
@@ -238,25 +242,48 @@ public: | |||
refhold.emplace_back(new AlgoS8WinogradF23_8x8( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
m_winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoS8WinogradF23_8x8_NCHW44( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
m_winograd_algos.emplace_back(refhold.back().get()); | |||
} | |||
} | |||
for (auto&& algo : m_direct_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
for (auto&& algo : m_winograd_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
SmallVector<AlgoBase*> direct_algos; | |||
SmallVector<AlgoBase*> winograd_algos; | |||
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& direct_algos() | |||
const { | |||
return m_direct_algos; | |||
} | |||
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& winograd_algos() | |||
const { | |||
return m_winograd_algos; | |||
} | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() { | |||
static AlgoPack sl_algo_pack; | |||
auto&& algos = fallback::ConvBiasImpl::algo_pack(); | |||
algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(), | |||
sl_algo_pack.direct_algos.end()); | |||
algos.insert(algos.end(), sl_algo_pack.winograd_algos.begin(), | |||
sl_algo_pack.winograd_algos.end()); | |||
const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { | |||
static AlgoPack algo_pack; | |||
return algo_pack; | |||
} | |||
MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl) | |||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> | |||
ConvBiasImpl::get_all_packed_algo() { | |||
auto&& algos = fallback::ConvBiasImpl::get_all_packed_algo(); | |||
algos.insert(algos.begin(), algo_pack().direct_algos().begin(), | |||
algo_pack().direct_algos().end()); | |||
algos.insert(algos.end(), algo_pack().winograd_algos().begin(), | |||
algo_pack().winograd_algos().end()); | |||
return std::move(algos); | |||
} | |||
@@ -12,6 +12,7 @@ | |||
#pragma once | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
#include "src/common/algo_base.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
@@ -27,7 +28,7 @@ public: | |||
} | |||
}; | |||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> algo_pack() override; | |||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> get_all_packed_algo() override; | |||
bool is_matmul_quantized_prefer( | |||
const fallback::ConvBiasImpl::NCBKernSizeParam& ncb_param) | |||
@@ -35,7 +36,8 @@ public: | |||
SmallVector<AlgoCategory> suggest_algo_category_order( | |||
const NCBKernSizeParam& param) const override; | |||
class AlgoPack; | |||
MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvBiasImpl); | |||
protected: | |||
const char* get_algorithm_set_name() const override; | |||
@@ -95,6 +97,9 @@ private: | |||
class AlgoF16Direct; | |||
class AlgoF16DirectStride1; | |||
#endif | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack(); | |||
}; | |||
} // namespace arm_common | |||
@@ -32,6 +32,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_QU8) | |||
}; | |||
class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase { | |||
@@ -48,6 +49,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_QU8) | |||
}; | |||
#if __ARM_FEATURE_DOTPROD | |||
class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase { | |||
@@ -65,6 +67,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_QU8) | |||
}; | |||
class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase { | |||
@@ -81,6 +84,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_QU8) | |||
}; | |||
#endif | |||
} // namespace arm_common | |||
@@ -36,6 +36,7 @@ public: | |||
ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, | |||
const NCBKernSizeParam&) const override; | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_INT8X8X32) | |||
}; | |||
class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2 final | |||
@@ -54,6 +55,7 @@ public: | |||
ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, | |||
const NCBKernSizeParam&) const override; | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_INT8X8X32) | |||
}; | |||
#endif | |||
@@ -1086,6 +1086,10 @@ bool deconv::can_stride1_int8x8x32_dot(const NCBKernSizeParam& param) { | |||
(FH == 2 || FH == 3 || FH == 5 || FH == 7) && | |||
FH >= PH + 1 && FW >= PW + 1; | |||
avaiable &= (param.filter_type.enumv() == DTypeEnum::QuantizedS8 || | |||
param.filter_type.enumv() == DTypeEnum::Int8) && | |||
(param.grad_type.enumv() == DTypeEnum::QuantizedS32 || | |||
param.grad_type.enumv() == DTypeEnum::Int32); | |||
return avaiable && | |||
((FH == 2 && OC <= 8) || | |||
((FH == 3 || FH == 5 || FH == 7) && (IC < 32 && OC <= 16))); | |||
@@ -1180,6 +1180,10 @@ bool deconv::can_stride2_int8x8x32_dot(const NCBKernSizeParam& param) { | |||
(FH == 2 || FH == 3 || FH == 5 || FH == 7) && | |||
FH >= PH + 1 && FW >= PW + 1; | |||
avaiable &= (param.filter_type.enumv() == DTypeEnum::QuantizedS8 || | |||
param.filter_type.enumv() == DTypeEnum::Int8) && | |||
(param.grad_type.enumv() == DTypeEnum::QuantizedS32 || | |||
param.grad_type.enumv() == DTypeEnum::Int32); | |||
return avaiable && ((FH == 2 && OC <= 4) || (FH == 3 && OC <= 8) || | |||
(FH == 5 && OC <= 16) || (FH == 7 && OC < 32)); | |||
} | |||
@@ -23,15 +23,54 @@ using namespace arm_common; | |||
/* ===================== ConvolutionBackwardData ===================== */ | |||
struct ConvolutionBackwardDataImpl::AlgoPack { | |||
class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { | |||
#if __ARM_FEATURE_DOTPROD | |||
AlgoSdot8DirectStride1 i8x8x32_direct_stride1_sdot; | |||
AlgoSdot8DirectStride2 i8x8x32_direct_stride2_sdot; | |||
AlgoUdot8DirectStride1 quint8_direct_stride1_udot; | |||
AlgoUdot8DirectStride2 quint8_direct_stride2_udot; | |||
#endif | |||
fallback::ConvolutionBackwardDataImpl::AlgoBase::Mapper m_all_algos_map; | |||
SmallVector<fallback::ConvolutionBackwardDataImpl::AlgoBase*> | |||
m_all_algos; | |||
public: | |||
AlgoPack() { | |||
#if __ARM_FEATURE_DOTPROD | |||
m_all_algos.emplace_back(&i8x8x32_direct_stride1_sdot); | |||
m_all_algos.emplace_back(&i8x8x32_direct_stride2_sdot); | |||
m_all_algos.emplace_back(&quint8_direct_stride1_udot); | |||
m_all_algos.emplace_back(&quint8_direct_stride2_udot); | |||
#endif | |||
for (auto&& algo : m_all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
const SmallVector<fallback::ConvolutionBackwardDataImpl::AlgoBase*>& | |||
all_algos() const { | |||
return m_all_algos; | |||
} | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack; | |||
const ConvolutionBackwardDataImpl::AlgoPack& | |||
ConvolutionBackwardDataImpl::algo_pack() { | |||
static AlgoPack algo_pack; | |||
return algo_pack; | |||
} | |||
MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl) | |||
SmallVector<fallback::ConvolutionBackwardDataImpl::AlgoBase*> | |||
ConvolutionBackwardDataImpl::get_all_packed_algo() { | |||
auto&& algos = fallback::ConvolutionBackwardDataImpl::get_all_packed_algo(); | |||
algos.insert(algos.begin(), algo_pack().all_algos().begin(), | |||
algo_pack().all_algos().end()); | |||
return std::move(algos); | |||
} | |||
ConvolutionBackwardDataImpl::ncb_kern_t | |||
ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern( | |||
@@ -52,35 +91,6 @@ size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace( | |||
param); | |||
} | |||
std::vector<ConvolutionBackwardDataImpl::Algorithm*> | |||
ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms( | |||
const NCBKernSizeParam& param) { | |||
auto ret = fallback::ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms( | |||
param); | |||
#if __ARM_FEATURE_DOTPROD | |||
if ((param.filter_type.enumv() == DTypeEnum::QuantizedS8 || | |||
param.filter_type.enumv() == DTypeEnum::Int8) && | |||
(param.grad_type.enumv() == DTypeEnum::QuantizedS32 || | |||
param.grad_type.enumv() == DTypeEnum::Int32)) { | |||
if (sm_algo_pack.i8x8x32_direct_stride1_sdot.usable(this, param)) { | |||
ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride1_sdot); | |||
} | |||
if (sm_algo_pack.i8x8x32_direct_stride2_sdot.usable(this, param)) { | |||
ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride2_sdot); | |||
} | |||
} else if (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm && | |||
param.grad_type.enumv() == DTypeEnum::QuantizedS32) { | |||
if (sm_algo_pack.quint8_direct_stride1_udot.usable(this, param)) { | |||
ret.insert(ret.begin(), &sm_algo_pack.quint8_direct_stride1_udot); | |||
} | |||
if (sm_algo_pack.quint8_direct_stride2_udot.usable(this, param)) { | |||
ret.insert(ret.begin(), &sm_algo_pack.quint8_direct_stride2_udot); | |||
} | |||
} | |||
#endif | |||
return ret; | |||
} | |||
const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const { | |||
// arm common version 0 | |||
return "DeconvAC0"; | |||
@@ -47,11 +47,14 @@ protected: | |||
size_t ncb_1g_get_workspace(Algorithm* algo, | |||
const NCBKernSizeParam& param) override; | |||
std::vector<Algorithm*> ncb_1g_get_all_algorithms( | |||
const NCBKernSizeParam& param) override; | |||
const char* get_algorithm_set_name() const override; | |||
SmallVector<fallback::ConvolutionBackwardDataImpl::AlgoBase*> | |||
get_all_packed_algo() override; | |||
public: | |||
MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl); | |||
private: | |||
#if __ARM_FEATURE_DOTPROD | |||
class AlgoSdot8DirectStride1; | |||
@@ -59,8 +62,8 @@ private: | |||
class AlgoUdot8DirectStride1; | |||
class AlgoUdot8DirectStride2; | |||
#endif | |||
struct AlgoPack; | |||
static AlgoPack sm_algo_pack; | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack(); | |||
}; | |||
} // namespace arm_common | |||
@@ -36,6 +36,7 @@ public: | |||
ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, | |||
const NCBKernSizeParam&) const override; | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_QU8) | |||
}; | |||
class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2 final | |||
@@ -55,6 +56,7 @@ public: | |||
ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, | |||
const NCBKernSizeParam&) const override; | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_QU8) | |||
}; | |||
#endif | |||
} // namespace arm_common | |||
@@ -1236,6 +1236,9 @@ bool deconv::can_stride1_quint8_dot(const NCBKernSizeParam& param) { | |||
(FH == 2 || FH == 3 || FH == 5 || FH == 7) && | |||
FH >= PH + 1 && FW >= PW + 1; | |||
avaiable &= (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm || | |||
param.grad_type.enumv() == DTypeEnum::Int32); | |||
/** | |||
* \note In the kernel, we use int32_t to calc the value, in order | |||
* not generate negative number, we first initialize SHIFT and sub | |||
@@ -1337,6 +1337,9 @@ bool deconv::can_stride2_quint8_dot(const NCBKernSizeParam& param) { | |||
(FH == 2 || FH == 3 || FH == 5 || FH == 7) && | |||
FH >= PH + 1 && FW >= PW + 1; | |||
avaiable &= (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm || | |||
param.grad_type.enumv() == DTypeEnum::Int32); | |||
/** | |||
* \note In the kernel, we use uint32_t to calc the value, in order | |||
* not generate negative number, we first initialize SHIFT and sub | |||
@@ -59,6 +59,7 @@ public: | |||
virtual bool is_available(const KernParam&) const = 0; | |||
virtual void exec(const KernParam&) const = 0; | |||
virtual ~AlgoBase() = default; | |||
uint32_t type() const override { return INVALID_ALGO_TYPE; }; | |||
}; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
@@ -26,6 +26,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::INT8X8X16, DEFAULT) | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X16) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { | |||
@@ -39,6 +40,7 @@ public: | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT) | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase { | |||
@@ -52,6 +54,7 @@ public: | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4) | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4) | |||
}; | |||
#if __ARM_FEATURE_DOTPROD | |||
@@ -66,6 +69,7 @@ public: | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4_DOT) | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4_DOT) | |||
}; | |||
#endif | |||
@@ -96,6 +100,7 @@ public: | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4) | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_F32_GEMV_MK4) | |||
}; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
@@ -110,6 +115,7 @@ public: | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::FLOAT16, DEFAULT) | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_F16_GEMV) | |||
}; | |||
#endif | |||
@@ -130,6 +136,7 @@ public: | |||
static_cast<uint32_t>(AlgoDataType::FLOAT32) | | |||
static_cast<uint32_t>(AlgoDataType::QINT8X8X32)), | |||
DEFAULT) | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_GEVM) | |||
}; | |||
} // namespace arm_common | |||
@@ -28,28 +28,47 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
AlgoGevm gevm; | |||
AlgoF32GemvMK4 f32_gemv_mk4; | |||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos; | |||
fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map; | |||
public: | |||
AlgoPack() { | |||
all_algos.emplace_back(&int8x8x16); | |||
m_all_algos.emplace_back(&int8x8x16); | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
all_algos.emplace_back(&f16gemv); | |||
m_all_algos.emplace_back(&f16gemv); | |||
#endif | |||
#if __ARM_FEATURE_DOTPROD | |||
all_algos.emplace_back(&int8x8x32_gemv_mk4_dot); | |||
m_all_algos.emplace_back(&int8x8x32_gemv_mk4_dot); | |||
#endif | |||
all_algos.emplace_back(&int8x8x32_gemv); | |||
all_algos.emplace_back(&int8x8x32_gemv_mk4); | |||
all_algos.emplace_back(&f32_gemv_mk4); | |||
all_algos.emplace_back(&gevm); | |||
m_all_algos.emplace_back(&int8x8x32_gemv); | |||
m_all_algos.emplace_back(&int8x8x32_gemv_mk4); | |||
m_all_algos.emplace_back(&f32_gemv_mk4); | |||
m_all_algos.emplace_back(&gevm); | |||
for (auto&& algo : m_all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
const SmallVector<fallback::MatrixMulImpl::AlgoBase*>& all_algos() const { | |||
return m_all_algos; | |||
} | |||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> all_algos; | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() { | |||
const MatrixMulImpl::AlgoPack& MatrixMulImpl::algo_pack() { | |||
static AlgoPack algo_pack; | |||
return algo_pack; | |||
} | |||
MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(MatrixMulImpl) | |||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> | |||
MatrixMulImpl::get_all_packed_algo() { | |||
static AlgoPack s_algo_pack; | |||
auto&& algos = fallback::MatrixMulImpl::algo_pack(); | |||
algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), | |||
s_algo_pack.all_algos.end()); | |||
auto&& algos = fallback::MatrixMulImpl::get_all_packed_algo(); | |||
algos.insert(algos.begin(), algo_pack().all_algos().begin(), | |||
algo_pack().all_algos().end()); | |||
return std::move(algos); | |||
} | |||
@@ -11,6 +11,7 @@ | |||
#pragma once | |||
#include "src/common/utils.h" | |||
#include "src/fallback/matrix_mul/opr_impl.h" | |||
#include "src/common/algo_base.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
@@ -27,7 +28,10 @@ public: | |||
} | |||
}; | |||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> algo_pack() override; | |||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> get_all_packed_algo() | |||
override; | |||
MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl); | |||
protected: | |||
class AlgoF32Gemv; // Arm_common F32 Gemv | |||
@@ -43,6 +47,9 @@ protected: | |||
#endif | |||
class AlgoInt8x8x16; // Arm_common Int 8x8x16 | |||
class AlgoPack; | |||
public: | |||
static const AlgoPack& algo_pack(); | |||
}; | |||
} // namespace arm_common | |||
@@ -10,6 +10,7 @@ | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs/base.h" | |||
#include "src/fallback/pooling/opr_impl.h" | |||
namespace megdnn { | |||
@@ -72,6 +73,8 @@ public: | |||
virtual ~AlgoBase() = default; | |||
virtual bool usable(const PoolingKernSizeParam& param) const = 0; | |||
virtual void exec(const PoolingKernParam& param) const = 0; | |||
uint32_t type() const override { return INVALID_ALGO_TYPE; }; | |||
}; | |||
private: | |||
@@ -40,6 +40,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARMV7_MATMUL_S8) | |||
}; | |||
} // namespace armv7 | |||
@@ -24,22 +24,40 @@ using namespace armv7; | |||
class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
AlgoS8MatrixMul s8_matrix_mul; | |||
AlgoQU8MatrixMul qu8_matrix_mul; | |||
fallback::ConvBiasImpl::AlgoBase::Mapper m_all_algos_map; | |||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_all_algos; | |||
public: | |||
AlgoPack() { | |||
all_algos.emplace_back(&qu8_matrix_mul); | |||
all_algos.emplace_back(&s8_matrix_mul); | |||
m_all_algos.emplace_back(&qu8_matrix_mul); | |||
m_all_algos.emplace_back(&s8_matrix_mul); | |||
for (auto&& algo : m_all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& all_algos() | |||
const { | |||
return m_all_algos; | |||
} | |||
SmallVector<AlgoBase*> all_algos; | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() { | |||
static AlgoPack sl_algo_pack; | |||
auto&& algos = arm_common::ConvBiasImpl::algo_pack(); | |||
const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { | |||
static AlgoPack algo_pack; | |||
return algo_pack; | |||
} | |||
MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl) | |||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> | |||
ConvBiasImpl::get_all_packed_algo() { | |||
auto&& algos = arm_common::ConvBiasImpl::get_all_packed_algo(); | |||
//! TODO fused matmul bias is slower than matmul + elemwise in armv7 now, | |||
//! and nearly equal in aarch64, because of the waste of register in | |||
//! postprocess | |||
algos.insert(algos.end(), sl_algo_pack.all_algos.begin(), | |||
sl_algo_pack.all_algos.end()); | |||
algos.insert(algos.end(), algo_pack().all_algos().begin(), | |||
algo_pack().all_algos().end()); | |||
return std::move(algos); | |||
} | |||
@@ -25,7 +25,9 @@ public: | |||
} | |||
}; | |||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> algo_pack() override; | |||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> get_all_packed_algo() override; | |||
MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvBiasImpl); | |||
protected: | |||
const char* get_algorithm_set_name() const override; | |||
@@ -34,6 +36,7 @@ private: | |||
class AlgoS8MatrixMul; | |||
class AlgoQU8MatrixMul; | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack(); | |||
}; | |||
} // namespace armv7 | |||
@@ -42,6 +42,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARMV7_MATMUL_QU8) | |||
}; | |||
} // namespace armv7 | |||
@@ -27,6 +27,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(ARMV7_F32) | |||
}; | |||
class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase { | |||
@@ -37,6 +38,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(ARMV7_F32_MK4_PACK_4X12) | |||
}; | |||
class MatrixMulImpl::AlgoF32MK4_4x8 final : public AlgoBase { | |||
@@ -48,6 +50,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4) | |||
MEGDNN_DECL_ALGO_TYPE(ARMV7_F32_MK4_4x8) | |||
}; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
@@ -59,6 +62,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(ARMV7_F16_K4X16X1) | |||
}; | |||
class MatrixMulImpl::AlgoF16MK8_4x8 final : public AlgoBase { | |||
public: | |||
@@ -69,6 +73,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::FLOAT16, MK8) | |||
MEGDNN_DECL_ALGO_TYPE(ARMV7_F16_MK8_4X8) | |||
}; | |||
#endif | |||
#if __ARM_FEATURE_DOTPROD | |||
@@ -80,6 +85,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8_K6X8X4) | |||
}; | |||
class MatrixMulImpl::AlgoQuint8DotK4x8x4 final : public AlgoBase { | |||
@@ -90,6 +96,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(ARMV7_QUINT8_K4X8X4) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd final : public AlgoBase { | |||
@@ -102,11 +109,18 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8_MK4_8X4X4_DOTPROD) | |||
}; | |||
#endif | |||
class MatrixMulImpl::AlgoF32Gemv final | |||
: public arm_common::MatrixMulImpl::AlgoF32Gemv {}; | |||
: public arm_common::MatrixMulImpl::AlgoF32Gemv { | |||
public: | |||
AlgoF32Gemv() : arm_common::MatrixMulImpl::AlgoF32Gemv() { | |||
m_handle_type = Handle::HandleType::ARMV7; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARMV7_F32_GEMV) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x32K4x2x16 final : public AlgoBase { | |||
public: | |||
@@ -117,6 +131,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X32_K4X2X16) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x32K4x8x8 final : public AlgoBase { | |||
@@ -128,6 +143,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X32_K4X8X8) | |||
}; | |||
class MatrixMulImpl::AlgoQuint8K4x8x8 final : public AlgoBase { | |||
@@ -138,6 +154,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(ARMV7_QUINT8_K4X8X8) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x16K4x2x16 final : public AlgoBase { | |||
@@ -149,6 +166,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_K4X2X16) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x16K4x8x8 final : public AlgoBase { | |||
@@ -160,6 +178,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_K4X8X8) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase { | |||
@@ -171,6 +190,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_MK4_K8X8X4) | |||
}; | |||
class MatrixMulImpl::AlgoInt16x16x32K12x4x1 final : public AlgoBase { | |||
@@ -182,6 +202,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT16X16X32_K12X4X1) | |||
}; | |||
class MatrixMulImpl::AlgoInt16x16x32MK8_4x8 final : public AlgoBase { | |||
@@ -193,6 +214,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::INT16X16X32, MK8) | |||
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT16X16X32_MK8_4X8) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { | |||
@@ -204,6 +226,7 @@ public: | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X32_MK4_4X2X16) | |||
}; | |||
} // namespace armv7 | |||
@@ -43,42 +43,60 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
AlgoInt16x16x32K12x4x1 int16x16x32_k12x4x1; | |||
AlgoInt16x16x32MK8_4x8 int16x16x32_mk8_4x8; | |||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos; | |||
fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map; | |||
public: | |||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> all_algos; | |||
AlgoPack() { | |||
all_algos.emplace_back(&f32_gemv); | |||
all_algos.emplace_back(&f32); | |||
all_algos.emplace_back(&f32_mk4_pack_4x12); | |||
all_algos.emplace_back(&f32_mk4_4x8); | |||
m_all_algos.emplace_back(&f32_gemv); | |||
m_all_algos.emplace_back(&f32); | |||
m_all_algos.emplace_back(&f32_mk4_pack_4x12); | |||
m_all_algos.emplace_back(&f32_mk4_4x8); | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
all_algos.emplace_back(&f16_k4x16x1); | |||
all_algos.emplace_back(&f16_mk8_4x8); | |||
m_all_algos.emplace_back(&f16_k4x16x1); | |||
m_all_algos.emplace_back(&f16_mk8_4x8); | |||
#endif | |||
#if __ARM_FEATURE_DOTPROD | |||
all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod); | |||
all_algos.emplace_back(&int8_k6x8x4); | |||
all_algos.emplace_back(&quint8_k4x8x4); | |||
m_all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod); | |||
m_all_algos.emplace_back(&int8_k6x8x4); | |||
m_all_algos.emplace_back(&quint8_k4x8x4); | |||
#endif | |||
all_algos.emplace_back(&int8x8x32_mk4_4x2x16); | |||
all_algos.emplace_back(&int8x8x32_k4x2x16); | |||
all_algos.emplace_back(&int8x8x32_k4x8x8); | |||
all_algos.emplace_back(&quint8_k4x8x8); | |||
all_algos.emplace_back(&int8x8x16_mk4_8x8x4); | |||
all_algos.emplace_back(&int8x8x16_k4x2x16); | |||
all_algos.emplace_back(&int8x8x16_k4x8x8); | |||
m_all_algos.emplace_back(&int8x8x32_mk4_4x2x16); | |||
m_all_algos.emplace_back(&int8x8x32_k4x2x16); | |||
m_all_algos.emplace_back(&int8x8x32_k4x8x8); | |||
m_all_algos.emplace_back(&quint8_k4x8x8); | |||
m_all_algos.emplace_back(&int8x8x16_mk4_8x8x4); | |||
m_all_algos.emplace_back(&int8x8x16_k4x2x16); | |||
m_all_algos.emplace_back(&int8x8x16_k4x8x8); | |||
m_all_algos.emplace_back(&int16x16x32_k12x4x1); | |||
m_all_algos.emplace_back(&int16x16x32_mk8_4x8); | |||
for (auto&& algo : m_all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
all_algos.emplace_back(&int16x16x32_k12x4x1); | |||
all_algos.emplace_back(&int16x16x32_mk8_4x8); | |||
const SmallVector<fallback::MatrixMulImpl::AlgoBase*>& all_algos() const { | |||
return m_all_algos; | |||
} | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() { | |||
static AlgoPack s_algo_pack; | |||
auto algos = arm_common::MatrixMulImpl::algo_pack(); | |||
algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), | |||
s_algo_pack.all_algos.end()); | |||
const MatrixMulImpl::AlgoPack& MatrixMulImpl::algo_pack() { | |||
static AlgoPack algo_pack; | |||
return algo_pack; | |||
} | |||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> | |||
MatrixMulImpl::get_all_packed_algo() { | |||
auto algos = arm_common::MatrixMulImpl::get_all_packed_algo(); | |||
algos.insert(algos.begin(), algo_pack().all_algos().begin(), | |||
algo_pack().all_algos().end()); | |||
return algos; | |||
} | |||
MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(MatrixMulImpl) | |||
// vim: syntax=cpp.doxygen |
@@ -25,7 +25,10 @@ public: | |||
} | |||
}; | |||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> algo_pack() override; | |||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> get_all_packed_algo() | |||
override; | |||
MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl); | |||
private: | |||
class AlgoF32; // Armv7 F32 | |||
@@ -52,6 +55,9 @@ private: | |||
// DotProduct | |||
#endif | |||
class AlgoPack; | |||
public: | |||
static const AlgoPack& algo_pack(); | |||
}; | |||
} // namespace armv7 | |||
@@ -0,0 +1,101 @@ | |||
/** | |||
* \file dnn/src/common/algo_base.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include <functional> | |||
#include <string> | |||
#include "megdnn/oprs/base.h" | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
#define MEGDNN_DECL_ALGO_TYPE(_type) \ | |||
uint32_t type() const override { \ | |||
return static_cast<std::underlying_type<AlgoType>::type>( \ | |||
AlgoType::_type); \ | |||
} | |||
#define MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(_opr) \ | |||
static fallback::_opr::AlgoBase* get_algo_from_desc( \ | |||
const AlgorithmDesc& desc) | |||
#define MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(_opr) \ | |||
fallback::_opr::AlgoBase* _opr::get_algo_from_desc( \ | |||
const AlgorithmDesc& desc) { \ | |||
megdnn_assert(algo_pack().all_algos_map().find(desc) != \ | |||
algo_pack().all_algos_map().end()); \ | |||
return algo_pack().all_algos_map().at(desc); \ | |||
} | |||
#define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr) \ | |||
_opr::AlgoBase* _opr::get_algo_from_desc(const AlgorithmDesc& desc) { \ | |||
megdnn_assert(algo_pack().all_algos_map().find(desc) != \ | |||
algo_pack().all_algos_map().end()); \ | |||
return algo_pack().all_algos_map().at(desc); \ | |||
} | |||
/** | |||
* \brief construct algo from AlgorithmDesc | |||
*/ | |||
template <typename AlgoBase> | |||
class AlgoConstructMixin { | |||
private: | |||
std::vector<std::unique_ptr<AlgoBase>> m_refhold; | |||
protected: | |||
typename AlgoBase::Mapper m_all_algos_map; | |||
public: | |||
//! construct the algo which described by desc, and return the instance | |||
AlgoBase* construct_and_get_algo( | |||
const detail::Algorithm::Info::Desc& desc) { | |||
auto iter = m_all_algos_map.find(desc); | |||
if (iter != m_all_algos_map.end()) { | |||
return m_all_algos_map.at(desc); | |||
} | |||
std::string serialized_bin; | |||
AlgoBase::serialize_write_pod(desc.type, serialized_bin); | |||
serialized_bin += desc.param; | |||
m_refhold.emplace_back(AlgoBase::deserialize(serialized_bin)); | |||
m_all_algos_map.emplace(desc, m_refhold.back().get()); | |||
return m_refhold.back().get(); | |||
} | |||
void clear() { | |||
m_all_algos_map.clear(); | |||
m_refhold.clear(); | |||
} | |||
const typename AlgoBase::Mapper& all_algos_map() const { | |||
return m_all_algos_map; | |||
} | |||
}; | |||
} // namespace megdnn | |||
namespace std { | |||
template <> | |||
struct hash<megdnn::detail::Algorithm::Info::Desc> { | |||
std::size_t operator()( | |||
const megdnn::detail::Algorithm::Info::Desc& desc) const { | |||
return megdnn::hash_combine<size_t>( | |||
megdnn::hash_combine<size_t>( | |||
std::hash<std::string>()(desc.param), | |||
std::hash<uint32_t>()(desc.type)), | |||
std::hash<uint32_t>()(static_cast<uint32_t>(desc.handle_type))); | |||
} | |||
}; | |||
} // namespace std | |||
// vim: syntax=cpp.doxygen |
@@ -25,15 +25,34 @@ namespace megdnn { | |||
*/ | |||
template <class Opr, typename... Args> | |||
typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { | |||
typename Opr::Algorithm* ret; | |||
if (auto set = opr->execution_policy().algorithm) { | |||
typename Opr::AlgorithmInfo ret; | |||
auto set = opr->execution_policy().algo; | |||
if (set.valid()) { | |||
ret = set; | |||
} else { | |||
ret = opr->get_algorithm_heuristic(std::forward<Args>(args)..., | |||
std::numeric_limits<size_t>::max(), | |||
false); | |||
ret = opr->get_algorithm_info_heuristic( | |||
std::forward<Args>(args)..., std::numeric_limits<size_t>::max(), | |||
false); | |||
} | |||
return opr->get_algo_from_desc(ret.desc); | |||
} | |||
/*! | |||
* \brief get user-configured algorithm, or heuristic algorithm. used in opencl | |||
* whose algo need to be constructed each time. | |||
*/ | |||
template <class Opr, typename... Args> | |||
typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) { | |||
typename Opr::AlgorithmInfo ret; | |||
auto set = opr->execution_policy().algo; | |||
if (set.valid()) { | |||
return opr->algo_pack().construct_and_get_algo(set.desc); | |||
} else { | |||
ret = opr->get_algorithm_info_heuristic( | |||
std::forward<Args>(args)..., std::numeric_limits<size_t>::max(), | |||
false); | |||
return opr->get_algo_from_desc(ret.desc); | |||
} | |||
return static_cast<typename Opr::AlgoBase*>(ret); | |||
} | |||
/*! | |||
@@ -9,6 +9,32 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
/** | |||
* Boost Software License - Version 1.0 - August 17th, 2003 | |||
* | |||
* Permission is hereby granted, free of charge, to any person or organization | |||
* obtaining a copy of the software and accompanying documentation covered by | |||
* this license (the "Software") to use, reproduce, display, distribute, | |||
* execute, and transmit the Software, and to prepare derivative works of the | |||
* Software, and to permit third-parties to whom the Software is furnished to | |||
* do so, all subject to the following: | |||
* | |||
* The copyright notices in the Software and this entire statement, including | |||
* the above license grant, this restriction and the following disclaimer, | |||
* must be included in all copies of the Software, in whole or in part, and | |||
* all derivative works of the Software, unless such copies or derivative | |||
* works are solely in the form of machine-executable object code generated by | |||
* a source language processor. | |||
* | |||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |||
* FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT | |||
* SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE | |||
* FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, | |||
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER | |||
* DEALINGS IN THE SOFTWARE. | |||
*/ | |||
#pragma once | |||
#include "megdnn/arch.h" | |||
@@ -263,6 +289,13 @@ constexpr uint32_t operator"" _hash(char const* str, size_t count) { | |||
return XXHash64CT::hash(str, count, 20160701); | |||
} | |||
// refer to https://www.boost.org/doc/libs/1_64_0/boost/functional/hash/hash.hpp | |||
template <typename T> | |||
inline T hash_combine(T seed, T value) { | |||
seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2); | |||
return seed; | |||
} | |||
template <typename Vec> | |||
std::string vec2str(Vec&& vec) { | |||
std::string res; | |||
@@ -18,8 +18,14 @@ using namespace cuda; | |||
BatchConvBiasForwardImpl::AlgoPack::AlgoPack() { | |||
all_algos.push_back(&int8_nchw4_gemm_dotprod); | |||
all_algos.push_back(&int8_nchw4_implicit_gemm_dotprod); | |||
for (auto&& algo : all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
MEGDNN_DEF_GET_ALGO_FROM_DESC(BatchConvBiasForwardImpl) | |||
BatchConvBiasForwardImpl::AlgoPack BatchConvBiasForwardImpl::sm_algo_pack; | |||
BatchConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs( | |||
@@ -11,13 +11,16 @@ | |||
#pragma once | |||
#include <csetjmp> | |||
#include <unordered_map> | |||
#include "megdnn/oprs.h" | |||
#include "src/common/utils.h" | |||
#include "src/cuda/batch_conv_bias/opr_impl.h" | |||
#include "src/cuda/handle.h" | |||
#include "src/common/algo_base.h" | |||
#include "src/common/metahelper.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
@@ -26,6 +29,12 @@ protected: | |||
~AlgoBase() = default; | |||
public: | |||
enum class AlgoType : uint32_t { | |||
CUDA_GEMM_NCHW4_DOTPROD_INT8, | |||
CUDA_IMPLICIT_GEMM_PRECOMP_NCHW4_DOTPROD_INT8, | |||
}; | |||
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
struct SizeArgs { | |||
BatchConvBiasForwardImpl* opr; | |||
@@ -85,6 +94,7 @@ public: | |||
const char* name() const override { | |||
return "BATCH_CONV_BIAS_INT8_NCHW4_GEMM_DOTPROD"; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_GEMM_NCHW4_DOTPROD_INT8) | |||
}; | |||
class BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemmPrecomp final | |||
@@ -99,15 +109,16 @@ public: | |||
const char* name() const override { | |||
return "BATCH_CONV_BIAS_INT8_NCHW4_IMPLICIT_GEMM_PRECOMP_DOTPROD"; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_PRECOMP_NCHW4_DOTPROD_INT8) | |||
private: | |||
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | |||
const SizeArgs& args) const; | |||
}; | |||
class BatchConvBiasForwardImpl::AlgoPack { | |||
AlgoPack(const AlgoPack&) = delete; | |||
AlgoPack& operator=(const AlgoPack&) = delete; | |||
class BatchConvBiasForwardImpl::AlgoPack : NonCopyableObj { | |||
private: | |||
AlgoBase::Mapper m_all_algos_map; | |||
public: | |||
AlgoPack(); | |||
@@ -116,6 +127,8 @@ public: | |||
AlgoInt8NCHW4DotProdImplicitGemmPrecomp int8_nchw4_implicit_gemm_dotprod; | |||
std::vector<AlgoBase*> all_algos; | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
} // namespace cuda | |||
@@ -26,6 +26,18 @@ public: | |||
const TensorLayout& bias, | |||
const TensorLayout& z, | |||
const TensorLayout& dst) override; | |||
const char* get_algorithm_set_name() const override; | |||
class AlgoBase; | |||
class AlgoInt8NCHW4DotProdGemm; | |||
class AlgoInt8NCHW4DotProdImplicitGemmPrecomp; | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& src, const TensorLayout& filter, | |||
const TensorLayout& bias, const TensorLayout& z, | |||
@@ -37,15 +49,6 @@ public: | |||
const TensorLayout& dst, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
const char* get_algorithm_set_name() const override; | |||
class AlgoBase; | |||
class AlgoInt8NCHW4DotProdGemm; | |||
class AlgoInt8NCHW4DotProdImplicitGemmPrecomp; | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
private: | |||
static AlgoPack sm_algo_pack; | |||
@@ -60,4 +60,12 @@ BatchedMatrixMulForwardImpl::AlgoPack::AlgoPack() { | |||
for (auto& algo : brute_force_algos) { | |||
all_algos.push_back(&algo); | |||
} | |||
for (auto&& algo : all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
MEGDNN_DEF_GET_ALGO_FROM_DESC(BatchedMatrixMulForwardImpl) | |||
// vim: syntax=cpp.doxygen |
@@ -16,6 +16,8 @@ | |||
#include "src/common/utils.h" | |||
#include "src/cuda/batched_matrix_mul/opr_impl.h" | |||
#include "src/cuda/matrix_mul/cublasLt_wrapper.h" | |||
#include "src/common/metahelper.h" | |||
#if CUDA_VERSION >= 10010 | |||
#include <cublasLt.h> | |||
#endif | |||
@@ -28,6 +30,14 @@ protected: | |||
~AlgoBase() = default; | |||
public: | |||
enum class AlgoType : uint32_t { | |||
CUDA_BRUTE_FORCE, | |||
CUDA_CUBLAS, | |||
CUDA_CUBLASLT, | |||
CUDA_INT8X8X32, | |||
}; | |||
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
struct SizeArgs { | |||
BatchedMatrixMulForwardImpl* opr; | |||
@@ -90,6 +100,13 @@ public: | |||
void exec(const ExecArgs& args) const final; | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return m_name.c_str(); } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_BRUTE_FORCE) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_algorithm, ret); | |||
return ret; | |||
} | |||
}; | |||
class BatchedMatrixMulForwardImpl::AlgoCublas final | |||
: public BatchedMatrixMulForwardImpl::AlgoBase { | |||
@@ -100,6 +117,7 @@ public: | |||
void exec(const ExecArgs& args) const final; | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "CUBLAS"; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) | |||
}; | |||
#if CUDA_VERSION >= 10010 | |||
class BatchedMatrixMulForwardImpl::AlgoCublasLt final : public AlgoBase { | |||
@@ -110,6 +128,7 @@ public: | |||
void exec(const ExecArgs& args) const final; | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "CUBLAS_LT"; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT) | |||
}; | |||
#endif | |||
class BatchedMatrixMulForwardImpl::AlgoInt8x8x32 final | |||
@@ -121,11 +140,13 @@ public: | |||
void exec(const ExecArgs& args) const final; | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "INT8x8x32"; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_INT8X8X32) | |||
}; | |||
class BatchedMatrixMulForwardImpl::AlgoPack { | |||
class BatchedMatrixMulForwardImpl::AlgoPack : NonCopyableObj { | |||
private: | |||
AlgoBase::Mapper m_all_algos_map; | |||
MatrixMulForwardImpl::AlgoPack mm_pack; | |||
AlgoPack(const AlgoPack&) = delete; | |||
AlgoPack& operator=(const AlgoPack&) = delete; | |||
public: | |||
AlgoPack(); | |||
@@ -137,6 +158,8 @@ public: | |||
AlgoInt8x8x32 int8x8x32; | |||
std::vector<AlgoBase*> all_algos; | |||
std::vector<AlgoBruteForce> brute_force_algos; | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn |
@@ -24,7 +24,7 @@ bool BatchedMatrixMulForwardImpl::AlgoBruteForce::is_available( | |||
const SizeArgs& args) const { | |||
MatrixMulForwardImpl mm{args.opr->handle()}; | |||
mm.param() = {args.opr->param().transposeA, args.opr->param().transposeB}; | |||
mm.execution_policy() = {m_algorithm}; | |||
mm.execution_policy() = {m_algorithm->info()}; | |||
auto mm_layout_a = args.layout_a.remove_axis(0); | |||
auto mm_layout_b = args.layout_b.remove_axis(0); | |||
@@ -39,7 +39,7 @@ size_t BatchedMatrixMulForwardImpl::AlgoBruteForce::get_workspace_in_bytes( | |||
auto mm_opr = args.opr->handle()->create_operator<MatrixMulForward>(); | |||
mm_opr->param() = {args.opr->param().transposeA, | |||
args.opr->param().transposeB}; | |||
mm_opr->execution_policy() = {m_algorithm}; | |||
mm_opr->execution_policy() = {m_algorithm->info()}; | |||
return mm_opr->get_workspace_in_bytes(args.layout_a, args.layout_b, | |||
args.layout_c); | |||
@@ -50,7 +50,7 @@ void BatchedMatrixMulForwardImpl::AlgoBruteForce::exec( | |||
auto&& mm_opr = args.opr->handle()->create_operator<MatrixMulForward>(); | |||
mm_opr->param() = {args.opr->param().transposeA, | |||
args.opr->param().transposeB}; | |||
mm_opr->execution_policy() = {m_algorithm}; | |||
mm_opr->execution_policy() = {m_algorithm->info()}; | |||
rep(n, N) { | |||
TensorND A_, B_, C_; | |||
auto tensor_n_from_batch = [n](const TensorND& in, TensorND& out) { | |||
@@ -32,6 +32,16 @@ public: | |||
_megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout& A, const TensorLayout& B, | |||
const TensorLayout& C) override; | |||
const char* get_algorithm_set_name() const override { | |||
return "BATCHED_MATMUL"; | |||
} | |||
bool is_thread_safe() const override { return true; } | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | |||
const TensorLayout& B, | |||
const TensorLayout& C) override; | |||
@@ -40,12 +50,6 @@ public: | |||
const TensorLayout& C, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
const char* get_algorithm_set_name() const override { | |||
return "BATCHED_MATMUL"; | |||
} | |||
bool is_thread_safe() const override { return true; } | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
private: | |||
static AlgoPack sm_algo_pack; | |||
@@ -100,10 +100,16 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() { | |||
for (size_t i = all_algo_size; i < all_algos.size(); ++i) { | |||
non_cudnn_algos.push_back(all_algos[i]); | |||
} | |||
for (auto&& algo : all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
ConvBiasForwardImpl::AlgoPack ConvBiasForwardImpl::sm_algo_pack; | |||
MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvBiasForwardImpl) | |||
ConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs( | |||
ConvBiasForwardImpl* o, const TensorLayout& src, | |||
const TensorLayout& filter, const TensorLayout& bias, | |||
@@ -172,43 +178,10 @@ std::string ConvBiasForwardImpl::AlgoBase::SizeArgs::to_string() const { | |||
} | |||
void ConvBiasForwardImpl::AlgoPack::fill_cudnn_algos() { | |||
#define V1(v) #v | |||
#define V(v) V1(v) | |||
#define DEF_ALGO(NAME, REPROD) \ | |||
cudnn_conv_bias_activations.push_back( \ | |||
{REPROD, \ | |||
"CUDNN:ConvBiasActivation:" #NAME \ | |||
"v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL), \ | |||
NAME}); \ | |||
cudnn_convs.push_back( \ | |||
{REPROD, \ | |||
"CUDNN:Convolution:" #NAME \ | |||
"v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL), \ | |||
NAME}) | |||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true); | |||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true); | |||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_GEMM, true); | |||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, true); | |||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT, true); | |||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true); | |||
#if CUDNN_MAJOR >= 5 | |||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, true); | |||
#if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1 | |||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, true); | |||
#endif | |||
#endif | |||
#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) | |||
#pragma message "not latest cudnn" | |||
#endif | |||
#undef DEF_ALGO | |||
#undef V | |||
#undef V1 | |||
for (auto&& algo : CudnnAlgoPack::conv_fwd_algos()) { | |||
cudnn_conv_bias_activations.push_back(algo.first); | |||
cudnn_convs.push_back(algo.first); | |||
} | |||
} | |||
#if CUDA_VERSION >= 10000 | |||
@@ -6,19 +6,23 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
#include "src/common/algo_base.h" | |||
#include "src/common/utils.h" | |||
#include "src/common/metahelper.h" | |||
#include "src/cuda/conv_bias/conv_bias_int8.cuh" | |||
#include "src/cuda/conv_bias/helper.h" | |||
#include "src/cuda/conv_bias/opr_impl.h" | |||
#include "src/cuda/convolution_helper/parameter.cuh" | |||
#include "src/cuda/handle.h" | |||
#include "src/cuda/cudnn_wrapper.h" | |||
#include <cuda.h> | |||
#include <memory> | |||
@@ -38,11 +42,39 @@ protected: | |||
~AlgoBase() = default; | |||
public: | |||
enum class AlgoType : uint32_t { | |||
CUDA_CUDNN_CONVBIAS, | |||
CUDA_CHANWISE, | |||
CUDA_CHANWISE_SMALL, | |||
CUDA_CHANWISE_INT8X8X32, | |||
CUDA_CUDNN_CONV, | |||
CUDA_INPLACE_MATMUL, | |||
CUDA_MATMUL, | |||
CUDA_MATMUL_INT8X8X32, | |||
CUDA_1X1, | |||
CUDA_BATCHED_MATMUL, | |||
CUDA_GROUP_CONV_GENERAL, | |||
CUDA_WMMA_UINT4X4X32, | |||
CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8, | |||
CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8, | |||
CUDA_IMPLICIT_GEMM_CHWN4_IMMA_INT8, | |||
CUDA_IMPLICIT_GEMM_NCHW4_IMMA_INT8, | |||
CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8, | |||
CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8, | |||
CUDA_IMPLICIT_GEMM_IMMA_NCHW32_INT8, | |||
CUDA_BFLOAT16, | |||
CUDA_IMPLICIT_GEMM_SASS_NCHW4_DOTPROD_INT8, | |||
CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW4_DOTPROD_INT8, | |||
CUDA_IMPLICIT_GEMM_SASS_NCHW32_IMMA_INT8, | |||
CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW32_IMMA_INT8, | |||
}; | |||
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
struct SizeArgs : public conv_bias::BiasForwardSizeArgs { | |||
ConvBiasForwardImpl* opr; | |||
const PreprocessedFilter* preprocessed_filter; | |||
std::string to_string() const; | |||
SizeArgs(ConvBiasForwardImpl* opr, const TensorLayout& src, | |||
const TensorLayout& filter, const TensorLayout& bias, | |||
@@ -80,13 +112,17 @@ public: | |||
virtual void exec(const ExecArgs& args) const = 0; | |||
virtual size_t get_preprocess_workspace_in_bytes( | |||
const SizeArgs& args) const { | |||
MEGDNN_MARK_USED_VAR(args); | |||
return 0; | |||
} | |||
virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout( | |||
const SizeArgs& args) const { | |||
MEGDNN_MARK_USED_VAR(args); | |||
return {}; | |||
} | |||
virtual void exec_preprocess(const ExecArgs& args) const {} | |||
virtual void exec_preprocess(const ExecArgs& args) const { | |||
MEGDNN_MARK_USED_VAR(args); | |||
} | |||
bool is_available_wk(const SizeArgs& args, size_t limit) { | |||
return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
@@ -114,11 +150,14 @@ public: | |||
class ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation final : public AlgoBase { | |||
public: | |||
AlgoCUDNNConvBiasActivation(bool is_reproducible, const char* name, | |||
cudnnConvolutionFwdAlgo_t cudnn_enum) | |||
: m_is_reproducible(is_reproducible), | |||
m_name(ConvBiasForward::algo_name<DefaultParam>(name, {})), | |||
m_cudnn_enum(cudnn_enum) {} | |||
AlgoCUDNNConvBiasActivation(cudnnConvolutionFwdAlgo_t cudnn_enum) | |||
: m_cudnn_enum(cudnn_enum) { | |||
megdnn_assert(CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) != | |||
CudnnAlgoPack::conv_fwd_algos().end()); | |||
m_attr = CudnnAlgoPack::conv_fwd_algos().at(cudnn_enum); | |||
m_name = ConvBiasForward::algo_name<DefaultParam>( | |||
"CUDNN:ConvBiasActivation:" + m_attr.name, {}); | |||
} | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
@@ -127,16 +166,24 @@ public: | |||
const char* name() const override { return m_name.c_str(); } | |||
bool is_reproducible() const override { return m_is_reproducible; } | |||
bool is_reproducible() const override { return m_attr.is_reproducible; } | |||
cudnnConvolutionFwdAlgo_t cudnn_enum() { return m_cudnn_enum; } | |||
bool is_cudnn() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONVBIAS) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_cudnn_enum, ret); | |||
return ret; | |||
} | |||
private: | |||
bool m_is_reproducible; | |||
std::string m_name; | |||
cudnnConvolutionFwdAlgo_t m_cudnn_enum; | |||
CudnnAlgoPack::Attr m_attr; | |||
}; | |||
class ConvBiasForwardImpl::AlgoChanwise final : public AlgoBase { | |||
@@ -154,6 +201,8 @@ public: | |||
} | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | |||
private: | |||
mutable std::string m_name; | |||
}; | |||
@@ -172,6 +221,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL) | |||
private: | |||
mutable std::string m_name; | |||
@@ -190,6 +240,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_INT8X8X32) | |||
private: | |||
mutable std::string m_name; | |||
@@ -197,27 +248,39 @@ private: | |||
class ConvBiasForwardImpl::AlgoCUDNNConv final : public AlgoBase { | |||
public: | |||
AlgoCUDNNConv(bool is_reproducible, const char* name, | |||
cudnnConvolutionFwdAlgo_t cudnn_enum) | |||
: m_is_reproducible(is_reproducible), | |||
m_name(ConvBiasForward::algo_name<DefaultParam>(name, {})), | |||
m_cudnn_enum(cudnn_enum) {} | |||
AlgoCUDNNConv(cudnnConvolutionFwdAlgo_t cudnn_enum) | |||
: m_cudnn_enum(cudnn_enum) { | |||
megdnn_assert(CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) != | |||
CudnnAlgoPack::conv_fwd_algos().end()); | |||
m_attr = CudnnAlgoPack::conv_fwd_algos().at(cudnn_enum); | |||
m_name = ConvBiasForward::algo_name<DefaultParam>( | |||
"CUDNN:Convolution:" + m_attr.name, {}); | |||
} | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
bool is_reproducible() const override { return m_is_reproducible; } | |||
bool is_reproducible() const override { return m_attr.is_reproducible; } | |||
const char* name() const override { return m_name.c_str(); } | |||
cudnnConvolutionFwdAlgo_t cudnn_enum() const { return m_cudnn_enum; } | |||
bool is_cudnn() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONV) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_cudnn_enum, ret); | |||
return ret; | |||
} | |||
private: | |||
bool m_is_reproducible; | |||
std::string m_name; | |||
cudnnConvolutionFwdAlgo_t m_cudnn_enum; | |||
CudnnAlgoPack::Attr m_attr; | |||
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | |||
}; | |||
@@ -237,6 +300,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL) | |||
private: | |||
mutable std::string m_name; | |||
@@ -261,6 +325,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | |||
private: | |||
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | |||
@@ -281,6 +346,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL_INT8X8X32) | |||
private: | |||
bool need_src_unroll(const SizeArgs& args) const; | |||
@@ -310,6 +376,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_1X1) | |||
private: | |||
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | |||
@@ -333,6 +400,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) | |||
private: | |||
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | |||
@@ -354,6 +422,13 @@ public: | |||
static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | |||
TensorLayout& dst_pg, TensorLayout& bias_pg); | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_impl, ret); | |||
return ret; | |||
} | |||
private: | |||
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | |||
@@ -370,10 +445,13 @@ public: | |||
void exec(const ExecArgs& args) const override; | |||
const char* name() const override { return "QUINT4x4x32_WMMA"; } | |||
bool is_reproducible() const override { return true; } | |||
private: | |||
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const; | |||
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | |||
const SizeArgs& args) const; | |||
bool use_kernel_fhxfw(const SizeArgs& args) const; | |||
size_t get_workspace_in_bytes_do_conv(const SizeArgs& args) const; | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_WMMA_UINT4X4X32) | |||
}; | |||
#endif | |||
@@ -395,6 +473,7 @@ public: | |||
const convolution::ConvParam& param, float alpha, float beta, | |||
float gamma, float scale, cudaStream_t stream, | |||
param::ConvBias::NonlineMode nonlinear_mode); | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8) | |||
}; | |||
class ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm final | |||
@@ -415,8 +494,9 @@ public: | |||
warp_k == 32 && stage == 2) { | |||
return ""; | |||
} | |||
return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, threadblock_n, | |||
threadblock_k, warp_m, warp_n, warp_k, stage); | |||
return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, | |||
threadblock_n, threadblock_k, warp_m, warp_n, | |||
warp_k, stage); | |||
} | |||
}; | |||
AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param) | |||
@@ -433,6 +513,13 @@ public: | |||
SmallVector<TensorLayout> deduce_preprocessed_filter_layout( | |||
const SizeArgs& args) const override; | |||
void exec_preprocess(const ExecArgs& args) const override; | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_algo_param, ret); | |||
return ret; | |||
} | |||
private: | |||
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | |||
@@ -457,9 +544,7 @@ public: | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
const char* name() const override { | |||
return m_name.c_str(); | |||
} | |||
const char* name() const override { return m_name.c_str(); } | |||
bool is_reproducible() const override { return true; } | |||
template <typename BiasVisitor> | |||
static void dispatch_nonlinear_mode( | |||
@@ -471,6 +556,14 @@ public: | |||
MMATileSize mma_tile_size); | |||
static std::string to_string(MMATileSize mma_tile_size); | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_IMMA_INT8) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_mma_tile_size, ret); | |||
return ret; | |||
} | |||
private: | |||
MMATileSize m_mma_tile_size; | |||
std::string m_name; | |||
@@ -488,10 +581,16 @@ public: | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
const char* name() const override { | |||
return m_name.c_str(); | |||
} | |||
const char* name() const override { return m_name.c_str(); } | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_IMMA_INT8) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_mma_tile_size, ret); | |||
return ret; | |||
} | |||
private: | |||
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | |||
const SizeArgs& args) const; | |||
@@ -513,6 +612,13 @@ public: | |||
void exec(const ExecArgs& args) const override; | |||
const char* name() const override { return m_name.c_str(); } | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_mma_tile_size, ret); | |||
return ret; | |||
} | |||
private: | |||
MMATileSize m_mma_tile_size; | |||
@@ -533,6 +639,13 @@ public: | |||
void exec(const ExecArgs& args) const override; | |||
const char* name() const override { return m_name.c_str(); } | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_mma_tile_size, ret); | |||
return ret; | |||
} | |||
private: | |||
MMATileSize m_mma_tile_size; | |||
@@ -570,6 +683,13 @@ public: | |||
SmallVector<TensorLayout> deduce_preprocessed_filter_layout( | |||
const SizeArgs& args) const override; | |||
void exec_preprocess(const ExecArgs& args) const override; | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NCHW32_INT8) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_algo_param, ret); | |||
return ret; | |||
} | |||
private: | |||
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | |||
@@ -592,6 +712,14 @@ public: | |||
bool is_reproducible() const override { return m_impl->is_reproducible(); } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_impl, ret); | |||
return ret; | |||
} | |||
private: | |||
SizeArgs float_args(const SizeArgs& args, ConvBiasForwardImpl* opr, | |||
TensorLayout& fsrc, TensorLayout& ffilter, | |||
@@ -603,17 +731,16 @@ private: | |||
}; | |||
class ConvBiasForwardImpl::AlgoPack { | |||
AlgoPack(const AlgoPack&) = delete; | |||
AlgoPack& operator=(const AlgoPack&) = delete; | |||
class ConvBiasForwardImpl::AlgoPack : NonCopyableObj { | |||
private: | |||
AlgoBase::Mapper m_all_algos_map; | |||
public: | |||
AlgoPack(); | |||
std::vector<AlgoBase*> all_algos, | |||
//! non-cudnn algos, used for heuristic if cudnn is not supported | |||
non_cudnn_algos, | |||
bfloat16_algos; | |||
non_cudnn_algos, bfloat16_algos; | |||
std::vector<AlgoCUDNNConvBiasActivation> cudnn_conv_bias_activations; | |||
std::vector<AlgoCUDNNConv> cudnn_convs; | |||
AlgoChanwise chanwise; | |||
@@ -646,6 +773,8 @@ public: | |||
AlgoBase* cudnn_conv_from_enum(cudnnConvolutionFwdAlgo_t algo); | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
private: | |||
#if CUDA_VERSION >= 10000 | |||
void fill_imma_algos(); | |||
@@ -47,7 +47,7 @@ ConvBiasForwardImpl::AlgoBFloat16::float_args( | |||
change_dtype(fdst); | |||
opr->param() = args.opr->param(); | |||
opr->param().compute_mode = Param::ComputeMode::DEFAULT; | |||
opr->execution_policy() = {m_impl}; | |||
opr->execution_policy() = {m_impl->info()}; | |||
return SizeArgs(opr, fsrc, ffilter, fbias, fz, fdst); | |||
} | |||
@@ -110,7 +110,7 @@ void ConvBiasForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const { | |||
auto convbias_opr = args.handle->create_operator<ConvBias>(); | |||
convbias_opr->param() = args.opr->param(); | |||
convbias_opr->param().compute_mode = Param::ComputeMode::DEFAULT; | |||
convbias_opr->execution_policy() = {m_impl}; | |||
convbias_opr->execution_policy() = {m_impl->info()}; | |||
convbias_opr->exec(fsrc_tensor, ffilter_tensor, fbias_tensor, fz_tensor, | |||
fdst_tensor, nullptr, cvter.workspace()); | |||
} | |||
@@ -63,12 +63,12 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | |||
auto conv_args = args; | |||
auto cudnn_conv_bias_act_from_enum_wrapper = | |||
[this](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* { | |||
[](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* { | |||
return sm_algo_pack.cudnn_conv_bias_act_from_enum(algo); | |||
}; | |||
auto cudnn_conv_from_enum_wrapper = | |||
[this](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* { | |||
[](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* { | |||
return sm_algo_pack.cudnn_conv_from_enum(algo); | |||
}; | |||
@@ -24,17 +24,6 @@ public: | |||
_megdnn_tensor_out dst, | |||
const PreprocessedFilter* preprocessed_filter, | |||
_megdnn_workspace workspace) override; | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& src, const TensorLayout& filter, | |||
const TensorLayout& bias, const TensorLayout& z, | |||
const TensorLayout& dst) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const TensorLayout& filter, | |||
const TensorLayout& bias, | |||
const TensorLayout& z, | |||
const TensorLayout& dst, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, | |||
@@ -80,6 +69,20 @@ public: | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& src, const TensorLayout& filter, | |||
const TensorLayout& bias, const TensorLayout& z, | |||
const TensorLayout& dst) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const TensorLayout& filter, | |||
const TensorLayout& bias, | |||
const TensorLayout& z, | |||
const TensorLayout& dst, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
private: | |||
static AlgoPack sm_algo_pack; | |||
}; | |||
@@ -52,8 +52,14 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() { | |||
all_algos.push_back(bfloat16_refhold.back().get()); | |||
bfloat16_algos.push_back(bfloat16_refhold.back().get()); | |||
} | |||
for (auto&& algo : all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl) | |||
ConvolutionBackwardDataImpl::AlgoCUDNN* | |||
ConvolutionBackwardDataImpl::AlgoPack::cudnn_from_enum( | |||
cudnnConvolutionBwdDataAlgo_t algo) { | |||
@@ -11,8 +11,11 @@ | |||
#pragma once | |||
#include "src/cuda/convolution/helper.h" | |||
#include <unordered_map> | |||
#include "src/common/algo_base.h" | |||
#include "src/common/metahelper.h" | |||
#include "src/cuda/convolution/helper.h" | |||
#include "src/cuda/cudnn_wrapper.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
@@ -23,154 +26,146 @@ namespace cuda { | |||
* All the algo impls should try to support non-contiguous batch dim, for group | |||
* conv execution. | |||
*/ | |||
class ConvolutionBackwardDataImpl::AlgoBase: public Algorithm { | |||
protected: | |||
~AlgoBase() = default; | |||
public: | |||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
struct SizeArgs { | |||
HandleImpl *handle; | |||
CanonizedFilterMeta filter_meta; | |||
const TensorLayout *diff_layout, *grad_layout, *filter_layout; | |||
ConvolutionBackwardDataImpl *opr; | |||
std::string to_string() const; | |||
void init_desc(convolution::CUDNNBwdDataDescs &desc) const { | |||
desc.set(filter_meta, *diff_layout, *grad_layout, opr->param()); | |||
} | |||
SizeArgs(ConvolutionBackwardDataImpl* opr, | |||
const TensorLayout& filter, const TensorLayout& diff, | |||
const TensorLayout& grad); | |||
SizeArgs(ConvolutionBackwardDataImpl* opr, | |||
const TensorLayout& filter, | |||
const CanonizedFilterMeta& filter_meta, | |||
const TensorLayout& diff, const TensorLayout& grad); | |||
convolution::ForwardSizeArgs as_fwd_args() const { | |||
return {handle, grad_layout, filter_layout, filter_meta, | |||
diff_layout}; | |||
} | |||
}; | |||
struct ExecArgs: public SizeArgs { | |||
const TensorND *filter_tensor, *diff_tensor, *grad_tensor; | |||
Workspace workspace; | |||
ExecArgs(ConvolutionBackwardDataImpl *opr, | |||
_megdnn_tensor_in filter, | |||
_megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace); | |||
}; | |||
virtual bool is_available(const SizeArgs &args) const = 0; | |||
virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0; | |||
virtual void exec(const ExecArgs &args) const = 0; | |||
bool is_available_wk(const SizeArgs &args, size_t limit) { | |||
return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
} | |||
class ConvolutionBackwardDataImpl::AlgoBase : public Algorithm { | |||
protected: | |||
~AlgoBase() = default; | |||
bool is_available_reproducible( | |||
const SizeArgs& args, bool reproducible = true, | |||
size_t limit = std::numeric_limits<size_t>::max()) { | |||
return (!reproducible || is_reproducible()) && | |||
is_available_wk(args, limit); | |||
} | |||
AlgoBase& check_workspace( | |||
const SizeArgs &args, const Workspace &workspace) { | |||
auto req = get_workspace_in_bytes(args); | |||
megdnn_assert(req <= workspace.size, | |||
"conv bwd data algo %s: " | |||
"required workspace %zu bytes, got %zu", | |||
name(), req, workspace.size); | |||
return *this; | |||
public: | |||
enum class AlgoType : uint32_t { | |||
CUDA_CUDNN, | |||
CUDA_MATMUL, | |||
CUDA_CHANWISE, | |||
CUDA_CHANWISE_SMALL, | |||
CUDA_BFLOAT16, | |||
CUDA_GROUP_CONV_GENERAL, | |||
}; | |||
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
struct SizeArgs { | |||
HandleImpl* handle; | |||
CanonizedFilterMeta filter_meta; | |||
const TensorLayout *diff_layout, *grad_layout, *filter_layout; | |||
ConvolutionBackwardDataImpl* opr; | |||
std::string to_string() const; | |||
void init_desc(convolution::CUDNNBwdDataDescs& desc) const { | |||
desc.set(filter_meta, *diff_layout, *grad_layout, opr->param()); | |||
} | |||
virtual bool is_cudnn() const { | |||
return false; | |||
SizeArgs(ConvolutionBackwardDataImpl* opr, const TensorLayout& filter, | |||
const TensorLayout& diff, const TensorLayout& grad); | |||
SizeArgs(ConvolutionBackwardDataImpl* opr, const TensorLayout& filter, | |||
const CanonizedFilterMeta& filter_meta, | |||
const TensorLayout& diff, const TensorLayout& grad); | |||
convolution::ForwardSizeArgs as_fwd_args() const { | |||
return {handle, grad_layout, filter_layout, filter_meta, | |||
diff_layout}; | |||
} | |||
}; | |||
struct ExecArgs : public SizeArgs { | |||
const TensorND *filter_tensor, *diff_tensor, *grad_tensor; | |||
Workspace workspace; | |||
ExecArgs(ConvolutionBackwardDataImpl* opr, _megdnn_tensor_in filter, | |||
_megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
_megdnn_workspace workspace); | |||
}; | |||
virtual bool is_available(const SizeArgs& args) const = 0; | |||
virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||
virtual void exec(const ExecArgs& args) const = 0; | |||
bool is_available_wk(const SizeArgs& args, size_t limit) { | |||
return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
} | |||
bool is_available_reproducible( | |||
const SizeArgs& args, bool reproducible = true, | |||
size_t limit = std::numeric_limits<size_t>::max()) { | |||
return (!reproducible || is_reproducible()) && | |||
is_available_wk(args, limit); | |||
} | |||
AlgoBase& check_workspace(const SizeArgs& args, | |||
const Workspace& workspace) { | |||
auto req = get_workspace_in_bytes(args); | |||
megdnn_assert(req <= workspace.size, | |||
"conv bwd data algo %s: " | |||
"required workspace %zu bytes, got %zu", | |||
name(), req, workspace.size); | |||
return *this; | |||
} | |||
virtual bool is_cudnn() const { return false; } | |||
}; | |||
class ConvolutionBackwardDataImpl::AlgoCUDNN final : public AlgoBase { | |||
bool m_is_reproducible; | |||
const char *m_name; | |||
cudnnConvolutionBwdDataAlgo_t m_cudnn_enum; | |||
CudnnAlgoPack::Attr m_attr; | |||
public: | |||
public: | |||
AlgoCUDNN(cudnnConvolutionBwdDataAlgo_t cudnn_enum) | |||
: m_cudnn_enum(cudnn_enum) { | |||
megdnn_assert(CudnnAlgoPack::conv_bwd_data_algos().find(cudnn_enum) != | |||
CudnnAlgoPack::conv_bwd_data_algos().end()); | |||
m_attr = CudnnAlgoPack::conv_bwd_data_algos().at(cudnn_enum); | |||
} | |||
AlgoCUDNN(bool is_reproducible, const char *name, | |||
cudnnConvolutionBwdDataAlgo_t cudnn_enum): | |||
m_is_reproducible(is_reproducible), | |||
m_name(name), | |||
m_cudnn_enum(cudnn_enum) | |||
{} | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
bool is_available(const SizeArgs &args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
void exec(const ExecArgs &args) const override; | |||
bool is_reproducible() const override { return m_attr.is_reproducible; } | |||
bool is_reproducible() const override { | |||
return m_is_reproducible; | |||
} | |||
const char* name() const override { return m_attr.name.c_str(); } | |||
const char* name() const override { | |||
return m_name; | |||
} | |||
cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; } | |||
cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { | |||
return m_cudnn_enum; | |||
} | |||
bool is_cudnn() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN) | |||
bool is_cudnn() const override { | |||
return true; | |||
} | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_cudnn_enum, ret); | |||
return ret; | |||
} | |||
}; | |||
//! im2col and matmul, with dilation | |||
class ConvolutionBackwardDataImpl::AlgoMatmul final: public AlgoBase { | |||
template<typename T> | |||
static void exec_internal(const ExecArgs &args); | |||
class ConvolutionBackwardDataImpl::AlgoMatmul final : public AlgoBase { | |||
template <typename T> | |||
static void exec_internal(const ExecArgs& args); | |||
public: | |||
bool is_available(const SizeArgs &args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
void exec(const ExecArgs &args) const override; | |||
public: | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
const char* name() const override { | |||
return "MATMUL"; | |||
} | |||
bool is_reproducible() const override { | |||
return true; | |||
} | |||
const char* name() const override { return "MATMUL"; } | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | |||
}; | |||
class ConvolutionBackwardDataImpl::AlgoChanwise final: public AlgoBase { | |||
public: | |||
bool is_available(const SizeArgs &args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
void exec(const ExecArgs &args) const override; | |||
class ConvolutionBackwardDataImpl::AlgoChanwise final : public AlgoBase { | |||
public: | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
const char* name() const override { | |||
return "CHANNEL_WISE"; | |||
} | |||
bool is_reproducible() const override { | |||
return true; | |||
} | |||
const char* name() const override { return "CHANNEL_WISE"; } | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | |||
}; | |||
class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final: public AlgoBase { | |||
public: | |||
bool is_available(const SizeArgs &args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
void exec(const ExecArgs &args) const override; | |||
class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final : public AlgoBase { | |||
public: | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
const char* name() const override { | |||
return "CHANNEL_WISE_SMALL"; | |||
} | |||
bool is_reproducible() const override { | |||
return true; | |||
} | |||
const char* name() const override { return "CHANNEL_WISE_SMALL"; } | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL) | |||
}; | |||
class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase { | |||
@@ -190,61 +185,72 @@ private: | |||
TensorLayout& fsrc, TensorLayout& ffilter, | |||
TensorLayout& fdst) const; | |||
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_algorithm, ret); | |||
return ret; | |||
} | |||
}; | |||
//! implement group conv by another algo | |||
class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final: public AlgoBase { | |||
AlgoBase *m_impl; | |||
class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final | |||
: public AlgoBase { | |||
AlgoBase* m_impl; | |||
std::string m_name; | |||
public: | |||
AlgoGroupConvGeneral(AlgoBase *impl); | |||
public: | |||
AlgoGroupConvGeneral(AlgoBase* impl); | |||
bool is_available(const SizeArgs &args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
void exec(const ExecArgs &args) const override; | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
const char* name() const override { | |||
return m_name.c_str(); | |||
} | |||
const char* name() const override { return m_name.c_str(); } | |||
bool is_reproducible() const override { | |||
return m_impl->is_reproducible(); | |||
} | |||
bool is_reproducible() const override { return m_impl->is_reproducible(); } | |||
static void modify_size_args(SizeArgs& args, TensorLayout& diff_pg, | |||
TensorLayout& grad_pg); | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | |||
static void modify_size_args(SizeArgs &args, | |||
TensorLayout &diff_pg, TensorLayout &grad_pg); | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_impl, ret); | |||
return ret; | |||
} | |||
}; | |||
class ConvolutionBackwardDataImpl::AlgoPack { | |||
class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { | |||
// defined in cudnn.cpp | |||
void fill_cudnn_algos(); | |||
AlgoPack(const AlgoPack&) = delete; | |||
AlgoPack& operator = (const AlgoPack &) = delete; | |||
AlgoBase::Mapper m_all_algos_map; | |||
public: | |||
AlgoPack(); | |||
public: | |||
AlgoPack(); | |||
std::vector<AlgoCUDNN> cudnn; | |||
AlgoMatmul matmul; | |||
AlgoChanwise chanwise; | |||
AlgoChanwiseSmall chanwise_small; | |||
std::vector<AlgoGroupConvGeneral> gconv; | |||
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||
std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold; | |||
std::vector<AlgoCUDNN> cudnn; | |||
AlgoMatmul matmul; | |||
AlgoChanwise chanwise; | |||
AlgoChanwiseSmall chanwise_small; | |||
std::vector<AlgoGroupConvGeneral> gconv; | |||
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||
std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold; | |||
std::vector<AlgoBase*> | |||
std::vector<AlgoBase*> | |||
//! all algorithms | |||
all_algos, | |||
//! non-cudnn algos, used for heuristic if cudnn is not supported | |||
non_cudnn_algos, | |||
bfloat16_algos; | |||
non_cudnn_algos, bfloat16_algos; | |||
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo); | |||
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo); | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -42,7 +42,7 @@ ConvolutionBackwardDataImpl::AlgoBFloat16::float_args( | |||
change_dtype(fgrad); | |||
opr->param() = args.opr->param(); | |||
opr->param().compute_mode = Param::ComputeMode::DEFAULT; | |||
opr->execution_policy() = {m_algorithm}; | |||
opr->execution_policy() = {m_algorithm->info()}; | |||
return SizeArgs(opr, ffilter, fdiff, fgrad); | |||
} | |||
@@ -105,7 +105,7 @@ void ConvolutionBackwardDataImpl::AlgoBFloat16::exec( | |||
args.handle->create_operator<ConvolutionBackwardData>(); | |||
conv_back_data_opr->param() = args.opr->param(); | |||
conv_back_data_opr->param().compute_mode = Param::ComputeMode::DEFAULT; | |||
conv_back_data_opr->execution_policy() = {m_algorithm}; | |||
conv_back_data_opr->execution_policy() = {m_algorithm->info()}; | |||
conv_back_data_opr->exec(ffilter_tensor, fdiff_tensor, fgrad_tensor, | |||
cvter.workspace()); | |||
} | |||
@@ -98,35 +98,9 @@ void ConvolutionBackwardDataImpl::AlgoCUDNN::exec( | |||
} | |||
void ConvolutionBackwardDataImpl::AlgoPack::fill_cudnn_algos() { | |||
#define V1(v) #v | |||
#define V(v) V1(v) | |||
#define DEF_ALGO(NAME, REPROD) \ | |||
cudnn.push_back({ \ | |||
REPROD, #NAME \ | |||
"v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) \ | |||
"." V(CUDNN_PATCHLEVEL), \ | |||
NAME}) | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false); | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true); | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, true); | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true); | |||
#if CUDNN_MAJOR >= 5 | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, true); | |||
#if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1 | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED, true); | |||
#endif | |||
#endif | |||
#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) | |||
#pragma message "not latest cudnn" | |||
#endif | |||
#undef DEF_ALGO | |||
#undef V | |||
#undef V1 | |||
for (auto&& algo : CudnnAlgoPack::conv_bwd_data_algos()) { | |||
cudnn.push_back(algo.first); | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -49,8 +49,14 @@ ConvolutionBackwardFilterImpl::AlgoPack::AlgoPack() { | |||
all_algos.push_back(bfloat16_refhold.back().get()); | |||
bfloat16_algos.push_back(bfloat16_refhold.back().get()); | |||
} | |||
for (auto&& algo : all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardFilterImpl) | |||
ConvolutionBackwardFilterImpl::AlgoCUDNN* | |||
ConvolutionBackwardFilterImpl::AlgoPack::cudnn_from_enum( | |||
cudnnConvolutionBwdFilterAlgo_t algo) { | |||
@@ -6,13 +6,16 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "src/cuda/convolution/helper.h" | |||
#include <unordered_map> | |||
#include "src/common/algo_base.h" | |||
#include "src/common/metahelper.h" | |||
#include "src/cuda/convolution/helper.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
@@ -23,141 +26,134 @@ namespace cuda { | |||
* All the algo impls should try to support non-contiguous batch dim, for group | |||
* conv execution. | |||
*/ | |||
class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm { | |||
protected: | |||
~AlgoBase() = default; | |||
public: | |||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
struct SizeArgs { | |||
HandleImpl *handle; | |||
const TensorLayout *src_layout, *diff_layout, *grad_layout; | |||
CanonizedFilterMeta grad_filter_meta; | |||
ConvolutionBackwardFilterImpl *opr; | |||
std::string to_string() const; | |||
void init_desc(convolution::CUDNNBwdFilterDescs &desc) const { | |||
desc.set(*src_layout, *diff_layout, grad_filter_meta, | |||
opr->param()); | |||
} | |||
SizeArgs(ConvolutionBackwardFilterImpl *opr, | |||
const TensorLayout &src, const TensorLayout &diff, | |||
const TensorLayout &grad); | |||
SizeArgs(ConvolutionBackwardFilterImpl* opr, | |||
const TensorLayout& src, const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
const CanonizedFilterMeta& grad_meta); | |||
convolution::ForwardSizeArgs as_fwd_args() const { | |||
return {handle, src_layout, grad_layout, grad_filter_meta, | |||
diff_layout}; | |||
} | |||
}; | |||
struct ExecArgs: public SizeArgs { | |||
const TensorND *src_tensor, *diff_tensor, *grad_tensor; | |||
Workspace workspace; | |||
ExecArgs(ConvolutionBackwardFilterImpl *opr, | |||
_megdnn_tensor_in src, | |||
_megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace); | |||
}; | |||
virtual bool is_available(const SizeArgs &args) const = 0; | |||
virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0; | |||
virtual void exec(const ExecArgs &args) const = 0; | |||
bool is_available_wk(const SizeArgs &args, size_t limit) { | |||
return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
} | |||
class ConvolutionBackwardFilterImpl::AlgoBase : public Algorithm { | |||
protected: | |||
~AlgoBase() = default; | |||
bool is_available_reproducible( | |||
const SizeArgs& args, bool reproducible = true, | |||
size_t limit = std::numeric_limits<size_t>::max()) { | |||
return (!reproducible || is_reproducible()) && | |||
is_available_wk(args, limit); | |||
public: | |||
enum class AlgoType : uint32_t { | |||
CUDA_CUDNN, | |||
CUDA_MATMUL, | |||
CUDA_CHANWISE, | |||
CUDA_BFLOAT16, | |||
CUDA_GROUP_CONV_GENERAL, | |||
}; | |||
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
struct SizeArgs { | |||
HandleImpl* handle; | |||
const TensorLayout *src_layout, *diff_layout, *grad_layout; | |||
CanonizedFilterMeta grad_filter_meta; | |||
ConvolutionBackwardFilterImpl* opr; | |||
std::string to_string() const; | |||
void init_desc(convolution::CUDNNBwdFilterDescs& desc) const { | |||
desc.set(*src_layout, *diff_layout, grad_filter_meta, opr->param()); | |||
} | |||
AlgoBase& check_workspace( | |||
const SizeArgs &args, const Workspace &workspace) { | |||
auto req = get_workspace_in_bytes(args); | |||
megdnn_assert(req <= workspace.size, | |||
"conv bwd filter algo %s: " | |||
"required workspace %zu bytes, got %zu", | |||
name(), req, workspace.size); | |||
return *this; | |||
} | |||
virtual bool is_cudnn() const { | |||
return false; | |||
SizeArgs(ConvolutionBackwardFilterImpl* opr, const TensorLayout& src, | |||
const TensorLayout& diff, const TensorLayout& grad); | |||
SizeArgs(ConvolutionBackwardFilterImpl* opr, const TensorLayout& src, | |||
const TensorLayout& diff, const TensorLayout& grad, | |||
const CanonizedFilterMeta& grad_meta); | |||
convolution::ForwardSizeArgs as_fwd_args() const { | |||
return {handle, src_layout, grad_layout, grad_filter_meta, | |||
diff_layout}; | |||
} | |||
}; | |||
struct ExecArgs : public SizeArgs { | |||
const TensorND *src_tensor, *diff_tensor, *grad_tensor; | |||
Workspace workspace; | |||
ExecArgs(ConvolutionBackwardFilterImpl* opr, _megdnn_tensor_in src, | |||
_megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
_megdnn_workspace workspace); | |||
}; | |||
virtual bool is_available(const SizeArgs& args) const = 0; | |||
virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||
virtual void exec(const ExecArgs& args) const = 0; | |||
bool is_available_wk(const SizeArgs& args, size_t limit) { | |||
return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
} | |||
bool is_available_reproducible( | |||
const SizeArgs& args, bool reproducible = true, | |||
size_t limit = std::numeric_limits<size_t>::max()) { | |||
return (!reproducible || is_reproducible()) && | |||
is_available_wk(args, limit); | |||
} | |||
AlgoBase& check_workspace(const SizeArgs& args, | |||
const Workspace& workspace) { | |||
auto req = get_workspace_in_bytes(args); | |||
megdnn_assert(req <= workspace.size, | |||
"conv bwd filter algo %s: " | |||
"required workspace %zu bytes, got %zu", | |||
name(), req, workspace.size); | |||
return *this; | |||
} | |||
virtual bool is_cudnn() const { return false; } | |||
}; | |||
class ConvolutionBackwardFilterImpl::AlgoCUDNN final : public AlgoBase { | |||
bool m_is_reproducible; | |||
const char *m_name; | |||
cudnnConvolutionBwdFilterAlgo_t m_cudnn_enum; | |||
CudnnAlgoPack::Attr m_attr; | |||
public: | |||
public: | |||
AlgoCUDNN(cudnnConvolutionBwdFilterAlgo_t cudnn_enum) | |||
: m_cudnn_enum(cudnn_enum) { | |||
megdnn_assert(CudnnAlgoPack::conv_bwd_flt_algos().find(cudnn_enum) != | |||
CudnnAlgoPack::conv_bwd_flt_algos().end()); | |||
m_attr = CudnnAlgoPack::conv_bwd_flt_algos().at(cudnn_enum); | |||
} | |||
AlgoCUDNN(bool is_reproducible, const char *name, | |||
cudnnConvolutionBwdFilterAlgo_t cudnn_enum): | |||
m_is_reproducible(is_reproducible), | |||
m_name(name), | |||
m_cudnn_enum(cudnn_enum) | |||
{} | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
bool is_available(const SizeArgs &args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
void exec(const ExecArgs &args) const override; | |||
bool is_reproducible() const override { return m_attr.is_reproducible; } | |||
bool is_reproducible() const override { | |||
return m_is_reproducible; | |||
} | |||
const char* name() const override { return m_attr.name.c_str(); } | |||
const char* name() const override { | |||
return m_name; | |||
} | |||
cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const { return m_cudnn_enum; } | |||
cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const { | |||
return m_cudnn_enum; | |||
} | |||
bool is_cudnn() const override { return true; } | |||
bool is_cudnn() const override { | |||
return true; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_cudnn_enum, ret); | |||
return ret; | |||
} | |||
}; | |||
//! im2col and matmul, with dilation | |||
class ConvolutionBackwardFilterImpl::AlgoMatmul final: public AlgoBase { | |||
template<typename T> | |||
static void exec_internal(const ExecArgs &args); | |||
class ConvolutionBackwardFilterImpl::AlgoMatmul final : public AlgoBase { | |||
template <typename T> | |||
static void exec_internal(const ExecArgs& args); | |||
public: | |||
bool is_available(const SizeArgs &args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
void exec(const ExecArgs &args) const override; | |||
public: | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
const char* name() const override { | |||
return "MATMUL"; | |||
} | |||
bool is_reproducible() const override { | |||
return true; | |||
} | |||
const char* name() const override { return "MATMUL"; } | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | |||
}; | |||
class ConvolutionBackwardFilterImpl::AlgoChanwise final: public AlgoBase { | |||
public: | |||
bool is_available(const SizeArgs &args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
void exec(const ExecArgs &args) const override; | |||
class ConvolutionBackwardFilterImpl::AlgoChanwise final : public AlgoBase { | |||
public: | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
const char* name() const override { | |||
return "CHANNEL_WISE"; | |||
} | |||
bool is_reproducible() const override { | |||
return true; | |||
} | |||
const char* name() const override { return "CHANNEL_WISE"; } | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | |||
}; | |||
class ConvolutionBackwardFilterImpl::AlgoBFloat16 final : public AlgoBase { | |||
@@ -169,6 +165,13 @@ public: | |||
const char* name() const override { return m_name.c_str(); } | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_algorithm, ret); | |||
return ret; | |||
} | |||
private: | |||
std::string m_name; | |||
@@ -180,57 +183,62 @@ private: | |||
}; | |||
//! implement group conv by another algo | |||
class ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral final: public AlgoBase { | |||
AlgoBase *m_impl; | |||
class ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral final | |||
: public AlgoBase { | |||
AlgoBase* m_impl; | |||
std::string m_name; | |||
public: | |||
AlgoGroupConvGeneral(AlgoBase *impl); | |||
public: | |||
AlgoGroupConvGeneral(AlgoBase* impl); | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
const char* name() const override { return m_name.c_str(); } | |||
bool is_available(const SizeArgs &args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
void exec(const ExecArgs &args) const override; | |||
bool is_reproducible() const override { return m_impl->is_reproducible(); } | |||
const char* name() const override { | |||
return m_name.c_str(); | |||
} | |||
static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | |||
TensorLayout& diff_pg); | |||
bool is_reproducible() const override { | |||
return m_impl->is_reproducible(); | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | |||
static void modify_size_args(SizeArgs &args, | |||
TensorLayout &src_pg, TensorLayout &diff_pg); | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_impl, ret); | |||
return ret; | |||
} | |||
}; | |||
class ConvolutionBackwardFilterImpl::AlgoPack { | |||
class ConvolutionBackwardFilterImpl::AlgoPack : NonCopyableObj { | |||
// defined in cudnn.cpp | |||
void fill_cudnn_algos(); | |||
AlgoPack(const AlgoPack&) = delete; | |||
AlgoPack& operator = (const AlgoPack &) = delete; | |||
AlgoBase::Mapper m_all_algos_map; | |||
public: | |||
AlgoPack(); | |||
public: | |||
AlgoPack(); | |||
std::vector<AlgoCUDNN> cudnn; | |||
AlgoMatmul matmul; | |||
AlgoChanwise chanwise; | |||
std::vector<AlgoGroupConvGeneral> gconv; | |||
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||
std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold; | |||
std::vector<AlgoCUDNN> cudnn; | |||
AlgoMatmul matmul; | |||
AlgoChanwise chanwise; | |||
std::vector<AlgoGroupConvGeneral> gconv; | |||
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||
std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold; | |||
std::vector<AlgoBase*> | |||
std::vector<AlgoBase*> | |||
//! all algorithms | |||
all_algos, | |||
//! non-cudnn algos, used for heuristic if cudnn is not supported | |||
non_cudnn_algos, | |||
bfloat16_algos; | |||
non_cudnn_algos, bfloat16_algos; | |||
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo); | |||
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo); | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -42,7 +42,7 @@ ConvolutionBackwardFilterImpl::AlgoBFloat16::float_args( | |||
change_dtype(fgrad); | |||
opr->param() = args.opr->param(); | |||
opr->param().compute_mode = Param::ComputeMode::DEFAULT; | |||
opr->execution_policy() = {m_algorithm}; | |||
opr->execution_policy() = {m_algorithm->info()}; | |||
return SizeArgs(opr, fsrc, fdiff, fgrad); | |||
} | |||
@@ -107,7 +107,7 @@ void ConvolutionBackwardFilterImpl::AlgoBFloat16::exec( | |||
conv_back_filter_opr->param() = args.opr->param(); | |||
conv_back_filter_opr->param().compute_mode = | |||
Param::ComputeMode::DEFAULT; | |||
conv_back_filter_opr->execution_policy() = {m_algorithm}; | |||
conv_back_filter_opr->execution_policy() = {m_algorithm->info()}; | |||
conv_back_filter_opr->exec(fsrc_tensor, fdiff_tensor, fgrad_tensor, | |||
cvter.workspace()); | |||
} | |||
@@ -80,35 +80,9 @@ void ConvolutionBackwardFilterImpl::AlgoCUDNN::exec( | |||
} | |||
void ConvolutionBackwardFilterImpl::AlgoPack::fill_cudnn_algos() { | |||
#define V1(v) #v | |||
#define V(v) V1(v) | |||
#define DEF_ALGO(NAME, REPROD) \ | |||
cudnn.push_back({ \ | |||
REPROD, #NAME \ | |||
"v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) \ | |||
"." V(CUDNN_PATCHLEVEL), \ | |||
NAME}) | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false); | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true); | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, true); | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false); | |||
#if CUDNN_MAJOR >= 6 || (CUDNN_MAJOR >= 5 && CUDNN_MINOR >= 1) | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, true); | |||
#if CUDNN_MAJOR >= 6 | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, true); | |||
#endif | |||
#endif | |||
#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) | |||
#pragma message "not latest cudnn" | |||
#endif | |||
#undef DEF_ALGO | |||
#undef V | |||
#undef V1 | |||
for(auto&& algo : CudnnAlgoPack::conv_bwd_flt_algos()) { | |||
cudnn.push_back(algo.first); | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -70,7 +70,7 @@ ConvolutionForwardImpl::conv_bias_extra_data(const TensorLayout& src, | |||
conv_param.dilate_w, | |||
0, | |||
conv_param.compute_mode}; | |||
ret.convbias_opr->execution_policy() = {this->execution_policy().algorithm}; | |||
ret.convbias_opr->execution_policy() = {this->execution_policy().algo}; | |||
return ret; | |||
} | |||
@@ -183,15 +183,6 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter, | |||
CUDNNBwdDataDescs desc; | |||
args.init_desc(desc); | |||
//disable, segfault in megbrain, need further investigate. | |||
#if 0 | |||
bool is_heuristic_success= convolution:: | |||
PerformanceModelBackwardData::get_algo_backward_data_success( | |||
args, desc, workspace_limit_in_bytes, &algo); | |||
if (is_heuristic_success) { | |||
return sm_algo_pack.cudnn_from_enum(algo); | |||
} | |||
#endif | |||
#if CUDNN_MAJOR >= 7 | |||
int max_count = 0; | |||
cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount( | |||
@@ -24,14 +24,6 @@ class ConvolutionForwardImpl: public ConvolutionForward { | |||
const PreprocessedFilter* preprocessed_filter, | |||
_megdnn_workspace workspace) override; | |||
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | |||
const TensorLayout &filter, | |||
const TensorLayout &dst) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const TensorLayout& filter, | |||
const TensorLayout& dst, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& filter, | |||
const TensorLayout& dst, | |||
@@ -60,99 +52,129 @@ class ConvolutionForwardImpl: public ConvolutionForward { | |||
TensorLayout bias_layout; | |||
TensorLayout z_layout; | |||
}; | |||
private: | |||
ConvBiasExtraData conv_bias_extra_data(const TensorLayout&, | |||
const TensorLayout&, | |||
const TensorLayout&); | |||
}; | |||
class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { | |||
public: | |||
using ConvolutionBackwardData::ConvolutionBackwardData; | |||
void exec(_megdnn_tensor_in filter, | |||
_megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) override; | |||
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &filter, | |||
const TensorLayout &diff, | |||
const TensorLayout &grad) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& src, const TensorLayout& filter, | |||
const TensorLayout& dst) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const TensorLayout& filter, | |||
const TensorLayout& dst, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
Algorithm* get_algorithm_heuristic( | |||
const TensorLayout& filter, | |||
const CanonizedFilterMeta& filter_meta, | |||
const TensorLayout& diff, const TensorLayout& grad, | |||
size_t workspace_limit_in_bytes, bool reproducible); | |||
size_t get_workspace_in_bytes(const TensorLayout& filter, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad) override; | |||
const char* get_algorithm_set_name() const override; | |||
class AlgoBase; | |||
class AlgoCUDNN; | |||
class AlgoMatmul; | |||
class AlgoChanwise; | |||
class AlgoChanwiseSmall; | |||
class AlgoGroupConvGeneral; | |||
class AlgoBFloat16; | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { | |||
return sm_algo_pack; | |||
} | |||
private: | |||
static AlgoPack sm_algo_pack; | |||
ConvBiasExtraData conv_bias_extra_data(const TensorLayout&, | |||
const TensorLayout&, | |||
const TensorLayout&); | |||
}; | |||
class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { | |||
public: | |||
using ConvolutionBackwardFilter::ConvolutionBackwardFilter; | |||
void exec(_megdnn_tensor_in src, | |||
_megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) override; | |||
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | |||
const TensorLayout &diff, | |||
const TensorLayout &grad) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const TensorLayout& diff, | |||
const TensorLayout& gradk, | |||
const CanonizedFilterMeta& grad_meta, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible); | |||
size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad) override; | |||
const char* get_algorithm_set_name() const override; | |||
class AlgoBase; | |||
class AlgoCUDNN; | |||
class AlgoMatmul; | |||
class AlgoChanwise; | |||
class AlgoGroupConvGeneral; | |||
class AlgoBFloat16; | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { | |||
return sm_algo_pack; | |||
} | |||
class ConvolutionBackwardDataImpl : public ConvolutionBackwardData { | |||
public: | |||
using ConvolutionBackwardData::ConvolutionBackwardData; | |||
void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||
AlgorithmInfo get_algorithm_info_heuristic( | |||
const TensorLayout& filter, const CanonizedFilterMeta& filter_meta, | |||
const TensorLayout& diff, const TensorLayout& grad, | |||
size_t workspace_limit_in_bytes, bool reproducible) { | |||
return get_algorithm_heuristic(filter, filter_meta, diff, grad, | |||
workspace_limit_in_bytes, reproducible) | |||
->info(); | |||
} | |||
size_t get_workspace_in_bytes(const TensorLayout& filter, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad) override; | |||
const char* get_algorithm_set_name() const override; | |||
class AlgoBase; | |||
class AlgoCUDNN; | |||
class AlgoMatmul; | |||
class AlgoChanwise; | |||
class AlgoChanwiseSmall; | |||
class AlgoGroupConvGeneral; | |||
class AlgoBFloat16; | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& filter, const TensorLayout& diff, | |||
const TensorLayout& grad) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
private: | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||
const CanonizedFilterMeta& filter_meta, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible); | |||
static AlgoPack sm_algo_pack; | |||
}; | |||
private: | |||
static AlgoPack sm_algo_pack; | |||
class ConvolutionBackwardFilterImpl : public ConvolutionBackwardFilter { | |||
public: | |||
using ConvolutionBackwardFilter::ConvolutionBackwardFilter; | |||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad) override; | |||
AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& src, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
const CanonizedFilterMeta& grad_meta, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) { | |||
return get_algorithm_heuristic(src, diff, grad, grad_meta, | |||
workspace_limit_in_bytes, reproducible) | |||
->info(); | |||
} | |||
const char* get_algorithm_set_name() const override; | |||
class AlgoBase; | |||
class AlgoCUDNN; | |||
class AlgoMatmul; | |||
class AlgoChanwise; | |||
class AlgoGroupConvGeneral; | |||
class AlgoBFloat16; | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& src, const TensorLayout& diff, | |||
const TensorLayout& grad) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
private: | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
const CanonizedFilterMeta& grad_meta, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible); | |||
static AlgoPack sm_algo_pack; | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -39,8 +39,14 @@ Convolution3DBackwardDataImpl::AlgoPack::AlgoPack() { | |||
all_algos.push_back(&i); | |||
} | |||
megdnn_assert(all_algos_data == all_algos.data()); | |||
for (auto&& algo : all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
MEGDNN_DEF_GET_ALGO_FROM_DESC(Convolution3DBackwardDataImpl) | |||
Convolution3DBackwardDataImpl::AlgoCUDNN* | |||
Convolution3DBackwardDataImpl::AlgoPack::cudnn_from_enum( | |||
cudnnConvolutionBwdDataAlgo_t algo) { | |||
@@ -96,7 +102,7 @@ std::string Convolution3DBackwardDataImpl::AlgoBase::SizeArgs::to_string() const | |||
fm.group, fm.ocpg, fm.icpg, fm.spatial[0], fm.spatial[1], fm.spatial[2], | |||
diff_layout->to_string().c_str(), | |||
grad_layout->to_string().c_str(), | |||
fm.padding[0], fm.padding[1], fm.padding[2], | |||
fm.padding[0], fm.padding[1], fm.padding[2], | |||
fm.stride[0], fm.stride[1], fm.stride[2], | |||
fm.dilation[0], fm.dilation[1] ,fm.dilation[2], | |||
!fm.should_flip, | |||
@@ -6,13 +6,16 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "src/cuda/convolution3d/helper.h" | |||
#include <unordered_map> | |||
#include "src/cuda/convolution3d/helper.h" | |||
#include "src/common/algo_base.h" | |||
#include "src/common/metahelper.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
@@ -23,170 +26,174 @@ namespace cuda { | |||
* All the algo impls should try to support non-contiguous batch dim, for group | |||
* conv execution. | |||
*/ | |||
class Convolution3DBackwardDataImpl::AlgoBase: public Algorithm { | |||
protected: | |||
~AlgoBase() = default; | |||
public: | |||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
struct SizeArgs { | |||
HandleImpl *handle; | |||
CanonizedFilterMeta filter_meta; | |||
const TensorLayout *diff_layout, *grad_layout; | |||
Convolution3DBackwardDataImpl *opr; | |||
std::string to_string() const; | |||
void init_desc(convolution3d::CUDNNBwdDataDescs &desc) const { | |||
desc.set(filter_meta, *diff_layout, *grad_layout, opr->param()); | |||
} | |||
SizeArgs(Convolution3DBackwardDataImpl *opr, | |||
const TensorLayout &filter, const TensorLayout &diff, | |||
const TensorLayout &grad); | |||
SizeArgs(Convolution3DBackwardDataImpl *opr, | |||
const CanonizedFilterMeta &filter, const TensorLayout &diff, | |||
const TensorLayout &grad); | |||
convolution3d::ForwardSizeArgs as_fwd_args() const { | |||
return {handle, grad_layout, filter_meta, diff_layout, | |||
opr->param().data_type}; | |||
} | |||
}; | |||
struct ExecArgs: public SizeArgs { | |||
const TensorND *filter_tensor, *diff_tensor, *grad_tensor; | |||
Workspace workspace; | |||
ExecArgs(Convolution3DBackwardDataImpl *opr, | |||
_megdnn_tensor_in filter, | |||
_megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace); | |||
}; | |||
virtual bool is_available(const SizeArgs &args) const = 0; | |||
virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0; | |||
virtual void exec(const ExecArgs &args) const = 0; | |||
bool is_available_wk(const SizeArgs &args, size_t limit) { | |||
return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
} | |||
bool is_available_reproducible( | |||
const SizeArgs& args, bool reproducible = true, | |||
size_t limit = std::numeric_limits<size_t>::max()) { | |||
return (!reproducible || is_reproducible()) && | |||
is_available_wk(args, limit); | |||
} | |||
AlgoBase& check_workspace( | |||
const SizeArgs &args, const Workspace &workspace) { | |||
auto req = get_workspace_in_bytes(args); | |||
megdnn_assert(req <= workspace.size, | |||
"conv bwd data algo %s: " | |||
"required workspace %zu bytes, got %zu", | |||
name(), req, workspace.size); | |||
return *this; | |||
class Convolution3DBackwardDataImpl::AlgoBase : public Algorithm { | |||
protected: | |||
~AlgoBase() = default; | |||
public: | |||
enum class AlgoType : uint32_t { | |||
CUDA_GROUP_CONV_GENERAL, | |||
CUDA_CUDNN, | |||
CUDA_CHANWISE, | |||
}; | |||
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
struct SizeArgs { | |||
HandleImpl* handle; | |||
CanonizedFilterMeta filter_meta; | |||
const TensorLayout *diff_layout, *grad_layout; | |||
Convolution3DBackwardDataImpl* opr; | |||
std::string to_string() const; | |||
void init_desc(convolution3d::CUDNNBwdDataDescs& desc) const { | |||
desc.set(filter_meta, *diff_layout, *grad_layout, opr->param()); | |||
} | |||
virtual bool is_cudnn() const { | |||
return false; | |||
SizeArgs(Convolution3DBackwardDataImpl* opr, const TensorLayout& filter, | |||
const TensorLayout& diff, const TensorLayout& grad); | |||
SizeArgs(Convolution3DBackwardDataImpl* opr, | |||
const CanonizedFilterMeta& filter, const TensorLayout& diff, | |||
const TensorLayout& grad); | |||
convolution3d::ForwardSizeArgs as_fwd_args() const { | |||
return {handle, grad_layout, filter_meta, diff_layout, | |||
opr->param().data_type}; | |||
} | |||
}; | |||
struct ExecArgs : public SizeArgs { | |||
const TensorND *filter_tensor, *diff_tensor, *grad_tensor; | |||
Workspace workspace; | |||
ExecArgs(Convolution3DBackwardDataImpl* opr, _megdnn_tensor_in filter, | |||
_megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
_megdnn_workspace workspace); | |||
}; | |||
virtual bool is_available(const SizeArgs& args) const = 0; | |||
virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||
virtual void exec(const ExecArgs& args) const = 0; | |||
bool is_available_wk(const SizeArgs& args, size_t limit) { | |||
return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
} | |||
bool is_available_reproducible( | |||
const SizeArgs& args, bool reproducible = true, | |||
size_t limit = std::numeric_limits<size_t>::max()) { | |||
return (!reproducible || is_reproducible()) && | |||
is_available_wk(args, limit); | |||
} | |||
AlgoBase& check_workspace(const SizeArgs& args, | |||
const Workspace& workspace) { | |||
auto req = get_workspace_in_bytes(args); | |||
megdnn_assert(req <= workspace.size, | |||
"conv bwd data algo %s: " | |||
"required workspace %zu bytes, got %zu", | |||
name(), req, workspace.size); | |||
return *this; | |||
} | |||
virtual bool is_cudnn() const { return false; } | |||
}; | |||
class Convolution3DBackwardDataImpl::AlgoCUDNN final : public AlgoBase { | |||
bool m_is_reproducible; | |||
const char *m_name; | |||
cudnnConvolutionBwdDataAlgo_t m_cudnn_enum; | |||
CudnnAlgoPack::Attr m_attr; | |||
public: | |||
public: | |||
AlgoCUDNN(cudnnConvolutionBwdDataAlgo_t cudnn_enum) | |||
: m_cudnn_enum(cudnn_enum) { | |||
megdnn_assert(CudnnAlgoPack::conv3d_bwd_data_algos().find(cudnn_enum) != | |||
CudnnAlgoPack::conv3d_bwd_data_algos().end()); | |||
m_attr = CudnnAlgoPack::conv3d_bwd_data_algos().at(cudnn_enum); | |||
} | |||
AlgoCUDNN(bool is_reproducible, const char *name, | |||
cudnnConvolutionBwdDataAlgo_t cudnn_enum): | |||
m_is_reproducible(is_reproducible), | |||
m_name(name), | |||
m_cudnn_enum(cudnn_enum) | |||
{} | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
bool is_available(const SizeArgs &args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
void exec(const ExecArgs &args) const override; | |||
bool is_reproducible() const override { return m_attr.is_reproducible; } | |||
bool is_reproducible() const override { | |||
return m_is_reproducible; | |||
} | |||
const char* name() const override { return m_attr.name.c_str(); } | |||
const char* name() const override { | |||
return m_name; | |||
} | |||
cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; } | |||
cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { | |||
return m_cudnn_enum; | |||
} | |||
bool is_cudnn() const override { return true; } | |||
bool is_cudnn() const override { | |||
return true; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_cudnn_enum, ret); | |||
return ret; | |||
} | |||
}; | |||
class Convolution3DBackwardDataImpl::AlgoChanwise final: public AlgoBase { | |||
public: | |||
bool is_available(const SizeArgs &args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
void exec(const ExecArgs &args) const override; | |||
class Convolution3DBackwardDataImpl::AlgoChanwise final : public AlgoBase { | |||
public: | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
const char* name() const override { | |||
return "CHANNEL_WISE"; | |||
} | |||
bool is_reproducible() const override { | |||
return true; | |||
} | |||
const char* name() const override { return "CHANNEL_WISE"; } | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | |||
}; | |||
//! implement group conv by another algo | |||
class Convolution3DBackwardDataImpl::AlgoGroupConvGeneral final: public AlgoBase { | |||
AlgoBase *m_impl; | |||
class Convolution3DBackwardDataImpl::AlgoGroupConvGeneral final | |||
: public AlgoBase { | |||
AlgoBase* m_impl; | |||
std::string m_name; | |||
public: | |||
AlgoGroupConvGeneral(AlgoBase *impl); | |||
public: | |||
AlgoGroupConvGeneral(AlgoBase* impl); | |||
bool is_available(const SizeArgs &args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
void exec(const ExecArgs &args) const override; | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
const char* name() const override { | |||
return m_name.c_str(); | |||
} | |||
const char* name() const override { return m_name.c_str(); } | |||
bool is_reproducible() const override { | |||
return m_impl->is_reproducible(); | |||
} | |||
bool is_reproducible() const override { return m_impl->is_reproducible(); } | |||
static void modify_size_args(SizeArgs& args, TensorLayout& diff_pg, | |||
TensorLayout& grad_pg); | |||
static void modify_size_args(SizeArgs &args, | |||
TensorLayout &diff_pg, TensorLayout &grad_pg); | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_impl, ret); | |||
return ret; | |||
} | |||
}; | |||
class Convolution3DBackwardDataImpl::AlgoPack { | |||
class Convolution3DBackwardDataImpl::AlgoPack : NonCopyableObj { | |||
// defined in cudnn.cpp | |||
void fill_cudnn_algos(); | |||
AlgoPack(const AlgoPack&) = delete; | |||
AlgoPack& operator = (const AlgoPack &) = delete; | |||
AlgoBase::Mapper m_all_algos_map; | |||
public: | |||
AlgoPack(); | |||
public: | |||
AlgoPack(); | |||
std::vector<AlgoCUDNN> cudnn; | |||
AlgoChanwise chanwise; | |||
std::vector<AlgoGroupConvGeneral> gconv; | |||
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||
std::vector<AlgoCUDNN> cudnn; | |||
AlgoChanwise chanwise; | |||
std::vector<AlgoGroupConvGeneral> gconv; | |||
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||
std::vector<AlgoBase*> | |||
std::vector<AlgoBase*> | |||
//! all algorithms | |||
all_algos, | |||
//! non-cudnn algos, used for heuristic if cudnn is not supported | |||
non_cudnn_algos; | |||
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo); | |||
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo); | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -80,27 +80,9 @@ void Convolution3DBackwardDataImpl::AlgoCUDNN::exec( | |||
} | |||
void Convolution3DBackwardDataImpl::AlgoPack::fill_cudnn_algos() { | |||
#define V1(v) #v | |||
#define V(v) V1(v) | |||
#define DEF_ALGO(NAME, REPROD) \ | |||
cudnn.push_back({ \ | |||
REPROD, #NAME \ | |||
"v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) \ | |||
"." V(CUDNN_PATCHLEVEL), \ | |||
NAME}) | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false); | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true); | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true); | |||
#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) | |||
#pragma message "not latest cudnn" | |||
#endif | |||
#undef DEF_ALGO | |||
#undef V | |||
#undef V1 | |||
for (auto&& algo : CudnnAlgoPack::conv3d_bwd_data_algos()) { | |||
cudnn.push_back(algo.first); | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -17,7 +17,7 @@ using namespace cuda; | |||
Convolution3DBackwardFilterImpl::AlgoPack::AlgoPack() { | |||
non_cudnn_algos.push_back(&chanwise); | |||
non_cudnn_algos.push_back(&inplace_matmul); | |||
non_cudnn_algos.push_back(&inplace_matmul); | |||
all_algos.push_back(&chanwise); // prefer chanwise | |||
fill_cudnn_algos(); | |||
@@ -41,8 +41,14 @@ Convolution3DBackwardFilterImpl::AlgoPack::AlgoPack() { | |||
} | |||
megdnn_assert(all_algos_data == all_algos.data()); | |||
non_cudnn_algos.push_back(all_algos.rbegin()[0]); //group inplace_matmul | |||
for (auto&& algo : all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
MEGDNN_DEF_GET_ALGO_FROM_DESC(Convolution3DBackwardFilterImpl) | |||
Convolution3DBackwardFilterImpl::AlgoCUDNN* | |||
Convolution3DBackwardFilterImpl::AlgoPack::cudnn_from_enum( | |||
cudnnConvolutionBwdFilterAlgo_t algo) { | |||
@@ -99,9 +105,9 @@ Convolution3DBackwardFilterImpl::AlgoBase::SizeArgs::to_string() const { | |||
"pad=%ux%ux%u, stride=%ux%ux%u, dilate=%ux%ux%u, xcorr=%d, dtype=%s,%s", | |||
src_layout->to_string().c_str(), | |||
diff_layout->to_string().c_str(), | |||
fm.group, fm.ocpg, fm.icpg, | |||
fm.group, fm.ocpg, fm.icpg, | |||
fm.spatial[0], fm.spatial[1], fm.spatial[2], | |||
fm.padding[0], fm.padding[1], fm.padding[2], | |||
fm.padding[0], fm.padding[1], fm.padding[2], | |||
fm.stride[0], fm.stride[1], fm.stride[2], | |||
fm.dilation[0], fm.dilation[1], fm.dilation[2], | |||
!fm.should_flip, | |||
@@ -6,198 +6,198 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "src/cuda/convolution3d/helper.h" | |||
#include <unordered_map> | |||
#include "src/cuda/convolution3d/helper.h" | |||
#include "src/common/algo_base.h" | |||
#include "src/common/metahelper.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
class Convolution3DBackwardFilterImpl::AlgoBase: public Algorithm { | |||
protected: | |||
~AlgoBase() = default; | |||
public: | |||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
struct SizeArgs { | |||
HandleImpl *handle; | |||
const TensorLayout *src_layout, *diff_layout; | |||
CanonizedFilterMeta grad_filter_meta; | |||
Convolution3DBackwardFilterImpl *opr; | |||
std::string to_string() const; | |||
void init_desc(convolution3d::CUDNNBwdFilterDescs &desc) const { | |||
desc.set(*src_layout, *diff_layout, grad_filter_meta, | |||
opr->param()); | |||
} | |||
SizeArgs(Convolution3DBackwardFilterImpl *opr, | |||
const TensorLayout &src, const TensorLayout &diff, | |||
const TensorLayout &grad); | |||
SizeArgs(Convolution3DBackwardFilterImpl *opr, | |||
const TensorLayout &src, const TensorLayout &diff, | |||
const CanonizedFilterMeta &grad); | |||
convolution3d::ForwardSizeArgs as_fwd_args() const { | |||
return {handle, src_layout, grad_filter_meta, diff_layout, | |||
opr->param().data_type}; | |||
} | |||
}; | |||
struct ExecArgs: public SizeArgs { | |||
const TensorND *src_tensor, *diff_tensor, *grad_tensor; | |||
Workspace workspace; | |||
ExecArgs(Convolution3DBackwardFilterImpl *opr, | |||
_megdnn_tensor_in src, | |||
_megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace); | |||
}; | |||
virtual bool is_available(const SizeArgs &args) const = 0; | |||
virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0; | |||
virtual void exec(const ExecArgs &args) const = 0; | |||
bool is_available_wk(const SizeArgs &args, size_t limit) { | |||
return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
} | |||
bool is_available_reproducible( | |||
const SizeArgs& args, bool reproducible = true, | |||
size_t limit = std::numeric_limits<size_t>::max()) { | |||
return (!reproducible || is_reproducible()) && | |||
is_available_wk(args, limit); | |||
} | |||
AlgoBase& check_workspace(const SizeArgs& args, | |||
const Workspace& workspace) { | |||
auto req = get_workspace_in_bytes(args); | |||
megdnn_assert(req <= workspace.size, | |||
"conv bwd filter algo %s: " | |||
"required workspace %zu bytes, got %zu", | |||
name(), req, workspace.size); | |||
return *this; | |||
class Convolution3DBackwardFilterImpl::AlgoBase : public Algorithm { | |||
protected: | |||
~AlgoBase() = default; | |||
public: | |||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
enum class AlgoType : uint32_t { | |||
CUDA_GROUP_CONV_GENERAL, | |||
CUDA_CUDNN, | |||
CUDA_INPLACE_MATMUL, | |||
CUDA_CHANWISE, | |||
}; | |||
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
struct SizeArgs { | |||
HandleImpl* handle; | |||
const TensorLayout *src_layout, *diff_layout; | |||
CanonizedFilterMeta grad_filter_meta; | |||
Convolution3DBackwardFilterImpl* opr; | |||
std::string to_string() const; | |||
void init_desc(convolution3d::CUDNNBwdFilterDescs& desc) const { | |||
desc.set(*src_layout, *diff_layout, grad_filter_meta, opr->param()); | |||
} | |||
SizeArgs(Convolution3DBackwardFilterImpl* opr, const TensorLayout& src, | |||
const TensorLayout& diff, const TensorLayout& grad); | |||
SizeArgs(Convolution3DBackwardFilterImpl* opr, const TensorLayout& src, | |||
const TensorLayout& diff, const CanonizedFilterMeta& grad); | |||
virtual bool is_cudnn() const { | |||
return false; | |||
convolution3d::ForwardSizeArgs as_fwd_args() const { | |||
return {handle, src_layout, grad_filter_meta, diff_layout, | |||
opr->param().data_type}; | |||
} | |||
}; | |||
struct ExecArgs : public SizeArgs { | |||
const TensorND *src_tensor, *diff_tensor, *grad_tensor; | |||
Workspace workspace; | |||
ExecArgs(Convolution3DBackwardFilterImpl* opr, _megdnn_tensor_in src, | |||
_megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
_megdnn_workspace workspace); | |||
}; | |||
virtual bool is_available(const SizeArgs& args) const = 0; | |||
virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||
virtual void exec(const ExecArgs& args) const = 0; | |||
bool is_available_wk(const SizeArgs& args, size_t limit) { | |||
return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
} | |||
bool is_available_reproducible( | |||
const SizeArgs& args, bool reproducible = true, | |||
size_t limit = std::numeric_limits<size_t>::max()) { | |||
return (!reproducible || is_reproducible()) && | |||
is_available_wk(args, limit); | |||
} | |||
AlgoBase& check_workspace(const SizeArgs& args, | |||
const Workspace& workspace) { | |||
auto req = get_workspace_in_bytes(args); | |||
megdnn_assert(req <= workspace.size, | |||
"conv bwd filter algo %s: " | |||
"required workspace %zu bytes, got %zu", | |||
name(), req, workspace.size); | |||
return *this; | |||
} | |||
virtual bool is_cudnn() const { return false; } | |||
}; | |||
class Convolution3DBackwardFilterImpl::AlgoCUDNN final : public AlgoBase { | |||
bool m_is_reproducible; | |||
const char *m_name; | |||
cudnnConvolutionBwdFilterAlgo_t m_cudnn_enum; | |||
CudnnAlgoPack::Attr m_attr; | |||
public: | |||
public: | |||
AlgoCUDNN(cudnnConvolutionBwdFilterAlgo_t cudnn_enum) | |||
: m_cudnn_enum(cudnn_enum) { | |||
megdnn_assert(CudnnAlgoPack::conv3d_bwd_flt_algos().find(cudnn_enum) != | |||
CudnnAlgoPack::conv3d_bwd_flt_algos().end()); | |||
m_attr = CudnnAlgoPack::conv3d_bwd_flt_algos().at(cudnn_enum); | |||
} | |||
AlgoCUDNN(bool is_reproducible, const char *name, | |||
cudnnConvolutionBwdFilterAlgo_t cudnn_enum): | |||
m_is_reproducible(is_reproducible), | |||
m_name(name), | |||
m_cudnn_enum(cudnn_enum) | |||
{} | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
bool is_available(const SizeArgs &args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
void exec(const ExecArgs &args) const override; | |||
bool is_reproducible() const override { return m_attr.is_reproducible; } | |||
bool is_reproducible() const override { | |||
return m_is_reproducible; | |||
} | |||
const char* name() const override { return m_attr.name.c_str(); } | |||
const char* name() const override { | |||
return m_name; | |||
} | |||
cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const { return m_cudnn_enum; } | |||
cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const { | |||
return m_cudnn_enum; | |||
} | |||
bool is_cudnn() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN) | |||
bool is_cudnn() const override { | |||
return true; | |||
} | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_cudnn_enum, ret); | |||
return ret; | |||
} | |||
}; | |||
class Convolution3DBackwardFilterImpl::AlgoInplaceMatmul final | |||
: public AlgoBase { | |||
public: | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
class Convolution3DBackwardFilterImpl::AlgoInplaceMatmul final: public AlgoBase { | |||
public: | |||
bool is_available(const SizeArgs &args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
void exec(const ExecArgs &args) const override; | |||
const char* name() const override { | |||
return "INPLACE_MATMUL"; | |||
} | |||
bool is_reproducible() const override { | |||
return false; | |||
} | |||
const char* name() const override { return "INPLACE_MATMUL"; } | |||
bool is_reproducible() const override { return false; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL) | |||
}; | |||
class Convolution3DBackwardFilterImpl::AlgoChanwise final: public AlgoBase { | |||
public: | |||
bool is_available(const SizeArgs &args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
void exec(const ExecArgs &args) const override; | |||
class Convolution3DBackwardFilterImpl::AlgoChanwise final : public AlgoBase { | |||
public: | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
const char* name() const override { | |||
return "CHANNEL_WISE"; | |||
} | |||
bool is_reproducible() const override { | |||
return true; | |||
} | |||
const char* name() const override { return "CHANNEL_WISE"; } | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | |||
}; | |||
//! implement group conv by another algo | |||
class Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral final: public AlgoBase { | |||
AlgoBase *m_impl; | |||
class Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral final | |||
: public AlgoBase { | |||
AlgoBase* m_impl; | |||
std::string m_name; | |||
public: | |||
AlgoGroupConvGeneral(AlgoBase *impl); | |||
public: | |||
AlgoGroupConvGeneral(AlgoBase* impl); | |||
bool is_available(const SizeArgs &args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
void exec(const ExecArgs &args) const override; | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
const char* name() const override { | |||
return m_name.c_str(); | |||
} | |||
const char* name() const override { return m_name.c_str(); } | |||
bool is_reproducible() const override { | |||
return m_impl->is_reproducible(); | |||
} | |||
bool is_reproducible() const override { return m_impl->is_reproducible(); } | |||
static void modify_size_args(SizeArgs &args, | |||
TensorLayout &src_pg, TensorLayout &diff_pg); | |||
static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | |||
TensorLayout& diff_pg); | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_impl, ret); | |||
return ret; | |||
} | |||
}; | |||
class Convolution3DBackwardFilterImpl::AlgoPack { | |||
class Convolution3DBackwardFilterImpl::AlgoPack : NonCopyableObj { | |||
// defined in cudnn.cpp | |||
void fill_cudnn_algos(); | |||
AlgoPack(const AlgoPack&) = delete; | |||
AlgoPack& operator = (const AlgoPack &) = delete; | |||
AlgoBase::Mapper m_all_algos_map; | |||
public: | |||
AlgoPack(); | |||
public: | |||
AlgoPack(); | |||
std::vector<AlgoCUDNN> cudnn; | |||
AlgoInplaceMatmul inplace_matmul; | |||
AlgoChanwise chanwise; | |||
std::vector<AlgoGroupConvGeneral> gconv; | |||
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||
std::vector<AlgoCUDNN> cudnn; | |||
AlgoInplaceMatmul inplace_matmul; | |||
AlgoChanwise chanwise; | |||
std::vector<AlgoGroupConvGeneral> gconv; | |||
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||
std::vector<AlgoBase*> | |||
std::vector<AlgoBase*> | |||
//! all algorithms | |||
all_algos, | |||
//! non-cudnn algos, used for heuristic if cudnn is not supported | |||
non_cudnn_algos; | |||
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo); | |||
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo); | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -66,29 +66,9 @@ void Convolution3DBackwardFilterImpl::AlgoCUDNN::exec( | |||
} | |||
void Convolution3DBackwardFilterImpl::AlgoPack::fill_cudnn_algos() { | |||
#define V1(v) #v | |||
#define V(v) V1(v) | |||
#define DEF_ALGO(NAME, REPROD) \ | |||
cudnn.push_back({REPROD, \ | |||
#NAME "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V( \ | |||
CUDNN_PATCHLEVEL), \ | |||
NAME}) | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false); | |||
#pragma message \ | |||
"fp16 dilated conv with odd size filter, only algo_1 works, need focus on doc" | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true); | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false); | |||
#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) | |||
#pragma message "not latest cudnn" | |||
#endif | |||
#undef DEF_ALGO | |||
#undef V | |||
#undef V1 | |||
for (auto&& algo : CudnnAlgoPack::conv3d_bwd_flt_algos()) { | |||
cudnn.push_back(algo.first); | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -21,13 +21,13 @@ Convolution3DForwardImpl::AlgoPack::AlgoPack() { | |||
non_cudnn_algos.push_back(&a1x1x1); | |||
all_algos.push_back(&chanwise); | |||
fill_cudnn_algos(); | |||
for (auto &&i: cudnn) { | |||
all_algos.push_back(&i); | |||
all_algos.push_back(&i); | |||
} | |||
all_algos.push_back(&inplace_matmul); | |||
all_algos.push_back(&a1x1x1); | |||
all_algos.push_back(&a1x1x1); | |||
all_algos.reserve(all_algos.size() * 2); | |||
// add gconv algos by AlgoGroupConvGeneral | |||
@@ -42,10 +42,16 @@ Convolution3DForwardImpl::AlgoPack::AlgoPack() { | |||
all_algos.push_back(&i); | |||
} | |||
megdnn_assert(all_algos_data == all_algos.data()); | |||
non_cudnn_algos.push_back(all_algos.rbegin()[1]); // group inplace_matmul | |||
non_cudnn_algos.push_back(all_algos.rbegin()[1]); // group inplace_matmul | |||
non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group 1x1x1 | |||
for (auto&& algo : all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
MEGDNN_DEF_GET_ALGO_FROM_DESC(Convolution3DForwardImpl) | |||
Convolution3DForwardImpl::AlgoCUDNN* | |||
Convolution3DForwardImpl::AlgoPack::cudnn_from_enum( | |||
cudnnConvolutionFwdAlgo_t algo) { | |||
@@ -99,7 +105,7 @@ std::string Convolution3DForwardImpl::AlgoBase::SizeArgs::to_string() const { | |||
"src=%s, filter=%u{%u,%u,%u,%u,%u}, dst=%s, " | |||
"pad=%ux%ux%u, stride=%ux%ux%u, dilate=%ux%ux%u, xcorr=%d, dtype=%s,%s", | |||
src_layout->to_string().c_str(), | |||
fm.group, fm.ocpg, fm.icpg, | |||
fm.group, fm.ocpg, fm.icpg, | |||
fm.spatial[0], fm.spatial[1], fm.spatial[2], | |||
dst_layout->to_string().c_str(), | |||
fm.padding[0], fm.padding[1], fm.padding[2], | |||
@@ -6,17 +6,20 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
#include "src/common/utils.h" | |||
#include "src/cuda/convolution3d/helper.h" | |||
#include "src/cuda/handle.h" | |||
#include "src/cuda/convolution3d/opr_impl.h" | |||
#include "src/common/utils.h" | |||
#include "src/cuda/handle.h" | |||
#include "src/common/algo_base.h" | |||
#include "src/common/metahelper.h" | |||
#include <unordered_map> | |||
@@ -29,195 +32,189 @@ namespace cuda { | |||
* All the algo impls should try to support non-contiguous batch dim, for group | |||
* conv execution. | |||
*/ | |||
class Convolution3DForwardImpl::AlgoBase: public Algorithm { | |||
protected: | |||
~AlgoBase() = default; | |||
public: | |||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
struct SizeArgs: public convolution3d::ForwardSizeArgs { | |||
Convolution3DForwardImpl *opr; | |||
std::string to_string() const; | |||
void init_desc(convolution3d::CUDNNForwardDescs &desc) const { | |||
desc.set(*src_layout, filter_meta, *dst_layout, opr->param()); | |||
} | |||
SizeArgs(Convolution3DForwardImpl *opr, | |||
const TensorLayout &src, | |||
const TensorLayout &filter, | |||
const TensorLayout &dst); | |||
SizeArgs(Convolution3DForwardImpl *opr, | |||
const TensorLayout &src, | |||
const CanonizedFilterMeta &filter, | |||
const TensorLayout &dst); | |||
}; | |||
struct ExecArgs : public SizeArgs { | |||
const TensorND *src_tensor, *filter_tensor, *dst_tensor; | |||
Workspace workspace; | |||
ExecArgs(Convolution3DForwardImpl *opr, | |||
_megdnn_tensor_in src, | |||
_megdnn_tensor_in filter, | |||
_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace); | |||
}; | |||
virtual bool is_available(const SizeArgs &args) const = 0; | |||
virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0; | |||
virtual void exec(const ExecArgs &args) const = 0; | |||
bool is_available_wk(const SizeArgs &args, size_t limit) { | |||
return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
} | |||
bool is_available_reproducible( | |||
const SizeArgs& args, bool reproducible = true, | |||
size_t limit = std::numeric_limits<size_t>::max()) { | |||
return (!reproducible || is_reproducible()) && | |||
is_available_wk(args, limit); | |||
} | |||
AlgoBase& check_workspace(const SizeArgs& args, | |||
const Workspace& workspace) { | |||
auto req = get_workspace_in_bytes(args); | |||
megdnn_assert(req <= workspace.size, | |||
"conv3d fwd algo %s: required workspace %zu bytes, got %zu", | |||
name(), req, workspace.size); | |||
return *this; | |||
} | |||
virtual bool is_cudnn() const { | |||
return false; | |||
} | |||
class Convolution3DForwardImpl::AlgoBase : public Algorithm { | |||
protected: | |||
~AlgoBase() = default; | |||
public: | |||
enum class AlgoType : uint32_t { | |||
CUDA_1X1X1, | |||
CUDA_GROUP_CONV_GENERAL, | |||
CUDA_CUDNN, | |||
CUDA_INPLACE_MATMUL, | |||
CUDA_CHANWISE, | |||
}; | |||
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
struct SizeArgs : public convolution3d::ForwardSizeArgs { | |||
Convolution3DForwardImpl* opr; | |||
std::string to_string() const; | |||
void init_desc(convolution3d::CUDNNForwardDescs& desc) const { | |||
desc.set(*src_layout, filter_meta, *dst_layout, opr->param()); | |||
} | |||
SizeArgs(Convolution3DForwardImpl* opr, const TensorLayout& src, | |||
const TensorLayout& filter, const TensorLayout& dst); | |||
SizeArgs(Convolution3DForwardImpl* opr, const TensorLayout& src, | |||
const CanonizedFilterMeta& filter, const TensorLayout& dst); | |||
}; | |||
struct ExecArgs : public SizeArgs { | |||
const TensorND *src_tensor, *filter_tensor, *dst_tensor; | |||
Workspace workspace; | |||
ExecArgs(Convolution3DForwardImpl* opr, _megdnn_tensor_in src, | |||
_megdnn_tensor_in filter, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace); | |||
}; | |||
virtual bool is_available(const SizeArgs& args) const = 0; | |||
virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||
virtual void exec(const ExecArgs& args) const = 0; | |||
bool is_available_wk(const SizeArgs& args, size_t limit) { | |||
return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
} | |||
bool is_available_reproducible( | |||
const SizeArgs& args, bool reproducible = true, | |||
size_t limit = std::numeric_limits<size_t>::max()) { | |||
return (!reproducible || is_reproducible()) && | |||
is_available_wk(args, limit); | |||
} | |||
AlgoBase& check_workspace(const SizeArgs& args, | |||
const Workspace& workspace) { | |||
auto req = get_workspace_in_bytes(args); | |||
megdnn_assert( | |||
req <= workspace.size, | |||
"conv3d fwd algo %s: required workspace %zu bytes, got %zu", | |||
name(), req, workspace.size); | |||
return *this; | |||
} | |||
virtual bool is_cudnn() const { return false; } | |||
}; | |||
class Convolution3DForwardImpl::Algo1x1x1 final: public AlgoBase { | |||
static void extract_matmul_layouts(const SizeArgs &args, | |||
TensorLayout &A, TensorLayout &B, TensorLayout &C); | |||
public: | |||
bool is_available(const SizeArgs &args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
void exec(const ExecArgs &args) const override; | |||
const char* name() const override { | |||
return "1x1x1"; | |||
} | |||
bool is_reproducible() const override { | |||
return true; | |||
} | |||
class Convolution3DForwardImpl::Algo1x1x1 final : public AlgoBase { | |||
static void extract_matmul_layouts(const SizeArgs& args, TensorLayout& A, | |||
TensorLayout& B, TensorLayout& C); | |||
public: | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
const char* name() const override { return "1x1x1"; } | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_1X1X1) | |||
}; | |||
//! implement group conv by another algo | |||
class Convolution3DForwardImpl::AlgoGroupConvGeneral final: public AlgoBase { | |||
AlgoBase *m_impl; | |||
class Convolution3DForwardImpl::AlgoGroupConvGeneral final : public AlgoBase { | |||
AlgoBase* m_impl; | |||
std::string m_name; | |||
public: | |||
AlgoGroupConvGeneral(AlgoBase *impl); | |||
public: | |||
AlgoGroupConvGeneral(AlgoBase* impl); | |||
bool is_available(const SizeArgs &args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
void exec(const ExecArgs &args) const override; | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
const char* name() const override { | |||
return m_name.c_str(); | |||
} | |||
const char* name() const override { return m_name.c_str(); } | |||
bool is_reproducible() const override { | |||
return m_impl->is_reproducible(); | |||
} | |||
bool is_reproducible() const override { return m_impl->is_reproducible(); } | |||
static void modify_size_args(SizeArgs &args, | |||
TensorLayout &src_pg, TensorLayout &dst_pg); | |||
static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | |||
TensorLayout& dst_pg); | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_impl, ret); | |||
return ret; | |||
} | |||
}; | |||
class Convolution3DForwardImpl::AlgoCUDNN final : public AlgoBase { | |||
bool m_is_reproducible; | |||
const char *m_name; | |||
cudnnConvolutionFwdAlgo_t m_cudnn_enum; | |||
CudnnAlgoPack::Attr m_attr; | |||
public: | |||
public: | |||
AlgoCUDNN(cudnnConvolutionFwdAlgo_t cudnn_enum) : m_cudnn_enum(cudnn_enum) { | |||
megdnn_assert(CudnnAlgoPack::conv3d_fwd_algos().find(cudnn_enum) != | |||
CudnnAlgoPack::conv3d_fwd_algos().end()); | |||
m_attr = CudnnAlgoPack::conv3d_fwd_algos().at(cudnn_enum); | |||
} | |||
AlgoCUDNN(bool is_reproducible, const char *name, | |||
cudnnConvolutionFwdAlgo_t cudnn_enum): | |||
m_is_reproducible(is_reproducible), | |||
m_name(name), | |||
m_cudnn_enum(cudnn_enum) | |||
{} | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
bool is_available(const SizeArgs &args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
void exec(const ExecArgs &args) const override; | |||
bool is_reproducible() const override { return m_attr.is_reproducible; } | |||
bool is_reproducible() const override { | |||
return m_is_reproducible; | |||
} | |||
const char* name() const override { return m_attr.name.c_str(); } | |||
const char* name() const override { | |||
return m_name; | |||
} | |||
cudnnConvolutionFwdAlgo_t cudnn_enum() const { return m_cudnn_enum; } | |||
cudnnConvolutionFwdAlgo_t cudnn_enum() const { | |||
return m_cudnn_enum; | |||
} | |||
bool is_cudnn() const override { return true; } | |||
bool is_cudnn() const override { | |||
return true; | |||
} | |||
}; | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN) | |||
class Convolution3DForwardImpl::AlgoInplaceMatmul final: public AlgoBase { | |||
public: | |||
bool is_available(const SizeArgs &args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
void exec(const ExecArgs &args) const override; | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_cudnn_enum, ret); | |||
return ret; | |||
} | |||
const char* name() const override { | |||
return "INPLACE_MATMUL"; | |||
} | |||
bool is_reproducible() const override { | |||
return true; | |||
} | |||
}; | |||
class Convolution3DForwardImpl::AlgoInplaceMatmul final : public AlgoBase { | |||
public: | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
class Convolution3DForwardImpl::AlgoChanwise final: public AlgoBase { | |||
public: | |||
bool is_available(const SizeArgs &args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs &args) const override; | |||
void exec(const ExecArgs &args) const override; | |||
const char* name() const override { return "INPLACE_MATMUL"; } | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL) | |||
}; | |||
const char* name() const override { | |||
return "CHANNEL_WISE"; | |||
} | |||
bool is_reproducible() const override { | |||
return true; | |||
} | |||
class Convolution3DForwardImpl::AlgoChanwise final : public AlgoBase { | |||
public: | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
const char* name() const override { return "CHANNEL_WISE"; } | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | |||
}; | |||
class Convolution3DForwardImpl::AlgoPack { | |||
class Convolution3DForwardImpl::AlgoPack : NonCopyableObj { | |||
// defined in cudnn.cpp | |||
void fill_cudnn_algos(); | |||
AlgoPack(const AlgoPack&) = delete; | |||
AlgoPack& operator = (const AlgoPack &) = delete; | |||
AlgoBase::Mapper m_all_algos_map; | |||
public: | |||
AlgoPack(); | |||
public: | |||
AlgoPack(); | |||
std::vector<AlgoCUDNN> cudnn; | |||
Algo1x1x1 a1x1x1; | |||
AlgoInplaceMatmul inplace_matmul; | |||
AlgoChanwise chanwise; | |||
std::vector<AlgoGroupConvGeneral> gconv; | |||
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||
std::vector<AlgoCUDNN> cudnn; | |||
Algo1x1x1 a1x1x1; | |||
AlgoInplaceMatmul inplace_matmul; | |||
AlgoChanwise chanwise; | |||
std::vector<AlgoGroupConvGeneral> gconv; | |||
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||
std::vector<AlgoBase*> | |||
std::vector<AlgoBase*> | |||
//! all algorithms | |||
all_algos, | |||
//! non-cudnn algos, used for heuristic if cudnn is not supported | |||
non_cudnn_algos; | |||
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionFwdAlgo_t algo); | |||
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionFwdAlgo_t algo); | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -78,30 +78,10 @@ void Convolution3DForwardImpl::AlgoCUDNN::exec( | |||
cudnnGetErrorString(status), args.to_string().c_str()); | |||
} | |||
void Convolution3DForwardImpl::AlgoPack::fill_cudnn_algos() { | |||
#define V1(v) #v | |||
#define V(v) V1(v) | |||
#define DEF_ALGO(NAME, REPROD) \ | |||
cudnn.push_back({ \ | |||
REPROD, #NAME \ | |||
"v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) \ | |||
"." V(CUDNN_PATCHLEVEL), \ | |||
NAME}) | |||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true); | |||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true); | |||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true); | |||
#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) | |||
#pragma message "not latest cudnn" | |||
#endif | |||
#undef DEF_ALGO | |||
#undef V | |||
#undef V1 | |||
for (auto&& algo : CudnnAlgoPack::conv3d_fwd_algos()) { | |||
cudnn.push_back(algo.first); | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -6,7 +6,8 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
@@ -15,126 +16,155 @@ | |||
namespace megdnn { | |||
namespace cuda { | |||
class Convolution3DForwardImpl: public Convolution3DForward { | |||
public: | |||
using Convolution3DForward::Convolution3DForward; | |||
void exec(_megdnn_tensor_in src, | |||
_megdnn_tensor_in filter, | |||
_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) override; | |||
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | |||
const TensorLayout &filter, | |||
const TensorLayout &dst) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const TensorLayout& filter, | |||
const TensorLayout& dst, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const CanonizedFilterMeta& filter, | |||
const TensorLayout& dst, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible); | |||
size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& filter, | |||
const TensorLayout& dst) override; | |||
const char* get_algorithm_set_name() const override; | |||
class AlgoBase; | |||
class AlgoCUDNN; | |||
class Algo1x1x1; | |||
class AlgoInplaceMatmul; | |||
class AlgoChanwise; | |||
class AlgoGroupConvGeneral; | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { | |||
return sm_algo_pack; | |||
} | |||
private: | |||
static AlgoPack sm_algo_pack; | |||
class Convolution3DForwardImpl : public Convolution3DForward { | |||
public: | |||
using Convolution3DForward::Convolution3DForward; | |||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||
AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& src, | |||
const CanonizedFilterMeta& filter, | |||
const TensorLayout& dst, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) { | |||
return get_algorithm_heuristic(src, filter, dst, | |||
workspace_limit_in_bytes, reproducible) | |||
->info(); | |||
} | |||
size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& filter, | |||
const TensorLayout& dst) override; | |||
const char* get_algorithm_set_name() const override; | |||
class AlgoBase; | |||
class AlgoCUDNN; | |||
class Algo1x1x1; | |||
class AlgoInplaceMatmul; | |||
class AlgoChanwise; | |||
class AlgoGroupConvGeneral; | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& src, const TensorLayout& filter, | |||
const TensorLayout& dst) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const TensorLayout& filter, | |||
const TensorLayout& dst, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
private: | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const CanonizedFilterMeta& filter, | |||
const TensorLayout& dst, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible); | |||
static AlgoPack sm_algo_pack; | |||
}; | |||
class Convolution3DBackwardDataImpl: public Convolution3DBackwardData { | |||
public: | |||
using Convolution3DBackwardData::Convolution3DBackwardData; | |||
void exec(_megdnn_tensor_in filter, | |||
_megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) override; | |||
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &filter, | |||
const TensorLayout &diff, | |||
const TensorLayout &grad) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible); | |||
size_t get_workspace_in_bytes(const TensorLayout& filter, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad) override; | |||
const char* get_algorithm_set_name() const override; | |||
class AlgoBase; | |||
class AlgoCUDNN; | |||
class AlgoInplaceMatmul; | |||
class AlgoChanwise; | |||
class AlgoGroupConvGeneral; | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { | |||
return sm_algo_pack; | |||
} | |||
private: | |||
static AlgoPack sm_algo_pack; | |||
class Convolution3DBackwardDataImpl : public Convolution3DBackwardData { | |||
public: | |||
using Convolution3DBackwardData::Convolution3DBackwardData; | |||
void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||
AlgorithmInfo get_algorithm_info_heuristic( | |||
const CanonizedFilterMeta& filter, const TensorLayout& diff, | |||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||
bool reproducible) { | |||
return get_algorithm_heuristic(filter, diff, grad, | |||
workspace_limit_in_bytes, reproducible) | |||
->info(); | |||
} | |||
size_t get_workspace_in_bytes(const TensorLayout& filter, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad) override; | |||
const char* get_algorithm_set_name() const override; | |||
class AlgoBase; | |||
class AlgoCUDNN; | |||
class AlgoInplaceMatmul; | |||
class AlgoChanwise; | |||
class AlgoGroupConvGeneral; | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& filter, const TensorLayout& diff, | |||
const TensorLayout& grad) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
private: | |||
Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible); | |||
static AlgoPack sm_algo_pack; | |||
}; | |||
class Convolution3DBackwardFilterImpl: public Convolution3DBackwardFilter { | |||
public: | |||
using Convolution3DBackwardFilter::Convolution3DBackwardFilter; | |||
void exec(_megdnn_tensor_in src, | |||
_megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) override; | |||
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | |||
const TensorLayout &diff, | |||
const TensorLayout &grad) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const TensorLayout& diff, | |||
const CanonizedFilterMeta& grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible); | |||
size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad) override; | |||
const char* get_algorithm_set_name() const override; | |||
class AlgoBase; | |||
class AlgoCUDNN; | |||
class AlgoInplaceMatmul; | |||
class AlgoChanwise; | |||
class AlgoGroupConvGeneral; | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { | |||
return sm_algo_pack; | |||
} | |||
private: | |||
static AlgoPack sm_algo_pack; | |||
class Convolution3DBackwardFilterImpl : public Convolution3DBackwardFilter { | |||
public: | |||
using Convolution3DBackwardFilter::Convolution3DBackwardFilter; | |||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad) override; | |||
AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& src, | |||
const TensorLayout& diff, | |||
const CanonizedFilterMeta& grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) { | |||
return get_algorithm_heuristic(src, diff, grad, | |||
workspace_limit_in_bytes, reproducible) | |||
->info(); | |||
} | |||
const char* get_algorithm_set_name() const override; | |||
class AlgoBase; | |||
class AlgoCUDNN; | |||
class AlgoInplaceMatmul; | |||
class AlgoChanwise; | |||
class AlgoGroupConvGeneral; | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& src, const TensorLayout& diff, | |||
const TensorLayout& grad) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
private: | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const TensorLayout& diff, | |||
const CanonizedFilterMeta& grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible); | |||
static AlgoPack sm_algo_pack; | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -433,6 +433,137 @@ void Conv3DDesc::set(const param::Convolution3D& param, const size_t nr_group) { | |||
desc, 3, padA, filterStrideA, dilationA, mode, CUDNN_DATA_FLOAT)); | |||
} | |||
////////////////////////// CudnnAlgoPack ////////////////////////// | |||
#define V1(v) #v | |||
#define V(v) V1(v) | |||
#define DEF_NAME(NAME) \ | |||
#NAME "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL) | |||
#define DEF_ALGO(NAME, PROD) \ | |||
{ \ | |||
NAME, { DEF_NAME(NAME), PROD } \ | |||
} | |||
#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) | |||
#pragma message "not latest cudnn" | |||
#endif | |||
const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, CudnnAlgoPack::Attr> | |||
CudnnAlgoPack::conv_bwd_data_algos() { | |||
static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, | |||
CudnnAlgoPack::Attr> | |||
algos = { | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false), | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true), | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, true), | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true), | |||
#if CUDNN_MAJOR >= 5 | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, true), | |||
#if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1 | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED, | |||
true), | |||
#endif | |||
#endif | |||
}; | |||
return algos; | |||
} | |||
const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, CudnnAlgoPack::Attr> | |||
CudnnAlgoPack::conv_bwd_flt_algos() { | |||
static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, | |||
CudnnAlgoPack::Attr> | |||
algos = { | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false), | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true), | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, true), | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false), | |||
#if CUDNN_MAJOR >= 6 || (CUDNN_MAJOR >= 5 && CUDNN_MINOR >= 1) | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, | |||
true), | |||
#if CUDNN_MAJOR >= 6 | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, true), | |||
#endif | |||
#endif | |||
}; | |||
return algos; | |||
} | |||
const std::unordered_map<cudnnConvolutionFwdAlgo_t, CudnnAlgoPack::Attr> | |||
CudnnAlgoPack::conv_fwd_algos() { | |||
static const std::unordered_map<cudnnConvolutionFwdAlgo_t, | |||
CudnnAlgoPack::Attr> | |||
algos = { | |||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true), | |||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, | |||
true), | |||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_GEMM, true), | |||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, true), | |||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT, true), | |||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true), | |||
#if CUDNN_MAJOR >= 5 | |||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, true), | |||
#if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1 | |||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, true), | |||
#endif | |||
#endif | |||
}; | |||
return algos; | |||
} | |||
const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, CudnnAlgoPack::Attr> | |||
CudnnAlgoPack::conv3d_bwd_data_algos() { | |||
static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, | |||
CudnnAlgoPack::Attr> | |||
algos = { | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false), | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true), | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true), | |||
}; | |||
return algos; | |||
} // namespace cuda | |||
const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, CudnnAlgoPack::Attr> | |||
CudnnAlgoPack::conv3d_bwd_flt_algos() { | |||
#pragma message \ | |||
"fp16 dilated conv with odd size filter, only algo_1 works, need focus on doc" | |||
static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, | |||
CudnnAlgoPack::Attr> | |||
algos = { | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false), | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true), | |||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false), | |||
}; | |||
return algos; | |||
} | |||
const std::unordered_map<cudnnConvolutionFwdAlgo_t, CudnnAlgoPack::Attr> | |||
CudnnAlgoPack::conv3d_fwd_algos() { | |||
static const std::unordered_map<cudnnConvolutionFwdAlgo_t, | |||
CudnnAlgoPack::Attr> | |||
algos = { | |||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true), | |||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, | |||
true), | |||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true), | |||
}; | |||
return algos; | |||
} | |||
#undef DEF_ALGO | |||
#undef DEF_NAME | |||
#undef V | |||
#undef V1 | |||
} // namespace cuda | |||
} // namespace megdnn | |||
@@ -10,6 +10,7 @@ | |||
*/ | |||
#pragma once | |||
#include <unordered_map> | |||
#include "megdnn/basic_types.h" | |||
#include "megdnn/oprs/nn.h" | |||
#include "src/cuda/cudnn_with_check.h" | |||
@@ -27,7 +28,7 @@ class TensorDesc { | |||
public: | |||
TensorDesc(); | |||
//! default layout is nchw | |||
void set(const TensorLayout& layout, const param::Convolution::Format = | |||
void set(const TensorLayout& layout, const param::Convolution::Format = | |||
param::Convolution::Format::NCHW); | |||
~TensorDesc(); | |||
cudnnTensorDescriptor_t desc; | |||
@@ -103,9 +104,52 @@ class Conv3DDesc { | |||
cudnnConvolutionDescriptor_t desc; | |||
}; | |||
class CudnnAlgoPack { | |||
public: | |||
//! algorithm attr | |||
struct Attr { | |||
std::string name; | |||
bool is_reproducible; | |||
}; | |||
static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, Attr> | |||
conv_bwd_data_algos(); | |||
} // namespace cuda | |||
} // namespace megdnn | |||
static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, Attr> | |||
conv_bwd_flt_algos(); | |||
static const std::unordered_map<cudnnConvolutionFwdAlgo_t, Attr> | |||
conv_fwd_algos(); | |||
static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, Attr> | |||
conv3d_bwd_data_algos(); | |||
static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, Attr> | |||
conv3d_bwd_flt_algos(); | |||
static const std::unordered_map<cudnnConvolutionFwdAlgo_t, Attr> | |||
conv3d_fwd_algos(); | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn | |||
namespace std { | |||
#define DEF_HASH(_type) \ | |||
template <> \ | |||
struct hash<_type> { \ | |||
std::size_t operator()(const _type& algo) const { \ | |||
return std::hash<uint32_t>()(static_cast<uint32_t>(algo)); \ | |||
} \ | |||
} | |||
DEF_HASH(cudnnConvolutionBwdDataAlgo_t); | |||
DEF_HASH(cudnnConvolutionBwdFilterAlgo_t); | |||
DEF_HASH(cudnnConvolutionFwdAlgo_t); | |||
#undef DEF_HASH | |||
} // namespace std | |||
// vim: syntax=cpp.doxygen |
@@ -19,7 +19,12 @@ using OprImpl = DeformableConvBackwardDataImpl; | |||
OprImpl::AlgoPack::AlgoPack() { | |||
all_algos.push_back(&algo_matmul); | |||
for (auto&& algo : all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
MEGDNN_DEF_GET_ALGO_FROM_DESC(DeformableConvBackwardDataImpl) | |||
OprImpl::AlgoPack OprImpl::sm_algo_pack; | |||
@@ -13,11 +13,15 @@ | |||
#include "megdnn/oprs.h" | |||
#include "src/common/algo_base.h" | |||
#include "src/common/metahelper.h" | |||
#include "src/common/utils.h" | |||
#include "src/cuda/handle.h" | |||
#include "src/cuda/deformable_conv/opr_impl.h" | |||
#include <unordered_map> | |||
namespace megdnn { | |||
namespace cuda { | |||
@@ -26,6 +30,10 @@ protected: | |||
~AlgoBase() = default; | |||
public: | |||
enum class AlgoType : uint32_t { | |||
CUDA_MATMUL, | |||
}; | |||
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
struct SizeArgs { | |||
DeformableConvBackwardDataImpl* opr; | |||
@@ -107,17 +115,18 @@ public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "AlgoMatmul"; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | |||
}; | |||
class DeformableConvBackwardDataImpl::AlgoPack { | |||
AlgoPack(const AlgoPack&) = delete; | |||
AlgoPack& operator=(const AlgoPack&) = delete; | |||
class DeformableConvBackwardDataImpl::AlgoPack : NonCopyableObj { | |||
AlgoBase::Mapper m_all_algos_map; | |||
public: | |||
AlgoPack(); | |||
AlgoMatmul algo_matmul; | |||
//! all algorithms | |||
std::vector<AlgoBase*> all_algos; | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
} // namespace cuda | |||
@@ -20,7 +20,11 @@ using OprImpl = DeformableConvBackwardFilterImpl; | |||
OprImpl::AlgoPack::AlgoPack() { | |||
all_algos.push_back(&algo_matmul); | |||
for (auto&& algo : all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
MEGDNN_DEF_GET_ALGO_FROM_DESC(DeformableConvBackwardFilterImpl) | |||
OprImpl::AlgoPack OprImpl::sm_algo_pack; | |||
@@ -13,11 +13,15 @@ | |||
#include "megdnn/oprs.h" | |||
#include "src/common/algo_base.h" | |||
#include "src/common/metahelper.h" | |||
#include "src/common/utils.h" | |||
#include "src/cuda/handle.h" | |||
#include "src/cuda/deformable_conv/opr_impl.h" | |||
#include <unordered_map> | |||
namespace megdnn { | |||
namespace cuda { | |||
@@ -26,6 +30,11 @@ protected: | |||
~AlgoBase() = default; | |||
public: | |||
enum class AlgoType : uint32_t { | |||
CUDA_MATMUL, | |||
}; | |||
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
struct SizeArgs { | |||
DeformableConvBackwardFilterImpl* opr; | |||
@@ -97,18 +106,18 @@ public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "AlgoMatmul"; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | |||
}; | |||
class DeformableConvBackwardFilterImpl::AlgoPack { | |||
AlgoPack(const AlgoPack&) = delete; | |||
AlgoPack& operator=(const AlgoPack&) = delete; | |||
class DeformableConvBackwardFilterImpl::AlgoPack : NonCopyableObj { | |||
AlgoBase::Mapper m_all_algos_map; | |||
public: | |||
AlgoPack(); | |||
AlgoMatmul algo_matmul; | |||
//! all algorithms | |||
std::vector<AlgoBase*> all_algos; | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
} // namespace cuda | |||
@@ -22,8 +22,14 @@ using OprImpl = DeformableConvForwardImpl; | |||
OprImpl::AlgoPack::AlgoPack() { | |||
all_algos.push_back(&algo_matmul); | |||
for (auto&& algo : all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
MEGDNN_DEF_GET_ALGO_FROM_DESC(DeformableConvForwardImpl) | |||
OprImpl::AlgoPack OprImpl::sm_algo_pack; | |||
OprImpl::AlgoBase::SizeArgs::SizeArgs(OprImpl* o, const TensorLayout& im, | |||
@@ -13,9 +13,13 @@ | |||
#include "megdnn/oprs.h" | |||
#include "src/common/algo_base.h" | |||
#include "src/common/metahelper.h" | |||
#include "src/cuda/deformable_conv/opr_impl.h" | |||
#include "src/cuda/utils.h" | |||
#include <unordered_map> | |||
namespace megdnn { | |||
namespace cuda { | |||
@@ -24,6 +28,11 @@ protected: | |||
~AlgoBase() = default; | |||
public: | |||
enum class AlgoType : uint32_t { | |||
CUDA_MATMUL, | |||
}; | |||
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
struct SizeArgs { | |||
DeformableConvForwardImpl* opr; | |||
@@ -92,17 +101,17 @@ public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "AlgoMatmul"; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | |||
}; | |||
class DeformableConvForwardImpl::AlgoPack { | |||
AlgoPack(const AlgoPack&) = delete; | |||
AlgoPack& operator=(const AlgoPack&) = delete; | |||
class DeformableConvForwardImpl::AlgoPack : NonCopyableObj { | |||
AlgoBase::Mapper m_all_algos_map; | |||
public: | |||
AlgoPack(); | |||
AlgoMatmul algo_matmul; | |||
//! all algorithms | |||
std::vector<AlgoBase*> all_algos; | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
} // namespace cuda | |||
@@ -6,7 +6,8 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
@@ -29,19 +30,6 @@ public: | |||
const TensorLayout& mask, | |||
const TensorLayout& dst) override; | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& im, const TensorLayout& filter, | |||
const TensorLayout& offset, const TensorLayout& mask, | |||
const TensorLayout& dst) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& im, | |||
const TensorLayout& filter, | |||
const TensorLayout& offset, | |||
const TensorLayout& mask, | |||
const TensorLayout& dst, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& im, | |||
const CanonizedFilterMeta& filter, | |||
const TensorLayout& offset, | |||
@@ -58,31 +46,35 @@ public: | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& im, const TensorLayout& filter, | |||
const TensorLayout& offset, const TensorLayout& mask, | |||
const TensorLayout& dst) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& im, | |||
const TensorLayout& filter, | |||
const TensorLayout& offset, | |||
const TensorLayout& mask, | |||
const TensorLayout& dst, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
private: | |||
static AlgoPack sm_algo_pack; | |||
}; | |||
class DeformableConvBackwardFilterImpl: public DeformableConvBackwardFilter { | |||
class DeformableConvBackwardFilterImpl : public DeformableConvBackwardFilter { | |||
public: | |||
using DeformableConvBackwardFilter::DeformableConvBackwardFilter; | |||
void exec(_megdnn_tensor_in im,_megdnn_tensor_in offset, _megdnn_tensor_in mask, | |||
_megdnn_tensor_in out_grad, _megdnn_tensor_out filter_grad, | |||
void exec(_megdnn_tensor_in im, _megdnn_tensor_in offset, | |||
_megdnn_tensor_in mask, _megdnn_tensor_in out_grad, | |||
_megdnn_tensor_out filter_grad, | |||
_megdnn_workspace workspace) override; | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& im, const TensorLayout& offset, const TensorLayout& mask, | |||
const TensorLayout& out_grad, const TensorLayout& filter_grad) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& im, | |||
const TensorLayout& offset, | |||
const TensorLayout& mask, | |||
const TensorLayout& out_grad, | |||
const TensorLayout& filter_grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& im, | |||
const TensorLayout& offset, | |||
const TensorLayout& mask, | |||
@@ -91,9 +83,11 @@ public: | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible); | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& im, const TensorLayout& offset, const TensorLayout& mask, | |||
const TensorLayout& out_grad, const TensorLayout& filter_grad) override; | |||
size_t get_workspace_in_bytes(const TensorLayout& im, | |||
const TensorLayout& offset, | |||
const TensorLayout& mask, | |||
const TensorLayout& out_grad, | |||
const TensorLayout& filter_grad) override; | |||
const char* get_algorithm_set_name() const override; | |||
@@ -103,6 +97,21 @@ public: | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& im, const TensorLayout& offset, | |||
const TensorLayout& mask, const TensorLayout& out_grad, | |||
const TensorLayout& filter_grad) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& im, | |||
const TensorLayout& offset, | |||
const TensorLayout& mask, | |||
const TensorLayout& out_grad, | |||
const TensorLayout& filter_grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
private: | |||
static AlgoPack sm_algo_pack; | |||
@@ -118,19 +127,6 @@ public: | |||
_megdnn_tensor_out offset_grad, _megdnn_tensor_out mask_grad, | |||
_megdnn_workspace workspace) override; | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& im, const TensorLayout& filter, | |||
const TensorLayout& offset, const TensorLayout& mask, | |||
const TensorLayout& out_grad, const TensorLayout& im_grad, | |||
const TensorLayout& offset_grad, const TensorLayout& mask_grad) override; | |||
Algorithm* get_algorithm_heuristic( | |||
const TensorLayout& im, const TensorLayout& filter, | |||
const TensorLayout& offset, const TensorLayout& mask, | |||
const TensorLayout& out_grad, const TensorLayout& im_grad, | |||
const TensorLayout& offset_grad, const TensorLayout& mask_grad, | |||
size_t workspace_limit_in_bytes, bool reproducible) override; | |||
Algorithm* get_algorithm_heuristic( | |||
const TensorLayout& im, const CanonizedFilterMeta& filter, | |||
const TensorLayout& offset, const TensorLayout& mask, | |||
@@ -138,11 +134,14 @@ public: | |||
const TensorLayout& offset_grad, const TensorLayout& mask_grad, | |||
size_t workspace_limit_in_bytes, bool reproducible); | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& im, const TensorLayout& filter, | |||
const TensorLayout& offset, const TensorLayout& mask, | |||
const TensorLayout& out_grad, const TensorLayout& im_grad, | |||
const TensorLayout& offset_grad, const TensorLayout& mask_grad) override; | |||
size_t get_workspace_in_bytes(const TensorLayout& im, | |||
const TensorLayout& filter, | |||
const TensorLayout& offset, | |||
const TensorLayout& mask, | |||
const TensorLayout& out_grad, | |||
const TensorLayout& im_grad, | |||
const TensorLayout& offset_grad, | |||
const TensorLayout& mask_grad) override; | |||
const char* get_algorithm_set_name() const override; | |||
@@ -152,6 +151,22 @@ public: | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& im, const TensorLayout& filter, | |||
const TensorLayout& offset, const TensorLayout& mask, | |||
const TensorLayout& out_grad, const TensorLayout& im_grad, | |||
const TensorLayout& offset_grad, | |||
const TensorLayout& mask_grad) override; | |||
Algorithm* get_algorithm_heuristic( | |||
const TensorLayout& im, const TensorLayout& filter, | |||
const TensorLayout& offset, const TensorLayout& mask, | |||
const TensorLayout& out_grad, const TensorLayout& im_grad, | |||
const TensorLayout& offset_grad, const TensorLayout& mask_grad, | |||
size_t workspace_limit_in_bytes, bool reproducible) override; | |||
private: | |||
static AlgoPack sm_algo_pack; | |||
@@ -18,8 +18,14 @@ using namespace cuda; | |||
LocalShareBackwardDataImpl::AlgoPack::AlgoPack() { | |||
all_algos.push_back(&implicit_gemm); | |||
all_algos.push_back(&batched_matmul); | |||
for (auto&& algo : all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
MEGDNN_DEF_GET_ALGO_FROM_DESC(LocalShareBackwardDataImpl) | |||
LocalShareBackwardDataImpl::AlgoPack LocalShareBackwardDataImpl::sm_algo_pack; | |||
LocalShareBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs( | |||
@@ -13,10 +13,14 @@ | |||
#include "megdnn/oprs.h" | |||
#include "src/common/algo_base.h" | |||
#include "src/common/metahelper.h" | |||
#include "src/common/utils.h" | |||
#include "src/cuda/handle.h" | |||
#include "src/cuda/local_share/opr_impl.h" | |||
#include <unordered_map> | |||
namespace megdnn { | |||
namespace cuda { | |||
@@ -25,6 +29,13 @@ protected: | |||
~AlgoBase() = default; | |||
public: | |||
enum class AlgoType : uint32_t { | |||
CUDA_IMPLICIT_GEMM, | |||
CUDA_BATCHED_MATMUL, | |||
}; | |||
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
struct SizeArgs { | |||
LocalShareBackwardDataImpl* opr; | |||
@@ -77,6 +88,7 @@ public: | |||
const char* name() const override { | |||
return "LOCAL_SHARE_IMPLICIT_GEMM"; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM) | |||
}; | |||
class LocalShareBackwardDataImpl::AlgoBatchedMatMul final | |||
@@ -93,11 +105,11 @@ public: | |||
const char* name() const override { | |||
return "LOCAL_SHARE_BATCHED_MATMUL"; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) | |||
}; | |||
class LocalShareBackwardDataImpl::AlgoPack { | |||
AlgoPack(const AlgoPack&) = delete; | |||
AlgoPack& operator=(const AlgoPack&) = delete; | |||
class LocalShareBackwardDataImpl::AlgoPack : NonCopyableObj { | |||
AlgoBase::Mapper m_all_algos_map; | |||
public: | |||
AlgoPack(); | |||
@@ -106,6 +118,7 @@ public: | |||
AlgoBatchedMatMul batched_matmul; | |||
std::vector<AlgoBase*> all_algos; | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
} // namespace cuda | |||
@@ -18,8 +18,14 @@ using namespace cuda; | |||
LocalShareBackwardFilterImpl::AlgoPack::AlgoPack() { | |||
all_algos.push_back(&implicit_gemm); | |||
all_algos.push_back(&batched_matmul); | |||
for (auto&& algo : all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
MEGDNN_DEF_GET_ALGO_FROM_DESC(LocalShareBackwardFilterImpl) | |||
LocalShareBackwardFilterImpl::AlgoPack LocalShareBackwardFilterImpl::sm_algo_pack; | |||
LocalShareBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs( | |||
@@ -13,10 +13,14 @@ | |||
#include "megdnn/oprs.h" | |||
#include "src/common/algo_base.h" | |||
#include "src/common/metahelper.h" | |||
#include "src/common/utils.h" | |||
#include "src/cuda/handle.h" | |||
#include "src/cuda/local_share/opr_impl.h" | |||
#include <unordered_map> | |||
namespace megdnn { | |||
namespace cuda { | |||
@@ -25,6 +29,12 @@ protected: | |||
~AlgoBase() = default; | |||
public: | |||
enum class AlgoType : uint32_t { | |||
CUDA_IMPLICIT_GEMM, | |||
CUDA_BATCHED_MATMUL, | |||
}; | |||
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
struct SizeArgs { | |||
LocalShareBackwardFilterImpl* opr; | |||
@@ -75,6 +85,7 @@ public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "LOCAL_SHARE_IMPLICIT_GEMM"; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM) | |||
}; | |||
class LocalShareBackwardFilterImpl::AlgoBatchedMatMul final : public AlgoBase { | |||
@@ -88,11 +99,11 @@ public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) | |||
}; | |||
class LocalShareBackwardFilterImpl::AlgoPack { | |||
AlgoPack(const AlgoPack&) = delete; | |||
AlgoPack& operator=(const AlgoPack&) = delete; | |||
class LocalShareBackwardFilterImpl::AlgoPack : NonCopyableObj { | |||
AlgoBase::Mapper m_all_algos_map; | |||
public: | |||
AlgoPack(); | |||
@@ -101,6 +112,8 @@ public: | |||
AlgoBatchedMatMul batched_matmul; | |||
std::vector<AlgoBase*> all_algos; | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
} // namespace cuda | |||
@@ -19,8 +19,14 @@ LocalShareForwardImpl::AlgoPack::AlgoPack() { | |||
all_algos.push_back(&batch_size_aware_chwn_small_image); | |||
all_algos.push_back(&batch_size_aware_chwn); | |||
all_algos.push_back(&batched_matmul); | |||
for (auto&& algo : all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
MEGDNN_DEF_GET_ALGO_FROM_DESC(LocalShareForwardImpl) | |||
LocalShareForwardImpl::AlgoPack LocalShareForwardImpl::sm_algo_pack; | |||
LocalShareForwardImpl::AlgoBase::SizeArgs::SizeArgs(LocalShareForwardImpl* o, | |||
@@ -14,9 +14,13 @@ | |||
#include "megdnn/oprs.h" | |||
#include "src/common/utils.h" | |||
#include "src/common/algo_base.h" | |||
#include "src/common/metahelper.h" | |||
#include "src/cuda/handle.h" | |||
#include "src/cuda/local_share/opr_impl.h" | |||
#include <unordered_map> | |||
namespace megdnn { | |||
namespace cuda { | |||
@@ -25,6 +29,13 @@ protected: | |||
~AlgoBase() = default; | |||
public: | |||
enum class AlgoType : uint32_t { | |||
CUDA_CHWN_BATCH_SIZE_AWARE, | |||
CUDA_CHWN_BATCH_SIZE_AWARE_SMALL_IMAGE, | |||
CUDA_BATCHED_MATMUL | |||
}; | |||
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
struct SizeArgs { | |||
LocalShareForwardImpl* opr; | |||
@@ -79,6 +90,7 @@ public: | |||
const char* name() const override { | |||
return "LOCAL_SHARE_CHWN_BATCH_SIZE_AWARE"; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CHWN_BATCH_SIZE_AWARE) | |||
}; | |||
class LocalShareForwardImpl::AlgoCHWNBatchSizeAwareSmallImage final | |||
@@ -95,6 +107,7 @@ public: | |||
const char* name() const override { | |||
return "LOCAL_SHARE_CHWN_BATCH_SIZE_AWARE_SMALL_IMAGE"; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CHWN_BATCH_SIZE_AWARE_SMALL_IMAGE) | |||
}; | |||
class LocalShareForwardImpl::AlgoBatchedMatMul final : public AlgoBase { | |||
@@ -108,11 +121,11 @@ public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) | |||
}; | |||
class LocalShareForwardImpl::AlgoPack { | |||
AlgoPack(const AlgoPack&) = delete; | |||
AlgoPack& operator=(const AlgoPack&) = delete; | |||
class LocalShareForwardImpl::AlgoPack : NonCopyableObj { | |||
AlgoBase::Mapper m_all_algos_map; | |||
public: | |||
AlgoPack(); | |||
@@ -122,6 +135,7 @@ public: | |||
AlgoBatchedMatMul batched_matmul; | |||
std::vector<AlgoBase*> all_algos; | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
} // namespace cuda | |||
@@ -23,14 +23,6 @@ public: | |||
size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& filter, | |||
const TensorLayout& dst) override; | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& src, const TensorLayout& filter, | |||
const TensorLayout& dst) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const TensorLayout& filter, | |||
const TensorLayout& dst, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
const char* get_algorithm_set_name() const override; | |||
class AlgoBase; | |||
@@ -41,7 +33,17 @@ public: | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& src, const TensorLayout& filter, | |||
const TensorLayout& dst) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const TensorLayout& filter, | |||
const TensorLayout& dst, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
private: | |||
static AlgoPack sm_algo_pack; | |||
}; | |||
@@ -54,14 +56,6 @@ public: | |||
size_t get_workspace_in_bytes(const TensorLayout& filter, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad) override; | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& filter, const TensorLayout& diff, | |||
const TensorLayout& grad) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
const char* get_algorithm_set_name() const override; | |||
class AlgoBase; | |||
@@ -71,6 +65,17 @@ public: | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& filter, const TensorLayout& diff, | |||
const TensorLayout& grad) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
private: | |||
static AlgoPack sm_algo_pack; | |||
@@ -84,14 +89,6 @@ public: | |||
size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad) override; | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& src, const TensorLayout& diff, | |||
const TensorLayout& grad) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
const char* get_algorithm_set_name() const override; | |||
class AlgoBase; | |||
@@ -101,6 +98,17 @@ public: | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& src, const TensorLayout& diff, | |||
const TensorLayout& grad) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
private: | |||
static AlgoPack sm_algo_pack; | |||
@@ -11,6 +11,7 @@ | |||
#include "./algos.h" | |||
#include "src/cuda/utils.h" | |||
#include "src/common/algo_base.h" | |||
#include <cuda.h> | |||
#if CUDA_VERSION >= 10010 | |||
@@ -33,10 +34,16 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() { | |||
cublas_bfloat16 = std::make_unique<AlgoBFloat16>(&cublas); | |||
all_algos.push_back(cublas_bfloat16.get()); | |||
#endif | |||
for (auto&& algo : all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack; | |||
MEGDNN_DEF_GET_ALGO_FROM_DESC(MatrixMulForwardImpl) | |||
MatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs(MatrixMulForwardImpl* o, | |||
const TensorLayout& A, | |||
const TensorLayout& B, | |||
@@ -67,4 +74,5 @@ std::string MatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const { | |||
m, k, k, n, m, n, param.transposeA, param.transposeB, | |||
layout_a.stride[0], layout_b.stride[0], layout_c.stride[0])); | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -6,14 +6,18 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
#include "src/common/utils.h" | |||
#include "src/cuda/matrix_mul/opr_impl.h" | |||
#include "src/common/algo_base.h" | |||
#include "src/common/metahelper.h" | |||
#include <unordered_map> | |||
#include <cuda.h> | |||
#include <memory> | |||
#if CUDA_VERSION >= 10010 | |||
@@ -32,6 +36,15 @@ protected: | |||
~AlgoBase() = default; | |||
public: | |||
enum class AlgoType : uint32_t { | |||
CUDA_CUBLAS, | |||
CUDA_WMMA_UINT4X4X32, | |||
CUDA_CUBLASLT, | |||
CUDA_NAIVE, | |||
CUDA_BFLOAT16 | |||
}; | |||
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||
struct SizeArgs { | |||
MatrixMulForwardImpl* opr; | |||
@@ -62,12 +75,12 @@ public: | |||
virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; | |||
virtual void exec(const ExecArgs& args) const = 0; | |||
bool is_available_wk(const SizeArgs& args, size_t limit) { | |||
bool is_available_wk(const SizeArgs& args, size_t limit) const { | |||
return is_available(args) && get_workspace_in_bytes(args) <= limit; | |||
} | |||
bool is_available_reproducible( | |||
const SizeArgs& args, bool reproducible = true, | |||
size_t limit = std::numeric_limits<size_t>::max()) { | |||
size_t limit = std::numeric_limits<size_t>::max()) const { | |||
return (!reproducible || is_reproducible()) && | |||
is_available_wk(args, limit); | |||
} | |||
@@ -80,8 +93,6 @@ public: | |||
name(), req, workspace.size); | |||
return *this; | |||
} | |||
}; | |||
class MatrixMulForwardImpl::AlgoCuBlas final : public AlgoBase { | |||
@@ -91,13 +102,10 @@ public: | |||
size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override { | |||
return 0_z; | |||
} | |||
const char* name() const override { | |||
return "CUBLAS"; | |||
} | |||
const char* name() const override { return "CUBLAS"; } | |||
void exec(const ExecArgs& args) const override; | |||
bool is_reproducible() const override { | |||
return true; | |||
} | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) | |||
}; | |||
#if CUDA_VERSION >= 10000 | |||
@@ -106,13 +114,10 @@ public: | |||
AlgoUInt4x4x32WMMA() = default; | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
const char* name() const override { | |||
return "UINT4x4x32_WMMA"; | |||
} | |||
const char* name() const override { return "UINT4x4x32_WMMA"; } | |||
void exec(const ExecArgs& args) const override; | |||
bool is_reproducible() const override { | |||
return true; | |||
} | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_WMMA_UINT4X4X32) | |||
}; | |||
#endif | |||
#if CUDA_VERSION >= 10010 | |||
@@ -120,13 +125,10 @@ class MatrixMulForwardImpl::AlgoCuBlasLt final : public AlgoBase { | |||
public: | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
const char* name() const override { | |||
return "CUBLAS_LT"; | |||
} | |||
const char* name() const override { return "CUBLAS_LT"; } | |||
void exec(const ExecArgs& args) const override; | |||
bool is_reproducible() const override { | |||
return true; | |||
} | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT) | |||
}; | |||
#endif | |||
@@ -140,6 +142,7 @@ public: | |||
const char* name() const override { return "NAIVE"; } | |||
void exec(const ExecArgs& args) const override; | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_NAIVE) | |||
}; | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
@@ -151,6 +154,13 @@ public: | |||
const char* name() const override { return m_name.c_str(); } | |||
void exec(const ExecArgs& args) const override; | |||
bool is_reproducible() const override { return true; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_NAIVE) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_algorithm, ret); | |||
return ret; | |||
} | |||
private: | |||
MatrixMulForwardImpl::AlgoBase* m_algorithm = nullptr; | |||
@@ -160,9 +170,9 @@ private: | |||
}; | |||
#endif | |||
class MatrixMulForwardImpl::AlgoPack { | |||
AlgoPack(const AlgoPack&) = delete; | |||
AlgoPack& operator=(const AlgoPack&) = delete; | |||
class MatrixMulForwardImpl::AlgoPack : NonCopyableObj { | |||
private: | |||
AlgoBase::Mapper m_all_algos_map; | |||
public: | |||
AlgoPack(); | |||
@@ -178,6 +188,8 @@ public: | |||
std::unique_ptr<AlgoBFloat16> cublas_bfloat16; | |||
#endif | |||
std::vector<AlgoBase*> all_algos; | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
} // namespace cuda | |||
@@ -82,7 +82,7 @@ void MatrixMulForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const { | |||
args.opr->handle()->create_operator<MatrixMulForward>(); | |||
matmul_opr->param() = args.opr->param(); | |||
matmul_opr->param().compute_mode = Param::ComputeMode::DEFAULT; | |||
matmul_opr->execution_policy() = {m_algorithm}; | |||
matmul_opr->execution_policy() = {m_algorithm->info()}; | |||
matmul_opr->exec(a, b, c, ctypecvt.workspace()); | |||
} | |||
ctypecvt.comp_to_dst_type(c, args.tensor_c); | |||
@@ -25,15 +25,6 @@ public: | |||
bool is_thread_safe() const override { return true; } | |||
std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | |||
const TensorLayout& B, | |||
const TensorLayout& C) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& A, | |||
const TensorLayout& B, | |||
const TensorLayout& C, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
const char* get_algorithm_set_name() const override { | |||
return "CUDA MATMUL"; | |||
} | |||
@@ -55,6 +46,17 @@ public: | |||
static const AlgoPack& algo_pack() { | |||
return sm_algo_pack; | |||
} | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | |||
const TensorLayout& B, | |||
const TensorLayout& C) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& A, | |||
const TensorLayout& B, | |||
const TensorLayout& C, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
private: | |||
static AlgoPack sm_algo_pack; | |||
@@ -10,10 +10,14 @@ | |||
*/ | |||
#include "src/fallback/conv_bias/algos.h" | |||
#include "src/fallback/conv_bias/conv1x1/algos.h" | |||
#include "src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h" | |||
#include "src/fallback/conv_bias/im2col/algos.h" | |||
#include "megdnn/opr_param_defs.h" | |||
#include "src/common/opr_delegate.h" | |||
#include "src/fallback/conv_bias/winograd/strategy.h" | |||
#include "src/naive/convolution/helper.h" | |||
#include "src/common/algo_base.h" | |||
#include "midout.h" | |||
@@ -176,6 +180,7 @@ void kern_default(const ConvBiasImpl::NCBKernParam& p) { | |||
} // namespace | |||
MIDOUT_DECL(megdnn_fallback_naive) | |||
/* ======================= AlgoNaive ======================== */ | |||
bool ConvBiasImpl::AlgoNaive::usable( | |||
@@ -36,6 +36,7 @@ public: | |||
static_cast<uint32_t>(AlgoDataType::QUINT8X8X32)); | |||
return {support_data_type, AlgoCategory::NAIVE}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(FB_NAIVE) | |||
}; | |||
class ConvBiasImpl::AlgoWinogradF32 final : public AlgoBase { | |||
@@ -59,6 +60,12 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_F32) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_matmul_algo, ret); | |||
return ret; | |||
} | |||
private: | |||
MatrixMulImpl::AlgoBase* m_matmul_algo; | |||
@@ -87,6 +94,12 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_4X4_F32) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_matmul_algo, ret); | |||
return ret; | |||
} | |||
private: | |||
MatrixMulImpl::AlgoBase* m_matmul_algo; | |||
@@ -115,6 +128,12 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_QS8) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_matmul_algo, ret); | |||
return ret; | |||
} | |||
private: | |||
MatrixMulImpl::AlgoBase* m_matmul_algo; | |||
@@ -143,6 +162,12 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_8X8_QS8) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_matmul_algo, ret); | |||
return ret; | |||
} | |||
private: | |||
MatrixMulImpl::AlgoBase* m_matmul_algo; | |||
@@ -156,6 +156,12 @@ using BiasMode = ConvBiasForward::BiasMode; | |||
ConvAlgoTypePack get_algo_type() const override { \ | |||
return {_algo_data_type, AlgoCategory::WINOGRAD}; \ | |||
} \ | |||
std::string param() const override { \ | |||
std::string ret; \ | |||
serialize_write_pod(m_matmul_algo, ret); \ | |||
serialize_write_pod(m_tile_size, ret); \ | |||
return ret; \ | |||
} \ | |||
\ | |||
private: \ | |||
fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; \ | |||
@@ -60,6 +60,13 @@ public: | |||
return {m_matmul_algo->matmul_description().algo_type.data_type, | |||
AlgoCategory::IM2COL}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_8X8_QS8) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_matmul_algo, ret); | |||
serialize_write_pod(m_oc_block_size, ret); | |||
return ret; | |||
} | |||
protected: | |||
size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const; | |||
@@ -43,6 +43,7 @@ public: | |||
static_cast<uint32_t>(AlgoDataType::QUINT8X8X32)); | |||
return {support_data_type, AlgoCategory::IM2COL}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(FB_CONV1x1_GEMV) | |||
protected: | |||
size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const; | |||
@@ -68,6 +68,14 @@ public: | |||
return {m_matmul_algo->matmul_description().algo_type.data_type, | |||
AlgoCategory::IM2COL}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(FB_IM2COL) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_matmul_algo, ret); | |||
serialize_write_pod(m_ohw_tile_size, ret); | |||
return ret; | |||
} | |||
private: | |||
MatrixMulImpl::AlgoBase* m_matmul_algo; | |||
@@ -22,6 +22,14 @@ | |||
#include "src/naive/convolution/algorithms.h" | |||
#include "src/naive/handle.h" | |||
#if MEGDNN_X86 | |||
#include "src/x86/conv_bias/opr_impl.h" | |||
#elif MEGDNN_AARCH64 | |||
#include "src/aarch64/conv_bias/opr_impl.h" | |||
#elif MEGDNN_ARMV7 | |||
#include "src/armv7/conv_bias/opr_impl.h" | |||
#endif | |||
#include <cstring> | |||
using namespace megdnn; | |||
@@ -65,17 +73,19 @@ void incr_ptr(T*& dst, ptrdiff_t delta) { | |||
class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
AlgoNaive algo_naive; | |||
SmallVector<std::unique_ptr<AlgoBase>> refhold; | |||
SmallVector<AlgoBase*> m_all_algos; | |||
AlgoBase::Mapper m_all_algos_map; | |||
public: | |||
AlgoPack() { | |||
refhold.emplace_back(new AlgoConv1x1Gemv()); | |||
all_algos.emplace_back(refhold.back().get()); | |||
m_all_algos.emplace_back(refhold.back().get()); | |||
static CpuOprDelegationStorage<> storage; | |||
auto matmul_opr = storage.get<MatrixMul>(); | |||
auto&& matmul_algos = | |||
static_cast<fallback::MatrixMulImpl*>(matmul_opr)->algo_pack(); | |||
auto&& matmul_algos = static_cast<fallback::MatrixMulImpl*>(matmul_opr) | |||
->get_all_packed_algo(); | |||
for (auto&& algo : matmul_algos) { | |||
#if MEGDNN_X86 | |||
//! As we haven't direct conv for int8x8x16 yet, if we disable gemv here, it may | |||
@@ -97,13 +107,13 @@ public: | |||
refhold.emplace_back(new AlgoIm2col( | |||
static_cast<MatrixMulImpl::AlgoBase*>(algo), | |||
ohw_tile_size)); | |||
all_algos.emplace_back(refhold.back().get()); | |||
m_all_algos.emplace_back(refhold.back().get()); | |||
} | |||
for (size_t oc_tile_size : {48, 24}) { | |||
refhold.emplace_back(new AlgoConv1x1( | |||
static_cast<MatrixMulImpl::AlgoBase*>(algo), | |||
oc_tile_size)); | |||
all_algos.emplace_back(refhold.back().get()); | |||
m_all_algos.emplace_back(refhold.back().get()); | |||
} | |||
#endif | |||
@@ -113,26 +123,35 @@ public: | |||
//! FIXME: I do not know a better way to do it. | |||
refhold.emplace_back(new AlgoWinogradF32( | |||
static_cast<MatrixMulImpl::AlgoBase*>(algo))); | |||
all_algos.emplace_back(refhold.back().get()); | |||
m_all_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoWinogradF32_4x4( | |||
static_cast<MatrixMulImpl::AlgoBase*>(algo))); | |||
all_algos.emplace_back(refhold.back().get()); | |||
m_all_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoWinogradQS8( | |||
static_cast<MatrixMulImpl::AlgoBase*>(algo))); | |||
all_algos.emplace_back(refhold.back().get()); | |||
m_all_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoWinogradQS8_8x8( | |||
static_cast<MatrixMulImpl::AlgoBase*>(algo))); | |||
all_algos.emplace_back(refhold.back().get()); | |||
m_all_algos.emplace_back(refhold.back().get()); | |||
#endif | |||
} | |||
all_algos.emplace_back(&algo_naive); | |||
m_all_algos.emplace_back(&algo_naive); | |||
for (auto&& algo : m_all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
} | |||
SmallVector<AlgoBase*> all_algos; | |||
const SmallVector<AlgoBase*>& all_algos() const { return m_all_algos; } | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() { | |||
static AlgoPack sl_algo_pack; | |||
return sl_algo_pack.all_algos; | |||
const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { | |||
static AlgoPack algo_pack; | |||
return algo_pack; | |||
} | |||
SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::get_all_packed_algo() { | |||
return algo_pack().all_algos(); | |||
} | |||
SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::select_algo_type( | |||
@@ -140,7 +159,7 @@ SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::select_algo_type( | |||
megdnn_assert(nr_type_contain(target_type.data_type), | |||
"ConvBias algo selection only support one type"); | |||
SmallVector<ConvBiasImpl::AlgoBase*> algos; | |||
for (auto&& algo : algo_pack()) { | |||
for (auto&& algo : get_all_packed_algo()) { | |||
auto algo_type = algo->get_algo_type(); | |||
if (contain_data_type(algo_type.data_type, target_type.data_type) && | |||
algo_type.algo_category == target_type.algo_category) { | |||
@@ -166,7 +185,7 @@ void ConvBiasImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||
workspace.size, preprocessed_filter); | |||
auto fparam = make_ncb_kern_param(src, filter, bias, dst, workspace, | |||
preprocessed_filter); | |||
ConvBiasImpl::Algorithm* algo = get_algorithm(fparam, workspace.size); | |||
auto&& algo = get_algorithm(fparam, workspace.size); | |||
if (!is_naive_algo(algo) && | |||
NCB_ALGO_FUNC(get_workspace, algo, fparam) <= workspace.size) { | |||
exec_with_ncb_kern(fparam, algo); | |||
@@ -189,9 +208,10 @@ void ConvBiasImpl::exec_preprocess(const TensorLayout& src_layout, | |||
auto fparam = make_ncb_kern_param(src, filter, bias, dst, workspace, | |||
preprocessed_filter); | |||
//! should not pass workspace_size limit otherwise can not find match algo | |||
ConvBiasImpl::Algorithm* algo = get_algorithm(fparam); | |||
if (!is_naive_algo(algo) && NCB_ALGO_FUNC(get_preprocess_workspace, algo, | |||
fparam) <= workspace.size) { | |||
auto&& algo = get_algorithm(fparam); | |||
if (!is_naive_algo(algo) && | |||
NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam) <= | |||
workspace.size) { | |||
exec_preprocess_with_ncb_kern(fparam, algo); | |||
} else { | |||
naive::ConvBiasForwardImpl::exec_preprocess( | |||
@@ -207,7 +227,7 @@ size_t ConvBiasImpl::get_workspace_in_bytes( | |||
const PreprocessedFilter* preprocessed_filter) { | |||
auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, | |||
preprocessed_filter); | |||
ConvBiasImpl::Algorithm* algo = get_algorithm(fparam); | |||
auto&& algo = get_algorithm(fparam); | |||
if (is_naive_algo(algo)) { | |||
return naive::ConvBiasForwardImpl::get_workspace_in_bytes( | |||
src, filter, bias, z, dst, preprocessed_filter); | |||
@@ -221,7 +241,7 @@ size_t ConvBiasImpl::get_preprocess_workspace_in_bytes( | |||
const TensorLayout& bias, const TensorLayout& z, | |||
const TensorLayout& dst) { | |||
auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr); | |||
Algorithm* algo = get_algorithm(fparam); | |||
auto&& algo = get_algorithm(fparam); | |||
if (is_naive_algo(algo)) { | |||
return naive::ConvBiasForwardImpl::get_preprocess_workspace_in_bytes( | |||
src, filter, bias, z, dst); | |||
@@ -235,7 +255,7 @@ SmallVector<TensorLayout> ConvBiasImpl::deduce_preprocessed_filter_layout( | |||
const TensorLayout& bias, const TensorLayout& z, | |||
const TensorLayout& dst) { | |||
auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr); | |||
Algorithm* algo = get_algorithm(fparam); | |||
auto&& algo = get_algorithm(fparam); | |||
if (is_naive_algo(algo)) { | |||
return naive::ConvBiasForwardImpl::deduce_preprocessed_filter_layout( | |||
src, filter, bias, z, dst); | |||
@@ -443,7 +463,7 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_with_ncb( | |||
MEGDNN_MARK_USED_VAR(param); | |||
std::vector<Algorithm*> algos; | |||
std::vector<Algorithm*> prefer_algos; | |||
for (auto&& algo : algo_pack()) { | |||
for (auto&& algo : get_all_packed_algo()) { | |||
if (algo->usable(param, AlgoSelectionStrategy::FULL_RUN)) { | |||
if (algo->is_preferred(param)) { | |||
prefer_algos.push_back(algo); | |||
@@ -457,10 +477,49 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_with_ncb( | |||
return algos; | |||
} | |||
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algo_from_desc( | |||
const AlgorithmDesc& desc) const { | |||
if (!desc.valid()) { | |||
return nullptr; | |||
} else { | |||
switch (desc.handle_type) { | |||
case Handle::HandleType::FALLBACK: { | |||
const auto& map = algo_pack().all_algos_map(); | |||
megdnn_assert(map.find(desc) != map.end()); | |||
return map.at(desc); | |||
}; | |||
#if MEGDNN_X86 | |||
case Handle::HandleType::X86: | |||
return x86::ConvBiasImpl::get_algo_from_desc(desc); | |||
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
case Handle::HandleType::ARM_COMMON: | |||
return arm_common::ConvBiasImpl::get_algo_from_desc(desc); | |||
#if MEGDNN_AARCH64 | |||
case Handle::HandleType::AARCH64: | |||
return aarch64::ConvBiasImpl::get_algo_from_desc(desc); | |||
#else | |||
case Handle::HandleType::ARMV7: | |||
return armv7::ConvBiasImpl::get_algo_from_desc(desc); | |||
#endif | |||
#endif | |||
case Handle::HandleType::NAIVE: { | |||
auto algo = static_cast<naive::HandleImpl*>(handle()) | |||
->default_conv_bias_fwd_algo(); | |||
megdnn_assert(algo->info().desc == desc); | |||
return algo; | |||
} | |||
default: | |||
megdnn_throw("Unknown handle type"); | |||
return nullptr; | |||
} | |||
} | |||
} | |||
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm( | |||
const NCBKernSizeParam& param, size_t workspace_size) { | |||
if (auto set = execution_policy().algorithm) { | |||
return set; | |||
if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) { | |||
return algo; | |||
} | |||
if (!m_prev_selected_algo || | |||
memcmp(&m_prev_selected_algo_sizep, ¶m, sizeof(NCBKernSizeParam))) { | |||
@@ -216,6 +216,86 @@ public: | |||
AlgoBase() : Algorithm() { | |||
m_handle_type = Handle::HandleType::FALLBACK; | |||
} | |||
enum class AlgoType : uint32_t { | |||
//! fallback | |||
FB_NAIVE = 1 << 0, | |||
FB_WINOGRAD_F32, | |||
FB_WINOGRAD_4X4_F32, | |||
FB_WINOGRAD_QS8, | |||
FB_WINOGRAD_8X8_QS8, | |||
FB_CONV1x1, | |||
FB_CONV1x1_GEMV, | |||
FB_IM2COL, | |||
#if MEGDNN_X86 | |||
X86_DIRECT = 1 << 8, | |||
X86_DIRECT_STRD2, | |||
X86_WINOGRAD_F63_8x8_F32, | |||
X86_WINOGRAD_F23_8x8_F32, | |||
X86_MKLDNN, | |||
X86_CHANWISE_AVX2_STRD1_QINT8, | |||
X86_CHANWISE_AVX2_STRD2_QINT8, | |||
X86_DIRECT_AVX2_STRD1_INT8, | |||
X86_DIRECT_AVX2_STRD2_INT8, | |||
X86_MKLDNN_QINT8, | |||
X86_MKLDNN_MATMUL_QINT8, | |||
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
ARM_COMMON_WINOGRAD_F23_FP16 = 1 << 8, | |||
ARM_COMMON_WINOGRAD_F45_FP16, | |||
ARM_COMMON_WINOGRAD_F63_FP16, | |||
ARM_COMMON_WINOGRAD_F23_8X8_FP16, | |||
ARM_COMMON_DIRECT_FP16, | |||
ARM_COMMON_DIRECT_STRD1_FP16, | |||
ARM_COMMON_WINOGRAD_F23_4X4_FP32, | |||
ARM_COMMON_WINOGRAD_F63_FP32, | |||
ARM_COMMON_WINOGRAD_F63_4X4_FP32, | |||
ARM_COMMON_WINOGRAD_F54_FP32, | |||
ARM_COMMON_WINOGRAD_F45_FP32, | |||
ARM_COMMON_WINOGRAD_F23_4X4_NCHW44_F32, | |||
ARM_COMMON_WINOGRAD_F63_4X4_NCHW44_F32, | |||
ARM_COMMON_WINOGRAD_F73_4X4_NCHW44_F32, | |||
ARM_COMMON_DIRECT_FP32, | |||
ARM_COMMON_DIRECT_STRD1_FP32, | |||
ARM_COMMON_DIRECT_STRD2_FP32, | |||
ARM_COMMON_DIRECT_NCHW44_FP32, | |||
ARM_COMMON_DIRECT_NCHW_NCHW44_FP32, | |||
ARM_COMMON_CHWNWISE_NCHW44_F32, | |||
ARM_COMMON_DIRECT_STRD1_S8, | |||
ARM_COMMON_DIRECT_STRD2_S8, | |||
ARM_COMMON_DIRECT_NCHW44, | |||
ARM_COMMON_DIRECT_NCHW_NCHW44_S8, | |||
ARM_COMMON_CHANWISE_STRD1_NCHW44_S8, | |||
ARM_COMMON_CHANWISE_STRD2_NCHW44_S8, | |||
ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8, | |||
ARM_COMMON_DIRECT_STRD1_DOT_S8, | |||
ARM_COMMON_DIRECT_STRD2_DOT_S8, | |||
ARM_COMMON_DIRECT_NCHW44_DOT_S8, | |||
ARM_COMMON_WINOGRAD_F23_8X8_S8, | |||
ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8CF32, | |||
ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8, | |||
ARM_COMMON_DIRECT_INT8X8X16, | |||
ARM_COMMON_DIRECT_NCHW44_INT8X8X16, | |||
ARM_COMMON_DIRECT_STRD2_INT8X8X16, | |||
ARM_COMMON_DIRECT_STRD2_F2_INT8X8X16, | |||
ARM_COMMON_CHWNWISE_STRD1_STRD2_NCHW44_INT8X8X16, | |||
ARM_COMMON_DIRECT_NCHW_NCHW44_INT8X8X16, | |||
ARM_COMMON_DIRECT_STRD1_QU8, | |||
ARM_COMMON_DIRECT_STRD2_QU8, | |||
ARM_COMMON_DIRECT_STRD1_DOT_QU8, | |||
ARM_COMMON_DIRECT_STRD2_DOT_QU8, | |||
#if MEGDNN_AARCH64 | |||
AARCH64_DIRECT_STRD2_FP16, | |||
AARCH64_DIRECT_STRD2_FP32, | |||
AARCH64_MATMUL_S8, | |||
AARCH64_MATMUL_QU8, | |||
#else | |||
ARMV7_MATMUL_S8, | |||
ARMV7_MATMUL_QU8, | |||
#endif // MEGDNN_AARCH64 | |||
#endif | |||
}; | |||
virtual ~AlgoBase() = default; | |||
virtual bool usable( | |||
const NCBKernSizeParam& param, | |||
@@ -255,12 +335,14 @@ public: | |||
//! get the type of the algo | |||
virtual ConvAlgoTypePack get_algo_type() const = 0; | |||
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
}; | |||
using AlgoMapper = AlgoBase::Mapper; | |||
/** | |||
* \brief get all the algorithm for the opr. | |||
*/ | |||
virtual SmallVector<AlgoBase*> algo_pack(); | |||
virtual SmallVector<AlgoBase*> get_all_packed_algo(); | |||
/** | |||
* \brief select algo according to input algo type | |||
@@ -305,6 +387,8 @@ private: | |||
bool is_naive_algo(ConvBiasImpl::Algorithm* algo); | |||
Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const; | |||
//! get algorithm set by user or by heuristic | |||
Algorithm* get_algorithm( | |||
const NCBKernSizeParam& param, | |||
@@ -320,6 +404,8 @@ private: | |||
_megdnn_tensor_in bias, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace, | |||
const PreprocessedFilter* preprocessed_filter); | |||
static const AlgoPack& algo_pack(); | |||
}; | |||
inline bool is_enable_filter_preprocess( | |||