Browse Source

feat(lite): add global layout transform python interface for lite

GitOrigin-RevId: f159f49208
tags/v1.8.0
Megvii Engine Team 3 years ago
parent
commit
f7e034b506
6 changed files with 48 additions and 4 deletions
  1. +1
    -2
      lite/include/lite/network.h
  2. +9
    -0
      lite/pylite/megenginelite/network.py
  3. +17
    -0
      lite/pylite/test/test_network.py
  4. +19
    -0
      lite/pylite/test/test_network_cuda.py
  5. +1
    -1
      lite/src/mge/network_impl.cpp
  6. +1
    -1
      lite/test/test_network.cpp

+ 1
- 2
lite/include/lite/network.h View File

@@ -97,7 +97,7 @@ struct LITE_API Options {
bool no_profiling_on_shape_change = false; bool no_profiling_on_shape_change = false;
uint8_t jit_level = 0; uint8_t jit_level = 0;
uint8_t comp_node_seq_record_level = 0; uint8_t comp_node_seq_record_level = 0;
uint8_t graph_opt_level = 0;
uint8_t graph_opt_level = 2;
uint16_t async_exec_level = 1; uint16_t async_exec_level = 1;


//! layout transform options //! layout transform options
@@ -368,7 +368,6 @@ public:
const std::shared_ptr<Network> src_network); const std::shared_ptr<Network> src_network);


//! set global layout transform optimization for network //! set global layout transform optimization for network

static void enable_global_layout_transform(std::shared_ptr<Network> network); static void enable_global_layout_transform(std::shared_ptr<Network> network);


//! dump network after global layout transform optimization //! dump network after global layout transform optimization


+ 9
- 0
lite/pylite/megenginelite/network.py View File

@@ -362,6 +362,8 @@ class _NetworkAPI(_LiteCObjBase):
("LITE_set_start_callback", [_Cnetwork, LiteStartCallback]), ("LITE_set_start_callback", [_Cnetwork, LiteStartCallback]),
("LITE_set_finish_callback", [_Cnetwork, LiteFinishCallback]), ("LITE_set_finish_callback", [_Cnetwork, LiteFinishCallback]),
("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]), ("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]),
("LITE_enable_global_layout_transform", [_Cnetwork]),
("LITE_dump_layout_transform_model", [_Cnetwork, c_char_p]),
] ]




@@ -610,3 +612,10 @@ class LiteNetwork(object):
def get_static_memory_alloc_info(self, log_dir="logs/test"): def get_static_memory_alloc_info(self, log_dir="logs/test"):
c_log_dir = log_dir.encode("utf-8") c_log_dir = log_dir.encode("utf-8")
self._api.LITE_get_static_memory_alloc_info(self._network, c_log_dir) self._api.LITE_get_static_memory_alloc_info(self._network, c_log_dir)

def enable_global_layout_transform(self):
self._api.LITE_enable_global_layout_transform(self._network)

def dump_layout_transform_model(self, model_file):
c_file = model_file.encode("utf-8")
self._api.LITE_dump_layout_transform_model(self._network, c_file)

+ 17
- 0
lite/pylite/test/test_network.py View File

@@ -451,3 +451,20 @@ class TestNetwork(TestShuffleNet):
network.wait() network.wait()


self.check_correct(out_array) self.check_correct(out_array)

def test_enable_global_layout_transform(self):
network = LiteNetwork()
network.enable_global_layout_transform()
network.load(self.model_path)
self.do_forward(network)

def test_dump_layout_transform_model(self):
network = LiteNetwork()
network.enable_global_layout_transform()
network.load(self.model_path)
network.dump_layout_transform_model("./model_afer_layoutTrans.mgb")
self.do_forward(network)

fi = open("./model_afer_layoutTrans.mgb", "r")
fi.close()
os.remove("./model_afer_layoutTrans.mgb")

+ 19
- 0
lite/pylite/test/test_network_cuda.py View File

@@ -272,3 +272,22 @@ class TestNetwork(TestShuffleNetCuda):
| LiteAlgoSelectStrategy.LITE_ALGO_REPRODUCIBLE | LiteAlgoSelectStrategy.LITE_ALGO_REPRODUCIBLE
) )
self.do_forward(network) self.do_forward(network)

@require_cuda()
def test_enable_global_layout_transform(self):
network = LiteNetwork()
network.enable_global_layout_transform()
network.load(self.model_path)
self.do_forward(network)

@require_cuda()
def test_dump_layout_transform_model(self):
network = LiteNetwork()
network.enable_global_layout_transform()
network.load(self.model_path)
network.dump_layout_transform_model("./model_afer_layoutTrans.mgb")
self.do_forward(network)

fi = open("./model_afer_layoutTrans.mgb", "r")
fi.close()
os.remove("./model_afer_layoutTrans.mgb")

+ 1
- 1
lite/src/mge/network_impl.cpp View File

@@ -406,7 +406,7 @@ void NetworkImplDft::load_model(
use_tensorrt(); use_tensorrt();
} }


m_load_result = m_loader->load(m_load_config, false);
m_load_result = m_loader->load(m_load_config, true);


global_layout_transform(); global_layout_transform();




+ 1
- 1
lite/test/test_network.cpp View File

@@ -910,7 +910,6 @@ TEST(TestNetWork, LoadPackedModel) {
} }


TEST(TestNetWork, GlabalLayoutTransform) { TEST(TestNetWork, GlabalLayoutTransform) {
// set_log_level(LiteLogLevel::DEBUG);
auto tensor = get_input_data("./input_data.npy"); auto tensor = get_input_data("./input_data.npy");
std::string model_path = "./shufflenet.mge"; std::string model_path = "./shufflenet.mge";
std::string input_name = "data"; std::string input_name = "data";
@@ -931,6 +930,7 @@ TEST(TestNetWork, GlabalLayoutTransform) {
network->forward(); network->forward();
network->wait(); network->wait();
ASSERT_TRUE(fopen(dump_model_name.c_str(), "r")); ASSERT_TRUE(fopen(dump_model_name.c_str(), "r"));
remove(dump_model_name.c_str());
} }


TEST(TestNetWork, GetDeviceType) { TEST(TestNetWork, GetDeviceType) {


Loading…
Cancel
Save