@@ -242,7 +242,13 @@ int LITE_destroy_network(LiteNetwork network) { | |||
LITE_CAPI_BEGIN(); | |||
LITE_ASSERT(network, "The network pass to LITE api is null"); | |||
LITE_LOCK_GUARD(mtx_network); | |||
get_gloabl_network_holder().erase(network); | |||
auto& global_holder = get_gloabl_network_holder(); | |||
if (global_holder.find(network) != global_holder.end()) { | |||
global_holder.erase(network); | |||
} else { | |||
//! means the network has been destoryed | |||
return -1; | |||
} | |||
LITE_CAPI_END(); | |||
} | |||
@@ -60,8 +60,10 @@ int LITE_make_tensor(const LiteTensorDesc tensor_describe, LiteTensor* tensor) { | |||
auto lite_tensor = std::make_shared<lite::Tensor>( | |||
tensor_describe.device_id, tensor_describe.device_type, layout, | |||
tensor_describe.is_pinned_host); | |||
LITE_LOCK_GUARD(mtx_tensor); | |||
get_global_tensor_holder()[lite_tensor.get()] = lite_tensor; | |||
{ | |||
LITE_LOCK_GUARD(mtx_tensor); | |||
get_global_tensor_holder()[lite_tensor.get()] = lite_tensor; | |||
} | |||
*tensor = lite_tensor.get(); | |||
LITE_CAPI_END(); | |||
} | |||
@@ -70,7 +72,13 @@ int LITE_destroy_tensor(LiteTensor tensor) { | |||
LITE_CAPI_BEGIN(); | |||
LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null"); | |||
LITE_LOCK_GUARD(mtx_tensor); | |||
get_global_tensor_holder().erase(tensor); | |||
auto& global_holder = get_global_tensor_holder(); | |||
if (global_holder.find(tensor) != global_holder.end()) { | |||
global_holder.erase(tensor); | |||
} else { | |||
//! return -1, means the tensor has been destroyed. | |||
return -1; | |||
} | |||
LITE_CAPI_END(); | |||
} | |||
@@ -126,8 +134,10 @@ int LITE_tensor_slice( | |||
} | |||
} | |||
auto ret_tensor = static_cast<lite::Tensor*>(tensor)->slice(starts, ends, steps); | |||
LITE_LOCK_GUARD(mtx_tensor); | |||
get_global_tensor_holder()[ret_tensor.get()] = ret_tensor; | |||
{ | |||
LITE_LOCK_GUARD(mtx_tensor); | |||
get_global_tensor_holder()[ret_tensor.get()] = ret_tensor; | |||
} | |||
*slice_tensor = ret_tensor.get(); | |||
LITE_CAPI_END(); | |||
} | |||
@@ -226,12 +236,16 @@ int LITE_tensor_concat( | |||
LiteTensor* tensors, int nr_tensor, int dim, LiteDeviceType dst_device, | |||
int device_id, LiteTensor* result_tensor) { | |||
LITE_CAPI_BEGIN(); | |||
LITE_ASSERT(result_tensor, "The tensor pass to LITE c_api is null"); | |||
std::vector<lite::Tensor> v_tensors; | |||
for (int i = 0; i < nr_tensor; i++) { | |||
v_tensors.push_back(*static_cast<lite::Tensor*>(tensors[i])); | |||
} | |||
auto tensor = lite::TensorUtils::concat(v_tensors, dim, dst_device, device_id); | |||
get_global_tensor_holder()[tensor.get()] = tensor; | |||
{ | |||
LITE_LOCK_GUARD(mtx_tensor); | |||
get_global_tensor_holder()[tensor.get()] = tensor; | |||
} | |||
*result_tensor = tensor.get(); | |||
LITE_CAPI_END() | |||
} | |||
@@ -476,7 +476,7 @@ def start_finish_callback(func): | |||
def wrapper(c_ios, c_tensors, size): | |||
ios = {} | |||
for i in range(size): | |||
tensor = LiteTensor() | |||
tensor = LiteTensor(physic_construct=False) | |||
tensor._tensor = c_void_p(c_tensors[i]) | |||
tensor.update() | |||
io = c_ios[i] | |||
@@ -729,7 +729,7 @@ class LiteNetwork(object): | |||
c_name = c_char_p(name.encode("utf-8")) | |||
else: | |||
c_name = c_char_p(name) | |||
tensor = LiteTensor() | |||
tensor = LiteTensor(physic_construct=False) | |||
self._api.LITE_get_io_tensor( | |||
self._network, c_name, phase, byref(tensor._tensor) | |||
) | |||
@@ -233,6 +233,7 @@ class LiteTensor(object): | |||
is_pinned_host=False, | |||
shapes=None, | |||
dtype=None, | |||
physic_construct=True, | |||
): | |||
self._tensor = _Ctensor() | |||
self._layout = LiteLayout() | |||
@@ -250,8 +251,10 @@ class LiteTensor(object): | |||
tensor_desc.device_type = device_type | |||
tensor_desc.device_id = device_id | |||
tensor_desc.is_pinned_host = is_pinned_host | |||
self._api.LITE_make_tensor(tensor_desc, byref(self._tensor)) | |||
self.update() | |||
if physic_construct: | |||
self._api.LITE_make_tensor(tensor_desc, byref(self._tensor)) | |||
self.update() | |||
def __del__(self): | |||
self._api.LITE_destroy_tensor(self._tensor) | |||
@@ -399,7 +402,7 @@ class LiteTensor(object): | |||
c_start = (c_size_t * length)(*start) | |||
c_end = (c_size_t * length)(*end) | |||
c_step = (c_size_t * length)(*step) | |||
slice_tensor = LiteTensor() | |||
slice_tensor = LiteTensor(physic_construct=False) | |||
self._api.LITE_tensor_slice( | |||
self._tensor, c_start, c_end, c_step, length, byref(slice_tensor._tensor), | |||
) | |||
@@ -560,7 +563,7 @@ def LiteTensorConcat( | |||
length = len(tensors) | |||
c_tensors = [t._tensor for t in tensors] | |||
c_tensors = (_Ctensor * length)(*c_tensors) | |||
result_tensor = LiteTensor() | |||
result_tensor = LiteTensor(physic_construct=False) | |||
api.LITE_tensor_concat( | |||
cast(byref(c_tensors), POINTER(c_void_p)), | |||
length, | |||
@@ -1022,6 +1022,20 @@ TEST(TestCapiNetWork, TestShareWeights) { | |||
LITE_CAPI_CHECK(LITE_destroy_network(c_network2)); | |||
} | |||
TEST(TestCapiNetWork, GlobalHolder) { | |||
std::string model_path = "./shufflenet.mge"; | |||
LiteNetwork c_network; | |||
LITE_CAPI_CHECK( | |||
LITE_make_network(&c_network, *default_config(), *default_network_io())); | |||
auto destroy_network = c_network; | |||
LITE_CAPI_CHECK( | |||
LITE_make_network(&c_network, *default_config(), *default_network_io())); | |||
//! make sure destroy_network is destroyed by LITE_make_network | |||
LITE_destroy_network(destroy_network); | |||
ASSERT_EQ(LITE_destroy_network(destroy_network), -1); | |||
LITE_CAPI_CHECK(LITE_destroy_network(c_network)); | |||
} | |||
#endif | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -251,6 +251,7 @@ TEST(TestCapiTensor, Slice) { | |||
} | |||
} | |||
LITE_destroy_tensor(tensor); | |||
LITE_destroy_tensor(slice_tensor); | |||
}; | |||
check(1, 8, 1, true); | |||
check(1, 8, 1, false); | |||
@@ -316,6 +317,21 @@ TEST(TestCapiTensor, ThreadLocalError) { | |||
thread2.join(); | |||
} | |||
TEST(TestCapiTensor, GlobalHolder) { | |||
LiteTensor c_tensor0; | |||
LiteTensorDesc description = default_desc; | |||
description.layout = LiteLayout{{20, 20}, 2, LiteDataType::LITE_FLOAT}; | |||
LITE_make_tensor(description, &c_tensor0); | |||
auto destroy_tensor = c_tensor0; | |||
LITE_make_tensor(description, &c_tensor0); | |||
//! make sure destroy_tensor is destroyed by LITE_make_tensor | |||
LITE_destroy_tensor(destroy_tensor); | |||
ASSERT_EQ(LITE_destroy_tensor(destroy_tensor), -1); | |||
LITE_destroy_tensor(c_tensor0); | |||
} | |||
#endif | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |