You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

group_local.cpp 1.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. #include "test/arm_common/fixture.h"
  2. #include "test/common/benchmarker.h"
  3. #include "test/common/checker.h"
  4. #include "test/common/group_local.h"
  5. #include "test/common/task_record_check.h"
  6. #include "test/common/timer.h"
  7. namespace megdnn {
  8. namespace test {
  9. using Param = param::Convolution;
  10. TEST_F(ARM_COMMON, GROUP_LOCAL_FORWARD) {
  11. auto args = group_local::get_args();
  12. Checker<GroupLocalForward> checker(handle());
  13. for (auto&& arg : args) {
  14. checker.set_param(arg.param).execs({arg.sshape(), arg.fshape(), arg.dshape()});
  15. }
  16. NormalRNG rng(10.f);
  17. checker.set_rng(0, &rng).set_rng(1, &rng);
  18. args = group_local::get_args_for_fp16();
  19. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  20. for (auto&& arg : args) {
  21. checker.set_dtype(0, dtype::Float16())
  22. .set_dtype(1, dtype::Float16())
  23. .set_dtype(2, dtype::Float16());
  24. checker.set_epsilon(1e-2);
  25. checker.set_param(arg.param).execs({arg.sshape(), arg.fshape(), arg.dshape()});
  26. }
  27. #endif
  28. }
  29. TEST_F(ARM_COMMON, GROUP_LOCAL_FORWARD_RECORD) {
  30. auto args = group_local::get_args();
  31. TaskRecordChecker<GroupLocalForward> checker(0);
  32. for (auto&& arg : args) {
  33. checker.set_param(arg.param).execs({arg.sshape(), arg.fshape(), arg.dshape()});
  34. }
  35. NormalRNG rng(10.f);
  36. checker.set_rng(0, &rng).set_rng(1, &rng);
  37. args = group_local::get_args_for_fp16();
  38. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  39. for (auto&& arg : args) {
  40. checker.set_dtype(0, dtype::Float16())
  41. .set_dtype(1, dtype::Float16())
  42. .set_dtype(2, dtype::Float16());
  43. checker.set_epsilon(1e-2);
  44. checker.set_param(arg.param).execs({arg.sshape(), arg.fshape(), arg.dshape()});
  45. }
  46. #endif
  47. }
  48. } // namespace test
  49. } // namespace megdnn
  50. // vim: syntax=cpp.doxygen