You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

concat.cpp 1.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. #include "test/fallback/fixture.h"
  2. #include "test/common/checker.h"
  3. #include "test/common/task_record_check.h"
  4. namespace megdnn {
  5. namespace test {
  6. TEST_F(FALLBACK, CONCAT) {
  7. Checker<Concat> checker(handle());
  8. using Param = Concat::Param;
  9. for (auto dtype : std::vector<DType>{
  10. dtype::Float32(), dtype::Int32(), dtype::Int16(), dtype::Float16(),
  11. dtype::Int8(), dtype::Uint8()}) {
  12. for (size_t axis = 0; axis < 4; ++axis) {
  13. Param param;
  14. param.axis = axis;
  15. TensorShapeArray shapes(4, TensorShape({12, 13, 14, 15}));
  16. for (size_t i = 0; i < 4; ++i) {
  17. shapes[i].shape[axis] = i + 1;
  18. }
  19. shapes.emplace_back();
  20. for (size_t i = 0; i < shapes.size(); ++i)
  21. checker.set_dtype(i, dtype);
  22. checker.set_param(param).exec(shapes);
  23. }
  24. }
  25. }
  26. TEST_F(FALLBACK, CONCAT_RECORD) {
  27. TaskRecordChecker<Concat> checker(1);
  28. using Param = Concat::Param;
  29. Param param;
  30. param.axis = 0;
  31. TensorShapeArray shapes(4, TensorShape({12, 13, 14, 15}));
  32. for (size_t i = 0; i < 4; ++i) {
  33. shapes[i].shape[0] = i + 1;
  34. }
  35. shapes.emplace_back();
  36. for (size_t i = 0; i < shapes.size(); ++i)
  37. checker.set_dtype(i, dtype::Float32());
  38. checker.set_param(param).exec(shapes);
  39. }
  40. } // namespace test
  41. } // namespace megdnn
  42. // vim: syntax=cpp.doxygen