You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cond_take.cpp 757 B

1234567891011121314151617181920212223
  1. #include "test/common/cond_take.h"
  2. #include "megdnn/oprs.h"
  3. #include "test/common/checker.h"
  4. #include "test/cuda/fixture.h"
  5. using namespace megdnn;
  6. using namespace test;
  7. TEST_F(CUDA, COND_TAKE) {
  8. auto opr_naive = handle_naive()->create_operator<CondTake>();
  9. auto opr_cuda = handle_cuda()->create_operator<CondTake>();
  10. size_t tot_size = 0;
  11. for (auto&& i : CondTakeTestcase::make()) {
  12. auto ret_naive = i.run(opr_naive.get()), ret_cuda = i.run(opr_cuda.get());
  13. MEGDNN_ASSERT_TENSOR_EQ(*ret_naive.first, *ret_cuda.first);
  14. MEGDNN_ASSERT_TENSOR_EQ(*ret_naive.second, *ret_cuda.second);
  15. tot_size += ret_naive.first->layout.total_nr_elems();
  16. }
  17. ASSERT_GT(tot_size, (size_t)0);
  18. }
  19. // vim: syntax=cpp.doxygen