You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

host_static_calc.cpp 1.2 kB

12345678910111213141516171819202122232425262728293031323334353637
  1. /**
  2. * \file test/src/host_static_calc.cpp
  3. *
  4. * This file is part of MegBrain, a deep learning framework developed by Megvii.
  5. *
  6. * \brief static calculation on the host, used to check operator correctness
  7. *
  8. * \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  9. *
  10. */
  11. #include "megbrain/test/host_static_calc.h"
  12. void mgb::elemwise_static_calc(
  13. opr::Elemwise::Mode mode, HostTensorND& dest,
  14. const std::vector<HostTensorND>& inputs) {
  15. #if defined(ANDROID) || defined(IOS) || defined(__arm__)
  16. static opr::intl::UniqPtrWithCN<megdnn::Elemwise> opr_impl;
  17. static std::mutex mtx;
  18. MGB_LOCK_GUARD(mtx);
  19. #else
  20. static thread_local opr::intl::UniqPtrWithCN<megdnn::Elemwise> opr_impl;
  21. #endif
  22. auto cn = CompNode::default_cpu();
  23. if (!opr_impl) {
  24. opr_impl = opr::intl::create_megdnn_opr<megdnn::Elemwise>(cn);
  25. }
  26. DeviceTensorND dev_dest{cn};
  27. SmallVector<DeviceTensorND> dev_inp(inputs.size());
  28. for (size_t i = 0; i < inputs.size(); ++i) {
  29. dev_inp[i].comp_node(cn).copy_from(inputs[i]);
  30. }
  31. opr::Elemwise::perform(mode, dev_dest, dev_inp, opr_impl);
  32. dest.copy_from(dev_dest);
  33. }
  34. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}