You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

fake_quant.h 1.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. #pragma once
  2. #include "megdnn/basic_types.h"
  3. #include "megdnn/opr_param_defs.h"
  4. namespace megdnn {
  5. namespace test {
  6. namespace fake_quant {
  7. struct TestArg {
  8. param::FakeQuant param;
  9. TensorShape ishape;
  10. TensorShape scale_shape;
  11. TensorShape zeropoint_shape;
  12. TestArg(param::FakeQuant param, TensorShape ishape, TensorShape scale_shape,
  13. TensorShape zeropoint_shape)
  14. : param(param),
  15. ishape(ishape),
  16. scale_shape(scale_shape),
  17. zeropoint_shape(zeropoint_shape) {}
  18. };
  19. inline std::vector<TestArg> get_args() {
  20. std::vector<TestArg> args;
  21. param::FakeQuant cur_param;
  22. cur_param.qmin = -128;
  23. cur_param.qmax = 128;
  24. for (size_t i = 10; i < 40; i += 2) {
  25. args.emplace_back(
  26. cur_param, TensorShape{10, 64, i, i}, TensorShape{1}, TensorShape{1});
  27. }
  28. for (size_t m : {1, 10})
  29. for (size_t n : {1, 10})
  30. for (size_t j : {1, 10})
  31. for (size_t k : {1, 10}) {
  32. args.emplace_back(
  33. cur_param, TensorShape{10, 64, 10, 10},
  34. TensorShape{10, 64, m, n}, TensorShape{10, 64, j, k});
  35. }
  36. return args;
  37. }
  38. } // namespace fake_quant
  39. } // namespace test
  40. } // namespace megdnn