
cudnn_conv_v8.cpp 3.9 kB

/**
 * \file dnn/src/cuda/conv_bias/cudnn_conv_v8.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/common/conv_bias.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/cudnn_wrapper_v8.h"
#include "src/cuda/utils.h"
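// The cuDNN "v8 frontend" (execution plan) API used below was introduced in
// cuDNN 8.0.4; CUDNN_VERSION encodes major * 1000 + minor * 100 + patch, so
// 8004 corresponds to that release, hence the version guard.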
#if CUDNN_VERSION >= 8004
using namespace megdnn;
using namespace cuda;
using namespace conv_bias;
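// Reports whether this algorithm can handle the given problem. The checks
// below fall into three groups: layout/format constraints, dtype
// restrictions, and a final dry run of the cuDNN heuristics to confirm that
// an execution plan actually exists for the problem.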
bool ConvBiasForwardImpl::AlgoCUDNNConvV8::is_available(const SizeArgs& args) const {
    if (args.filter_meta.format != Param::Format::NCHW &&
        args.filter_meta.format != Param::Format::NHWC) {
        if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) {
            return false;
        }
    }
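    // 4-bit quantized outputs are not handled by this path.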
    if (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
        args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) {
        return false;
    }
    // FIXME: cudnn cannot handle the case when the initial value of dst tensor
    // contains nan and beta is zero, because the result of 0.f * nan is still
    // nan
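    // (cudnn convolutions compute dst = alpha * conv(src, filter) + beta * dst,
    // so with beta == 0 an uninitialized dst that happens to contain NaN still
    // poisons the result, since 0.f * NaN == NaN.)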
    if (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
        args.dst_layout->dtype.enumv() == DTypeEnum::Float32 &&
        args.opr->param().format == param::ConvBias::Format::NCHW) {
        return false;
    }
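    // When the bias dtype differs from the dst dtype, re-deduce the dtype that
    // the bare convolution (without bias/activation) would produce, so that the
    // availability check below presumably queries cudnn with the dtype it will
    // actually be asked to emit.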
    auto dst_layout = *args.dst_layout;
    if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
        dst_layout.dtype = DType();
        args.opr->check_or_deduce_dtype_fwd(
                args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype);
    }
    SizeArgs conv_args = args;
    conv_args.dst_layout = &dst_layout;
    if (!is_cudnn_supported(conv_args))
        return false;
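    // Dry run of the cudnn v8 heuristics: the algorithm is available only if
    // the frontend can produce at least one execution plan for this problem.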
    auto conv_opr = args.handle->create_operator<ConvolutionForward>();
    conv_opr->param() = get_param_convolution(args);
    ConvolutionForward::CanonizedFilterMeta fm;
    fm.copy_from(args.filter_meta);
    auto plan = get_heuristic_plan_from_opr(
            conv_opr.get(), *conv_args.src_layout, *conv_args.dst_layout,
            *conv_args.filter_layout, {}, {}, fm);
    return plan != nullptr;
}
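// Queries the workspace requirement by rebuilding the same heuristic plan as
// is_available(). Unlike there, failure to obtain a plan here is a hard
// error: this should only be reached after is_available() has returned true
// for the same problem.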
size_t ConvBiasForwardImpl::AlgoCUDNNConvV8::cudnn_get_workspace_in_bytes(
        const SizeArgs& args) const {
    auto conv_opr = args.handle->create_operator<ConvolutionForward>();
    conv_opr->param() = get_param_convolution(args);
    ConvolutionForward::CanonizedFilterMeta fm;
    fm.copy_from(args.filter_meta);
    auto plan = get_heuristic_plan_from_opr(
            conv_opr.get(), *args.src_layout, *args.dst_layout, *args.filter_layout, {},
            {}, fm);
    megdnn_assert(
            plan != nullptr, "algo(%s) cannot find execution from heuristics", name());
    return plan->getWorkspaceSize();
}
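// Runs the convolution. Note that the plan is rebuilt from the heuristics on
// every call rather than cached; the same query on the same problem is
// expected to yield the same plan, keeping execution consistent with the
// workspace size reported above.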
void ConvBiasForwardImpl::AlgoCUDNNConvV8::cudnn_execute(
        const ExecArgs& args, const Workspace& workspace) const {
    auto conv_opr = args.handle->create_operator<ConvolutionForward>();
    conv_opr->param() = get_param_convolution(args);
    ConvolutionForward::CanonizedFilterMeta fm;
    fm.copy_from(args.filter_meta);
    auto plan = get_heuristic_plan_from_opr(
            conv_opr.get(), args.src_tensor->layout, args.dst_tensor->layout,
            args.filter_tensor->layout, {}, {}, fm);
    megdnn_assert(
            plan != nullptr, "algo(%s) cannot find execution from heuristics", name());
    auto&& handle = cudnn_handle(args.handle);
    run_single_conv_with_plan(
            handle, *plan, *args.src_tensor, *args.dst_tensor, *args.filter_tensor,
            workspace);
}
#endif

// vim: syntax=cpp.doxygen