You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

custom_opnode.sereg.h 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  #include "megbrain/opr/custom_opnode.h"
  #include "megbrain/serialization/sereg.h"

  #include <cstring>
  3. namespace mgb {
  4. namespace serialization {
  5. void custom_dumper(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
  6. auto&& custom_op = opr.cast_final_safe<opr::CustomOpNode>();
  7. std::string op_type = custom_op.op_type();
  8. ctx.dump_buf_with_len(op_type.c_str(), op_type.size());
  9. uint32_t tag = custom_op.param_tag();
  10. ctx.dump_buf_with_len(&tag, sizeof(tag));
  11. std::string bytes = custom_op.param().to_bytes();
  12. ctx.dump_buf_with_len(bytes.c_str(), bytes.size());
  13. }
  14. mgb::cg::OperatorNodeBase* custom_loader(
  15. OprLoadContext& ctx, const cg::VarNodeArray& inputs,
  16. const OperatorNodeConfig& config) {
  17. std::string op_type = ctx.load_buf_with_len();
  18. auto* op_manager = custom::CustomOpManager::inst();
  19. auto op = op_manager->find(op_type);
  20. std::string tag_str = ctx.load_buf_with_len();
  21. uint32_t tag = *reinterpret_cast<const uint32_t*>(tag_str.c_str());
  22. mgb_assert(
  23. tag == op->param_info().tag(),
  24. "Wrong Param TAG of Op %s, should be %u, but load %u\n", op_type.c_str(),
  25. op->param_info().tag(), tag);
  26. custom::Param param(op->param_info());
  27. std::string bytes = ctx.load_buf_with_len();
  28. param.from_bytes(bytes);
  29. return opr::CustomOpNode::make(op, inputs, param, config)[0]->owner_opr();
  30. }
  31. } // namespace serialization
  32. } // namespace mgb
  /*
   * CUSTOM_OP_SEREG_REG(cls)
   *
   * Registers ::mgb::serialization::custom_dumper / custom_loader as the
   * serializer pair for operator class `cls`, via the internal hooks
   * MGB_SEREG_OPR_INTL_CALL_ADD (called from a static entry() function) and
   * MGB_SEREG_OPR_INTL_CALL_ENTRY.
   *
   * `cls` is token-pasted into the helper struct name `_OprReg##cls`, so it
   * must be a single unqualified identifier; use a type alias to register a
   * namespace-qualified class. The helper struct is placed in an anonymous
   * namespace.
   *
   * NOTE(review): presumably MGB_SEREG_OPR_INTL_CALL_ENTRY arranges for
   * entry() to run at static-initialization time -- confirm in sereg.h.
   * NOTE(review): names beginning with an underscore followed by an uppercase
   * letter (`_OprReg...`) are reserved to the implementation in C++.
   * No comments are placed inside the body: with backslash line continuation
   * a `//` comment would swallow the following lines.
   */
  #define CUSTOM_OP_SEREG_REG(cls)                                \
      namespace {                                                 \
      struct _OprReg##cls {                                       \
          static void entry() {                                   \
              MGB_SEREG_OPR_INTL_CALL_ADD(                        \
                      cls, ::mgb::serialization::custom_dumper,   \
                      ::mgb::serialization::custom_loader);       \
          }                                                       \
      };                                                          \
      }                                                           \
      MGB_SEREG_OPR_INTL_CALL_ENTRY(cls, _OprReg##cls)
  /*
   * CUSTOM_OP_SEREG_REG_V2(cls, _version_min, _version_max)
   *
   * Versioned counterpart of CUSTOM_OP_SEREG_REG: registers
   * ::mgb::serialization::custom_dumper / custom_loader for operator class
   * `cls` through MGB_SEREG_OPR_INTL_CALL_ADD_V2, passing `_version_min` and
   * `_version_max` as the serialization-version bounds, and hooks the helper
   * struct in with MGB_SEREG_OPR_INTL_CALL_ENTRY_V2.
   *
   * NOTE(review): the `nullptr` passed right after the loader fills an
   * optional ADD_V2 slot (presumably a converter/hook) -- confirm its meaning,
   * and whether the version bounds are inclusive, in sereg.h.
   *
   * As with CUSTOM_OP_SEREG_REG, `cls` is token-pasted (`_OprRegV2##cls`) and
   * must be a single unqualified identifier. No comments are placed inside
   * the body, since `//` would swallow backslash-continued lines.
   */
  #define CUSTOM_OP_SEREG_REG_V2(cls, _version_min, _version_max)            \
      namespace {                                                            \
      struct _OprRegV2##cls {                                                \
          static void entry() {                                              \
              MGB_SEREG_OPR_INTL_CALL_ADD_V2(                                \
                      cls, ::mgb::serialization::custom_dumper,              \
                      ::mgb::serialization::custom_loader, nullptr,          \
                      _version_min, _version_max);                           \
          }                                                                  \
      };                                                                     \
      }                                                                      \
      MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(cls, _OprRegV2##cls)
  56. using namespace mgb;
  57. using CustomOpNode = opr::CustomOpNode;
  58. CUSTOM_OP_SEREG_REG(CustomOpNode);
  59. CUSTOM_OP_SEREG_REG_V2(CustomOpNode, 2, CURRENT_VERSION);