You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

chanwise_8x8x32.cpp 3.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. /**
  2. * \file dnn/src/cuda/conv_bias/chanwise_8x8x32.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./algo.h"
  12. #include "src/cuda/utils.h"
  13. #include "src/cuda/conv_bias/chanwise/kern.cuh"
  14. #include "src/common/conv_bias.h"
  15. #include "src/common/elemwise/kern_defs.cuh"
  16. using namespace megdnn;
  17. using namespace cuda;
  18. using namespace conv_bias;
  19. bool ConvBiasForwardImpl::AlgoChanwise8x8x32::is_available(
  20. const SizeArgs& args) const {
  21. if (!args.src_layout->is_contiguous() ||
  22. !args.dst_layout->is_contiguous()) {
  23. return false;
  24. }
  25. if (args.z_layout->ndim > 0)
  26. return false;
  27. using NonlineMode = param::ConvBias::NonlineMode;
  28. auto&& fm = args.filter_meta;
  29. return (args.nonlinear_mode == NonlineMode::IDENTITY ||
  30. args.nonlinear_mode == NonlineMode::RELU) &&
  31. args.filter_meta.format == Param::Format::NHWC &&
  32. args.src_layout->dtype == dtype::Int8() &&
  33. fm.dtype.enumv() == DTypeEnum::Int8 && fm.spatial_ndim == 2 &&
  34. fm.icpg == 1 && fm.ocpg == 1 && fm.group % 4 == 0;
  35. }
  36. size_t ConvBiasForwardImpl::AlgoChanwise8x8x32::get_workspace_in_bytes(
  37. const SizeArgs& args) const {
  38. auto dst_layout = *args.dst_layout;
  39. if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
  40. dst_layout.dtype = DType();
  41. args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype,
  42. args.filter_layout->dtype,
  43. dst_layout.dtype);
  44. return dst_layout.span().dist_byte();
  45. }
  46. return 0;
  47. }
  48. void ConvBiasForwardImpl::AlgoChanwise8x8x32::exec(const ExecArgs& args) const {
  49. WorkspaceBundle bundle{args.workspace.raw_ptr,
  50. {get_workspace_in_bytes(args)}};
  51. auto conv_dst_tensor = *args.dst_tensor;
  52. if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
  53. conv_dst_tensor.raw_ptr = bundle.get(0);
  54. conv_dst_tensor.layout.dtype = DType();
  55. args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype,
  56. args.filter_layout->dtype,
  57. conv_dst_tensor.layout.dtype);
  58. }
  59. {
  60. auto kparam = chanwise::Param::from_fwd_args(args);
  61. auto stream = cuda_stream(args.handle);
  62. chanwise::run_fwd_8x8x32(conv_dst_tensor.ptr<dt_int32>(),
  63. args.src_tensor->ptr<dt_int8>(),
  64. args.filter_tensor->ptr<dt_int8>(), kparam,
  65. stream);
  66. }
  67. handle_bias_and_nonlinear(args.handle, args.nonlinear_mode,
  68. &conv_dst_tensor, args.dst_tensor,
  69. args.bias_tensor);
  70. }
  71. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。