@@ -242,7 +242,13 @@ int LITE_destroy_network(LiteNetwork network) {
     LITE_CAPI_BEGIN();
     LITE_ASSERT(network, "The network pass to LITE api is null");
     LITE_LOCK_GUARD(mtx_network);
-    get_gloabl_network_holder().erase(network);
+    auto& global_holder = get_gloabl_network_holder();
+    if (global_holder.find(network) != global_holder.end()) {
+        global_holder.erase(network);
+    } else {
+        //! means the network has already been destroyed
+        return -1;
+    }
     LITE_CAPI_END();
 }
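
With this hunk, LITE_destroy_network no longer silently ignores an unknown handle: if the handle is not in the global holder (typically because it was already destroyed), the call reports -1. A minimal caller-side sketch of that contract (not part of the patch); it assumes the lite-c header path and that default_config()/default_network_io() are the convenience helpers also used by the tests further down:

```cpp
#include "lite-c/network_c.h"  // assumed header path for the Lite C API

int main() {
    LiteNetwork net = nullptr;
    // Create a handle; it gets registered in the global network holder.
    LITE_make_network(&net, *default_config(), *default_network_io());

    int first = LITE_destroy_network(net);   // handle found and erased -> 0
    int second = LITE_destroy_network(net);  // handle already gone -> -1
    return (first == 0 && second == -1) ? 0 : 1;
}
```
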
@@ -60,8 +60,10 @@ int LITE_make_tensor(const LiteTensorDesc tensor_describe, LiteTensor* tensor) {
     auto lite_tensor = std::make_shared<lite::Tensor>(
             tensor_describe.device_id, tensor_describe.device_type, layout,
             tensor_describe.is_pinned_host);
-    LITE_LOCK_GUARD(mtx_tensor);
-    get_global_tensor_holder()[lite_tensor.get()] = lite_tensor;
+    {
+        LITE_LOCK_GUARD(mtx_tensor);
+        get_global_tensor_holder()[lite_tensor.get()] = lite_tensor;
+    }
     *tensor = lite_tensor.get();
     LITE_CAPI_END();
 }
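
The extra braces here (and in the slice/concat hunks below) are what shrink the critical section: assuming LITE_LOCK_GUARD is a std::lock_guard-style RAII macro over the named mutex, mtx_tensor is now released at the closing brace, before the out-pointer is written and LITE_CAPI_END() runs. A generic sketch of the same pattern with plain standard-library names (mtx, holder and make_value are illustrative, not Lite API):

```cpp
#include <map>
#include <memory>
#include <mutex>

std::mutex mtx;                                // plays the role of mtx_tensor
std::map<void*, std::shared_ptr<int>> holder;  // plays the role of the global holder

void* make_value() {
    auto value = std::make_shared<int>(42);
    {
        // Critical section covers only the holder update; the guard
        // releases the mutex at the closing brace of this block.
        std::lock_guard<std::mutex> guard(mtx);
        holder[value.get()] = value;
    }
    // Everything after the block runs without holding the mutex.
    return value.get();
}

int main() {
    return make_value() != nullptr ? 0 : 1;
}
```
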
@@ -70,7 +72,13 @@ int LITE_destroy_tensor(LiteTensor tensor) {
     LITE_CAPI_BEGIN();
     LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
     LITE_LOCK_GUARD(mtx_tensor);
-    get_global_tensor_holder().erase(tensor);
+    auto& global_holder = get_global_tensor_holder();
+    if (global_holder.find(tensor) != global_holder.end()) {
+        global_holder.erase(tensor);
+    } else {
+        //! return -1, means the tensor has already been destroyed
+        return -1;
+    }
     LITE_CAPI_END();
 }
@@ -126,8 +134,10 @@ int LITE_tensor_slice(
         }
     }
     auto ret_tensor = static_cast<lite::Tensor*>(tensor)->slice(starts, ends, steps);
-    LITE_LOCK_GUARD(mtx_tensor);
-    get_global_tensor_holder()[ret_tensor.get()] = ret_tensor;
+    {
+        LITE_LOCK_GUARD(mtx_tensor);
+        get_global_tensor_holder()[ret_tensor.get()] = ret_tensor;
+    }
     *slice_tensor = ret_tensor.get();
     LITE_CAPI_END();
 }
@@ -226,12 +236,16 @@ int LITE_tensor_concat(
         LiteTensor* tensors, int nr_tensor, int dim, LiteDeviceType dst_device,
         int device_id, LiteTensor* result_tensor) {
     LITE_CAPI_BEGIN();
+    LITE_ASSERT(result_tensor, "The tensor pass to LITE c_api is null");
     std::vector<lite::Tensor> v_tensors;
     for (int i = 0; i < nr_tensor; i++) {
         v_tensors.push_back(*static_cast<lite::Tensor*>(tensors[i]));
     }
     auto tensor = lite::TensorUtils::concat(v_tensors, dim, dst_device, device_id);
-    get_global_tensor_holder()[tensor.get()] = tensor;
+    {
+        LITE_LOCK_GUARD(mtx_tensor);
+        get_global_tensor_holder()[tensor.get()] = tensor;
+    }
     *result_tensor = tensor.get();
     LITE_CAPI_END()
 }
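
With the added LITE_ASSERT, a null result_tensor out-pointer is rejected up front instead of crashing at the final `*result_tensor = ...` store, and the holder update is now guarded by mtx_tensor like the other entry points. A hypothetical caller-side sketch (the header path, the LITE_CPU enumerator, and the shapes/device values are assumptions; the struct field names follow the hunks above and the tests below):

```cpp
#include "lite-c/tensor_c.h"  // assumed header path for the Lite C tensor API

int main() {
    // Describe a small 2x8 float tensor on CPU device 0.
    LiteLayout layout{{2, 8}, 2, LiteDataType::LITE_FLOAT};
    LiteTensorDesc desc;
    desc.layout = layout;
    desc.device_type = LiteDeviceType::LITE_CPU;  // assumed CPU enumerator
    desc.device_id = 0;
    desc.is_pinned_host = false;

    LiteTensor a = nullptr, b = nullptr, concated = nullptr;
    LITE_make_tensor(desc, &a);
    LITE_make_tensor(desc, &b);

    LiteTensor inputs[2] = {a, b};
    // Concatenate along dim 0; &concated must be non-null,
    // otherwise the newly added LITE_ASSERT fires.
    LITE_tensor_concat(inputs, 2, 0, LiteDeviceType::LITE_CPU, 0, &concated);

    LITE_destroy_tensor(a);
    LITE_destroy_tensor(b);
    LITE_destroy_tensor(concated);
    return 0;
}
```
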
@@ -476,7 +476,7 @@ def start_finish_callback(func):
     def wrapper(c_ios, c_tensors, size):
         ios = {}
         for i in range(size):
-            tensor = LiteTensor()
+            tensor = LiteTensor(physic_construct=False)
             tensor._tensor = c_void_p(c_tensors[i])
             tensor.update()
             io = c_ios[i]
@@ -729,7 +729,7 @@ class LiteNetwork(object):
             c_name = c_char_p(name.encode("utf-8"))
         else:
             c_name = c_char_p(name)
-        tensor = LiteTensor()
+        tensor = LiteTensor(physic_construct=False)
         self._api.LITE_get_io_tensor(
             self._network, c_name, phase, byref(tensor._tensor)
         )
@@ -233,6 +233,7 @@ class LiteTensor(object):
         is_pinned_host=False,
         shapes=None,
         dtype=None,
+        physic_construct=True,
     ):
         self._tensor = _Ctensor()
         self._layout = LiteLayout()
@@ -250,8 +251,10 @@ class LiteTensor(object):
         tensor_desc.device_type = device_type
         tensor_desc.device_id = device_id
         tensor_desc.is_pinned_host = is_pinned_host
-        self._api.LITE_make_tensor(tensor_desc, byref(self._tensor))
-        self.update()
+        if physic_construct:
+            self._api.LITE_make_tensor(tensor_desc, byref(self._tensor))
+            self.update()
 
     def __del__(self):
         self._api.LITE_destroy_tensor(self._tensor)
@@ -399,7 +402,7 @@ class LiteTensor(object):
         c_start = (c_size_t * length)(*start)
         c_end = (c_size_t * length)(*end)
         c_step = (c_size_t * length)(*step)
-        slice_tensor = LiteTensor()
+        slice_tensor = LiteTensor(physic_construct=False)
         self._api.LITE_tensor_slice(
             self._tensor, c_start, c_end, c_step, length, byref(slice_tensor._tensor),
         )
@@ -560,7 +563,7 @@ def LiteTensorConcat(
     length = len(tensors)
    c_tensors = [t._tensor for t in tensors]
     c_tensors = (_Ctensor * length)(*c_tensors)
-    result_tensor = LiteTensor()
+    result_tensor = LiteTensor(physic_construct=False)
     api.LITE_tensor_concat(
         cast(byref(c_tensors), POINTER(c_void_p)),
         length,
@@ -1022,6 +1022,20 @@ TEST(TestCapiNetWork, TestShareWeights) {
     LITE_CAPI_CHECK(LITE_destroy_network(c_network2));
 }
 
+TEST(TestCapiNetWork, GlobalHolder) {
+    std::string model_path = "./shufflenet.mge";
+    LiteNetwork c_network;
+    LITE_CAPI_CHECK(
+            LITE_make_network(&c_network, *default_config(), *default_network_io()));
+    auto destroy_network = c_network;
+    LITE_CAPI_CHECK(
+            LITE_make_network(&c_network, *default_config(), *default_network_io()));
+    //! destroy the first handle; a second destroy of it must return -1
+    LITE_destroy_network(destroy_network);
+    ASSERT_EQ(LITE_destroy_network(destroy_network), -1);
+    LITE_CAPI_CHECK(LITE_destroy_network(c_network));
+}
+
 #endif
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -251,6 +251,7 @@ TEST(TestCapiTensor, Slice) {
             }
         }
         LITE_destroy_tensor(tensor);
+        LITE_destroy_tensor(slice_tensor);
     };
     check(1, 8, 1, true);
     check(1, 8, 1, false);
@@ -316,6 +317,21 @@ TEST(TestCapiTensor, ThreadLocalError) {
     thread2.join();
 }
 
+TEST(TestCapiTensor, GlobalHolder) {
+    LiteTensor c_tensor0;
+    LiteTensorDesc description = default_desc;
+    description.layout = LiteLayout{{20, 20}, 2, LiteDataType::LITE_FLOAT};
+    LITE_make_tensor(description, &c_tensor0);
+    auto destroy_tensor = c_tensor0;
+    LITE_make_tensor(description, &c_tensor0);
+    //! destroy the first handle; a second destroy of it must return -1
+    LITE_destroy_tensor(destroy_tensor);
+    ASSERT_EQ(LITE_destroy_tensor(destroy_tensor), -1);
+    LITE_destroy_tensor(c_tensor0);
+}
+
 #endif
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}