You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

lsq.h 1.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. /**
  2. * \file dnn/test/common/lsq.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include "megdnn/basic_types.h"
  14. #include "megdnn/opr_param_defs.h"
  15. namespace megdnn {
  16. namespace test {
  17. namespace lsq {
  18. struct TestArg {
  19. param::LSQ param;
  20. TensorShape ishape;
  21. TensorShape scale_shape;
  22. TensorShape zeropoint_shape;
  23. TensorShape gradscale_shape;
  24. TestArg(param::LSQ param, TensorShape ishape, TensorShape scale_shape,
  25. TensorShape zeropoint_shape, TensorShape gradscale_shape)
  26. : param(param),
  27. ishape(ishape),
  28. scale_shape(scale_shape),
  29. zeropoint_shape(zeropoint_shape),
  30. gradscale_shape(gradscale_shape) {}
  31. };
  32. inline std::vector<TestArg> get_args() {
  33. std::vector<TestArg> args;
  34. param::LSQ cur_param;
  35. cur_param.qmin = -127;
  36. cur_param.qmax = 127;
  37. for (size_t i = 10; i < 30; i += 2) {
  38. args.emplace_back(
  39. cur_param, TensorShape{10, 64, i, i}, TensorShape{1}, TensorShape{1},
  40. TensorShape{1});
  41. }
  42. return args;
  43. }
  44. } // namespace lsq
  45. } // namespace test
  46. } // namespace megdnn

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。