Browse Source

Add single_op model_id.

tags/v1.2.0
unknown 4 years ago
parent
commit
9accbccd2f
4 changed files with 34 additions and 20 deletions
  1. +4
    -4
      ge/single_op/single_op_manager.cc
  2. +12
    -10
      ge/single_op/stream_resource.cc
  3. +6
    -6
      ge/single_op/stream_resource.h
  4. +12
    -0
      tests/ut/ge/single_op/stream_resource_unittest.cc

+ 4
- 4
ge/single_op/single_op_manager.cc View File

@@ -46,14 +46,14 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOpManager::GetOpFr
return ACL_ERROR_GE_MEMORY_ALLOCATION; return ACL_ERROR_GE_MEMORY_ALLOCATION;
} }


SingleOp *op = res->GetOperator(model_data.model_data);
SingleOp *op = res->GetOperator(model_id);
if (op != nullptr) { if (op != nullptr) {
GELOGD("Got operator from stream cache"); GELOGD("Got operator from stream cache");
*single_op = op; *single_op = op;
return SUCCESS; return SUCCESS;
} }


return res->BuildOperator(model_name, model_data, single_op);
return res->BuildOperator(model_data, single_op, model_id);
} }


FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOpManager::ReleaseResource(void *stream) { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOpManager::ReleaseResource(void *stream) {
@@ -116,14 +116,14 @@ Status SingleOpManager::GetDynamicOpFromModel(const string &model_name,
return ACL_ERROR_GE_MEMORY_ALLOCATION; return ACL_ERROR_GE_MEMORY_ALLOCATION;
} }


DynamicSingleOp *op = res->GetDynamicOperator(model_data.model_data);
DynamicSingleOp *op = res->GetDynamicOperator(model_id);
if (op != nullptr) { if (op != nullptr) {
GELOGD("Got operator from stream cache"); GELOGD("Got operator from stream cache");
*single_op = op; *single_op = op;
return SUCCESS; return SUCCESS;
} }


return res->BuildDynamicOperator(model_name, model_data, single_op);
return res->BuildDynamicOperator(model_data, single_op, model_id);
} }


void SingleOpManager::RegisterTilingFunc() { void SingleOpManager::RegisterTilingFunc() {


+ 12
- 10
ge/single_op/stream_resource.cc View File

@@ -41,7 +41,7 @@ StreamResource::~StreamResource() {
} }
} }


SingleOp *StreamResource::GetOperator(const void *key) {
SingleOp *StreamResource::GetOperator(const uint64_t key) {
std::lock_guard<std::mutex> lk(mu_); std::lock_guard<std::mutex> lk(mu_);
auto it = op_map_.find(key); auto it = op_map_.find(key);
if (it == op_map_.end()) { if (it == op_map_.end()) {
@@ -51,7 +51,7 @@ SingleOp *StreamResource::GetOperator(const void *key) {
return it->second.get(); return it->second.get();
} }


DynamicSingleOp *StreamResource::GetDynamicOperator(const void *key) {
DynamicSingleOp *StreamResource::GetDynamicOperator(const uint64_t key) {
std::lock_guard<std::mutex> lk(mu_); std::lock_guard<std::mutex> lk(mu_);
auto it = dynamic_op_map_.find(key); auto it = dynamic_op_map_.find(key);
if (it == dynamic_op_map_.end()) { if (it == dynamic_op_map_.end()) {
@@ -138,11 +138,12 @@ uint8_t *StreamResource::MallocWeight(const std::string &purpose, size_t size) {
return buffer; return buffer;
} }


Status StreamResource::BuildDynamicOperator(const string &model_name,
const ModelData &model_data,
DynamicSingleOp **single_op) {
Status StreamResource::BuildDynamicOperator(const ModelData &model_data,
DynamicSingleOp **single_op,
const uint64_t model_id) {
const string &model_name = std::to_string(model_id);
std::lock_guard<std::mutex> lk(mu_); std::lock_guard<std::mutex> lk(mu_);
auto it = dynamic_op_map_.find(model_data.model_data);
auto it = dynamic_op_map_.find(model_id);
if (it != dynamic_op_map_.end()) { if (it != dynamic_op_map_.end()) {
*single_op = it->second.get(); *single_op = it->second.get();
return SUCCESS; return SUCCESS;
@@ -162,13 +163,14 @@ Status StreamResource::BuildDynamicOperator(const string &model_name,
GE_CHK_STATUS_RET(model.BuildDynamicOp(*this, *new_op), GE_CHK_STATUS_RET(model.BuildDynamicOp(*this, *new_op),
"Build op failed. op = %s, ret = %u", model_name.c_str(), ret); "Build op failed. op = %s, ret = %u", model_name.c_str(), ret);
*single_op = new_op.get(); *single_op = new_op.get();
dynamic_op_map_[model_data.model_data] = std::move(new_op);
dynamic_op_map_[model_id] = std::move(new_op);
return SUCCESS; return SUCCESS;
} }


Status StreamResource::BuildOperator(const string &model_name, const ModelData &model_data, SingleOp **single_op) {
Status StreamResource::BuildOperator(const ModelData &model_data, SingleOp **single_op, const uint64_t model_id) {
const string &model_name = std::to_string(model_id);
std::lock_guard<std::mutex> lk(mu_); std::lock_guard<std::mutex> lk(mu_);
auto it = op_map_.find(model_data.model_data);
auto it = op_map_.find(model_id);
if (it != op_map_.end()) { if (it != op_map_.end()) {
*single_op = it->second.get(); *single_op = it->second.get();
return SUCCESS; return SUCCESS;
@@ -191,7 +193,7 @@ Status StreamResource::BuildOperator(const string &model_name, const ModelData &
GE_CHK_STATUS_RET(model.BuildOp(*this, *new_op), "Build op failed. op = %s, ret = %u", model_name.c_str(), ret); GE_CHK_STATUS_RET(model.BuildOp(*this, *new_op), "Build op failed. op = %s, ret = %u", model_name.c_str(), ret);


*single_op = new_op.get(); *single_op = new_op.get();
op_map_[model_data.model_data] = std::move(new_op);
op_map_[model_id] = std::move(new_op);
return SUCCESS; return SUCCESS;
} }




+ 6
- 6
ge/single_op/stream_resource.h View File

@@ -40,11 +40,11 @@ class StreamResource {
rtStream_t GetStream() const; rtStream_t GetStream() const;
void SetStream(rtStream_t stream); void SetStream(rtStream_t stream);


SingleOp *GetOperator(const void *key);
DynamicSingleOp *GetDynamicOperator(const void *key);
SingleOp *GetOperator(const uint64_t key);
DynamicSingleOp *GetDynamicOperator(const uint64_t key);


Status BuildOperator(const std::string &model_name, const ModelData &model_data, SingleOp **single_op);
Status BuildDynamicOperator(const std::string &model_name, const ModelData &model_data, DynamicSingleOp **single_op);
Status BuildOperator(const ModelData &model_data, SingleOp **single_op, const uint64_t model_id);
Status BuildDynamicOperator(const ModelData &model_data, DynamicSingleOp **single_op, const uint64_t model_id);


uint8_t *MallocMemory(const std::string &purpose, size_t size, bool holding_lock = true); uint8_t *MallocMemory(const std::string &purpose, size_t size, bool holding_lock = true);
uint8_t *MallocWeight(const std::string &purpose, size_t size); uint8_t *MallocWeight(const std::string &purpose, size_t size);
@@ -60,8 +60,8 @@ class StreamResource {
size_t max_memory_size_ = 0; size_t max_memory_size_ = 0;
std::vector<uint8_t *> memory_list_; std::vector<uint8_t *> memory_list_;
std::vector<uint8_t *> weight_list_; std::vector<uint8_t *> weight_list_;
std::unordered_map<const void *, std::unique_ptr<SingleOp>> op_map_;
std::unordered_map<const void *, std::unique_ptr<DynamicSingleOp>> dynamic_op_map_;
std::unordered_map<uint64_t, std::unique_ptr<SingleOp>> op_map_;
std::unordered_map<uint64_t, std::unique_ptr<DynamicSingleOp>> dynamic_op_map_;
rtStream_t stream_ = nullptr; rtStream_t stream_ = nullptr;
std::mutex mu_; std::mutex mu_;
std::mutex stream_mu_; std::mutex stream_mu_;


+ 12
- 0
tests/ut/ge/single_op/stream_resource_unittest.cc View File

@@ -58,6 +58,18 @@ TEST_F(UtestStreamResource, test_malloc_memory) {
ASSERT_NE(res.MallocMemory(purpose, 100), nullptr); ASSERT_NE(res.MallocMemory(purpose, 100), nullptr);
} }


TEST_F(UtestStreamResource, test_malloc_memory) {
StreamResource res((uintptr_t)1);
ModelData model_data;
SingleOp *single_op = nullptr;
DynamicSingleOp *dynamic_single_op = nullptr;
res.op_map_[0] = &single_op;
res.dynamic_op_map_[1] = &dynamic_single_op;

ASSERT_EQ(res.BuildOperator(model_data, &single_op, 0), SUCCESS);
ASSERT_EQ(res.BuildDynamicOperator(model_data, &dynamic_single_op, 1), SUCCESS);
}

/* /*
TEST_F(UtestStreamResource, test_do_malloc_memory) { TEST_F(UtestStreamResource, test_do_malloc_memory) {
size_t max_allocated = 0; size_t max_allocated = 0;


Loading…
Cancel
Save