Browse Source

fix(imperative/tensor): fix ConstTensorCache

GitOrigin-RevId: 0767bcfa28
tags/v1.3.1
Megvii Engine Team 4 years ago
parent
commit
58ebb26156
1 changed files with 20 additions and 5 deletions
  1. +20
    -5
      imperative/src/impl/physical_tensor.cpp

+ 20
- 5
imperative/src/impl/physical_tensor.cpp View File

@@ -125,6 +125,7 @@ public:
size_t size;
BlobPtr blob;

Entry() = default;
Entry(const dt_byte* ptr, size_t size_, BlobPtr blob_)
: data(new dt_byte[size_]), size(size_), blob(blob_) {
memcpy(data.get(), ptr, size);
@@ -136,6 +137,8 @@ public:
}
};

using KV = std::pair<uint64_t, Entry>;

bool check(const HostTensorND& hv) {
auto&& layout = hv.layout();
auto&& span = layout.span();
@@ -190,7 +193,7 @@ public:
}

std::mutex mtx;
size_t hwm = 1024, lwm = 512, max_bytes = TensorShape::MAX_NDIM * 8, window = 65536;
const size_t hwm = 1024, lwm = 512, max_bytes = TensorShape::MAX_NDIM * 8, window = 65536;

private:
void maybe_collect_g0() {
@@ -200,25 +203,37 @@ private:
}
}
void maybe_collect_g1() {
if (g1.size() <= hwm) return;
if (g1.size() < hwm) return;

using KV = std::pair<uint64_t, Entry>;
std::vector<KV> tmp;
tmp.reserve(g1.size());
tmp.clear();
for (auto&& kv : g1) {
tmp.emplace_back(kv.first, std::move(kv.second));
}
std::nth_element(tmp.begin(), tmp.begin() + lwm, tmp.end(), [](const KV& lhs, const KV& rhs) {
return lhs.second.hitcnt > rhs.second.hitcnt;
});
tmp.resize(lwm);
g1.clear();
for (auto&& kv : tmp) {
kv.second.hitcnt = 0;
g1.emplace(std::move(kv));
}
}

// g0: records blobs which have been seen at least once (within a window)
// g0b: backup of g0
// g1: records the most frequently used blobs which have been seen at least
// twice. When `g1.size() == hwm`, it will be refreshed and only the top
// `lhw` frequently used blobs will be kept.
std::unordered_set<uint64_t> g0, g0b;
std::unordered_map<uint64_t, Entry> g1;
std::vector<KV> tmp;

public:
ConstTensorCache() {
g0.reserve(window), g0b.reserve(window);
g1.reserve(hwm), tmp.reserve(hwm);
}
};

struct MultiCNConstTensorCache : CompNodeDepedentObject {


Loading…
Cancel
Save