You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transformation.h 2.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. #pragma once
  2. #include <optional>
  3. #include <string>
  4. #include "pybind11/pybind11.h"
  5. #include "megbrain/imperative/dispatch.h"
  6. #include "megbrain/imperative/transformation.h"
  7. #include "megbrain/imperative/utils/helper.h"
  8. #include "megbrain/imperative/value.h"
  9. #include "megbrain/utils/small_vector.h"
  10. namespace mgb::imperative::python {
  11. struct TransformationManager {
  12. public:
  13. enum Segment {
  14. ModuleTrace,
  15. DTypePromote,
  16. DimExpansion,
  17. Grad,
  18. Scalar,
  19. Symbol,
  20. Trace,
  21. Eval,
  22. };
  23. std::array<std::vector<std::shared_ptr<Transformation>>, 8> segments;
  24. private:
  25. template <Segment segment>
  26. void unregister(std::shared_ptr<Transformation> transformation) noexcept {
  27. mgb_assert(segment < segments.size());
  28. auto iter = std::find(
  29. segments[segment].begin(), segments[segment].end(), transformation);
  30. mgb_assert(iter != segments[segment].end());
  31. transformation->unregister();
  32. segments[segment].erase(iter);
  33. }
  34. public:
  35. template <Segment segment>
  36. [[nodiscard]] std::unique_ptr<CleanupGuard<>> register_at(
  37. std::shared_ptr<Transformation> transformation) {
  38. mgb_assert(segment < segments.size());
  39. std::shared_ptr<Transformation> next;
  40. for (size_t i = segment; i < segments.size(); ++i) {
  41. if (!segments[i].empty()) {
  42. next = segments[i].back();
  43. break;
  44. }
  45. }
  46. if (!next) {
  47. transformation->register_at(Transformation::bottom());
  48. } else {
  49. transformation->register_at(next->pos());
  50. }
  51. segments[segment].push_back(transformation);
  52. return std::make_unique<CleanupGuard<>>(
  53. [this, transformation]() { unregister<segment>(transformation); });
  54. }
  55. static TransformationManager& get_instance() {
  56. static TransformationManager sl_instance;
  57. return sl_instance;
  58. }
  59. };
  60. class PyValue final : public PrimitiveValue<PyValue, pybind11::object> {
  61. public:
  62. using PrimitiveValue::PrimitiveValue;
  63. std::string to_string() const {
  64. return pybind11::str((const pybind11::object&)*this).cast<std::string>();
  65. }
  66. };
  67. } // namespace mgb::imperative::python