You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

misc.sereg.h 2.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. #include "megbrain/opr/misc.h"
  2. #include "megbrain/serialization/sereg.h"
  3. namespace mgb {
  4. namespace serialization {
  5. template <>
  6. struct OprMaker<opr::Argsort, 1> {
  7. using Opr = opr::Argsort;
  8. using Param = Opr::Param;
  9. static cg::OperatorNodeBase* make(
  10. const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph,
  11. const OperatorNodeConfig& config) {
  12. MGB_MARK_USED_VAR(graph);
  13. auto out = Opr::make(inputs[0], param, config);
  14. return out[0].node()->owner_opr();
  15. }
  16. };
  17. template <>
  18. struct OprMaker<opr::CondTake, 2> {
  19. using Opr = opr::CondTake;
  20. using Param = Opr::Param;
  21. static cg::OperatorNodeBase* make(
  22. const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph,
  23. const OperatorNodeConfig& config) {
  24. MGB_MARK_USED_VAR(graph);
  25. auto out = Opr::make(inputs[0], inputs[1], param, config);
  26. return out[0].node()->owner_opr();
  27. }
  28. };
  29. template <>
  30. struct OprMaker<opr::TopK, 2> {
  31. using Opr = opr::TopK;
  32. using Param = Opr::Param;
  33. static cg::OperatorNodeBase* make(
  34. const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph,
  35. const OperatorNodeConfig& config) {
  36. MGB_MARK_USED_VAR(graph);
  37. auto out = Opr::make(inputs[0], inputs[1], param, config);
  38. return out[0].node()->owner_opr();
  39. }
  40. };
  41. template <>
  42. struct OprMaker<opr::CheckNonFinite, 0> {
  43. using Opr = opr::CheckNonFinite;
  44. using Param = Opr::Param;
  45. static cg::OperatorNodeBase* make(
  46. const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph,
  47. const OperatorNodeConfig& config) {
  48. MGB_MARK_USED_VAR(graph);
  49. auto out = Opr::make(inputs, param, config);
  50. return out[0].node()->owner_opr();
  51. }
  52. };
  53. } // namespace serialization
  54. namespace opr {
  55. MGB_SEREG_OPR(Argmax, 1);
  56. MGB_SEREG_OPR(Argmin, 1);
  57. MGB_SEREG_OPR(Argsort, 1);
  58. MGB_SEREG_OPR(ArgsortBackward, 3);
  59. MGB_SEREG_OPR(CondTake, 2);
  60. MGB_SEREG_OPR(TopK, 2);
  61. //! current cumsum version
  62. using CumsumV1 = opr::Cumsum;
  63. MGB_SEREG_OPR(CumsumV1, 1);
  64. MGB_SEREG_OPR(Cumprod, 1);
  65. #if MGB_CUDA
  66. MGB_SEREG_OPR(NvOf, 1);
  67. #endif
  68. MGB_SEREG_OPR(CheckNonFinite, 0);
  69. } // namespace opr
  70. } // namespace mgb
  71. // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}