GitOrigin-RevId: f159f49208
tags/v1.8.0
@@ -97,7 +97,7 @@ struct LITE_API Options { | |||||
bool no_profiling_on_shape_change = false; | bool no_profiling_on_shape_change = false; | ||||
uint8_t jit_level = 0; | uint8_t jit_level = 0; | ||||
uint8_t comp_node_seq_record_level = 0; | uint8_t comp_node_seq_record_level = 0; | ||||
uint8_t graph_opt_level = 0; | |||||
uint8_t graph_opt_level = 2; | |||||
uint16_t async_exec_level = 1; | uint16_t async_exec_level = 1; | ||||
//! layout transform options | //! layout transform options | ||||
@@ -368,7 +368,6 @@ public: | |||||
const std::shared_ptr<Network> src_network); | const std::shared_ptr<Network> src_network); | ||||
//! set global layout transform optimization for network | //! set global layout transform optimization for network | ||||
static void enable_global_layout_transform(std::shared_ptr<Network> network); | static void enable_global_layout_transform(std::shared_ptr<Network> network); | ||||
//! dump network after global layout transform optimization | //! dump network after global layout transform optimization | ||||
@@ -362,6 +362,8 @@ class _NetworkAPI(_LiteCObjBase): | |||||
("LITE_set_start_callback", [_Cnetwork, LiteStartCallback]), | ("LITE_set_start_callback", [_Cnetwork, LiteStartCallback]), | ||||
("LITE_set_finish_callback", [_Cnetwork, LiteFinishCallback]), | ("LITE_set_finish_callback", [_Cnetwork, LiteFinishCallback]), | ||||
("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]), | ("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]), | ||||
("LITE_enable_global_layout_transform", [_Cnetwork]), | |||||
("LITE_dump_layout_transform_model", [_Cnetwork, c_char_p]), | |||||
] | ] | ||||
@@ -610,3 +612,10 @@ class LiteNetwork(object): | |||||
def get_static_memory_alloc_info(self, log_dir="logs/test"): | def get_static_memory_alloc_info(self, log_dir="logs/test"): | ||||
c_log_dir = log_dir.encode("utf-8") | c_log_dir = log_dir.encode("utf-8") | ||||
self._api.LITE_get_static_memory_alloc_info(self._network, c_log_dir) | self._api.LITE_get_static_memory_alloc_info(self._network, c_log_dir) | ||||
def enable_global_layout_transform(self): | |||||
self._api.LITE_enable_global_layout_transform(self._network) | |||||
def dump_layout_transform_model(self, model_file): | |||||
c_file = model_file.encode("utf-8") | |||||
self._api.LITE_dump_layout_transform_model(self._network, c_file) |
@@ -451,3 +451,20 @@ class TestNetwork(TestShuffleNet): | |||||
network.wait() | network.wait() | ||||
self.check_correct(out_array) | self.check_correct(out_array) | ||||
def test_enable_global_layout_transform(self): | |||||
network = LiteNetwork() | |||||
network.enable_global_layout_transform() | |||||
network.load(self.model_path) | |||||
self.do_forward(network) | |||||
def test_dump_layout_transform_model(self): | |||||
network = LiteNetwork() | |||||
network.enable_global_layout_transform() | |||||
network.load(self.model_path) | |||||
network.dump_layout_transform_model("./model_afer_layoutTrans.mgb") | |||||
self.do_forward(network) | |||||
fi = open("./model_afer_layoutTrans.mgb", "r") | |||||
fi.close() | |||||
os.remove("./model_afer_layoutTrans.mgb") |
@@ -272,3 +272,22 @@ class TestNetwork(TestShuffleNetCuda): | |||||
| LiteAlgoSelectStrategy.LITE_ALGO_REPRODUCIBLE | | LiteAlgoSelectStrategy.LITE_ALGO_REPRODUCIBLE | ||||
) | ) | ||||
self.do_forward(network) | self.do_forward(network) | ||||
@require_cuda() | |||||
def test_enable_global_layout_transform(self): | |||||
network = LiteNetwork() | |||||
network.enable_global_layout_transform() | |||||
network.load(self.model_path) | |||||
self.do_forward(network) | |||||
@require_cuda() | |||||
def test_dump_layout_transform_model(self): | |||||
network = LiteNetwork() | |||||
network.enable_global_layout_transform() | |||||
network.load(self.model_path) | |||||
network.dump_layout_transform_model("./model_afer_layoutTrans.mgb") | |||||
self.do_forward(network) | |||||
fi = open("./model_afer_layoutTrans.mgb", "r") | |||||
fi.close() | |||||
os.remove("./model_afer_layoutTrans.mgb") |
@@ -406,7 +406,7 @@ void NetworkImplDft::load_model( | |||||
use_tensorrt(); | use_tensorrt(); | ||||
} | } | ||||
m_load_result = m_loader->load(m_load_config, false); | |||||
m_load_result = m_loader->load(m_load_config, true); | |||||
global_layout_transform(); | global_layout_transform(); | ||||
@@ -910,7 +910,6 @@ TEST(TestNetWork, LoadPackedModel) { | |||||
} | } | ||||
TEST(TestNetWork, GlabalLayoutTransform) { | TEST(TestNetWork, GlabalLayoutTransform) { | ||||
// set_log_level(LiteLogLevel::DEBUG); | |||||
auto tensor = get_input_data("./input_data.npy"); | auto tensor = get_input_data("./input_data.npy"); | ||||
std::string model_path = "./shufflenet.mge"; | std::string model_path = "./shufflenet.mge"; | ||||
std::string input_name = "data"; | std::string input_name = "data"; | ||||
@@ -931,6 +930,7 @@ TEST(TestNetWork, GlabalLayoutTransform) { | |||||
network->forward(); | network->forward(); | ||||
network->wait(); | network->wait(); | ||||
ASSERT_TRUE(fopen(dump_model_name.c_str(), "r")); | ASSERT_TRUE(fopen(dump_model_name.c_str(), "r")); | ||||
remove(dump_model_name.c_str()); | |||||
} | } | ||||
TEST(TestNetWork, GetDeviceType) { | TEST(TestNetWork, GetDeviceType) { | ||||