GitOrigin-RevId: 7ac74a596a
tags/v1.9.0
@@ -275,14 +275,16 @@ class TestNetwork(TestShuffleNetCuda): | |||||
@require_cuda() | @require_cuda() | ||||
def test_enable_global_layout_transform(self): | def test_enable_global_layout_transform(self): | ||||
network = LiteNetwork() | |||||
config_ = LiteConfig(device_type=LiteDeviceType.LITE_CUDA) | |||||
network = LiteNetwork(config=config_) | |||||
network.enable_global_layout_transform() | network.enable_global_layout_transform() | ||||
network.load(self.model_path) | network.load(self.model_path) | ||||
self.do_forward(network) | self.do_forward(network) | ||||
@require_cuda() | @require_cuda() | ||||
def test_dump_layout_transform_model(self): | def test_dump_layout_transform_model(self): | ||||
network = LiteNetwork() | |||||
config_ = LiteConfig(device_type=LiteDeviceType.LITE_CUDA) | |||||
network = LiteNetwork(config=config_) | |||||
network.enable_global_layout_transform() | network.enable_global_layout_transform() | ||||
network.load(self.model_path) | network.load(self.model_path) | ||||
network.dump_layout_transform_model("./model_afer_layoutTrans.mgb") | network.dump_layout_transform_model("./model_afer_layoutTrans.mgb") | ||||
@@ -365,8 +365,22 @@ void NetworkImplDft::adapt_option_valid() { | |||||
void NetworkImplDft::global_layout_transform() { | void NetworkImplDft::global_layout_transform() { | ||||
if (m_set_layout_transform) { | if (m_set_layout_transform) { | ||||
m_load_result.output_var_list = mgb::gopt::layout_transform( | |||||
mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map; | |||||
auto output_var_array = mgb::gopt::layout_transform( | |||||
m_load_result.output_var_list, m_layout_transform_target); | m_load_result.output_var_list, m_layout_transform_target); | ||||
// replace symvar in output_var_list | |||||
for (size_t idx = 0; idx < output_var_array.size(); ++idx) { | |||||
out_var_map[m_load_result.output_var_list[idx]] = output_var_array[idx]; | |||||
m_load_result.output_var_list[idx] = output_var_array[idx]; | |||||
} | |||||
// replace symvar in output_var_map_id | |||||
for (auto&& item : m_load_result.output_var_map_id) { | |||||
item.second = out_var_map[item.second]; | |||||
} | |||||
// replace symvar in output_var_map | |||||
for (auto&& item : m_load_result.output_var_map) { | |||||
item.second = out_var_map[item.second]; | |||||
} | |||||
} | } | ||||
} | } | ||||