Browse Source

fix(pylite): fix pylite callback test bug

GitOrigin-RevId: f4bd153950
release-1.7
Megvii Engine Team 3 years ago
parent
commit
b20cda6bb8
3 changed files with 149 additions and 84 deletions
  1. +23
    -7
      lite/pylite/megenginelite/network.py
  2. +75
    -77
      lite/pylite/test/test_network.py
  3. +51
    -0
      lite/pylite/test/test_network_cuda.py

+ 23
- 7
lite/pylite/megenginelite/network.py View File

@@ -235,15 +235,29 @@ class LiteNetworkIO(object):


LiteAsyncCallback = CFUNCTYPE(c_int)
LiteStartCallback = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
LiteFinishCallback = CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)


def wrap_async_callback(func):
global wrapper

@CFUNCTYPE(c_int)
def wrapper():
return func()

return wrapper


def start_finish_callback(func):
global wrapper

@CFUNCTYPE(c_int, POINTER(LiteIO), POINTER(_Ctensor), c_size_t)
def wrapper(c_ios, c_tensors, size):
ios = {}
for i in range(size):
tensor = LiteTensor()
tensor._tensor = c_tensors[i]
tensor._tensor = c_void_p(c_tensors[i])
tensor.update()
io = c_ios[i]
ios[io] = tensor
@@ -288,8 +302,8 @@ class _NetworkAPI(_LiteCObjBase):
("LITE_enable_io_txt_dump", [_Cnetwork, c_char_p]),
("LITE_enable_io_bin_dump", [_Cnetwork, c_char_p]),
("LITE_set_async_callback", [_Cnetwork, LiteAsyncCallback]),
("LITE_set_start_callback", [_Cnetwork]),
("LITE_set_finish_callback", [_Cnetwork]),
("LITE_set_start_callback", [_Cnetwork, LiteStartCallback]),
("LITE_set_finish_callback", [_Cnetwork, LiteFinishCallback]),
("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]),
]

@@ -482,8 +496,8 @@ class LiteNetwork(object):
self._api.LITE_share_runtime_memroy(self._network, src_network._network)

def async_with_callback(self, async_callback):
async_callback = LiteAsyncCallback(async_callback)
self._api.LITE_set_async_callback(self._network, async_callback)
callback = wrap_async_callback(async_callback)
self._api.LITE_set_async_callback(self._network, callback)

def set_start_callback(self, start_callback):
"""
@@ -491,7 +505,8 @@ class LiteNetwork(object):
the start_callback with param mapping from LiteIO to the corresponding
LiteTensor
"""
self._api.LITE_set_start_callback(self._network, start_callback)
callback = start_finish_callback(start_callback)
self._api.LITE_set_start_callback(self._network, callback)

def set_finish_callback(self, finish_callback):
"""
@@ -499,7 +514,8 @@ class LiteNetwork(object):
the finish_callback with param mapping from LiteIO to the corresponding
LiteTensor
"""
self._api.LITE_set_finish_callback(self._network, finish_callback)
callback = start_finish_callback(finish_callback)
self._api.LITE_set_finish_callback(self._network, callback)

def enable_profile_performance(self, profile_file):
c_file = profile_file.encode("utf-8")


+ 75
- 77
lite/pylite/test/test_network.py View File

@@ -274,83 +274,81 @@ class TestNetwork(TestShuffleNet):
self.do_forward(src_network)
self.do_forward(new_network)

# def test_network_async(self):
# count = 0
# finished = False
#
# def async_callback():
# nonlocal finished
# finished = True
# return 0
#
# option = LiteOptions()
# option.var_sanity_check_first_run = 0
# config = LiteConfig(option=option)
#
# network = LiteNetwork(config=config)
# network.load(self.model_path)
#
# network.async_with_callback(async_callback)
#
# input_tensor = network.get_io_tensor(network.get_input_name(0))
# output_tensor = network.get_io_tensor(network.get_output_name(0))
#
# input_tensor.set_data_by_share(self.input_data)
# network.forward()
#
# while not finished:
# count += 1
#
# assert count > 0
# output_data = output_tensor.to_numpy()
# self.check_correct(output_data)
#
# def test_network_start_callback(self):
# network = LiteNetwork()
# network.load(self.model_path)
# start_checked = False
#
# @start_finish_callback
# def start_callback(ios):
# nonlocal start_checked
# start_checked = True
# assert len(ios) == 1
# for key in ios:
# io = key
# data = ios[key].to_numpy().flatten()
# input_data = self.input_data.flatten()
# assert data.size == input_data.size
# assert io.name.decode("utf-8") == "data"
# for i in range(data.size):
# assert data[i] == input_data[i]
# return 0
#
# network.set_start_callback(start_callback)
# self.do_forward(network, 1)
# assert start_checked == True
#
# def test_network_finish_callback(self):
# network = LiteNetwork()
# network.load(self.model_path)
# finish_checked = False
#
# @start_finish_callback
# def finish_callback(ios):
# nonlocal finish_checked
# finish_checked = True
# assert len(ios) == 1
# for key in ios:
# io = key
# data = ios[key].to_numpy().flatten()
# output_data = self.correct_data.flatten()
# assert data.size == output_data.size
# for i in range(data.size):
# assert data[i] == output_data[i]
# return 0
#
# network.set_finish_callback(finish_callback)
# self.do_forward(network, 1)
# assert finish_checked == True
def test_network_async(self):
count = 0
finished = False

def async_callback():
nonlocal finished
finished = True
return 0

option = LiteOptions()
option.var_sanity_check_first_run = 0
config = LiteConfig(option=option)

network = LiteNetwork(config=config)
network.load(self.model_path)

network.async_with_callback(async_callback)

input_tensor = network.get_io_tensor(network.get_input_name(0))
output_tensor = network.get_io_tensor(network.get_output_name(0))

input_tensor.set_data_by_share(self.input_data)
network.forward()

while not finished:
count += 1

assert count > 0
output_data = output_tensor.to_numpy()
self.check_correct(output_data)

def test_network_start_callback(self):
network = LiteNetwork()
network.load(self.model_path)
start_checked = False

def start_callback(ios):
nonlocal start_checked
start_checked = True
assert len(ios) == 1
for key in ios:
io = key
data = ios[key].to_numpy().flatten()
input_data = self.input_data.flatten()
assert data.size == input_data.size
assert io.name.decode("utf-8") == "data"
for i in range(data.size):
assert data[i] == input_data[i]
return 0

network.set_start_callback(start_callback)
self.do_forward(network, 1)
assert start_checked == True

def test_network_finish_callback(self):
network = LiteNetwork()
network.load(self.model_path)
finish_checked = False

def finish_callback(ios):
nonlocal finish_checked
finish_checked = True
assert len(ios) == 1
for key in ios:
io = key
data = ios[key].to_numpy().flatten()
output_data = self.correct_data.flatten()
assert data.size == output_data.size
for i in range(data.size):
assert data[i] == output_data[i]
return 0

network.set_finish_callback(finish_callback)
self.do_forward(network, 1)
assert finish_checked == True

def test_enable_profile(self):
network = LiteNetwork()


+ 51
- 0
lite/pylite/test/test_network_cuda.py View File

@@ -186,6 +186,57 @@ class TestNetwork(TestShuffleNetCuda):
self.do_forward(src_network)
self.do_forward(new_network)

@require_cuda
def test_network_start_callback(self):
config = LiteConfig()
config.device = LiteDeviceType.LITE_CUDA
network = LiteNetwork(config)
network.load(self.model_path)
start_checked = False

def start_callback(ios):
nonlocal start_checked
start_checked = True
assert len(ios) == 1
for key in ios:
io = key
data = ios[key].to_numpy().flatten()
input_data = self.input_data.flatten()
assert data.size == input_data.size
assert io.name.decode("utf-8") == "data"
for i in range(data.size):
assert data[i] == input_data[i]
return 0

network.set_start_callback(start_callback)
self.do_forward(network, 1)
assert start_checked == True

@require_cuda
def test_network_finish_callback(self):
config = LiteConfig()
config.device = LiteDeviceType.LITE_CUDA
network = LiteNetwork(config)
network.load(self.model_path)
finish_checked = False

def finish_callback(ios):
nonlocal finish_checked
finish_checked = True
assert len(ios) == 1
for key in ios:
io = key
data = ios[key].to_numpy().flatten()
output_data = self.correct_data.flatten()
assert data.size == output_data.size
for i in range(data.size):
assert data[i] == output_data[i]
return 0

network.set_finish_callback(finish_callback)
self.do_forward(network, 1)
assert finish_checked == True

@require_cuda()
def test_enable_profile(self):
config = LiteConfig()


Loading…
Cancel
Save