You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

oprs.h 1.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. #pragma once
  2. #include "megdnn/oprs/cv.h"
  3. #include "megdnn/oprs/general.h"
  4. #include "megdnn/oprs/imgproc.h"
  5. #include "megdnn/oprs/linalg.h"
  6. #include "megdnn/oprs/nn.h"
  7. #include "megdnn/oprs/nn_int.h"
  8. #include "megdnn/oprs/utils.h"
  /// Primary template for per-operator arity traits.
  ///
  /// Only declared here, never defined: the explicit specializations produced
  /// by INST_ARITY are the only complete types, so asking for the arity of an
  /// operator that has no specialization is a compile-time error.
  template <class Opr>
  struct OprArityTrait;
  /// Compile-time arity description that every OprArityTrait specialization
  /// inherits from.
  ///
  /// @tparam Opr     operator type; unused in the body, it only makes each
  ///                 instantiation a distinct type per operator.
  /// @tparam NumIn   input arity, exposed as `arity_in`.
  /// @tparam NumOut  output arity, exposed as `arity_out`.
  template <class Opr, int NumIn, int NumOut>
  struct OprArityTraitTmpl {
      static constexpr int arity_in{NumIn};
      static constexpr int arity_out{NumOut};
      // Total operand count: input arity plus output arity.
      static constexpr int arity{NumIn + NumOut};
  };
  17. #define INST_ARITY(_Opr, _in, _out) \
  18. template <> \
  19. struct OprArityTrait<_Opr> : public OprArityTraitTmpl<_Opr, _in, _out> {};
  20. INST_ARITY(megdnn::ConvolutionBackwardData, 2, 1);
  21. INST_ARITY(megdnn::ConvolutionBackwardFilter, 2, 1);
  22. INST_ARITY(megdnn::Convolution3DForward, 2, 1);
  23. INST_ARITY(megdnn::Convolution3DBackwardData, 2, 1);
  24. INST_ARITY(megdnn::Convolution3DBackwardFilter, 2, 1);
  25. INST_ARITY(megdnn::LocalShareForward, 2, 1);
  26. INST_ARITY(megdnn::LocalShareBackwardData, 2, 1);
  27. INST_ARITY(megdnn::LocalShareBackwardFilter, 2, 1);
  28. INST_ARITY(megdnn::Convolution, 2, 1);
  29. INST_ARITY(megdnn::DeformableConvForward, 4, 1);
  30. INST_ARITY(megdnn::DeformableConvBackwardFilter, 4, 1);
  31. INST_ARITY(megdnn::BatchConvBiasForward, 4, 1);
  32. INST_ARITY(megdnn::ConvBias, 4, 1);
  33. INST_ARITY(megdnn::DeformableConvBackwardData, 5, 3);
  34. INST_ARITY(megdnn::MatrixMul, 2, 1);
  35. INST_ARITY(megdnn::BatchedMatrixMul, 2, 1);
  36. INST_ARITY(megdnn::PoolingForward, 1, 1);
  37. INST_ARITY(megdnn::PoolingBackward, 3, 1);
  38. #undef INST_ARITY
  39. // vim: syntax=cpp.doxygen