@@ -235,15 +235,29 @@ class LiteNetworkIO(object): | |||
# ctypes signatures for the callbacks the C library invokes back into Python.
# Async callback: no arguments, returns int (0 for success — TODO confirm).
LiteAsyncCallback = CFUNCTYPE(c_int)
# Start/finish callbacks: receive parallel C arrays of LiteIO descriptors and
# tensor handles plus their element count; return int.
LiteStartCallback = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
LiteFinishCallback = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
def wrap_async_callback(func):
    """Wrap *func* in a ctypes trampoline matching ``LiteAsyncCallback``.

    :param func: zero-argument Python callable returning an int status.
    :return: the CFUNCTYPE trampoline suitable for ``LITE_set_async_callback``.

    A strong reference to every trampoline is retained for the lifetime of the
    process. The previous implementation stored it in a single module-level
    ``wrapper`` global that was shared with ``start_finish_callback``, so
    registering a second callback dropped the only reference to the first one
    and the C side could end up invoking a garbage-collected trampoline.
    """

    @CFUNCTYPE(c_int)
    def wrapper():
        return func()

    # Keep the trampoline alive as long as C may hold a pointer to it.
    wrap_async_callback._keep_alive.append(wrapper)
    return wrapper


# Strong references to every trampoline ever handed to the C library.
wrap_async_callback._keep_alive = []
def start_finish_callback(func):
    # Wrap *func* in a ctypes trampoline matching LiteStartCallback /
    # LiteFinishCallback; the trampoline converts the raw C arrays into a
    # {LiteIO: LiteTensor} mapping before calling func.
    # NOTE(review): the module-level ``wrapper`` name keeps the trampoline
    # referenced so it is not garbage collected while C holds its pointer,
    # but the same global is used by wrap_async_callback — registering a
    # second callback clobbers the first reference; confirm against callers.
    global wrapper

    @CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
    def wrapper(c_ios, c_tensors, size):
        # Build the LiteIO -> LiteTensor mapping handed to the user callback.
        ios = {}
        for i in range(size):
            tensor = LiteTensor()
            tensor._tensor = c_tensors[i]
            # NOTE(review): the assignment above looks like diff residue of the
            # pre-change form; only the c_void_p-wrapped assignment below should
            # remain in the real file — confirm against the repository.
            tensor._tensor = c_void_p(c_tensors[i])
            tensor.update()
            io = c_ios[i]
            ios[io] = tensor
            # NOTE(review): this hunk is truncated here — the wrapper
            # presumably calls func(ios) and returns its status after the
            # loop; confirm against the full file.
@@ -288,8 +302,8 @@ class _NetworkAPI(_LiteCObjBase): | |||
("LITE_enable_io_txt_dump", [_Cnetwork, c_char_p]), | |||
("LITE_enable_io_bin_dump", [_Cnetwork, c_char_p]), | |||
("LITE_set_async_callback", [_Cnetwork, LiteAsyncCallback]), | |||
("LITE_set_start_callback", [_Cnetwork]), | |||
("LITE_set_finish_callback", [_Cnetwork]), | |||
("LITE_set_start_callback", [_Cnetwork, LiteStartCallback]), | |||
("LITE_set_finish_callback", [_Cnetwork, LiteFinishCallback]), | |||
("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]), | |||
] | |||
@@ -482,8 +496,8 @@ class LiteNetwork(object): | |||
self._api.LITE_share_runtime_memroy(self._network, src_network._network) | |||
def async_with_callback(self, async_callback):
    """Register *async_callback* to be invoked when an async forward finishes.

    :param async_callback: zero-argument Python callable returning an int.

    The callback is wrapped via ``wrap_async_callback`` so the ctypes
    trampoline stays referenced while the C library holds it. (This hunk
    previously contained both the old direct ``LiteAsyncCallback(...)``
    registration and the new wrapped form — only the wrapped form, consistent
    with ``set_start_callback``/``set_finish_callback``, is kept.)
    """
    callback = wrap_async_callback(async_callback)
    self._api.LITE_set_async_callback(self._network, callback)
def set_start_callback(self, start_callback):
    """Register *start_callback* to run when the network starts a forward.

    :param start_callback: callable taking a mapping from LiteIO to the
        corresponding LiteTensor and returning an int status.

    The callable is wrapped via ``start_finish_callback`` so the ctypes
    trampoline matches ``LiteStartCallback`` and stays referenced. (The old
    unwrapped registration line from the diff has been dropped.)
    """
    callback = start_finish_callback(start_callback)
    self._api.LITE_set_start_callback(self._network, callback)
def set_finish_callback(self, finish_callback):
    """Register *finish_callback* to run when the network finishes a forward.

    :param finish_callback: callable taking a mapping from LiteIO to the
        corresponding LiteTensor and returning an int status.

    The callable is wrapped via ``start_finish_callback`` so the ctypes
    trampoline matches ``LiteFinishCallback`` and stays referenced. (The old
    unwrapped registration line from the diff has been dropped.)
    """
    callback = start_finish_callback(finish_callback)
    self._api.LITE_set_finish_callback(self._network, callback)
def enable_profile_performance(self, profile_file): | |||
c_file = profile_file.encode("utf-8") | |||
@@ -274,83 +274,81 @@ class TestNetwork(TestShuffleNet): | |||
self.do_forward(src_network) | |||
self.do_forward(new_network) | |||
def test_network_async(self):
    """Forward asynchronously; the callback must fire and flip the flag.

    The three commented-out copies of these tests that preceded this method
    duplicated the live implementations verbatim and have been deleted.
    """
    count = 0
    finished = False

    def async_callback():
        nonlocal finished
        finished = True
        return 0

    option = LiteOptions()
    option.var_sanity_check_first_run = 0
    config = LiteConfig(option=option)

    network = LiteNetwork(config=config)
    network.load(self.model_path)
    network.async_with_callback(async_callback)

    input_tensor = network.get_io_tensor(network.get_input_name(0))
    output_tensor = network.get_io_tensor(network.get_output_name(0))

    input_tensor.set_data_by_share(self.input_data)
    network.forward()

    # Busy-wait until the callback runs; count > 0 shows forward() really
    # returned before the computation completed.
    # NOTE(review): could be flaky if forward completes before the first
    # loop iteration — confirm acceptable for CI.
    while not finished:
        count += 1
    assert count > 0

    output_data = output_tensor.to_numpy()
    self.check_correct(output_data)
def test_network_start_callback(self):
    """The start callback must observe exactly one IO, the "data" input."""
    network = LiteNetwork()
    network.load(self.model_path)
    start_checked = False

    def start_callback(ios):
        nonlocal start_checked
        start_checked = True
        assert len(ios) == 1
        expected = self.input_data.flatten()
        for io, tensor in ios.items():
            values = tensor.to_numpy().flatten()
            assert values.size == expected.size
            assert io.name.decode("utf-8") == "data"
            for have, want in zip(values, expected):
                assert have == want
        return 0

    network.set_start_callback(start_callback)
    self.do_forward(network, 1)
    assert start_checked
def test_network_finish_callback(self):
    """The finish callback must observe one IO whose data matches the reference output."""
    network = LiteNetwork()
    network.load(self.model_path)
    finish_checked = False

    def finish_callback(ios):
        nonlocal finish_checked
        finish_checked = True
        assert len(ios) == 1
        expected = self.correct_data.flatten()
        for _io, tensor in ios.items():
            values = tensor.to_numpy().flatten()
            assert values.size == expected.size
            for have, want in zip(values, expected):
                assert have == want
        return 0

    network.set_finish_callback(finish_callback)
    self.do_forward(network, 1)
    assert finish_checked
def test_enable_profile(self): | |||
network = LiteNetwork() | |||
@@ -186,6 +186,57 @@ class TestNetwork(TestShuffleNetCuda): | |||
self.do_forward(src_network) | |||
self.do_forward(new_network) | |||
@require_cuda
def test_network_start_callback(self):
    """CUDA variant: the start callback must observe the single "data" input."""
    config = LiteConfig()
    config.device = LiteDeviceType.LITE_CUDA
    network = LiteNetwork(config)
    network.load(self.model_path)
    start_checked = False

    def start_callback(ios):
        nonlocal start_checked
        start_checked = True
        assert len(ios) == 1
        expected = self.input_data.flatten()
        for io, tensor in ios.items():
            values = tensor.to_numpy().flatten()
            assert values.size == expected.size
            assert io.name.decode("utf-8") == "data"
            for have, want in zip(values, expected):
                assert have == want
        return 0

    network.set_start_callback(start_callback)
    self.do_forward(network, 1)
    assert start_checked
@require_cuda
def test_network_finish_callback(self):
    """CUDA variant: the finish callback output must match the reference data."""
    config = LiteConfig()
    config.device = LiteDeviceType.LITE_CUDA
    network = LiteNetwork(config)
    network.load(self.model_path)
    finish_checked = False

    def finish_callback(ios):
        nonlocal finish_checked
        finish_checked = True
        assert len(ios) == 1
        expected = self.correct_data.flatten()
        for _io, tensor in ios.items():
            values = tensor.to_numpy().flatten()
            assert values.size == expected.size
            for have, want in zip(values, expected):
                assert have == want
        return 0

    network.set_finish_callback(finish_callback)
    self.do_forward(network, 1)
    assert finish_checked
@require_cuda() | |||
def test_enable_profile(self): | |||
config = LiteConfig() | |||