Browse Source

feat(dnn): add an get_all_algorithms_safe interface

GitOrigin-RevId: e3734e4531
release-1.6
Megvii Engine Team 3 years ago
parent
commit
d69b59035d
67 changed files with 648 additions and 35 deletions
  1. +75
    -4
      dnn/include/megdnn/oprs/base.h
  2. +6
    -1
      dnn/src/arm_common/pooling/opr_impl.cpp
  3. +2
    -0
      dnn/src/arm_common/pooling/opr_impl.h
  4. +8
    -2
      dnn/src/common/algo_chooser.h
  5. +9
    -0
      dnn/src/cuda/batch_conv_bias/opr_impl.cpp
  6. +4
    -0
      dnn/src/cuda/batch_conv_bias/opr_impl.h
  7. +6
    -0
      dnn/src/cuda/batched_matrix_mul/opr_impl.cpp
  8. +3
    -0
      dnn/src/cuda/batched_matrix_mul/opr_impl.h
  9. +10
    -0
      dnn/src/cuda/conv_bias/opr_impl.cpp
  10. +4
    -0
      dnn/src/cuda/conv_bias/opr_impl.h
  11. +24
    -0
      dnn/src/cuda/convolution/opr_impl.cpp
  12. +12
    -0
      dnn/src/cuda/convolution/opr_impl.h
  13. +24
    -0
      dnn/src/cuda/convolution3d/opr_impl.cpp
  14. +9
    -0
      dnn/src/cuda/convolution3d/opr_impl.h
  15. +25
    -0
      dnn/src/cuda/deformable_conv/opr_impl.cpp
  16. +15
    -0
      dnn/src/cuda/deformable_conv/opr_impl.h
  17. +24
    -1
      dnn/src/cuda/local_share/opr_impl.cpp
  18. +9
    -0
      dnn/src/cuda/local_share/opr_impl.h
  19. +8
    -0
      dnn/src/cuda/matrix_mul/opr_impl.cpp
  20. +4
    -0
      dnn/src/cuda/matrix_mul/opr_impl.h
  21. +14
    -0
      dnn/src/cuda/pooling/opr_impl.cpp
  22. +5
    -0
      dnn/src/cuda/pooling/opr_impl.h
  23. +7
    -0
      dnn/src/fallback/batched_matrix_mul/opr_impl.cpp
  24. +3
    -0
      dnn/src/fallback/batched_matrix_mul/opr_impl.h
  25. +8
    -1
      dnn/src/fallback/conv_bias/opr_impl.cpp
  26. +4
    -0
      dnn/src/fallback/conv_bias/opr_impl.h
  27. +18
    -2
      dnn/src/fallback/convolution/opr_impl.cpp
  28. +7
    -0
      dnn/src/fallback/convolution/opr_impl.h
  29. +7
    -0
      dnn/src/fallback/matrix_mul/opr_impl.cpp
  30. +4
    -0
      dnn/src/fallback/matrix_mul/opr_impl.h
  31. +10
    -0
      dnn/src/naive/batch_conv_bias/opr_impl.cpp
  32. +5
    -0
      dnn/src/naive/batch_conv_bias/opr_impl.h
  33. +7
    -1
      dnn/src/naive/batched_matrix_mul/opr_impl.cpp
  34. +3
    -0
      dnn/src/naive/batched_matrix_mul/opr_impl.h
  35. +9
    -0
      dnn/src/naive/conv_bias/opr_impl.cpp
  36. +5
    -0
      dnn/src/naive/conv_bias/opr_impl.h
  37. +21
    -0
      dnn/src/naive/convolution/convolution.cpp
  38. +9
    -0
      dnn/src/naive/convolution/opr_impl.h
  39. +21
    -1
      dnn/src/naive/convolution3d/convolution3d.cpp
  40. +9
    -0
      dnn/src/naive/convolution3d/opr_impl.h
  41. +23
    -0
      dnn/src/naive/deformable_conv/opr_impl.h
  42. +23
    -0
      dnn/src/naive/local_share/opr_impl.cpp
  43. +12
    -1
      dnn/src/naive/local_share/opr_impl.h
  44. +7
    -0
      dnn/src/naive/matrix_mul/opr_impl.cpp
  45. +4
    -0
      dnn/src/naive/matrix_mul/opr_impl.h
  46. +9
    -0
      dnn/src/naive/pooling/opr_impl.cpp
  47. +5
    -0
      dnn/src/naive/pooling/opr_impl.h
  48. +8
    -0
      dnn/src/rocm/batched_matrix_mul/opr_impl.cpp
  49. +3
    -0
      dnn/src/rocm/batched_matrix_mul/opr_impl.h
  50. +24
    -0
      dnn/src/rocm/convolution/opr_impl.cpp
  51. +9
    -0
      dnn/src/rocm/convolution/opr_impl.h
  52. +8
    -0
      dnn/src/rocm/matrix_mul/opr_impl.cpp
  53. +4
    -0
      dnn/src/rocm/matrix_mul/opr_impl.h
  54. +12
    -1
      dnn/src/rocm/pooling/opr_impl.cpp
  55. +5
    -0
      dnn/src/rocm/pooling/opr_impl.h
  56. +4
    -1
      dnn/src/x86/pooling/opr_impl.cpp
  57. +2
    -0
      dnn/src/x86/pooling/opr_impl.h
  58. +1
    -1
      dnn/test/common/accuracy_shake_checker.h
  59. +1
    -1
      dnn/test/common/benchmarker.h
  60. +2
    -2
      dnn/test/common/checker.h
  61. +3
    -3
      dnn/test/common/convolution.cpp
  62. +4
    -4
      dnn/test/common/opr_algo_proxy.h
  63. +3
    -3
      dnn/test/common/opr_proxy.h
  64. +2
    -2
      dnn/test/cuda/cutlass_matmul.cpp
  65. +2
    -2
      src/core/test/graph/misc.cpp
  66. +1
    -1
      src/opr/impl/search_policy/algo_chooser.cpp
  67. +10
    -0
      src/opr/test/dnn/convolution.cpp

+ 75
- 4
dnn/include/megdnn/oprs/base.h View File

@@ -315,7 +315,7 @@ public:
/*! /*!
* \brief get a string representation for current algorithm set; * \brief get a string representation for current algorithm set;
* *
* get_all_algorithms() may return different algorithms only if
* get_all_algorithms_safe() may return different algorithms only if
* algorithm set name differs. This is used for checking cache * algorithm set name differs. This is used for checking cache
* validity. * validity.
*/ */
@@ -354,6 +354,15 @@ public:
return ret; return ret;
} }


std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0,
const TensorLayout& p1) {
std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms_safe(p0, p1)) {
ret.emplace_back(algo->info());
}
return ret;
}

/** /**
* \brief Returns the best algorithm information which indicate the * \brief Returns the best algorithm information which indicate the
* algorithm by heuristic. * algorithm by heuristic.
@@ -378,6 +387,8 @@ protected:
//! get all possible algorithms for the specified layouts //! get all possible algorithms for the specified layouts
virtual std::vector<Algorithm*> get_all_algorithms( virtual std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& p0, const TensorLayout& p1) = 0; const TensorLayout& p0, const TensorLayout& p1) = 0;
virtual std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& p0, const TensorLayout& p1) = 0;


/** /**
* \brief Returns the best algorithm by heuristic. * \brief Returns the best algorithm by heuristic.
@@ -412,6 +423,16 @@ public:
return ret; return ret;
} }


std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2) {
std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2)) {
ret.emplace_back(algo->info());
}
return ret;
}

/** /**
* \brief Returns the best algorithm information which indicate the * \brief Returns the best algorithm information which indicate the
* algorithm by heuristic. * algorithm by heuristic.
@@ -438,6 +459,9 @@ protected:
virtual std::vector<Algorithm*> get_all_algorithms( virtual std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2) = 0; const TensorLayout& p2) = 0;
virtual std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2) = 0;


/** /**
* \brief Returns the best algorithm by heuristic. * \brief Returns the best algorithm by heuristic.
@@ -463,7 +487,7 @@ public:
using AlgoAttribute = detail::Algorithm::Attribute; using AlgoAttribute = detail::Algorithm::Attribute;


//! get all possible algorithm decriptions for the specified layouts //! get all possible algorithm decriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
const TensorLayout& p1, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p2,
const TensorLayout& p3) { const TensorLayout& p3) {
@@ -474,6 +498,17 @@ public:
return ret; return ret;
} }


std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2,
const TensorLayout& p3) {
std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3)) {
ret.emplace_back(algo->info());
}
return ret;
}

/** /**
* \brief Returns the best algorithm information which indicate the * \brief Returns the best algorithm information which indicate the
* algorithm by heuristic. * algorithm by heuristic.
@@ -500,6 +535,9 @@ protected:
virtual std::vector<Algorithm*> get_all_algorithms( virtual std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3) = 0; const TensorLayout& p2, const TensorLayout& p3) = 0;
virtual std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3) = 0;


/** /**
* \brief Returns the best algorithm by heuristic. * \brief Returns the best algorithm by heuristic.
@@ -537,6 +575,18 @@ public:
return ret; return ret;
} }


std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2,
const TensorLayout& p3,
const TensorLayout& p4) {
std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4)) {
ret.emplace_back(algo->info());
}
return ret;
}

/** /**
* \brief Returns the best algorithm information which indicate the * \brief Returns the best algorithm information which indicate the
* algorithm by heuristic. * algorithm by heuristic.
@@ -562,7 +612,11 @@ protected:
~MultiAlgoOpr() = default; ~MultiAlgoOpr() = default;


//! get all possible algorithms for the specified layouts //! get all possible algorithms for the specified layouts
virtual std::vector<Algorithm*> get_all_algorithms(
virtual std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4) = 0;
virtual std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3, const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4) = 0; const TensorLayout& p4) = 0;
@@ -604,6 +658,18 @@ public:
return ret; return ret;
} }


std::vector<AlgorithmInfo> get_all_algorithms_info_safe(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p6, const TensorLayout& p7) {
std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4, p5, p6, p7)) {
ret.emplace_back(algo->info());
}
return ret;
}

/** /**
* \brief Returns the best algorithm information which indicate the * \brief Returns the best algorithm information which indicate the
* algorithm by heuristic. * algorithm by heuristic.
@@ -629,7 +695,12 @@ protected:
~MultiAlgoOpr() = default; ~MultiAlgoOpr() = default;


//! get all possible algorithms for the specified layouts //! get all possible algorithms for the specified layouts
virtual std::vector<Algorithm*> get_all_algorithms(
virtual std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p6, const TensorLayout& p7) = 0;
virtual std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3, const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4, const TensorLayout& p5, const TensorLayout& p4, const TensorLayout& p5,


+ 6
- 1
dnn/src/arm_common/pooling/opr_impl.cpp View File

@@ -172,9 +172,14 @@ std::vector<Algorithm*> PoolingImpl::get_all_algorithms(
ret.push_back(i); ret.push_back(i);
} }
} }
megdnn_assert(!ret.empty(), "no usable pooling fwd algorithm");
return ret; return ret;
} }
std::vector<Algorithm*> PoolingImpl::get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& dst) {
auto ret_safe = get_all_algorithms(src,dst);
megdnn_assert(!ret_safe.empty(), "no usable pooling fwd algorithm");
return ret_safe;
}


Algorithm* PoolingImpl::get_algorithm_heuristic( Algorithm* PoolingImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& dst, const TensorLayout& src, const TensorLayout& dst,


+ 2
- 0
dnn/src/arm_common/pooling/opr_impl.h View File

@@ -131,6 +131,8 @@ public:


std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& dst) override; const TensorLayout& src, const TensorLayout& dst) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& dst) override;


Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& dst, const TensorLayout& src, const TensorLayout& dst,


+ 8
- 2
dnn/src/common/algo_chooser.h View File

@@ -100,10 +100,16 @@ std::vector<typename Opr::Algorithm*> get_all_algorithms(
ret.push_back(i); ret.push_back(i);
} }
} }
megdnn_assert(!ret.empty(), "no algorithm for %s",
args.to_string().c_str());
return ret; return ret;
} }
template <class Opr>
std::vector<typename Opr::Algorithm*> get_all_algorithms_safe(
const typename Opr::AlgoBase::SizeArgs& args) {
auto ret_safe = get_all_algorithms<Opr>(args);
megdnn_assert(!ret_safe.empty(), "no algorithm for %s",
args.to_string().c_str());
return ret_safe;
}


/*! /*!
* \brief a helper function to get an algorithm match attribute. If require a * \brief a helper function to get an algorithm match attribute. If require a


+ 9
- 0
dnn/src/cuda/batch_conv_bias/opr_impl.cpp View File

@@ -51,6 +51,15 @@ BatchConvBiasForwardImpl::get_all_algorithms(const TensorLayout& src,
AlgoBase::SizeArgs args{this, src, filter, bias, z, dst}; AlgoBase::SizeArgs args{this, src, filter, bias, z, dst};
return megdnn::get_all_algorithms<BatchConvBiasForwardImpl>(args); return megdnn::get_all_algorithms<BatchConvBiasForwardImpl>(args);
} }
std::vector<BatchConvBiasForwardImpl::Algorithm*>
BatchConvBiasForwardImpl::get_all_algorithms_safe(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& bias,
const TensorLayout& z,
const TensorLayout& dst) {
AlgoBase::SizeArgs args{this, src, filter, bias, z, dst};
return megdnn::get_all_algorithms_safe<BatchConvBiasForwardImpl>(args);
}


size_t BatchConvBiasForwardImpl::get_workspace_in_bytes( size_t BatchConvBiasForwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,


+ 4
- 0
dnn/src/cuda/batch_conv_bias/opr_impl.h View File

@@ -42,6 +42,10 @@ protected:
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z, const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) override; const TensorLayout& dst) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z, const TensorLayout& bias, const TensorLayout& z,


+ 6
- 0
dnn/src/cuda/batched_matrix_mul/opr_impl.cpp View File

@@ -51,6 +51,12 @@ std::vector<Algorithm*> BatchedMatrixMulForwardImpl::get_all_algorithms(
} }
return ret; return ret;
} }
std::vector<Algorithm*> BatchedMatrixMulForwardImpl::get_all_algorithms_safe(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) {
auto ret_safe = get_all_algorithms(A,B,C);
megdnn_assert(!ret_safe.empty(), "no usable batchedmatrixmulForward fwd algorithm");
return ret_safe;
}


Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic( Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,


+ 3
- 0
dnn/src/cuda/batched_matrix_mul/opr_impl.h View File

@@ -45,6 +45,9 @@ protected:
std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A,
const TensorLayout& B, const TensorLayout& B,
const TensorLayout& C) override; const TensorLayout& C) override;
std::vector<Algorithm*> get_all_algorithms_safe(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,


+ 10
- 0
dnn/src/cuda/conv_bias/opr_impl.cpp View File

@@ -49,6 +49,16 @@ ConvBiasForwardImpl::get_all_algorithms(const TensorLayout& src,
{this, src, filter, bias, z, dst}); {this, src, filter, bias, z, dst});
} }


std::vector<ConvBiasForward::Algorithm*>
ConvBiasForwardImpl::get_all_algorithms_safe(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& bias,
const TensorLayout& z,
const TensorLayout& dst) {
return megdnn::get_all_algorithms_safe<ConvBiasForwardImpl>(
{this, src, filter, bias, z, dst});
}

ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z, const TensorLayout& bias, const TensorLayout& z,


+ 4
- 0
dnn/src/cuda/conv_bias/opr_impl.h View File

@@ -84,6 +84,10 @@ public:
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z, const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) override; const TensorLayout& dst) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z, const TensorLayout& bias, const TensorLayout& z,


+ 24
- 0
dnn/src/cuda/convolution/opr_impl.cpp View File

@@ -53,6 +53,14 @@ ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src,
return megdnn::get_all_algorithms<ConvolutionForwardImpl>(args); return megdnn::get_all_algorithms<ConvolutionForwardImpl>(args);
} }


std::vector<ConvolutionForwardImpl::Algorithm*>
ConvolutionForwardImpl::get_all_algorithms_safe(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst) {
AlgoBase::SizeArgs args{this, src, filter, dst};
return megdnn::get_all_algorithms_safe<ConvolutionForwardImpl>(args);
}

size_t ConvolutionForwardImpl::get_workspace_in_bytes( size_t ConvolutionForwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, const TensorLayout& dst,
@@ -97,6 +105,14 @@ ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter,
{this, filter, diff, grad}); {this, filter, diff, grad});
} }


std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad) {
return megdnn::get_all_algorithms_safe<ConvolutionBackwardDataImpl>(
{this, filter, diff, grad});
}

ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic( ConvolutionBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& filter, const TensorLayout& diff,
@@ -222,6 +238,14 @@ ConvolutionBackwardFilterImpl::get_all_algorithms(const TensorLayout& src,
{this, src, diff, grad}); {this, src, diff, grad});
} }


std::vector<ConvolutionBackwardFilterImpl::Algorithm*>
ConvolutionBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad) {
return megdnn::get_all_algorithms_safe<ConvolutionBackwardFilterImpl>(
{this, src, diff, grad});
}

ConvolutionBackwardFilterImpl::Algorithm* ConvolutionBackwardFilterImpl::Algorithm*
ConvolutionBackwardFilterImpl::get_algorithm_heuristic( ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& src, const TensorLayout& diff,


+ 12
- 0
dnn/src/cuda/convolution/opr_impl.h View File

@@ -59,6 +59,10 @@ protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override; const TensorLayout& dst) override;

std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes, const TensorLayout& dst, size_t workspace_limit_in_bytes,
@@ -111,6 +115,10 @@ protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override; const TensorLayout& grad) override;

std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,
@@ -159,6 +167,10 @@ protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) override; const TensorLayout& grad) override;

std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,


+ 24
- 0
dnn/src/cuda/convolution3d/opr_impl.cpp View File

@@ -108,6 +108,14 @@ Convolution3DForwardImpl::get_all_algorithms(const TensorLayout& src,
{this, src, filter, dst}); {this, src, filter, dst});
} }


std::vector<Convolution3DForwardImpl::Algorithm*>
Convolution3DForwardImpl::get_all_algorithms_safe(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst) {
return megdnn::get_all_algorithms_safe<Convolution3DForwardImpl>(
{this, src, filter, dst});
}

size_t Convolution3DForwardImpl::get_workspace_in_bytes( size_t Convolution3DForwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) { const TensorLayout& dst) {
@@ -146,6 +154,14 @@ Convolution3DBackwardDataImpl::get_all_algorithms(const TensorLayout& filter,
{this, filter, diff, grad}); {this, filter, diff, grad});
} }


std::vector<Convolution3DBackwardDataImpl::Algorithm*>
Convolution3DBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad) {
return megdnn::get_all_algorithms_safe<Convolution3DBackwardDataImpl>(
{this, filter, diff, grad});
}

Convolution3DBackwardDataImpl::Algorithm* Convolution3DBackwardDataImpl::Algorithm*
Convolution3DBackwardDataImpl::get_algorithm_heuristic( Convolution3DBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& filter, const TensorLayout& diff,
@@ -226,6 +242,14 @@ Convolution3DBackwardFilterImpl::get_all_algorithms(const TensorLayout& src,
{this, src, diff, grad}); {this, src, diff, grad});
} }


std::vector<Convolution3DBackwardFilterImpl::Algorithm*>
Convolution3DBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad) {
return megdnn::get_all_algorithms_safe<Convolution3DBackwardFilterImpl>(
{this, src, diff, grad});
}

Convolution3DBackwardFilterImpl::Algorithm* Convolution3DBackwardFilterImpl::Algorithm*
Convolution3DBackwardFilterImpl::get_algorithm_heuristic( Convolution3DBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& src, const TensorLayout& diff,


+ 9
- 0
dnn/src/cuda/convolution3d/opr_impl.h View File

@@ -39,6 +39,9 @@ protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override; const TensorLayout& dst) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes, const TensorLayout& dst, size_t workspace_limit_in_bytes,
@@ -72,6 +75,9 @@ public:


protected: protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override; const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
@@ -109,6 +115,9 @@ protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) override; const TensorLayout& grad) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,


+ 25
- 0
dnn/src/cuda/deformable_conv/opr_impl.cpp View File

@@ -51,6 +51,15 @@ std::vector<AlgoFwd*> Fwd::get_all_algorithms(const TensorLayout& /* im */,


return algos; return algos;
} }
std::vector<AlgoFwd*> Fwd::get_all_algorithms_safe(const TensorLayout& im,
const TensorLayout& filter,
const TensorLayout& offset,
const TensorLayout& mask,
const TensorLayout& dst) {
auto ret_safe = Fwd::get_all_algorithms(im,filter,offset,mask,dst);
megdnn_assert(!ret_safe.empty(), "no usable deformable_conv fwd algorithm");
return ret_safe;
}


AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im, AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im,
const TensorLayout& filter, const TensorLayout& filter,
@@ -115,6 +124,14 @@ std::vector<AlgoBwdFlt*> BwdFlt::get_all_algorithms(const TensorLayout& /* im */
return algos; return algos;
} }


std::vector<AlgoBwdFlt*> BwdFlt::get_all_algorithms_safe(const TensorLayout& im,
const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& out_grad, const TensorLayout& filter_grad) {
auto ret_safe = BwdFlt::get_all_algorithms(im,offset,mask,out_grad,filter_grad);
megdnn_assert(!ret_safe.empty(), "no usable deformable_conv bwd filter algorithm");
return ret_safe;
}

AlgoBwdFlt* BwdFlt::get_algorithm_heuristic( AlgoBwdFlt* BwdFlt::get_algorithm_heuristic(
const TensorLayout& im, const TensorLayout& offset, const TensorLayout& im, const TensorLayout& offset,
const TensorLayout& mask, const TensorLayout& out_grad, const TensorLayout& mask, const TensorLayout& out_grad,
@@ -181,6 +198,14 @@ std::vector<AlgoBwdData*> BwdData::get_all_algorithms(
algos.push_back(static_cast<AlgoBwdData*>(i)); algos.push_back(static_cast<AlgoBwdData*>(i));
return algos; return algos;
} }
std::vector<AlgoBwdData*> BwdData::get_all_algorithms_safe(
const TensorLayout& im, const TensorLayout& filter,
const TensorLayout& offset, const TensorLayout& mask, const TensorLayout& out_grad,
const TensorLayout& im_grad, const TensorLayout& offset_grad, const TensorLayout& mask_grad ) {
auto ret_safe = BwdData::get_all_algorithms(im,filter,offset,mask,out_grad,im_grad,offset_grad,mask_grad);
megdnn_assert(!ret_safe.empty(), "no usable deformable_conv bwd data algorithm");
return ret_safe;
}


AlgoBwdData* BwdData::get_algorithm_heuristic( AlgoBwdData* BwdData::get_algorithm_heuristic(
const TensorLayout& im, const TensorLayout& filter, const TensorLayout& im, const TensorLayout& filter,


+ 15
- 0
dnn/src/cuda/deformable_conv/opr_impl.h View File

@@ -54,6 +54,10 @@ protected:
const TensorLayout& im, const TensorLayout& filter, const TensorLayout& im, const TensorLayout& filter,
const TensorLayout& offset, const TensorLayout& mask, const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& dst) override; const TensorLayout& dst) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& im, const TensorLayout& filter,
const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& dst) override;


Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& im, const TensorLayout& filter, const TensorLayout& im, const TensorLayout& filter,
@@ -105,6 +109,10 @@ protected:
const TensorLayout& im, const TensorLayout& offset, const TensorLayout& im, const TensorLayout& offset,
const TensorLayout& mask, const TensorLayout& out_grad, const TensorLayout& mask, const TensorLayout& out_grad,
const TensorLayout& filter_grad) override; const TensorLayout& filter_grad) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& im, const TensorLayout& offset,
const TensorLayout& mask, const TensorLayout& out_grad,
const TensorLayout& filter_grad) override;


Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& im, const TensorLayout& offset, const TensorLayout& im, const TensorLayout& offset,
@@ -161,6 +169,13 @@ protected:
const TensorLayout& out_grad, const TensorLayout& im_grad, const TensorLayout& out_grad, const TensorLayout& im_grad,
const TensorLayout& offset_grad, const TensorLayout& offset_grad,
const TensorLayout& mask_grad) override; const TensorLayout& mask_grad) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& im, const TensorLayout& filter,
const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& out_grad, const TensorLayout& im_grad,
const TensorLayout& offset_grad,
const TensorLayout& mask_grad) override;


Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& im, const TensorLayout& filter, const TensorLayout& im, const TensorLayout& filter,


+ 24
- 1
dnn/src/cuda/local_share/opr_impl.cpp View File

@@ -47,7 +47,6 @@ LocalShareForwardImpl::get_algorithm_heuristic(
Algorithm::attribute_str(positive_attr).c_str(), Algorithm::attribute_str(positive_attr).c_str(),
args.to_string().c_str(), workspace_limit_in_bytes)); args.to_string().c_str(), workspace_limit_in_bytes));
} }

std::vector<LocalShareForwardImpl::Algorithm*> std::vector<LocalShareForwardImpl::Algorithm*>
LocalShareForwardImpl::get_all_algorithms(const TensorLayout& src, LocalShareForwardImpl::get_all_algorithms(const TensorLayout& src,
const TensorLayout& filter, const TensorLayout& filter,
@@ -56,6 +55,14 @@ LocalShareForwardImpl::get_all_algorithms(const TensorLayout& src,
return megdnn::get_all_algorithms<LocalShareForwardImpl>(args); return megdnn::get_all_algorithms<LocalShareForwardImpl>(args);
} }


std::vector<LocalShareForwardImpl::Algorithm*>
LocalShareForwardImpl::get_all_algorithms_safe(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst) {
AlgoBase::SizeArgs args{this, src, filter, dst};
return megdnn::get_all_algorithms_safe<LocalShareForwardImpl>(args);
}

size_t LocalShareForwardImpl::get_workspace_in_bytes(const TensorLayout& src, size_t LocalShareForwardImpl::get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& filter, const TensorLayout& filter,
const TensorLayout& dst) { const TensorLayout& dst) {
@@ -109,6 +116,14 @@ LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout& filter,
return megdnn::get_all_algorithms<LocalShareBackwardDataImpl>(args); return megdnn::get_all_algorithms<LocalShareBackwardDataImpl>(args);
} }


std::vector<LocalShareBackwardDataImpl::Algorithm*>
LocalShareBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad) {
AlgoBase::SizeArgs args{this, filter, diff, grad};
return megdnn::get_all_algorithms_safe<LocalShareBackwardDataImpl>(args);
}

size_t LocalShareBackwardDataImpl::get_workspace_in_bytes(const TensorLayout& filter, size_t LocalShareBackwardDataImpl::get_workspace_in_bytes(const TensorLayout& filter,
const TensorLayout& diff, const TensorLayout& diff,
const TensorLayout& grad) { const TensorLayout& grad) {
@@ -162,6 +177,14 @@ LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout& src,
return megdnn::get_all_algorithms<LocalShareBackwardFilterImpl>(args); return megdnn::get_all_algorithms<LocalShareBackwardFilterImpl>(args);
} }


std::vector<LocalShareBackwardFilterImpl::Algorithm*>
LocalShareBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad) {
AlgoBase::SizeArgs args{this, src, diff, grad};
return megdnn::get_all_algorithms_safe<LocalShareBackwardFilterImpl>(args);
}

size_t LocalShareBackwardFilterImpl::get_workspace_in_bytes(const TensorLayout& src, size_t LocalShareBackwardFilterImpl::get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& diff, const TensorLayout& diff,
const TensorLayout& grad) { const TensorLayout& grad) {


+ 9
- 0
dnn/src/cuda/local_share/opr_impl.h View File

@@ -37,6 +37,9 @@ public:


protected: protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override; const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
@@ -72,6 +75,9 @@ protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override; const TensorLayout& grad) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,
@@ -105,6 +111,9 @@ protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) override; const TensorLayout& grad) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,


+ 8
- 0
dnn/src/cuda/matrix_mul/opr_impl.cpp View File

@@ -28,6 +28,14 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
return megdnn::get_all_algorithms<MatrixMulForwardImpl>(args); return megdnn::get_all_algorithms<MatrixMulForwardImpl>(args);
} }


std::vector<MatrixMulForwardImpl::Algorithm*>
MatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) {
AlgoBase::SizeArgs args{this, A, B, C};
return megdnn::get_all_algorithms_safe<MatrixMulForwardImpl>(args);
}

MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,


+ 4
- 0
dnn/src/cuda/matrix_mul/opr_impl.h View File

@@ -60,6 +60,10 @@ protected:
std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A,
const TensorLayout& B, const TensorLayout& B,
const TensorLayout& C) override; const TensorLayout& C) override;

std::vector<Algorithm*> get_all_algorithms_safe(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,


+ 14
- 0
dnn/src/cuda/pooling/opr_impl.cpp View File

@@ -33,6 +33,11 @@ PoolingForwardImpl::get_all_algorithms(const TensorLayout& src,
const TensorLayout& dst) { const TensorLayout& dst) {
return megdnn::get_all_algorithms<PoolingForwardImpl>({this, src, dst}); return megdnn::get_all_algorithms<PoolingForwardImpl>({this, src, dst});
} }
std::vector<PoolingForwardImpl::Algorithm*>
PoolingForwardImpl::get_all_algorithms_safe(const TensorLayout& src,
const TensorLayout& dst) {
return megdnn::get_all_algorithms_safe<PoolingForwardImpl>({this, src, dst});
}


PoolingForwardImpl::Algorithm* PoolingForwardImpl::get_algorithm_heuristic( PoolingForwardImpl::Algorithm* PoolingForwardImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& dst, const TensorLayout& src, const TensorLayout& dst,
@@ -77,6 +82,15 @@ PoolingBackwardImpl::get_all_algorithms(const TensorLayout& src,
{this, src, dst, diff, grad}); {this, src, dst, diff, grad});
} }


std::vector<PoolingBackwardImpl::Algorithm*>
PoolingBackwardImpl::get_all_algorithms_safe(const TensorLayout& src,
const TensorLayout& dst,
const TensorLayout& diff,
const TensorLayout& grad) {
return megdnn::get_all_algorithms_safe<PoolingBackwardImpl>(
{this, src, dst, diff, grad});
}

PoolingBackwardImpl::Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( PoolingBackwardImpl::Algorithm* PoolingBackwardImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& dst, const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& diff, const TensorLayout& grad, const TensorLayout& diff, const TensorLayout& grad,


+ 5
- 0
dnn/src/cuda/pooling/opr_impl.h View File

@@ -55,6 +55,8 @@ public:
protected: protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& dst) override; const TensorLayout& src, const TensorLayout& dst) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& dst, const TensorLayout& src, const TensorLayout& dst,
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
@@ -99,6 +101,9 @@ protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& dst, const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& diff, const TensorLayout& grad) override; const TensorLayout& diff, const TensorLayout& grad) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& diff, const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& dst, const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& diff, const TensorLayout& grad, const TensorLayout& diff, const TensorLayout& grad,


+ 7
- 0
dnn/src/fallback/batched_matrix_mul/opr_impl.cpp View File

@@ -26,6 +26,13 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
AlgoBase::SizeArgs args{this, A, B, C}; AlgoBase::SizeArgs args{this, A, B, C};
return megdnn::get_all_algorithms<BatchedMatrixMulForwardImpl>(args); return megdnn::get_all_algorithms<BatchedMatrixMulForwardImpl>(args);
} }
std::vector<BatchedMatrixMulForwardImpl::Algorithm*>
BatchedMatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) {
AlgoBase::SizeArgs args{this, A, B, C};
return megdnn::get_all_algorithms_safe<BatchedMatrixMulForwardImpl>(args);
}


BatchedMatrixMulForwardImpl::Algorithm* BatchedMatrixMulForwardImpl::Algorithm*
BatchedMatrixMulForwardImpl::get_algorithm_heuristic( BatchedMatrixMulForwardImpl::get_algorithm_heuristic(


+ 3
- 0
dnn/src/fallback/batched_matrix_mul/opr_impl.h View File

@@ -35,6 +35,9 @@ private:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override; const TensorLayout& /*C*/) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override;


Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*A*/, const TensorLayout& /*B*/,


+ 8
- 1
dnn/src/fallback/conv_bias/opr_impl.cpp View File

@@ -279,11 +279,18 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms(
auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr); auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr);
auto ret = get_all_algorithms_with_ncb(fparam); auto ret = get_all_algorithms_with_ncb(fparam);
if (ret.empty()) { if (ret.empty()) {
return naive::ConvBiasForwardImpl::get_all_algorithms(src, filter, bias,
return naive::ConvBiasForwardImpl::get_all_algorithms_safe(src, filter, bias,
z, dst); z, dst);
} }
return ret; return ret;
} }
std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) {
auto ret_safe = ConvBiasImpl::get_all_algorithms(src,filter,bias,z,dst);
return ret_safe;
}


ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic( ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,


+ 4
- 0
dnn/src/fallback/conv_bias/opr_impl.h View File

@@ -87,6 +87,10 @@ public:
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z, const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) override; const TensorLayout& dst) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) override;


//! implemented by get_algorithm_heuristic_with_ncb() //! implemented by get_algorithm_heuristic_with_ncb()
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(


+ 18
- 2
dnn/src/fallback/convolution/opr_impl.cpp View File

@@ -198,12 +198,19 @@ std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms(
auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr); auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
auto ret = get_all_algorithms_with_ncb(fparam); auto ret = get_all_algorithms_with_ncb(fparam);
if (ret.empty()) { if (ret.empty()) {
return naive::ConvolutionForwardImpl::get_all_algorithms(src, filter,
return naive::ConvolutionForwardImpl::get_all_algorithms_safe(src, filter,
dst); dst);
} }
return ret; return ret;
} }


std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) {
auto ret_safe = ConvolutionImpl::get_all_algorithms(src,filter,dst);
return ret_safe;
}

ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic( ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes, const TensorLayout& dst, size_t workspace_limit_in_bytes,
@@ -536,10 +543,19 @@ ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter,
} }
auto fparam = make_ncb_kern_size_param(filter, diff, grad); auto fparam = make_ncb_kern_size_param(filter, diff, grad);
auto ret = get_all_algorithms_with_ncb(fparam); auto ret = get_all_algorithms_with_ncb(fparam);
megdnn_assert(!ret.empty(), "no usable conv fwd algorithm");
return ret; return ret;
} }


std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad) {
auto ret_safe = ConvolutionBackwardDataImpl::get_all_algorithms(filter,diff,grad);
megdnn_assert(!ret_safe.empty(), "no usable conv bwd algorithm");
return ret_safe;
}

ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic( ConvolutionBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& filter, const TensorLayout& diff,


+ 7
- 0
dnn/src/fallback/convolution/opr_impl.h View File

@@ -85,6 +85,10 @@ public:
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override; const TensorLayout& dst) override;


std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override;

//! implemented by get_algorithm_heuristic_with_ncb() //! implemented by get_algorithm_heuristic_with_ncb()
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
@@ -326,6 +330,9 @@ public:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override; const TensorLayout& grad) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,


+ 7
- 0
dnn/src/fallback/matrix_mul/opr_impl.cpp View File

@@ -96,6 +96,13 @@ std::vector<MatrixMul::Algorithm*> MatrixMulImpl::get_all_algorithms(
return gemv_algos; return gemv_algos;
} }


std::vector<MatrixMul::Algorithm*> MatrixMulImpl::get_all_algorithms_safe(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) {
auto gemv_algos_safe = get_all_algorithms(A,B,C);
megdnn_assert(!gemv_algos_safe.empty(), "no usable MatrixMul fwd algorithm");
return gemv_algos_safe;
}

MatrixMulImpl::Algorithm* MatrixMulImpl::get_algorithm_from_desc( MatrixMulImpl::Algorithm* MatrixMulImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) { const AlgorithmDesc& desc) {
if (!desc.valid()) { if (!desc.valid()) {


+ 4
- 0
dnn/src/fallback/matrix_mul/opr_impl.h View File

@@ -270,6 +270,10 @@ protected:
const TensorLayout& B, const TensorLayout& B,
const TensorLayout& C) override; const TensorLayout& C) override;


std::vector<Algorithm*> get_all_algorithms_safe(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) override;

Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,


+ 10
- 0
dnn/src/naive/batch_conv_bias/opr_impl.cpp View File

@@ -128,6 +128,16 @@ BatchConvBiasForwardImpl::get_all_algorithms(const TensorLayout&,
->default_batch_conv_bias_fwd_algo()}; ->default_batch_conv_bias_fwd_algo()};
} }


std::vector<BatchConvBiasForward::Algorithm*>
BatchConvBiasForwardImpl::get_all_algorithms_safe(const TensorLayout&,
const TensorLayout&,
const TensorLayout&,
const TensorLayout&,
const TensorLayout&) {
return {static_cast<HandleImpl*>(handle())
->default_batch_conv_bias_fwd_algo()};
}

BatchConvBiasForward::Algorithm* BatchConvBiasForward::Algorithm*
BatchConvBiasForwardImpl::get_algorithm_heuristic( BatchConvBiasForwardImpl::get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* filter */, const TensorLayout& /* src */, const TensorLayout& /* filter */,


+ 5
- 0
dnn/src/naive/batch_conv_bias/opr_impl.h View File

@@ -30,6 +30,11 @@ public:
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z, const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) override; const TensorLayout& dst) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) override;


Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,


+ 7
- 1
dnn/src/naive/batched_matrix_mul/opr_impl.cpp View File

@@ -63,7 +63,6 @@ void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A,
} }


} }

std::vector<BatchedMatrixMulForward::Algorithm*> std::vector<BatchedMatrixMulForward::Algorithm*>
BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/,
const TensorLayout& /*B*/, const TensorLayout& /*B*/,
@@ -71,6 +70,13 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/,
return {static_cast<HandleImpl*>(handle()) return {static_cast<HandleImpl*>(handle())
->default_batched_matmul_fwd_algo()}; ->default_batched_matmul_fwd_algo()};
} }
std::vector<BatchedMatrixMulForward::Algorithm*>
BatchedMatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& /*A*/,
const TensorLayout& /*B*/,
const TensorLayout& /*C*/) {
return {static_cast<HandleImpl*>(handle())
->default_batched_matmul_fwd_algo()};
}


BatchedMatrixMulForward::Algorithm* BatchedMatrixMulForward::Algorithm*
BatchedMatrixMulForwardImpl::get_algorithm_heuristic( BatchedMatrixMulForwardImpl::get_algorithm_heuristic(


+ 3
- 0
dnn/src/naive/batched_matrix_mul/opr_impl.h View File

@@ -27,6 +27,9 @@ public:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override; const TensorLayout& /*C*/) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override;


Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*A*/, const TensorLayout& /*B*/,


+ 9
- 0
dnn/src/naive/conv_bias/opr_impl.cpp View File

@@ -321,6 +321,15 @@ ConvBiasForwardImpl::get_all_algorithms(const TensorLayout&,
return {static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo()}; return {static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo()};
} }


std::vector<ConvBiasForward::Algorithm*>
ConvBiasForwardImpl::get_all_algorithms_safe(const TensorLayout&,
const TensorLayout&,
const TensorLayout&,
const TensorLayout&,
const TensorLayout&) {
return {static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo()};
}

ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* filter */, const TensorLayout& /* src */, const TensorLayout& /* filter */,
const TensorLayout& /* bias */, const TensorLayout& /* z */, const TensorLayout& /* bias */, const TensorLayout& /* z */,


+ 5
- 0
dnn/src/naive/conv_bias/opr_impl.h View File

@@ -31,6 +31,11 @@ public:
const TensorLayout& bias, const TensorLayout& z, const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) override; const TensorLayout& dst) override;


std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) override;

Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z, const TensorLayout& bias, const TensorLayout& z,


+ 21
- 0
dnn/src/naive/convolution/convolution.cpp View File

@@ -287,6 +287,13 @@ ConvolutionForwardImpl:: get_all_algorithms(const TensorLayout &,
return {static_cast<HandleImpl *>(handle())->default_conv_fwd_algo()}; return {static_cast<HandleImpl *>(handle())->default_conv_fwd_algo()};
} }


std::vector<ConvolutionForward::Algorithm *>
ConvolutionForwardImpl:: get_all_algorithms_safe(const TensorLayout &,
const TensorLayout &, const TensorLayout &)
{
return {static_cast<HandleImpl *>(handle())->default_conv_fwd_algo()};
}

ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic( ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* filter */, const TensorLayout& /* src */, const TensorLayout& /* filter */,
const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */, const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */,
@@ -313,6 +320,13 @@ ConvolutionBackwardDataImpl:: get_all_algorithms(const TensorLayout &,
return {static_cast<HandleImpl *>(handle())->default_conv_bwd_data_algo()}; return {static_cast<HandleImpl *>(handle())->default_conv_bwd_data_algo()};
} }


std::vector<ConvolutionBackwardData::Algorithm *>
ConvolutionBackwardDataImpl:: get_all_algorithms_safe(const TensorLayout &,
const TensorLayout &, const TensorLayout &)
{
return {static_cast<HandleImpl *>(handle())->default_conv_bwd_data_algo()};
}

ConvolutionBackwardData::Algorithm* ConvolutionBackwardData::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic( ConvolutionBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& /* filter */, const TensorLayout& /* diff */, const TensorLayout& /* filter */, const TensorLayout& /* diff */,
@@ -341,6 +355,13 @@ ConvolutionBackwardFilterImpl:: get_all_algorithms(const TensorLayout &,
return {static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo()}; return {static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo()};
} }


std::vector<ConvolutionBackwardFilter::Algorithm *>
ConvolutionBackwardFilterImpl:: get_all_algorithms_safe(const TensorLayout &,
const TensorLayout &, const TensorLayout &)
{
return {static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo()};
}

ConvolutionBackwardFilter::Algorithm* ConvolutionBackwardFilter::Algorithm*
ConvolutionBackwardFilterImpl::get_algorithm_heuristic( ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* diff */, const TensorLayout& /* src */, const TensorLayout& /* diff */,


+ 9
- 0
dnn/src/naive/convolution/opr_impl.h View File

@@ -25,6 +25,9 @@ class ConvolutionForwardImpl: public ConvolutionForward {
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src,
const TensorLayout &filter, const TensorLayout &filter,
const TensorLayout &dst) override; const TensorLayout &dst) override;
std::vector<Algorithm *> get_all_algorithms_safe(const TensorLayout &src,
const TensorLayout &filter,
const TensorLayout &dst) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes, const TensorLayout& dst, size_t workspace_limit_in_bytes,
@@ -67,6 +70,9 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData {
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &filter, std::vector<Algorithm *> get_all_algorithms(const TensorLayout &filter,
const TensorLayout &diff, const TensorLayout &diff,
const TensorLayout &grad) override; const TensorLayout &grad) override;
std::vector<Algorithm *> get_all_algorithms_safe(const TensorLayout &filter,
const TensorLayout &diff,
const TensorLayout &grad) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,
@@ -90,6 +96,9 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter {
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src,
const TensorLayout &diff, const TensorLayout &diff,
const TensorLayout &grad) override; const TensorLayout &grad) override;
std::vector<Algorithm *> get_all_algorithms_safe(const TensorLayout &src,
const TensorLayout &diff,
const TensorLayout &grad) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,


+ 21
- 1
dnn/src/naive/convolution3d/convolution3d.cpp View File

@@ -108,13 +108,18 @@ void Convolution3DBackwardFilterImpl::exec(_megdnn_tensor_in src,


megdnn_assert_internal(0); megdnn_assert_internal(0);
} }

std::vector<Convolution3DForward::Algorithm*> std::vector<Convolution3DForward::Algorithm*>
Convolution3DForwardImpl::get_all_algorithms(const TensorLayout&, Convolution3DForwardImpl::get_all_algorithms(const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const TensorLayout&) { const TensorLayout&) {
return {static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo()}; return {static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo()};
} }
std::vector<Convolution3DForward::Algorithm*>
Convolution3DForwardImpl::get_all_algorithms_safe(const TensorLayout&,
const TensorLayout&,
const TensorLayout&) {
return {static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo()};
}


Convolution3DForward::Algorithm* Convolution3DForward::Algorithm*
Convolution3DForwardImpl::get_algorithm_heuristic( Convolution3DForwardImpl::get_algorithm_heuristic(
@@ -143,6 +148,13 @@ Convolution3DBackwardDataImpl::get_all_algorithms(const TensorLayout&,
return {static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo()}; return {static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo()};
} }


std::vector<Convolution3DBackwardData::Algorithm*>
Convolution3DBackwardDataImpl::get_all_algorithms_safe(const TensorLayout&,
const TensorLayout&,
const TensorLayout&) {
return {static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo()};
}

Convolution3DBackwardData::Algorithm* Convolution3DBackwardData::Algorithm*
Convolution3DBackwardDataImpl::get_algorithm_heuristic( Convolution3DBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& /* filter */, const TensorLayout& /* diff */, const TensorLayout& /* filter */, const TensorLayout& /* diff */,
@@ -172,6 +184,14 @@ Convolution3DBackwardFilterImpl::get_all_algorithms(const TensorLayout&,
->default_conv3d_bwd_filter_algo()}; ->default_conv3d_bwd_filter_algo()};
} }


std::vector<Convolution3DBackwardFilter::Algorithm*>
Convolution3DBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout&,
const TensorLayout&,
const TensorLayout&) {
return {static_cast<HandleImpl*>(handle())
->default_conv3d_bwd_filter_algo()};
}

Convolution3DBackwardFilter::Algorithm* Convolution3DBackwardFilter::Algorithm*
Convolution3DBackwardFilterImpl::get_algorithm_heuristic( Convolution3DBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* diff */, const TensorLayout& /* src */, const TensorLayout& /* diff */,


+ 9
- 0
dnn/src/naive/convolution3d/opr_impl.h View File

@@ -22,6 +22,9 @@ public:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override; const TensorLayout& dst) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes, const TensorLayout& dst, size_t workspace_limit_in_bytes,
@@ -44,6 +47,9 @@ public:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override; const TensorLayout& grad) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,
@@ -66,6 +72,9 @@ public:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) override; const TensorLayout& grad) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,


+ 23
- 0
dnn/src/naive/deformable_conv/opr_impl.h View File

@@ -25,6 +25,12 @@ public:
const TensorLayout& /* dst */) override { const TensorLayout& /* dst */) override {
return std::vector<Algorithm*>(); return std::vector<Algorithm*>();
}; };
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& /* im */, const TensorLayout& /* filter */,
const TensorLayout& /* offset */, const TensorLayout& /* mask */,
const TensorLayout& /* dst */) override {
return std::vector<Algorithm*>();
};


Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* filter */, const TensorLayout& /* src */, const TensorLayout& /* filter */,
@@ -67,6 +73,13 @@ public:
const TensorLayout& /* filter_grad */) override { const TensorLayout& /* filter_grad */) override {
return std::vector<Algorithm*>(); return std::vector<Algorithm*>();
}; };
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& /* im */, const TensorLayout& /* offset */,
const TensorLayout& /* mask */, const TensorLayout& /* out_grad */,
const TensorLayout& /* filter_grad */) override {
return std::vector<Algorithm*>();
};


Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& /* im */, const TensorLayout& /* offset */, const TensorLayout& /* im */, const TensorLayout& /* offset */,
@@ -112,6 +125,16 @@ public:
return std::vector<Algorithm*>(); return std::vector<Algorithm*>();
}; };


std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& /* im */, const TensorLayout& /* filter */,
const TensorLayout& /* offset */, const TensorLayout& /* mask */,
const TensorLayout& /* out_grad */,
const TensorLayout& /* im_grad */,
const TensorLayout& /* offset_grad */,
const TensorLayout& /* mask_grad */) override {
return std::vector<Algorithm*>();
};

Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& /* im */, const TensorLayout& /* filter */, const TensorLayout& /* im */, const TensorLayout& /* filter */,
const TensorLayout& /* offset */, const TensorLayout& /* mask */, const TensorLayout& /* offset */, const TensorLayout& /* mask */,


+ 23
- 0
dnn/src/naive/local_share/opr_impl.cpp View File

@@ -159,6 +159,13 @@ LocalShareForwardImpl::get_all_algorithms(const TensorLayout&,
return {static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo()}; return {static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo()};
} }


std::vector<LocalShareForward::Algorithm*>
LocalShareForwardImpl::get_all_algorithms_safe(const TensorLayout&,
const TensorLayout&,
const TensorLayout&) {
return {static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo()};
}

LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic( LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* diff */, const TensorLayout& /* src */, const TensorLayout& /* diff */,
const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
@@ -187,6 +194,14 @@ LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout&,
->default_local_share_bwd_data_algo()}; ->default_local_share_bwd_data_algo()};
} }


std::vector<LocalShareBackwardData::Algorithm*>
LocalShareBackwardDataImpl::get_all_algorithms_safe(const TensorLayout&,
const TensorLayout&,
const TensorLayout&) {
return {static_cast<HandleImpl*>(handle())
->default_local_share_bwd_data_algo()};
}

LocalShareBackwardData::Algorithm* LocalShareBackwardData::Algorithm*
LocalShareBackwardDataImpl::get_algorithm_heuristic( LocalShareBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& /* filter */, const TensorLayout& /* diff */, const TensorLayout& /* filter */, const TensorLayout& /* diff */,
@@ -216,6 +231,14 @@ LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout&,
->default_local_share_bwd_filter_algo()}; ->default_local_share_bwd_filter_algo()};
} }


std::vector<LocalShareBackwardFilter::Algorithm*>
LocalShareBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout&,
const TensorLayout&,
const TensorLayout&) {
return {static_cast<HandleImpl*>(handle())
->default_local_share_bwd_filter_algo()};
}

LocalShareBackwardFilter::Algorithm* LocalShareBackwardFilter::Algorithm*
LocalShareBackwardFilterImpl::get_algorithm_heuristic( LocalShareBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& /* src */, const TensorLayout& /* diff */, const TensorLayout& /* src */, const TensorLayout& /* diff */,


+ 12
- 1
dnn/src/naive/local_share/opr_impl.h View File

@@ -30,6 +30,10 @@ public:
const TensorLayout& /*src*/, const TensorLayout& /*filter*/, const TensorLayout& /*src*/, const TensorLayout& /*filter*/,
const TensorLayout& /*dst*/) override; const TensorLayout& /*dst*/) override;


std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& /*src*/, const TensorLayout& /*filter*/,
const TensorLayout& /*dst*/) override;

Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& /*src*/, const TensorLayout& /*filter*/, const TensorLayout& /*src*/, const TensorLayout& /*filter*/,
const TensorLayout& /*dst*/, size_t /*workspace_limit_in_bytes*/, const TensorLayout& /*dst*/, size_t /*workspace_limit_in_bytes*/,
@@ -55,6 +59,10 @@ public:
const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, const TensorLayout& /*filter*/, const TensorLayout& /*diff*/,
const TensorLayout& /*grad*/) override; const TensorLayout& /*grad*/) override;


std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& /*filter*/, const TensorLayout& /*diff*/,
const TensorLayout& /*grad*/) override;

Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, const TensorLayout& /*filter*/, const TensorLayout& /*diff*/,
const TensorLayout& /*grad*/, size_t /*workspace_limit_in_bytes*/, const TensorLayout& /*grad*/, size_t /*workspace_limit_in_bytes*/,
@@ -75,11 +83,14 @@ public:
const TensorLayout&) override { const TensorLayout&) override {
return 0; return 0;
} }

std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& /*src*/, const TensorLayout& /*diff*/, const TensorLayout& /*src*/, const TensorLayout& /*diff*/,
const TensorLayout& /*grad*/) override; const TensorLayout& /*grad*/) override;


std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& /*src*/, const TensorLayout& /*diff*/,
const TensorLayout& /*grad*/) override;

Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& /*src*/, const TensorLayout& /*diff*/, const TensorLayout& /*src*/, const TensorLayout& /*diff*/,
const TensorLayout& /*grad*/, size_t /*workspace_limit_in_bytes*/, const TensorLayout& /*grad*/, size_t /*workspace_limit_in_bytes*/,


+ 7
- 0
dnn/src/naive/matrix_mul/opr_impl.cpp View File

@@ -88,6 +88,13 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/,
return {static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo()}; return {static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo()};
} }


std::vector<MatrixMulForward::Algorithm*>
MatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& /*A*/,
const TensorLayout& /*B*/,
const TensorLayout& /*C*/) {
return {static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo()};
}

MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,


+ 4
- 0
dnn/src/naive/matrix_mul/opr_impl.h View File

@@ -29,6 +29,10 @@ public:
const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override; const TensorLayout& /*C*/) override;


std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override;

Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,


+ 9
- 0
dnn/src/naive/pooling/opr_impl.cpp View File

@@ -603,6 +603,10 @@ std::vector<Algorithm*> PoolingForwardImpl::get_all_algorithms(
const TensorLayout&, const TensorLayout&) { const TensorLayout&, const TensorLayout&) {
return {static_cast<HandleImpl*>(handle())->default_pooling_fwd_algo()}; return {static_cast<HandleImpl*>(handle())->default_pooling_fwd_algo()};
} }
std::vector<Algorithm*> PoolingForwardImpl::get_all_algorithms_safe(
const TensorLayout&, const TensorLayout&) {
return {static_cast<HandleImpl*>(handle())->default_pooling_fwd_algo()};
}


Algorithm* PoolingForwardImpl::get_algorithm_heuristic( Algorithm* PoolingForwardImpl::get_algorithm_heuristic(
const TensorLayout& /*src*/, const TensorLayout& /*dst*/, const TensorLayout& /*src*/, const TensorLayout& /*dst*/,
@@ -626,6 +630,11 @@ std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms(
const TensorLayout& /*diff*/, const TensorLayout& /*grad*/) { const TensorLayout& /*diff*/, const TensorLayout& /*grad*/) {
return {static_cast<HandleImpl*>(handle())->default_pooling_bwd_algo()}; return {static_cast<HandleImpl*>(handle())->default_pooling_bwd_algo()};
} }
std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms_safe(
const TensorLayout& /*src*/, const TensorLayout& /*dst*/,
const TensorLayout& /*diff*/, const TensorLayout& /*grad*/) {
return {static_cast<HandleImpl*>(handle())->default_pooling_bwd_algo()};
}


Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( Algorithm* PoolingBackwardImpl::get_algorithm_heuristic(
const TensorLayout& /*src*/, const TensorLayout& /*dst*/, const TensorLayout& /*src*/, const TensorLayout& /*dst*/,


+ 5
- 0
dnn/src/naive/pooling/opr_impl.h View File

@@ -35,6 +35,8 @@ class PoolingForwardImpl: public PoolingForward {


std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& dst) override; const TensorLayout& src, const TensorLayout& dst) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& dst) override;


Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& dst, const TensorLayout& src, const TensorLayout& dst,
@@ -60,6 +62,9 @@ public:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& dst, const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& diff, const TensorLayout& grad) override; const TensorLayout& diff, const TensorLayout& grad) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& diff, const TensorLayout& grad) override;


Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& dst, const TensorLayout& src, const TensorLayout& dst,


+ 8
- 0
dnn/src/rocm/batched_matrix_mul/opr_impl.cpp View File

@@ -29,6 +29,14 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
return megdnn::get_all_algorithms<BatchedMatrixMulForwardImpl>(args); return megdnn::get_all_algorithms<BatchedMatrixMulForwardImpl>(args);
} }


std::vector<BatchedMatrixMulForwardImpl::Algorithm*>
BatchedMatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) {
AlgoBase::SizeArgs args{this, A, B, C};
return megdnn::get_all_algorithms_safe<BatchedMatrixMulForwardImpl>(args);
}

BatchedMatrixMulForwardImpl::Algorithm* BatchedMatrixMulForwardImpl::Algorithm*
BatchedMatrixMulForwardImpl::get_algorithm_heuristic( BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,


+ 3
- 0
dnn/src/rocm/batched_matrix_mul/opr_impl.h View File

@@ -35,6 +35,9 @@ private:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override; const TensorLayout& /*C*/) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override;


Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*A*/, const TensorLayout& /*B*/,


+ 24
- 0
dnn/src/rocm/convolution/opr_impl.cpp View File

@@ -109,6 +109,14 @@ ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src,
{this, src, filter, dst}); {this, src, filter, dst});
} }


std::vector<ConvolutionForwardImpl::Algorithm*>
ConvolutionForwardImpl::get_all_algorithms_safe(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst) {
return megdnn::get_all_algorithms_safe<ConvolutionForwardImpl>(
{this, src, filter, dst});
}

size_t ConvolutionForwardImpl::get_workspace_in_bytes( size_t ConvolutionForwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, const PreprocessedFilter*) { const TensorLayout& dst, const PreprocessedFilter*) {
@@ -162,6 +170,14 @@ ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter,
{this, filter, diff, grad}); {this, filter, diff, grad});
} }


std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad) {
return megdnn::get_all_algorithms_safe<ConvolutionBackwardDataImpl>(
{this, filter, diff, grad});
}

ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic( ConvolutionBackwardDataImpl::get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& filter, const TensorLayout& diff,
@@ -243,6 +259,14 @@ ConvolutionBackwardFilterImpl::get_all_algorithms(const TensorLayout& src,
{this, src, diff, grad}); {this, src, diff, grad});
} }


std::vector<ConvolutionBackwardFilterImpl::Algorithm*>
ConvolutionBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad) {
return megdnn::get_all_algorithms_safe<ConvolutionBackwardFilterImpl>(
{this, src, diff, grad});
}

ConvolutionBackwardFilterImpl::Algorithm* ConvolutionBackwardFilterImpl::Algorithm*
ConvolutionBackwardFilterImpl::get_algorithm_heuristic( ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& src, const TensorLayout& diff,


+ 9
- 0
dnn/src/rocm/convolution/opr_impl.h View File

@@ -74,6 +74,9 @@ private:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override; const TensorLayout& dst) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, size_t workspace_limit_in_bytes, const TensorLayout& dst, size_t workspace_limit_in_bytes,
@@ -123,6 +126,9 @@ private:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override; const TensorLayout& grad) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,
@@ -172,6 +178,9 @@ private:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) override; const TensorLayout& grad) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& diff, const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes, const TensorLayout& grad, size_t workspace_limit_in_bytes,


+ 8
- 0
dnn/src/rocm/matrix_mul/opr_impl.cpp View File

@@ -27,6 +27,14 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
return megdnn::get_all_algorithms<MatrixMulForwardImpl>(args); return megdnn::get_all_algorithms<MatrixMulForwardImpl>(args);
} }


std::vector<MatrixMulForwardImpl::Algorithm*>
MatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) {
AlgoBase::SizeArgs args{this, A, B, C};
return megdnn::get_all_algorithms_safe<MatrixMulForwardImpl>(args);
}

MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,


+ 4
- 0
dnn/src/rocm/matrix_mul/opr_impl.h View File

@@ -36,6 +36,10 @@ private:
const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override; const TensorLayout& /*C*/) override;


std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/) override;

Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*A*/, const TensorLayout& /*B*/,
const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,


+ 12
- 1
dnn/src/rocm/pooling/opr_impl.cpp View File

@@ -25,12 +25,16 @@ size_t PoolingForwardImpl::get_workspace_in_bytes(const TensorLayout& src,
const char* PoolingForwardImpl::get_algorithm_set_name() const { const char* PoolingForwardImpl::get_algorithm_set_name() const {
return "ROCM_POOLING_FORWARD"; return "ROCM_POOLING_FORWARD";
} }

std::vector<PoolingForwardImpl::Algorithm*> std::vector<PoolingForwardImpl::Algorithm*>
PoolingForwardImpl::get_all_algorithms(const TensorLayout& src, PoolingForwardImpl::get_all_algorithms(const TensorLayout& src,
const TensorLayout& dst) { const TensorLayout& dst) {
return megdnn::get_all_algorithms<PoolingForwardImpl>({this, src, dst}); return megdnn::get_all_algorithms<PoolingForwardImpl>({this, src, dst});
} }
std::vector<PoolingForwardImpl::Algorithm*>
PoolingForwardImpl::get_all_algorithms_safe(const TensorLayout& src,
const TensorLayout& dst) {
return megdnn::get_all_algorithms_safe<PoolingForwardImpl>({this, src, dst});
}


PoolingForwardImpl::Algorithm* PoolingForwardImpl::get_algorithm_heuristic( PoolingForwardImpl::Algorithm* PoolingForwardImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& dst, const TensorLayout& src, const TensorLayout& dst,
@@ -82,6 +86,13 @@ std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms(
{this, src, dst, diff, grad}); {this, src, dst, diff, grad});
} }


std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& diff, const TensorLayout& grad) {
return megdnn::get_all_algorithms_safe<PoolingBackwardImpl>(
{this, src, dst, diff, grad});
}

Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( Algorithm* PoolingBackwardImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& dst, const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& diff, const TensorLayout& grad, const TensorLayout& diff, const TensorLayout& grad,


+ 5
- 0
dnn/src/rocm/pooling/opr_impl.h View File

@@ -46,6 +46,8 @@ class PoolingForwardImpl final: public PoolingForward {
protected: protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& dst) override; const TensorLayout& src, const TensorLayout& dst) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& dst, const TensorLayout& src, const TensorLayout& dst,
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
@@ -93,6 +95,9 @@ class PoolingBackwardImpl final: public PoolingBackward {
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& dst, const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& diff, const TensorLayout& grad) override; const TensorLayout& diff, const TensorLayout& grad) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& diff, const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& dst, const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& diff, const TensorLayout& grad, const TensorLayout& diff, const TensorLayout& grad,


+ 4
- 1
dnn/src/x86/pooling/opr_impl.cpp View File

@@ -74,11 +74,14 @@ size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src,
return fallback_worksapce; return fallback_worksapce;
} }
} }

std::vector<Algorithm*> PoolingImpl::get_all_algorithms( std::vector<Algorithm*> PoolingImpl::get_all_algorithms(
const TensorLayout& src, const TensorLayout& dst) { const TensorLayout& src, const TensorLayout& dst) {
return megdnn::get_all_algorithms<PoolingImpl>({this, src, dst}); return megdnn::get_all_algorithms<PoolingImpl>({this, src, dst});
} }
std::vector<Algorithm*> PoolingImpl::get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& dst) {
return megdnn::get_all_algorithms_safe<PoolingImpl>({this, src, dst});
}


Algorithm* PoolingImpl::get_algorithm_heuristic( Algorithm* PoolingImpl::get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& dst, const TensorLayout& src, const TensorLayout& dst,


+ 2
- 0
dnn/src/x86/pooling/opr_impl.h View File

@@ -63,6 +63,8 @@ public:
protected: protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& dst) override; const TensorLayout& src, const TensorLayout& dst) override;
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& src, const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& src, const TensorLayout& dst, const TensorLayout& src, const TensorLayout& dst,
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,


+ 1
- 1
dnn/test/common/accuracy_shake_checker.h View File

@@ -164,7 +164,7 @@ public:
} }
std::vector<Algorithm::Info::Desc> ret; std::vector<Algorithm::Info::Desc> ret;
megdnn_assert(layouts.size() == OprTrait<Opr>::arity); megdnn_assert(layouts.size() == OprTrait<Opr>::arity);
auto vec = AlgoProxy<Opr, OprTrait<Opr>::arity>::get_all_algorithms_info(
auto vec = AlgoProxy<Opr, OprTrait<Opr>::arity>::get_all_algorithms_info_safe(
opr, layouts); opr, layouts);
for (auto algo_info : vec) { for (auto algo_info : vec) {
if (!(algo_info.attribute & if (!(algo_info.attribute &


+ 1
- 1
dnn/test/common/benchmarker.h View File

@@ -377,7 +377,7 @@ float algo_benchmark(Benchmarker<Opr, T>& benchmark, TensorLayoutArray layouts,
auto opr = benchmark.opr(); auto opr = benchmark.opr();
opr->param() = benchmark.param(); opr->param() = benchmark.param();
proxy.deduce_layout(opr, layouts); proxy.deduce_layout(opr, layouts);
auto algos = OprAlgoProxy<Opr>::get_all_algorithms_info(opr, layouts);
auto algos = OprAlgoProxy<Opr>::get_all_algorithms_info_safe(opr, layouts);
float min_used = std::numeric_limits<float>::max(); float min_used = std::numeric_limits<float>::max();
bool execed = false; bool execed = false;
for (auto i : algos) { for (auto i : algos) {


+ 2
- 2
dnn/test/common/checker.h View File

@@ -514,7 +514,7 @@ struct ExecutionPolicyAlgoName {
* \brief a callable to check that given algorithm is used for heuristic * \brief a callable to check that given algorithm is used for heuristic
* \param require_algo if its value is true, then requires * \param require_algo if its value is true, then requires
* get_algorithm_heuristic() to return the expected algo; otherwise the * get_algorithm_heuristic() to return the expected algo; otherwise the
* expected algo must exist in get_all_algorithms() and it would be set to
* expected algo must exist in get_all_algorithms_safe() and it would be set to
* be used * be used
*/ */
template <class Opr, typename OprAlgoProxy = OprAlgoProxy<Opr>> template <class Opr, typename OprAlgoProxy = OprAlgoProxy<Opr>>
@@ -536,7 +536,7 @@ public:
opr->param() = opr->param() =
Algorithm::deserialize_read_pod<typename Opr::Param>(param); Algorithm::deserialize_read_pod<typename Opr::Param>(param);
for (auto algo_info : for (auto algo_info :
AlgoProxy<Opr, OprTrait<Opr>::arity>::get_all_algorithms_info(
AlgoProxy<Opr, OprTrait<Opr>::arity>::get_all_algorithms_info_safe(
opr.get(), layouts)) { opr.get(), layouts)) {
if (std::regex_match( if (std::regex_match(
algo_info.desc.name, algo_info.desc.name,


+ 3
- 3
dnn/test/common/convolution.cpp View File

@@ -695,7 +695,7 @@ Checker<Convolution> checker(handle);
float scale = 1.0f / sqrt(fshp[channel_start] * FH * FW); float scale = 1.0f / sqrt(fshp[channel_start] * FH * FW);
UniformFloatRNG rng(scale, 2 * scale); UniformFloatRNG rng(scale, 2 * scale);
checker.set_rng(0, &rng).set_rng(1, &rng); checker.set_rng(0, &rng).set_rng(1, &rng);
for (auto algo : opr->get_all_algorithms_info(ily, fly, oly)) {
for (auto algo : opr->get_all_algorithms_info_safe(ily, fly, oly)) {
used_algos.insert(algo.desc); used_algos.insert(algo.desc);
opr->execution_policy().algo = algo.desc; opr->execution_policy().algo = algo.desc;


@@ -720,7 +720,7 @@ Checker<Convolution> checker(handle);
opr->param() = param; opr->param() = param;
std::string param_str; std::string param_str;
Algorithm::serialize_write_pod(opr->param(), param_str); Algorithm::serialize_write_pod(opr->param(), param_str);
for (auto algo : opr->get_all_algorithms_info(fly, oly, ily)) {
for (auto algo : opr->get_all_algorithms_info_safe(fly, oly, ily)) {
used_algos_bwd_data.insert(algo.desc); used_algos_bwd_data.insert(algo.desc);
opr->execution_policy().algo = algo.desc; opr->execution_policy().algo = algo.desc;
construct_sub_execution_policy_heuristic< construct_sub_execution_policy_heuristic<
@@ -747,7 +747,7 @@ Checker<Convolution> checker(handle);
opr->param() = param; opr->param() = param;
std::string param_str; std::string param_str;
Algorithm::serialize_write_pod(opr->param(), param_str); Algorithm::serialize_write_pod(opr->param(), param_str);
for (auto algo : opr->get_all_algorithms_info(ily, oly, fly)) {
for (auto algo : opr->get_all_algorithms_info_safe(ily, oly, fly)) {
used_algos_bwd_flt.insert(algo.desc); used_algos_bwd_flt.insert(algo.desc);
opr->execution_policy().algo = algo.desc; opr->execution_policy().algo = algo.desc;
construct_sub_execution_policy_heuristic< construct_sub_execution_policy_heuristic<


+ 4
- 4
dnn/test/common/opr_algo_proxy.h View File

@@ -25,9 +25,9 @@ struct AlgoProxy;
template <typename Opr> \ template <typename Opr> \
struct AlgoProxy<Opr, arity> { \ struct AlgoProxy<Opr, arity> { \
static std::vector<typename Opr::AlgorithmInfo> \ static std::vector<typename Opr::AlgorithmInfo> \
get_all_algorithms_info(Opr* opr, const TensorLayoutArray& layouts) { \
get_all_algorithms_info_safe(Opr* opr, const TensorLayoutArray& layouts) { \
megdnn_assert(layouts.size() == arity); \ megdnn_assert(layouts.size() == arity); \
return opr->get_all_algorithms_info(LAYOUTS); \
return opr->get_all_algorithms_info_safe(LAYOUTS); \
} \ } \
static typename Opr::AlgorithmInfo get_algorithm_info_heuristic( \ static typename Opr::AlgorithmInfo get_algorithm_info_heuristic( \
Opr* opr, const TensorLayoutArray& layouts) { \ Opr* opr, const TensorLayoutArray& layouts) { \
@@ -80,9 +80,9 @@ DEF_ALGO_PROXY(8);
template <> \ template <> \
struct AlgoProxy<Opr, arity> { \ struct AlgoProxy<Opr, arity> { \
static std::vector<typename Opr::AlgorithmInfo> \ static std::vector<typename Opr::AlgorithmInfo> \
get_all_algorithms_info(Opr* opr, const TensorLayoutArray& layouts) { \
get_all_algorithms_info_safe(Opr* opr, const TensorLayoutArray& layouts) { \
megdnn_assert(layouts.size() == arity); \ megdnn_assert(layouts.size() == arity); \
return opr->get_all_algorithms_info(LAYOUTS); \
return opr->get_all_algorithms_info_safe(LAYOUTS); \
} \ } \
static typename Opr::AlgorithmInfo get_algorithm_info_heuristic( \ static typename Opr::AlgorithmInfo get_algorithm_info_heuristic( \
Opr* opr, const TensorLayoutArray& layouts) { \ Opr* opr, const TensorLayoutArray& layouts) { \


+ 3
- 3
dnn/test/common/opr_proxy.h View File

@@ -288,7 +288,7 @@ struct OprProxyProfilingBase
Algorithm::deserialize_read_pod<typename Opr::Param>(param); Algorithm::deserialize_read_pod<typename Opr::Param>(param);


std::vector<Algorithm::SearchItem> ret; std::vector<Algorithm::SearchItem> ret;
for (auto algo_info : AlgoProxy<Opr, arity>::get_all_algorithms_info(
for (auto algo_info : AlgoProxy<Opr, arity>::get_all_algorithms_info_safe(
opr.get(), layouts)) { opr.get(), layouts)) {
Algorithm* algo = opr->get_algorithm_from_desc(algo_info.desc); Algorithm* algo = opr->get_algorithm_from_desc(algo_info.desc);
std::vector<Algorithm::SearchItem>&& sub_items = std::vector<Algorithm::SearchItem>&& sub_items =
@@ -367,7 +367,7 @@ struct OprProxyProfilingBase
megdnn_log("Find best algo %s in cache", algo->name()); megdnn_log("Find best algo %s in cache", algo->name());
return; return;
} }
for (auto algo : AlgoProxy<Opr, arity>::get_all_algorithms_info(
for (auto algo : AlgoProxy<Opr, arity>::get_all_algorithms_info_safe(
opr.get(), layouts)) { opr.get(), layouts)) {
//! construct execution_policy //! construct execution_policy
opr->execution_policy().algo = algo.desc; opr->execution_policy().algo = algo.desc;
@@ -492,7 +492,7 @@ struct OprWeightPreprocessProxyImpl : public OprProxyProfilingBase<Opr> {
if (Base::m_profiling && !Base::target_execution_policy.algo.valid()) { if (Base::m_profiling && !Base::target_execution_policy.algo.valid()) {
size_t min_time = std::numeric_limits<size_t>::max(); size_t min_time = std::numeric_limits<size_t>::max();
for (auto algo : for (auto algo :
AlgoProxy<Opr, arity>::get_all_algorithms_info(opr, layouts)) {
AlgoProxy<Opr, arity>::get_all_algorithms_info_safe(opr, layouts)) {
opr->execution_policy().algo = algo.desc; opr->execution_policy().algo = algo.desc;


auto preprocess_tensors = auto preprocess_tensors =


+ 2
- 2
dnn/test/cuda/cutlass_matmul.cpp View File

@@ -84,7 +84,7 @@ void test_multibatchsize(
auto opr_reference = handle_cuda->create_operator<MatrixMulForward>(); auto opr_reference = handle_cuda->create_operator<MatrixMulForward>();
{ {
opr_reference->execution_policy().algo.reset(); opr_reference->execution_policy().algo.reset();
for (auto i : opr_reference->get_all_algorithms_info(
for (auto i : opr_reference->get_all_algorithms_info_safe(
A_tensor.layout(), B_tensor.layout(), A_tensor.layout(), B_tensor.layout(),
C_tensor.layout())) { C_tensor.layout())) {
if (std::regex_match( if (std::regex_match(
@@ -113,7 +113,7 @@ void test_multibatchsize(
{{}, {}, C_tensor_prime.tensornd_host()}); {{}, {}, C_tensor_prime.tensornd_host()});
{ {
opr_reference->execution_policy().algo.reset(); opr_reference->execution_policy().algo.reset();
for (auto i : opr_reference->get_all_algorithms_info(
for (auto i : opr_reference->get_all_algorithms_info_safe(
A_tensor_prime.layout(), B_tensor.layout(), A_tensor_prime.layout(), B_tensor.layout(),
C_tensor_batch.layout())) { C_tensor_batch.layout())) {
if (std::regex_match( if (std::regex_match(


+ 2
- 2
src/core/test/graph/misc.cpp View File

@@ -1938,7 +1938,7 @@ typename megdnn::ExecutionPolicy try_find_any_weight_preprocess_algo(
return {}; return {};
} }
} }
for (auto&& algo : dnn_op->get_all_algorithms_info(
for (auto&& algo : dnn_op->get_all_algorithms_info_safe(
std::forward<Args>(args)...)) { std::forward<Args>(args)...)) {
dnn_op->execution_policy().algo = algo.desc; dnn_op->execution_policy().algo = algo.desc;
auto layouts = dnn_op->deduce_preprocessed_filter_layout( auto layouts = dnn_op->deduce_preprocessed_filter_layout(
@@ -1972,7 +1972,7 @@ typename megdnn::ExecutionPolicy try_find_any_bias_preprocess_algo(
return {}; return {};
} }
} }
for (auto&& algo : dnn_op->get_all_algorithms_info(
for (auto&& algo : dnn_op->get_all_algorithms_info_safe(
std::forward<Args>(args)...)) { std::forward<Args>(args)...)) {
dnn_op->execution_policy().algo = algo.desc; dnn_op->execution_policy().algo = algo.desc;
auto layouts = dnn_op->deduce_preprocessed_filter_layout( auto layouts = dnn_op->deduce_preprocessed_filter_layout(


+ 1
- 1
src/opr/impl/search_policy/algo_chooser.cpp View File

@@ -805,7 +805,7 @@ std::vector<typename AlgoChooser<Opr>::ImplAlgo>
AlgoChooser<Opr>::AlgoChooserHelper::get_all_candidates() const { AlgoChooser<Opr>::AlgoChooserHelper::get_all_candidates() const {
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_all_candidates"))) MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_all_candidates")))
auto heu = choose_by_heuristic(m_execution_policy.strategy); auto heu = choose_by_heuristic(m_execution_policy.strategy);
auto&& ret = APPLY(m_dnn_opr->get_all_algorithms_info(args...),
auto&& ret = APPLY(m_dnn_opr->get_all_algorithms_info_safe(args...),
m_fastrun_layouts); m_fastrun_layouts);
bool found = false; bool found = false;
for (size_t i = 0; i < ret.size(); ++i) { for (size_t i = 0; i < ret.size(); ++i) {


+ 10
- 0
src/opr/test/dnn/convolution.cpp View File

@@ -2473,6 +2473,11 @@ public:
std::vector<AlgorithmInfo>(const TensorLayout& p0, std::vector<AlgorithmInfo>(const TensorLayout& p0,
const TensorLayout& p1, const TensorLayout& p1,
const TensorLayout& p2)); const TensorLayout& p2));

MOCK_METHOD3(get_all_algorithms_info_safe,
std::vector<AlgorithmInfo>(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2));
MOCK_METHOD6(get_algorithm_info_heuristic, MOCK_METHOD6(get_algorithm_info_heuristic,
AlgorithmInfo(const TensorLayout& p0, const TensorLayout& p1, AlgorithmInfo(const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p2,
@@ -2484,6 +2489,11 @@ public:
std::vector<Algorithm*>(const TensorLayout& p0, std::vector<Algorithm*>(const TensorLayout& p0,
const TensorLayout& p1, const TensorLayout& p1,
const TensorLayout& p2)); const TensorLayout& p2));

MOCK_METHOD3(get_all_algorithms_safe,
std::vector<Algorithm*>(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2));
MOCK_METHOD6(get_algorithm_heuristic, MOCK_METHOD6(get_algorithm_heuristic,
Algorithm*(const TensorLayout& p0, const TensorLayout& p1, Algorithm*(const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p2,


Loading…
Cancel
Save