You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

algorithm_cache.h 2.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. #pragma once
  2. #include "megdnn/basic_types.h"
  3. #include "megdnn/oprs/base.h"
  4. #include <mutex>
  5. #include <string>
  6. #include <unordered_map>
  7. namespace megdnn {
  8. class AlgorithmCache {
  9. private:
  10. AlgorithmCache() = default;
  11. public:
  12. MGE_WIN_DECLSPEC_FUC static AlgorithmCache& instance();
  13. struct KeyStorage {
  14. size_t k1, k2;
  15. bool operator==(const KeyStorage& k) const { return k1 == k.k1 && k2 == k.k2; }
  16. };
  17. struct Key {
  18. Handle* m_handle;
  19. uint32_t m_opr_type;
  20. const TensorLayout* m_inp_layouts_ptr;
  21. size_t m_inp_layouts_size;
  22. const void* m_param_ptr;
  23. size_t m_param_size;
  24. mutable SmallVector<size_t> m_buf;
  25. public:
  26. Key(Handle* opr_handle, Algorithm::OprType opr_type,
  27. const TensorLayout* inp_layouts_ptr, size_t inp_layouts_size,
  28. const void* param_ptr = nullptr, size_t param_size = 0)
  29. : m_handle{opr_handle},
  30. m_opr_type{static_cast<uint32_t>(opr_type)},
  31. m_inp_layouts_ptr{inp_layouts_ptr},
  32. m_inp_layouts_size{inp_layouts_size},
  33. m_param_ptr{param_ptr},
  34. m_param_size{param_size} {}
  35. KeyStorage build_key_storage() const;
  36. };
  37. struct Result {
  38. ExecutionPolicy policy;
  39. size_t workspace;
  40. // for cache collision
  41. SmallVector<size_t> m_buf;
  42. SmallVector<char> m_param_buf;
  43. };
  44. MGE_WIN_DECLSPEC_FUC void put(const Key& key, Result& result);
  45. MGE_WIN_DECLSPEC_FUC Result get(const Key& key);
  46. MGE_WIN_DECLSPEC_FUC void clear();
  47. private:
  48. struct Hash {
  49. size_t operator()(const KeyStorage& k) const {
  50. size_t h1 = k.k1;
  51. size_t h2 = k.k2;
  52. h1 ^= h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2);
  53. return h1;
  54. }
  55. };
  56. std::unordered_map<KeyStorage, Result, Hash> m_heuristic_cache;
  57. #if __DEPLOY_ON_XP_SP2__
  58. size_t m_mtx;
  59. #else
  60. std::mutex m_mtx;
  61. #endif
  62. };
  63. } // namespace megdnn