You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

algorithm_cache.h 2.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. #pragma once
  2. #include "megdnn/basic_types.h"
  3. #include "megdnn/oprs/base.h"
  4. #include <mutex>
  5. #include <string>
  6. #include <unordered_map>
  7. namespace megdnn {
  8. class AlgorithmCache {
  9. private:
  10. AlgorithmCache() = default;
  11. public:
  12. MGE_WIN_DECLSPEC_FUC static AlgorithmCache& instance();
  13. struct KeyStorage {
  14. size_t k1, k2;
  15. bool operator==(const KeyStorage& k) const { return k1 == k.k1 && k2 == k.k2; }
  16. };
  17. struct Hash {
  18. size_t operator()(const KeyStorage& k) const {
  19. size_t h1 = k.k1;
  20. size_t h2 = k.k2;
  21. h1 ^= h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2);
  22. return h1;
  23. }
  24. };
  25. class Key {
  26. Handle* m_handle;
  27. uint32_t m_opr_type;
  28. const TensorLayout* m_inp_layouts_ptr;
  29. size_t m_inp_layouts_size;
  30. const void* m_param_ptr;
  31. size_t m_param_size;
  32. mutable SmallVector<size_t> m_buf;
  33. public:
  34. Key(Handle* opr_handle, Algorithm::OprType opr_type,
  35. const TensorLayout* inp_layouts_ptr, size_t inp_layouts_size,
  36. const void* param_ptr = nullptr, size_t param_size = 0)
  37. : m_handle{opr_handle},
  38. m_opr_type{static_cast<uint32_t>(opr_type)},
  39. m_inp_layouts_ptr{inp_layouts_ptr},
  40. m_inp_layouts_size{inp_layouts_size},
  41. m_param_ptr{param_ptr},
  42. m_param_size{param_size} {}
  43. KeyStorage build_key_storage() const;
  44. };
  45. struct Result {
  46. ExecutionPolicy policy;
  47. size_t workspace;
  48. // for cache collision
  49. SmallVector<size_t> m_buf;
  50. SmallVector<char> m_param_buf;
  51. };
  52. MGE_WIN_DECLSPEC_FUC void put(const Key& key, Result& result);
  53. MGE_WIN_DECLSPEC_FUC Result get(const Key& key);
  54. MGE_WIN_DECLSPEC_FUC void clear();
  55. private:
  56. std::unordered_map<KeyStorage, Result, Hash> m_heuristic_cache;
  57. #if __DEPLOY_ON_XP_SP2__
  58. size_t m_mtx;
  59. #else
  60. std::mutex m_mtx;
  61. #endif
  62. };
  63. } // namespace megdnn