@@ -17,8 +17,10 @@ | |||||
#include "lite-c/tensor_c.h" | #include "lite-c/tensor_c.h" | ||||
#include "lite/network.h" | #include "lite/network.h" | ||||
#if LITE_ENABLE_EXCEPTION | |||||
#include <exception> | #include <exception> | ||||
#include <stdexcept> | #include <stdexcept> | ||||
#endif | |||||
//! convert c Layout to lite::Layout | //! convert c Layout to lite::Layout | ||||
lite::Layout convert_to_layout(const LiteLayout& layout); | lite::Layout convert_to_layout(const LiteLayout& layout); | ||||
@@ -13,11 +13,7 @@ | |||||
#include "common.h" | #include "common.h" | ||||
#include "lite-c/global_c.h" | #include "lite-c/global_c.h" | ||||
#include <exception> | |||||
#include <mutex> | |||||
namespace { | namespace { | ||||
class ErrorMsg { | class ErrorMsg { | ||||
public: | public: | ||||
std::string& get_error_msg() { return error_msg; } | std::string& get_error_msg() { return error_msg; } | ||||
@@ -26,18 +22,22 @@ public: | |||||
private: | private: | ||||
std::string error_msg; | std::string error_msg; | ||||
}; | }; | ||||
static LITE_MUTEX mtx_error; | |||||
ErrorMsg& get_global_error() { | ErrorMsg& get_global_error() { | ||||
static thread_local ErrorMsg error_msg; | |||||
static ErrorMsg error_msg; | |||||
return error_msg; | return error_msg; | ||||
} | } | ||||
} // namespace | } // namespace | ||||
int LiteHandleException(const std::exception& e) { | int LiteHandleException(const std::exception& e) { | ||||
LITE_LOCK_GUARD(mtx_error); | |||||
get_global_error().set_error_msg(e.what()); | get_global_error().set_error_msg(e.what()); | ||||
return -1; | return -1; | ||||
} | } | ||||
const char* LITE_get_last_error() { | const char* LITE_get_last_error() { | ||||
LITE_LOCK_GUARD(mtx_error); | |||||
return get_global_error().get_error_msg().c_str(); | return get_global_error().get_error_msg().c_str(); | ||||
} | } | ||||
@@ -72,9 +72,9 @@ LiteNetworkIO* default_network_io() { | |||||
} | } | ||||
namespace { | namespace { | ||||
static LITE_MUTEX mtx_network; | |||||
std::unordered_map<void*, std::shared_ptr<lite::Network>>& get_gloabl_network_holder() { | std::unordered_map<void*, std::shared_ptr<lite::Network>>& get_gloabl_network_holder() { | ||||
static thread_local std::unordered_map<void*, std::shared_ptr<lite::Network>> | |||||
network_holder; | |||||
static std::unordered_map<void*, std::shared_ptr<lite::Network>> network_holder; | |||||
return network_holder; | return network_holder; | ||||
} | } | ||||
@@ -168,6 +168,7 @@ int LITE_make_default_network(LiteNetwork* network) { | |||||
LITE_CAPI_BEGIN(); | LITE_CAPI_BEGIN(); | ||||
LITE_ASSERT(network, "The network pass to LITE api is null"); | LITE_ASSERT(network, "The network pass to LITE api is null"); | ||||
auto lite_network = std::make_shared<lite::Network>(); | auto lite_network = std::make_shared<lite::Network>(); | ||||
LITE_LOCK_GUARD(mtx_network); | |||||
get_gloabl_network_holder()[lite_network.get()] = lite_network; | get_gloabl_network_holder()[lite_network.get()] = lite_network; | ||||
*network = lite_network.get(); | *network = lite_network.get(); | ||||
LITE_CAPI_END(); | LITE_CAPI_END(); | ||||
@@ -179,6 +180,7 @@ int LITE_make_network( | |||||
LITE_ASSERT(network, "The network pass to LITE api is null"); | LITE_ASSERT(network, "The network pass to LITE api is null"); | ||||
auto lite_network = std::make_shared<lite::Network>( | auto lite_network = std::make_shared<lite::Network>( | ||||
convert_to_lite_config(config), convert_to_lite_io(network_io)); | convert_to_lite_config(config), convert_to_lite_io(network_io)); | ||||
LITE_LOCK_GUARD(mtx_network); | |||||
get_gloabl_network_holder()[lite_network.get()] = lite_network; | get_gloabl_network_holder()[lite_network.get()] = lite_network; | ||||
*network = lite_network.get(); | *network = lite_network.get(); | ||||
LITE_CAPI_END(); | LITE_CAPI_END(); | ||||
@@ -188,6 +190,7 @@ int LITE_make_network_config(LiteNetwork* network, const LiteConfig config) { | |||||
LITE_CAPI_BEGIN(); | LITE_CAPI_BEGIN(); | ||||
LITE_ASSERT(network, "The network pass to LITE api is null"); | LITE_ASSERT(network, "The network pass to LITE api is null"); | ||||
auto lite_network = std::make_shared<lite::Network>(convert_to_lite_config(config)); | auto lite_network = std::make_shared<lite::Network>(convert_to_lite_config(config)); | ||||
LITE_LOCK_GUARD(mtx_network); | |||||
get_gloabl_network_holder()[lite_network.get()] = lite_network; | get_gloabl_network_holder()[lite_network.get()] = lite_network; | ||||
*network = lite_network.get(); | *network = lite_network.get(); | ||||
LITE_CAPI_END(); | LITE_CAPI_END(); | ||||
@@ -212,6 +215,7 @@ int LITE_load_model_from_path(LiteNetwork network, const char* model_path) { | |||||
int LITE_destroy_network(LiteNetwork network) { | int LITE_destroy_network(LiteNetwork network) { | ||||
LITE_CAPI_BEGIN(); | LITE_CAPI_BEGIN(); | ||||
LITE_ASSERT(network, "The network pass to LITE api is null"); | LITE_ASSERT(network, "The network pass to LITE api is null"); | ||||
LITE_LOCK_GUARD(mtx_network); | |||||
get_gloabl_network_holder().erase(network); | get_gloabl_network_holder().erase(network); | ||||
LITE_CAPI_END(); | LITE_CAPI_END(); | ||||
} | } | ||||
@@ -26,13 +26,16 @@ const LiteTensorDesc default_desc = { | |||||
.device_type = LiteDeviceType::LITE_CPU, | .device_type = LiteDeviceType::LITE_CPU, | ||||
.device_id = 0}; | .device_id = 0}; | ||||
namespace { | namespace { | ||||
static LITE_MUTEX mtx_tensor; | |||||
std::unordered_map<void*, std::shared_ptr<lite::Tensor>>& get_global_tensor_holder() { | std::unordered_map<void*, std::shared_ptr<lite::Tensor>>& get_global_tensor_holder() { | ||||
static thread_local std::unordered_map<void*, std::shared_ptr<lite::Tensor>> | |||||
global_holder; | |||||
static std::unordered_map<void*, std::shared_ptr<lite::Tensor>> global_holder; | |||||
return global_holder; | return global_holder; | ||||
} | } | ||||
static LITE_MUTEX mtx_attr; | |||||
std::unordered_map<std::string, lite::LiteAny>& get_global_tensor_attr_holder() { | std::unordered_map<std::string, lite::LiteAny>& get_global_tensor_attr_holder() { | ||||
static thread_local std::unordered_map<std::string, lite::LiteAny> global_holder; | |||||
static std::unordered_map<std::string, lite::LiteAny> global_holder; | |||||
return global_holder; | return global_holder; | ||||
} | } | ||||
} // namespace | } // namespace | ||||
@@ -68,6 +71,7 @@ int LITE_make_tensor(const LiteTensorDesc tensor_describe, LiteTensor* tensor) { | |||||
auto lite_tensor = std::make_shared<lite::Tensor>( | auto lite_tensor = std::make_shared<lite::Tensor>( | ||||
tensor_describe.device_id, tensor_describe.device_type, layout, | tensor_describe.device_id, tensor_describe.device_type, layout, | ||||
tensor_describe.is_pinned_host); | tensor_describe.is_pinned_host); | ||||
LITE_LOCK_GUARD(mtx_tensor); | |||||
get_global_tensor_holder()[lite_tensor.get()] = lite_tensor; | get_global_tensor_holder()[lite_tensor.get()] = lite_tensor; | ||||
*tensor = lite_tensor.get(); | *tensor = lite_tensor.get(); | ||||
LITE_CAPI_END(); | LITE_CAPI_END(); | ||||
@@ -76,6 +80,7 @@ int LITE_make_tensor(const LiteTensorDesc tensor_describe, LiteTensor* tensor) { | |||||
int LITE_destroy_tensor(LiteTensor tensor) { | int LITE_destroy_tensor(LiteTensor tensor) { | ||||
LITE_CAPI_BEGIN(); | LITE_CAPI_BEGIN(); | ||||
LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null"); | LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null"); | ||||
LITE_LOCK_GUARD(mtx_tensor); | |||||
get_global_tensor_holder().erase(tensor); | get_global_tensor_holder().erase(tensor); | ||||
LITE_CAPI_END(); | LITE_CAPI_END(); | ||||
} | } | ||||
@@ -132,6 +137,7 @@ int LITE_tensor_slice( | |||||
} | } | ||||
} | } | ||||
auto ret_tensor = static_cast<lite::Tensor*>(tensor)->slice(starts, ends, steps); | auto ret_tensor = static_cast<lite::Tensor*>(tensor)->slice(starts, ends, steps); | ||||
LITE_LOCK_GUARD(mtx_tensor); | |||||
get_global_tensor_holder()[ret_tensor.get()] = ret_tensor; | get_global_tensor_holder()[ret_tensor.get()] = ret_tensor; | ||||
*slice_tensor = ret_tensor.get(); | *slice_tensor = ret_tensor.get(); | ||||
LITE_CAPI_END(); | LITE_CAPI_END(); | ||||
@@ -18,6 +18,7 @@ | |||||
#include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
#include <memory> | #include <memory> | ||||
#include <thread> | |||||
TEST(TestCapiTensor, Basic) { | TEST(TestCapiTensor, Basic) { | ||||
LiteTensor c_tensor0, c_tensor1; | LiteTensor c_tensor0, c_tensor1; | ||||
@@ -305,6 +306,23 @@ TEST(TestCapiTensor, GetMemoryByIndex) { | |||||
LITE_destroy_tensor(c_tensor0); | LITE_destroy_tensor(c_tensor0); | ||||
} | } | ||||
TEST(TestCapiTensor, ThreadLocalError) { | |||||
LiteTensor c_tensor0; | |||||
LiteTensorDesc description = default_desc; | |||||
description.layout = LiteLayout{{20, 20}, 2, LiteDataType::LITE_FLOAT}; | |||||
void *ptr0, *ptr1; | |||||
std::thread thread1([&]() { | |||||
LITE_make_tensor(description, &c_tensor0); | |||||
LITE_get_tensor_memory(c_tensor0, &ptr0); | |||||
}); | |||||
thread1.join(); | |||||
std::thread thread2([&]() { | |||||
LITE_get_tensor_memory(c_tensor0, &ptr1); | |||||
LITE_destroy_tensor(c_tensor0); | |||||
}); | |||||
thread2.join(); | |||||
} | |||||
#endif | #endif | ||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |