You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

conv_pooling.cpp 1.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. #include "test/cuda/fixture.h"
  2. #include "megdnn/opr_param_defs.h"
  3. #include "megdnn/oprs.h"
  4. #include "test/common/checker.h"
  5. #include "test/common/conv_pooling.h"
  6. #include "test/common/rng.h"
  7. #include "test/common/tensor.h"
  8. #include "test/common/workspace_wrapper.h"
  9. namespace megdnn {
  10. namespace test {
  11. #if 0
  12. TEST_F(CUDA, CONV_POOLING_FORWARD)
  13. {
  14. using namespace conv_pooling;
  15. std::vector<TestArg> args = get_args();
  16. Checker<ConvPoolingForward> checker(handle_cuda());
  17. NormalRNG default_rng;
  18. ConstValue const_val;
  19. for (auto &&arg: args) {
  20. float scale = 1.0f / sqrt(arg.filter[1] * arg.filter[2] * arg.filter[3]);
  21. UniformFloatRNG rng(scale, 2 * scale);
  22. checker.
  23. set_dtype(0, dtype::Float32()).
  24. set_dtype(1, dtype::Float32()).
  25. set_dtype(2, dtype::Float32()).
  26. set_rng(0, &default_rng).
  27. set_rng(1, &default_rng).
  28. set_rng(2, &default_rng).
  29. set_epsilon(1e-3).
  30. set_param(arg.param).
  31. execs({arg.src, arg.filter, arg.bias, {}});
  32. /*checker.
  33. set_dtype(0, dtype::Float16()).
  34. set_dtype(1, dtype::Float16()).
  35. set_rng(0, &rng).
  36. set_rng(1, &rng).
  37. set_epsilon(1e-1).
  38. set_param(arg.param).
  39. execs({arg.src, arg.filter, {}});
  40. */
  41. }
  42. }
  43. #endif
  44. } // namespace test
  45. } // namespace megdnn
  46. // vim: syntax=cpp.doxygen