You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

resize.cpp 2.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
#include "test/common/resize.h"
#include "megdnn/oprs/cv.h"
#include "test/common/checker.h"
#include "test/naive/fixture.h"

#include <vector>

using namespace megdnn;
using namespace test;
  7. TEST_F(NAIVE, RESIZE_NCHW4) {
  8. Checker<Resize> checker(handle());
  9. auto args = resize::get_nchw4_args();
  10. auto convert_true_format = [](const TensorLayout& layout) {
  11. return layout.reshape({layout[0], layout[1] / 4, layout[2], layout[3], 4})
  12. .dimshuffle({0, 1, 4, 2, 3});
  13. };
  14. for (auto&& arg : args) {
  15. auto extra_impl = [this, param = arg.param,
  16. convert_true_format](const TensorNDArray& tensors) {
  17. auto resize = handle()->create_operator<Resize>();
  18. resize->param().imode = param.imode;
  19. resize->param().format = Resize::Param::Format::NCHW;
  20. TensorNDArray nchw_tensors;
  21. for (size_t i = 0; i < tensors.size(); ++i) {
  22. auto layout = tensors[i].layout;
  23. layout = layout.reshape(
  24. {layout[0], layout[1] * 4, layout[2], layout[3]});
  25. layout.dtype = dtype::Int8();
  26. nchw_tensors.emplace_back(malloc(layout.span().dist_byte()), layout);
  27. }
  28. TensorNDArray nchw4_tensors;
  29. for (size_t i = 0; i < tensors.size(); ++i) {
  30. auto layout = convert_true_format(nchw_tensors[i].layout);
  31. nchw4_tensors.emplace_back(tensors[i].raw_ptr(), std::move(layout));
  32. }
  33. auto relayout = handle()->create_operator<RelayoutForward>();
  34. relayout->exec(nchw4_tensors[0], nchw_tensors[0]);
  35. auto workspace_size = resize->get_workspace_in_bytes(
  36. nchw_tensors[0].layout, nchw_tensors[1].layout);
  37. dt_byte* workspace_ptr = static_cast<dt_byte*>(malloc(workspace_size));
  38. Workspace workspace{workspace_ptr, workspace_size};
  39. resize->exec(nchw_tensors[0], nchw_tensors[1], workspace);
  40. relayout->exec(nchw_tensors[1], nchw4_tensors[1]);
  41. free(workspace_ptr);
  42. for (auto&& tensor : nchw_tensors) {
  43. free(tensor.raw_ptr());
  44. }
  45. };
  46. checker.set_extra_opr_impl(extra_impl);
  47. checker.set_param(arg.param)
  48. .set_dtype(0, dtype::QuantizedS8(0.1f))
  49. .set_dtype(1, dtype::QuantizedS8(0.1f))
  50. .set_epsilon(1 + 1e-3)
  51. .execs({arg.src, arg.dst});
  52. }
  53. }