You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

lsq.cpp 952 B

12345678910111213141516171819202122232425262728293031323334
  1. #include "test/naive/fixture.h"
  2. #include "megdnn/oprs/nn.h"
  3. #include "test/common/checker.h"
  4. using namespace megdnn;
  5. using namespace test;
  6. TEST_F(NAIVE, LSQ_FORWARD) {
  7. Checker<LSQ> checker(handle(), /* check_dispatch */ false);
  8. param::LSQ param;
  9. param.qmin = -127;
  10. param.qmax = 127;
  11. TensorND input = TensorValue(
  12. {2, 2, 2, 2}, dtype::Float32(),
  13. {0, 1, 3, 4, 1, 2, 4, 5, 3, 4, 6, 7, 4, 5, 7, 8});
  14. TensorND scale_shape = TensorValue({1}, dtype::Float32(), {2});
  15. TensorND zero_point = TensorValue({1}, dtype::Float32(), {1});
  16. TensorND grad_scale = TensorValue({1}, dtype::Float32(), {0.5});
  17. TensorND output = TensorValue(
  18. {2, 2, 2, 2}, dtype::Float32(),
  19. {0, 2, 4, 4, 2, 2, 4, 6, 4, 4, 6, 8, 4, 6, 8, 8});
  20. checker.set_param(param).exect(
  21. Testcase{input, scale_shape, zero_point, grad_scale, {}},
  22. Testcase{{}, {}, {}, {}, output});
  23. }