You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cond_take.h 683 B

123456789101112131415161718192021222324252627
  1. #pragma once
  2. #include "./checker.h"
  3. #include "megdnn/oprs.h"
  4. namespace megdnn {
  5. namespace test {
  6. class CondTakeTestcase {
  7. std::unique_ptr<uint8_t> m_mem;
  8. CondTake::Param m_param;
  9. TensorND m_data, m_mask;
  10. CondTakeTestcase(
  11. CondTake::Param param, const TensorLayout& data, const TensorLayout& mask)
  12. : m_param{param}, m_data{nullptr, data}, m_mask{nullptr, mask} {}
  13. public:
  14. //! pair of (data, idx)
  15. using Result = std::pair<std::shared_ptr<TensorND>, std::shared_ptr<TensorND>>;
  16. Result run(CondTake* opr);
  17. static std::vector<CondTakeTestcase> make();
  18. };
  19. } // namespace test
  20. } // namespace megdnn
  21. // vim: syntax=cpp.doxygen