You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_impl.h 2.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  /**
   * \file dnn/src/naive/batch_conv_bias/opr_impl.h
   * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
   *
   * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
   *
   * Unless required by applicable law or agreed to in writing,
   * software distributed under the License is distributed on an
   * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   */
  11. #pragma once
  12. #include "megdnn/oprs.h"
  13. #include "src/common/utils.h"
  14. namespace megdnn {
  15. namespace naive {
  16. class BatchConvBiasForwardImpl : public BatchConvBiasForward {
  17. public:
  18. using BatchConvBiasForward::BatchConvBiasForward;
  19. void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
  20. _megdnn_tensor_in bias, _megdnn_tensor_in z,
  21. _megdnn_tensor_out dst, _megdnn_workspace workspace) override;
  22. size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
  23. const TensorLayout&, const TensorLayout&,
  24. const TensorLayout&) override;
  25. std::vector<Algorithm*> get_all_algorithms(
  26. const TensorLayout& src, const TensorLayout& filter,
  27. const TensorLayout& bias, const TensorLayout& z,
  28. const TensorLayout& dst) override;
  29. Algorithm* get_algorithm_heuristic(const TensorLayout& src,
  30. const TensorLayout& filter,
  31. const TensorLayout& bias,
  32. const TensorLayout& z,
  33. const TensorLayout& dst,
  34. size_t workspace_limit_in_bytes,
  35. const AlgoAttribute& attr) override;
  36. Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
  37. const char* get_algorithm_set_name() const override { return "DEFAULT"; }
  38. private:
  39. WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
  40. const TensorLayout& src,
  41. const TensorLayout& filter,
  42. const TensorLayout& bias,
  43. const TensorLayout& z,
  44. const TensorLayout& dst);
  45. };
  46. } // namespace naive
  47. } // namespace megdnn
  48. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台