@@ -235,15 +235,29 @@ class LiteNetworkIO(object): | |||||
LiteAsyncCallback = CFUNCTYPE(c_int) | LiteAsyncCallback = CFUNCTYPE(c_int) | ||||
LiteStartCallback = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t) | |||||
LiteFinishCallback = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t) | |||||
def wrap_async_callback(func): | |||||
global wrapper | |||||
@CFUNCTYPE(c_int) | |||||
def wrapper(): | |||||
return func() | |||||
return wrapper | |||||
def start_finish_callback(func): | def start_finish_callback(func): | ||||
global wrapper | |||||
@CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t) | @CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t) | ||||
def wrapper(c_ios, c_tensors, size): | def wrapper(c_ios, c_tensors, size): | ||||
ios = {} | ios = {} | ||||
for i in range(size): | for i in range(size): | ||||
tensor = LiteTensor() | tensor = LiteTensor() | ||||
tensor._tensor = c_tensors[i] | |||||
tensor._tensor = c_void_p(c_tensors[i]) | |||||
tensor.update() | tensor.update() | ||||
io = c_ios[i] | io = c_ios[i] | ||||
ios[io] = tensor | ios[io] = tensor | ||||
@@ -288,8 +302,8 @@ class _NetworkAPI(_LiteCObjBase): | |||||
("LITE_enable_io_txt_dump", [_Cnetwork, c_char_p]), | ("LITE_enable_io_txt_dump", [_Cnetwork, c_char_p]), | ||||
("LITE_enable_io_bin_dump", [_Cnetwork, c_char_p]), | ("LITE_enable_io_bin_dump", [_Cnetwork, c_char_p]), | ||||
("LITE_set_async_callback", [_Cnetwork, LiteAsyncCallback]), | ("LITE_set_async_callback", [_Cnetwork, LiteAsyncCallback]), | ||||
("LITE_set_start_callback", [_Cnetwork]), | |||||
("LITE_set_finish_callback", [_Cnetwork]), | |||||
("LITE_set_start_callback", [_Cnetwork, LiteStartCallback]), | |||||
("LITE_set_finish_callback", [_Cnetwork, LiteFinishCallback]), | |||||
("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]), | ("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]), | ||||
] | ] | ||||
@@ -482,8 +496,8 @@ class LiteNetwork(object): | |||||
self._api.LITE_share_runtime_memroy(self._network, src_network._network) | self._api.LITE_share_runtime_memroy(self._network, src_network._network) | ||||
def async_with_callback(self, async_callback): | def async_with_callback(self, async_callback): | ||||
async_callback = LiteAsyncCallback(async_callback) | |||||
self._api.LITE_set_async_callback(self._network, async_callback) | |||||
callback = wrap_async_callback(async_callback) | |||||
self._api.LITE_set_async_callback(self._network, callback) | |||||
def set_start_callback(self, start_callback): | def set_start_callback(self, start_callback): | ||||
""" | """ | ||||
@@ -491,7 +505,8 @@ class LiteNetwork(object): | |||||
the start_callback with param mapping from LiteIO to the corresponding | the start_callback with param mapping from LiteIO to the corresponding | ||||
LiteTensor | LiteTensor | ||||
""" | """ | ||||
self._api.LITE_set_start_callback(self._network, start_callback) | |||||
callback = start_finish_callback(start_callback) | |||||
self._api.LITE_set_start_callback(self._network, callback) | |||||
def set_finish_callback(self, finish_callback): | def set_finish_callback(self, finish_callback): | ||||
""" | """ | ||||
@@ -499,7 +514,8 @@ class LiteNetwork(object): | |||||
the finish_callback with param mapping from LiteIO to the corresponding | the finish_callback with param mapping from LiteIO to the corresponding | ||||
LiteTensor | LiteTensor | ||||
""" | """ | ||||
self._api.LITE_set_finish_callback(self._network, finish_callback) | |||||
callback = start_finish_callback(finish_callback) | |||||
self._api.LITE_set_finish_callback(self._network, callback) | |||||
def enable_profile_performance(self, profile_file): | def enable_profile_performance(self, profile_file): | ||||
c_file = profile_file.encode("utf-8") | c_file = profile_file.encode("utf-8") | ||||
@@ -274,83 +274,81 @@ class TestNetwork(TestShuffleNet): | |||||
self.do_forward(src_network) | self.do_forward(src_network) | ||||
self.do_forward(new_network) | self.do_forward(new_network) | ||||
# def test_network_async(self): | |||||
# count = 0 | |||||
# finished = False | |||||
# | |||||
# def async_callback(): | |||||
# nonlocal finished | |||||
# finished = True | |||||
# return 0 | |||||
# | |||||
# option = LiteOptions() | |||||
# option.var_sanity_check_first_run = 0 | |||||
# config = LiteConfig(option=option) | |||||
# | |||||
# network = LiteNetwork(config=config) | |||||
# network.load(self.model_path) | |||||
# | |||||
# network.async_with_callback(async_callback) | |||||
# | |||||
# input_tensor = network.get_io_tensor(network.get_input_name(0)) | |||||
# output_tensor = network.get_io_tensor(network.get_output_name(0)) | |||||
# | |||||
# input_tensor.set_data_by_share(self.input_data) | |||||
# network.forward() | |||||
# | |||||
# while not finished: | |||||
# count += 1 | |||||
# | |||||
# assert count > 0 | |||||
# output_data = output_tensor.to_numpy() | |||||
# self.check_correct(output_data) | |||||
# | |||||
# def test_network_start_callback(self): | |||||
# network = LiteNetwork() | |||||
# network.load(self.model_path) | |||||
# start_checked = False | |||||
# | |||||
# @start_finish_callback | |||||
# def start_callback(ios): | |||||
# nonlocal start_checked | |||||
# start_checked = True | |||||
# assert len(ios) == 1 | |||||
# for key in ios: | |||||
# io = key | |||||
# data = ios[key].to_numpy().flatten() | |||||
# input_data = self.input_data.flatten() | |||||
# assert data.size == input_data.size | |||||
# assert io.name.decode("utf-8") == "data" | |||||
# for i in range(data.size): | |||||
# assert data[i] == input_data[i] | |||||
# return 0 | |||||
# | |||||
# network.set_start_callback(start_callback) | |||||
# self.do_forward(network, 1) | |||||
# assert start_checked == True | |||||
# | |||||
# def test_network_finish_callback(self): | |||||
# network = LiteNetwork() | |||||
# network.load(self.model_path) | |||||
# finish_checked = False | |||||
# | |||||
# @start_finish_callback | |||||
# def finish_callback(ios): | |||||
# nonlocal finish_checked | |||||
# finish_checked = True | |||||
# assert len(ios) == 1 | |||||
# for key in ios: | |||||
# io = key | |||||
# data = ios[key].to_numpy().flatten() | |||||
# output_data = self.correct_data.flatten() | |||||
# assert data.size == output_data.size | |||||
# for i in range(data.size): | |||||
# assert data[i] == output_data[i] | |||||
# return 0 | |||||
# | |||||
# network.set_finish_callback(finish_callback) | |||||
# self.do_forward(network, 1) | |||||
# assert finish_checked == True | |||||
def test_network_async(self): | |||||
count = 0 | |||||
finished = False | |||||
def async_callback(): | |||||
nonlocal finished | |||||
finished = True | |||||
return 0 | |||||
option = LiteOptions() | |||||
option.var_sanity_check_first_run = 0 | |||||
config = LiteConfig(option=option) | |||||
network = LiteNetwork(config=config) | |||||
network.load(self.model_path) | |||||
network.async_with_callback(async_callback) | |||||
input_tensor = network.get_io_tensor(network.get_input_name(0)) | |||||
output_tensor = network.get_io_tensor(network.get_output_name(0)) | |||||
input_tensor.set_data_by_share(self.input_data) | |||||
network.forward() | |||||
while not finished: | |||||
count += 1 | |||||
assert count > 0 | |||||
output_data = output_tensor.to_numpy() | |||||
self.check_correct(output_data) | |||||
def test_network_start_callback(self): | |||||
network = LiteNetwork() | |||||
network.load(self.model_path) | |||||
start_checked = False | |||||
def start_callback(ios): | |||||
nonlocal start_checked | |||||
start_checked = True | |||||
assert len(ios) == 1 | |||||
for key in ios: | |||||
io = key | |||||
data = ios[key].to_numpy().flatten() | |||||
input_data = self.input_data.flatten() | |||||
assert data.size == input_data.size | |||||
assert io.name.decode("utf-8") == "data" | |||||
for i in range(data.size): | |||||
assert data[i] == input_data[i] | |||||
return 0 | |||||
network.set_start_callback(start_callback) | |||||
self.do_forward(network, 1) | |||||
assert start_checked == True | |||||
def test_network_finish_callback(self): | |||||
network = LiteNetwork() | |||||
network.load(self.model_path) | |||||
finish_checked = False | |||||
def finish_callback(ios): | |||||
nonlocal finish_checked | |||||
finish_checked = True | |||||
assert len(ios) == 1 | |||||
for key in ios: | |||||
io = key | |||||
data = ios[key].to_numpy().flatten() | |||||
output_data = self.correct_data.flatten() | |||||
assert data.size == output_data.size | |||||
for i in range(data.size): | |||||
assert data[i] == output_data[i] | |||||
return 0 | |||||
network.set_finish_callback(finish_callback) | |||||
self.do_forward(network, 1) | |||||
assert finish_checked == True | |||||
def test_enable_profile(self): | def test_enable_profile(self): | ||||
network = LiteNetwork() | network = LiteNetwork() | ||||
@@ -186,6 +186,57 @@ class TestNetwork(TestShuffleNetCuda): | |||||
self.do_forward(src_network) | self.do_forward(src_network) | ||||
self.do_forward(new_network) | self.do_forward(new_network) | ||||
@require_cuda | |||||
def test_network_start_callback(self): | |||||
config = LiteConfig() | |||||
config.device = LiteDeviceType.LITE_CUDA | |||||
network = LiteNetwork(config) | |||||
network.load(self.model_path) | |||||
start_checked = False | |||||
def start_callback(ios): | |||||
nonlocal start_checked | |||||
start_checked = True | |||||
assert len(ios) == 1 | |||||
for key in ios: | |||||
io = key | |||||
data = ios[key].to_numpy().flatten() | |||||
input_data = self.input_data.flatten() | |||||
assert data.size == input_data.size | |||||
assert io.name.decode("utf-8") == "data" | |||||
for i in range(data.size): | |||||
assert data[i] == input_data[i] | |||||
return 0 | |||||
network.set_start_callback(start_callback) | |||||
self.do_forward(network, 1) | |||||
assert start_checked == True | |||||
@require_cuda | |||||
def test_network_finish_callback(self): | |||||
config = LiteConfig() | |||||
config.device = LiteDeviceType.LITE_CUDA | |||||
network = LiteNetwork(config) | |||||
network.load(self.model_path) | |||||
finish_checked = False | |||||
def finish_callback(ios): | |||||
nonlocal finish_checked | |||||
finish_checked = True | |||||
assert len(ios) == 1 | |||||
for key in ios: | |||||
io = key | |||||
data = ios[key].to_numpy().flatten() | |||||
output_data = self.correct_data.flatten() | |||||
assert data.size == output_data.size | |||||
for i in range(data.size): | |||||
assert data[i] == output_data[i] | |||||
return 0 | |||||
network.set_finish_callback(finish_callback) | |||||
self.do_forward(network, 1) | |||||
assert finish_checked == True | |||||
@require_cuda() | @require_cuda() | ||||
def test_enable_profile(self): | def test_enable_profile(self): | ||||
config = LiteConfig() | config = LiteConfig() | ||||