From ec234135a67b1cb1a6e2613331751050cb3b3979 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Fri, 26 Aug 2022 19:43:50 +0800
Subject: [PATCH] feat(lite): support discrete inputs

GitOrigin-RevId: 25ce8da275d17d50d771986d36744e7866a0094f
---
 dnn/src/cuda/warp_perspective/forward.cpp |   8 +-
 lite/include/lite/network.h               |   2 +-
 lite/lite-c/include/lite-c/network_c.h    |   2 +-
 lite/lite-c/src/network.cpp               |   4 +-
 lite/pylite/megenginelite/network.py      |   6 +-
 lite/pylite/test/test_network.py          |  47 +++++++--
 lite/src/mge/network_impl.cpp             | 169 +++++++++++++++---------------
 lite/src/mge/network_impl.h               |   2 +-
 lite/src/network.cpp                      |   4 +-
 lite/src/network_impl_base.h              |   5 +-
 lite/test/test_network.cpp                |  59 ++++++++++-
 lite/test/test_network_c.cpp              |   2 +-
 12 files changed, 197 insertions(+), 113 deletions(-)

diff --git a/dnn/src/cuda/warp_perspective/forward.cpp b/dnn/src/cuda/warp_perspective/forward.cpp
index ed345d52..03af1745 100644
--- a/dnn/src/cuda/warp_perspective/forward.cpp
+++ b/dnn/src/cuda/warp_perspective/forward.cpp
@@ -554,9 +554,7 @@ void WarpPerspectiveForwardImpl::exec(
         cuda_check(cudaMemcpyAsync(
                 bundle.get(i), workspace_cpu.get(0), workspace_cpu.get_size(0),
                 cudaMemcpyHostToDevice, stream));
-        cuda_check(cudaStreamAddCallback(
-                stream, callback_free, static_cast<void*>(workspace_cpu_raw),
-                0));
+        free(workspace_cpu_raw);
         warp_perspective::forward_proxy_multi_src(
                 is_nhwc, srcs_gpu, mat.ptr<dt_float32>(),
                 mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
@@ -579,9 +577,7 @@ void WarpPerspectiveForwardImpl::exec(
         cuda_check(cudaMemcpyAsync(
                 bundle.get(0), workspace_cpu.get(0), workspace_cpu.get_size(0),
                 cudaMemcpyHostToDevice, stream));
-        cuda_check(cudaStreamAddCallback(
-                stream, callback_free, static_cast<void*>(workspace_cpu_raw),
-                0));
+        free(workspace_cpu_raw);
         warp_perspective::forward_proxy_multi_src(
                 is_nhwc, srcs_gpu, mat.ptr<dt_float32>(),
                 mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
diff --git a/lite/include/lite/network.h b/lite/include/lite/network.h
index 2b7e5abf..90d4433c 100644
--- a/lite/include/lite/network.h
+++ b/lite/include/lite/network.h
@@ -299,7 +299,7 @@ public:
      * @param io_name the name of the tensor
      * @param phase indicate the tensor is input tensor
      */
-    std::vector<std::shared_ptr<Tensor>> get_io_tensors(
+    std::vector<std::shared_ptr<Tensor>> get_discrete_tensors(
            std::string io_name, LiteTensorPhase phase = LiteTensorPhase::LITE_INPUT);

     //! get the network input tensor by index
diff --git a/lite/lite-c/include/lite-c/network_c.h b/lite/lite-c/include/lite-c/network_c.h
index 8b7316a7..c4e6c20b 100644
--- a/lite/lite-c/include/lite-c/network_c.h
+++ b/lite/lite-c/include/lite-c/network_c.h
@@ -311,7 +311,7 @@ LITE_API int LITE_get_io_tensor(
  * \param[in] phase The tensor phase
  * \param[out] tensor The IO tensor get from the network
  */
-LITE_API int LITE_get_io_tensors(
+LITE_API int LITE_get_discrete_tensor(
         LiteNetwork network, const char* io_name, size_t n_idx,
         LiteTensorPhase phase, LiteTensor* tensor);

diff --git a/lite/lite-c/src/network.cpp b/lite/lite-c/src/network.cpp
index 0c8ec8a4..e073a3f3 100644
--- a/lite/lite-c/src/network.cpp
+++ b/lite/lite-c/src/network.cpp
@@ -278,13 +278,13 @@ int LITE_get_io_tensor(
     LITE_CAPI_END();
 }
 
-int LITE_get_io_tensors(
+int LITE_get_discrete_tensor(
         LiteNetwork network, const char* io_name, size_t n_idx, LiteTensorPhase phase,
         LiteTensor* tensor) {
     LITE_CAPI_BEGIN();
     LITE_ASSERT(network, "The network pass to LITE api is null");
     auto io_tensors =
-            static_cast<lite::Network*>(network)->get_io_tensors(io_name, phase);
+            static_cast<lite::Network*>(network)->get_discrete_tensors(io_name, phase);
     LITE_ASSERT(
             n_idx < io_tensors.size(), "n_idx should be less than %zu",
             io_tensors.size());
diff --git a/lite/pylite/megenginelite/network.py b/lite/pylite/megenginelite/network.py
index 390e7627..dcd87965 100644
--- a/lite/pylite/megenginelite/network.py
+++ b/lite/pylite/megenginelite/network.py
@@ -542,7 +542,7 @@ class _NetworkAPI(_LiteCObjBase):
         ),
         ("LITE_extra_configure", [_Cnetwork, LiteExtraConfig]),
         (
-            "LITE_get_io_tensors",
+            "LITE_get_discrete_tensor",
             [_Cnetwork, c_char_p, c_size_t, c_int, POINTER(_Ctensor)],
         ),
     ]
@@ -745,7 +745,7 @@ class LiteNetwork(object):
         tensor.update()
         return tensor
 
-    def get_io_tensors(self, name, n_idx, phase=LiteTensorPhase.LITE_INPUT):
+    def get_discrete_tensor(self, name, n_idx, phase=LiteTensorPhase.LITE_INPUT):
         """
         get the n_idx'th tensor in the network input tensors whose
         input consists of discrete multiple tensors and tensor name is name
@@ -763,7 +763,7 @@ class LiteNetwork(object):
         else:
             c_name = c_char_p(name)
         tensor = LiteTensor(physic_construct=False)
-        self._api.LITE_get_io_tensors(
+        self._api.LITE_get_discrete_tensor(
             self._network, c_name, n_idx, phase, byref(tensor._tensor)
         )
         tensor.update()
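A minimal pylite usage sketch of the renamed API above, reusing the data names from the tests in this patch. It assumes pylite's LiteConfig exposes the C++ Config::discrete_input_name field; that binding is not part of this diff, so treat the config line as an assumption rather than confirmed API:

    import numpy as np
    from megenginelite import LiteConfig, LiteNetwork

    config = LiteConfig()
    # assumption: pylite mirrors Config::discrete_input_name from the C++ API
    config.discrete_input_name = "data".encode("utf-8")
    net = LiteNetwork(config)
    net.load("./test_discrete_input.mge")

    # each discrete tensor carries one sample with layout (1, c, h, w)
    for i in range(3):
        t = net.get_discrete_tensor("data", i)  # phase defaults to LITE_INPUT
        t.set_data_by_share(np.load("./data%d.npy" % i))

    roi = net.get_io_tensor("roi")
    roi.set_data_by_share(np.load("./roi.npy"))

    net.forward()
    net.wait()
    out = net.get_io_tensor(net.get_output_name(0)).to_numpy()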
diff --git a/lite/pylite/test/test_network.py b/lite/pylite/test/test_network.py
index 8e046992..1847ff87 100644
--- a/lite/pylite/test/test_network.py
+++ b/lite/pylite/test/test_network.py
@@ -504,28 +504,59 @@ class TestNetwork(TestShuffleNet):
 
 class TestDiscreteInputNet(unittest.TestCase):
     source_dir = os.getenv("LITE_TEST_RESOURCE")
+    data_path = os.path.join(source_dir, "data_b3.npy")
     data0_path = os.path.join(source_dir, "data0.npy")
     data1_path = os.path.join(source_dir, "data1.npy")
     data2_path = os.path.join(source_dir, "data2.npy")
+    roi_path = os.path.join(source_dir, "roi.npy")
     model_path = os.path.join(source_dir, "test_discrete_input.mge")
+    data = np.load(data_path)
     data0 = np.load(data0_path)
     data1 = np.load(data1_path)
     data2 = np.load(data2_path)
+    roi = np.load(roi_path)
 
-    def do_forward(self, network, times=3):
+    def check_correct(self, out_data, error=1e-4):
+        out_data = out_data.flatten()
+
+        config = LiteConfig()
+        net = LiteNetwork(config)
+        net.load(self.model_path)
+        input_tensor = net.get_io_tensor("data")
+        input_tensor.set_data_by_share(self.data)
+        roi_tensor = net.get_io_tensor("roi")
+        roi_tensor.set_data_by_share(self.roi)
+        output_name = net.get_output_name(0)
+        output_tensor = net.get_io_tensor(output_name)
+        net.forward()
+        net.wait()
+
+        correct_data = output_tensor.to_numpy().flatten()
+        assert correct_data.size == out_data.size
+        for i in range(out_data.size):
+            assert abs(out_data[i] - correct_data[i]) < error
+
+    def do_forward(self, network, times=1):
         data_name = network.get_input_name(1)
         datas = []
-        datas.append(network.get_io_tensors(data_name, 0))
-        datas.append(network.get_io_tensors(data_name, 1))
-        datas.append(network.get_io_tensors(data_name, 2))
-
-        datas[0].set_data_by_copy(self.data0)
-        datas[1].set_data_by_copy(self.data1)
-        datas[2].set_data_by_copy(self.data2)
+        datas.append(network.get_discrete_tensor(data_name, 0))
+        datas.append(network.get_discrete_tensor(data_name, 1))
+        datas.append(network.get_discrete_tensor(data_name, 2))
+
+        datas[0].set_data_by_share(self.data0)
+        datas[1].set_data_by_share(self.data1)
+        datas[2].set_data_by_share(self.data2)
+        roi_tensor = network.get_io_tensor("roi")
+        roi_tensor.set_data_by_share(self.roi)
+        out_name = network.get_output_name(0)
+        out_tensor = network.get_io_tensor(out_name)
 
         for i in range(times):
             network.forward()
             network.wait()
+        out_data = out_tensor.to_numpy()
+        self.check_correct(out_data)
+
 
 class TestDiscreteInput(TestDiscreteInputNet):
     def test_discrete_input(self):
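check_correct above validates discrete mode against a plain forward of the batched data_b3.npy, which pins down the intended contract: forwarding the n discrete (1, c, h, w) tensors must match forwarding their axis-0 concatenation. A numpy restatement of that contract, assuming data_b3.npy is the concatenation of data0/1/2, which is what the test's comparison relies on:

    import numpy as np

    slices = [np.load("./data%d.npy" % i) for i in range(3)]  # each (1, c, h, w)
    batched = np.load("./data_b3.npy")                        # (3, c, h, w)

    assert all(s.shape[0] == 1 for s in slices)
    # discrete inputs behave like this concatenation fed as one batch
    np.testing.assert_array_equal(np.concatenate(slices, axis=0), batched)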
diff --git a/lite/src/mge/network_impl.cpp b/lite/src/mge/network_impl.cpp
index a37c9d45..2d6840f9 100644
--- a/lite/src/mge/network_impl.cpp
+++ b/lite/src/mge/network_impl.cpp
@@ -268,57 +268,69 @@ void NetworkImplDft::replace_src_discrete_input_opr_pass() {
     gopt::SubGraph graph{dest_with_extra_deps};
     auto rewriter = graph.make_rewriter();
 
-    auto on_opr = [&](mgb::cg::OperatorNodeBase* opr) {
-        if (opr->same_type<mgb::opr::WarpPerspective>()) {
-            bool is_h2d = true;
-            if (opr->input(0)->owner_opr()->same_type<mgb::opr::Host2DeviceCopy>())
-                is_h2d = true;
-            else if (opr->input(0)
-                             ->owner_opr()
-                             ->same_type<mgb::opr::VolatileSharedDeviceTensor>())
-                is_h2d = false;
-            else
-                return;
-
-            SymbolVarArray srcs;
-            if (is_h2d) {
-                auto h2d = opr->input(0)->owner_opr();
-                for (auto&& inp : get_io_tensors(m_user_config->discrete_input_name)) {
-                    auto val = TensorHelper::implement(inp)
-                                       ->cast_final_safe<TensorImplDft>()
-                                       .m_host_tensor;
-                    LITE_ASSERT(val);
-                    srcs.push_back(mgb::opr::Host2DeviceCopy::make(
-                            *m_load_result.graph, val, h2d->config()));
+    auto on_opr = [&](cg::OperatorNodeBase* opr) {
+        bool replace_output = false;
+        for (auto inp : opr->input()) {
+            if ((inp->owner_opr()->same_type<mgb::opr::Host2DeviceCopy>() ||
+                 inp->owner_opr()->same_type<mgb::opr::VolatileSharedDeviceTensor>()) &&
+                inp->name() == m_user_config->discrete_input_name) {
+                bool is_h2d = true;
+                if (inp->owner_opr()->same_type<mgb::opr::Host2DeviceCopy>()) {
+                    is_h2d = true;
+                } else {
+                    is_h2d = false;
                 }
-            } else {
-                auto volatiled = opr->input(0)->owner_opr();
-                for (auto&& inp : get_io_tensors(m_user_config->discrete_input_name)) {
-                    auto val = TensorHelper::implement(inp)
-                                       ->cast_final_safe<TensorImplDft>()
-                                       .m_dev_tensor;
-                    LITE_ASSERT(val);
-                    srcs.push_back(mgb::opr::VolatileSharedDeviceTensor::make(
-                            *m_load_result.graph, val, volatiled->config()));
+
+                SymbolVarArray srcs;
+                if (is_h2d) {
+                    auto h2d = inp->owner_opr();
+                    for (auto&& i :
+                         get_discrete_tensors(m_user_config->discrete_input_name)) {
+                        auto val = TensorHelper::implement(i)
+                                           ->cast_final_safe<TensorImplDft>()
+                                           .m_host_tensor;
+                        LITE_ASSERT(val);
+                        srcs.push_back(mgb::opr::Host2DeviceCopy::make(
+                                *m_load_result.graph, val, h2d->config()));
+                    }
+                } else {
+                    auto volatiled = inp->owner_opr();
+                    for (auto&& i :
+                         get_discrete_tensors(m_user_config->discrete_input_name)) {
+                        auto val = TensorHelper::implement(i)
+                                           ->cast_final_safe<TensorImplDft>()
+                                           .m_dev_tensor;
+                        LITE_ASSERT(val);
+                        srcs.push_back(mgb::opr::VolatileSharedDeviceTensor::make(
+                                *m_load_result.graph, val, volatiled->config()));
+                    }
                 }
-            }
-            auto& warp = opr->cast_final<mgb::opr::WarpPerspective>();
-            SymbolVar new_out;
-            if (opr->input().size() == 3) {
-                new_out = mgb::opr::WarpPerspective::make(
-                        srcs, warp.input(1), warp.input(2), warp.param(),
-                        warp.config());
-            } else {
-                LITE_ASSERT(opr->input().size() == 4);
-                new_out = mgb::opr::WarpPerspective::make(
-                        srcs, warp.input(1), warp.input(2), warp.input(3), warp.param(),
-                        warp.config());
+                if (opr->same_type<mgb::opr::WarpPerspective>()) {
+                    auto& warp = opr->cast_final<mgb::opr::WarpPerspective>();
+                    SymbolVar new_out;
+                    if (opr->input().size() == 3) {
+                        new_out = mgb::opr::WarpPerspective::make(
+                                srcs, warp.input(1), warp.input(2), warp.param(),
+                                warp.config());
+                    } else {
+                        LITE_ASSERT(opr->input().size() == 4);
+                        new_out = mgb::opr::WarpPerspective::make(
+                                srcs, warp.input(1), warp.input(2), warp.input(3),
+                                warp.param(), warp.config());
+                    }
+                    rewriter.replace_var(
+                            warp.output(0), new_out.node(),
+                            "replace WarpPerspective to WarpPerspective multi src "
+                            "version.");
+                    replace_output = true;
+                } else {
+                    auto concat = mgb::opr::Concat::make(srcs, 0);
+                    rewriter.replace_var(inp, concat.node(), "add a concat opr.");
+                }
             }
-            rewriter.replace_var(
-                    warp.output(0), new_out.node(),
-                    "replace WarpPerspective to WarpPerspective multi src version.");
-        } else {
+        }
+        if (!replace_output) {
             rewriter.auto_replace_outputs(opr);
         }
     };
@@ -385,6 +397,10 @@ void NetworkImplDft::replace_dev_input_pass() {
             inp_var_map[host_val2var.at(host_val.get())] = dev_var;
             name2dev_tensor[config_in.name] = dev_val;
         }
+        //! reset lite_tensor in discrete mode
+        if (config_in.name == m_user_config->discrete_input_name) {
+            config_in.lite_tensor.reset();
+        }
     }
     auto new_ovar = mgb::cg::replace_vars(m_load_result.output_var_list, inp_var_map);
     for (size_t i = 0; i < new_ovar.size(); ++i) {
@@ -611,8 +627,9 @@ void NetworkImplDft::configure_after_loaded() {
 
 void NetworkImplDft::compile_graph() {
     replace_dev_input_pass();
-    if (!m_user_config->discrete_input_name.empty())
+    if (!m_user_config->discrete_input_name.empty()) {
         replace_src_discrete_input_opr_pass();
+    }
     make_output_spec();
     m_execute_func = m_load_result.graph_compile(m_output_spec);
 }
@@ -792,6 +809,7 @@ void NetworkImplDft::update_input() {
     }
 }
 
+//! initialize lite_tensors when the input is composed of multiple discrete tensors
 void NetworkImplDft::update_input_lite_tensors() {
     auto device_type = m_user_config->device_type;
     auto device_id = m_compnode_locator.device;
@@ -801,24 +819,22 @@ void NetworkImplDft::update_input_lite_tensors() {
         if (in_tensor_iter.first != m_user_config->discrete_input_name) {
             continue;
         }
-        bool found = false;
         for (auto&& config_in : m_network_io->inputs) {
             if (in_tensor_iter.first == config_in.name) {
-                found = true;
                 size_t bs = in_tensor_iter.second->shape(0);
                 auto shape = in_tensor_iter.second->shape();
-                shape.shape[0] = 1;
                 if (config_in.config_layout.ndim) {
                     bs = config_in.config_layout.shapes[0];
-                    shape.shape[1] = config_in.config_layout.shapes[1];
-                    shape.shape[2] = config_in.config_layout.shapes[2];
-                    shape.shape[3] = config_in.config_layout.shapes[3];
+                    for (size_t i = 0; i < config_in.config_layout.ndim; ++i) {
+                        shape.shape[i] = config_in.config_layout.shapes[i];
+                    }
                 }
-                HostTensorND tensor(
-                        in_tensor_iter.second->comp_node(), shape,
-                        in_tensor_iter.second->dtype(),
-                        in_tensor_iter.second->format());
+                shape.shape[0] = 1;
                 for (size_t i = 0; i < bs; ++i) {
+                    HostTensorND tensor(
+                            in_tensor_iter.second->comp_node(), shape,
+                            in_tensor_iter.second->dtype(),
+                            in_tensor_iter.second->format());
                     if (config_in.is_host) {
                         config_in.lite_tensors.push_back(std::make_shared<Tensor>(
                                 device_id, stream_id, device_type, true));
@@ -839,29 +855,6 @@ void NetworkImplDft::update_input_lite_tensors() {
                 }
             }
         }
-        if (!found) {
-            size_t bs = in_tensor_iter.second->shape(0);
-            auto shape = in_tensor_iter.second->shape();
-            shape.shape[0] = 1;
-            HostTensorND tensor(
-                    in_tensor_iter.second->comp_node(), shape,
-                    in_tensor_iter.second->dtype(), in_tensor_iter.second->format());
-            IOInner io_in;
-            io_in.name = in_tensor_iter.first;
-            for (size_t i = 0; i < bs; ++i) {
-                io_in.lite_tensors.push_back(std::make_shared<Tensor>(
-                        device_id, stream_id, device_type, true));
-                TensorHelper::implement(io_in.lite_tensors[i])
-                        ->cast_final_safe<TensorImplDft>()
-                        .m_host_tensor = std::make_shared<HostTensorND>(tensor);
-                TensorHelper::implement(io_in.lite_tensors[i])
-                        ->cast_final_safe<TensorImplDft>()
-                        .m_record_reset =
-                        m_user_config->options.comp_node_seq_record_level > 0;
-                io_in.lite_tensors[i]->update_from_implement();
-            }
-            m_network_io->inputs.push_back(io_in);
-        }
     }
 }
 
@@ -997,7 +990,15 @@ std::shared_ptr<Tensor> NetworkImplDft::get_io_tensor(
     if (phase == LiteTensorPhase::LITE_INPUT || phase == LiteTensorPhase::LITE_IO) {
         for (auto&& config_in : m_network_io->inputs) {
             if (io_name == config_in.name) {
-                return config_in.lite_tensor;
+                if (config_in.lite_tensor) {
+                    return config_in.lite_tensor;
+                } else {
+                    LITE_THROW(mgb::ssprintf(
+                            "%s input tensor is in discrete mode, you can use "
+                            "get_discrete_tensors to get this input.",
+                            io_name.c_str()));
+                    return nullptr;
+                }
             }
         }
     }
@@ -1018,7 +1019,7 @@ std::shared_ptr<Tensor> NetworkImplDft::get_io_tensor(
     return nullptr;
 }
 
-std::vector<std::shared_ptr<Tensor>> NetworkImplDft::get_io_tensors(
+std::vector<std::shared_ptr<Tensor>> NetworkImplDft::get_discrete_tensors(
         std::string io_name, LiteTensorPhase phase) {
     if (phase == LiteTensorPhase::LITE_INPUT) {
         for (auto&& config_in : m_network_io->inputs) {
@@ -1038,7 +1039,7 @@ std::shared_ptr<Tensor> NetworkImplDft::get_input_tensor(size_t index) {
 }
 
 std::vector<std::shared_ptr<Tensor>> NetworkImplDft::get_input_tensors(size_t index) {
-    return get_io_tensors(get_input_name(index));
+    return get_discrete_tensors(get_input_name(index));
 }
 
 std::shared_ptr<Tensor> NetworkImplDft::get_output_tensor(size_t index) {
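With the get_io_tensor change above, requesting a discrete input through the aggregate accessor now throws instead of returning a null tensor, steering callers to the per-sample accessor. A sketch of the resulting behaviour from Python, reusing net from the earlier sketch (pylite surfaces C-API errors as Python exceptions, so the except branch is the expected path):

    try:
        net.get_io_tensor("data")  # "data" is configured as the discrete input
    except Exception:
        # fetch the per-sample tensors instead
        tensors = [net.get_discrete_tensor("data", i) for i in range(3)]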
diff --git a/lite/src/mge/network_impl.h b/lite/src/mge/network_impl.h
index bef304bc..e2add83b 100644
--- a/lite/src/mge/network_impl.h
+++ b/lite/src/mge/network_impl.h
@@ -59,7 +59,7 @@ public:
 
     //! get the network input tensors which input consists of discrete multiple tensors,
     //! layout (1, c, h, w)
-    std::vector<std::shared_ptr<Tensor>> get_io_tensors(
+    std::vector<std::shared_ptr<Tensor>> get_discrete_tensors(
             std::string io_name,
             LiteTensorPhase phase = LiteTensorPhase::LITE_INPUT) override;
 
diff --git a/lite/src/network.cpp b/lite/src/network.cpp
index fad1343a..675e0a0a 100644
--- a/lite/src/network.cpp
+++ b/lite/src/network.cpp
@@ -127,12 +127,12 @@ std::shared_ptr<Tensor> Network::get_io_tensor(
     LITE_ERROR_HANDLER_END
 }
 
-std::vector<std::shared_ptr<Tensor>> Network::get_io_tensors(
+std::vector<std::shared_ptr<Tensor>> Network::get_discrete_tensors(
         std::string name, LiteTensorPhase phase) {
     LITE_ERROR_HANDLER_BEGIN
     LITE_ASSERT(m_loaded, "get_io_tensor should be used after model loaded.");
     LITE_CHECK_NON_NULL_POINTER(m_impl);
-    return m_impl->get_io_tensors(name, phase);
+    return m_impl->get_discrete_tensors(name, phase);
     LITE_ERROR_HANDLER_END
 }
diff --git a/lite/src/network_impl_base.h b/lite/src/network_impl_base.h
index dd1d3c75..49117df4 100644
--- a/lite/src/network_impl_base.h
+++ b/lite/src/network_impl_base.h
@@ -91,8 +91,10 @@ public:
 
     //! get the network input tensors which input consists of discrete multiple tensors,
     //! layout (1, c, h, w)
-    virtual std::vector<std::shared_ptr<Tensor>> get_io_tensors(
+    virtual std::vector<std::shared_ptr<Tensor>> get_discrete_tensors(
             std::string io_name, LiteTensorPhase phase = LiteTensorPhase::LITE_INPUT) {
+        LITE_MARK_USED_VAR(io_name);
+        LITE_MARK_USED_VAR(phase);
         return {};
     }
 
@@ -102,6 +104,7 @@ public:
     //! get the network input tensors which input consists of discrete multiple tensors
     //! by index
     virtual std::vector<std::shared_ptr<Tensor>> get_input_tensors(size_t index) {
+        LITE_MARK_USED_VAR(index);
         return {};
     }
 
diff --git a/lite/test/test_network.cpp b/lite/test/test_network.cpp
index d252b3c4..0134424b 100644
--- a/lite/test/test_network.cpp
+++ b/lite/test/test_network.cpp
@@ -1393,6 +1393,7 @@ TEST(TestNetWork, Discrete_Input) {
     auto data_0 = get_input_data("./data0.npy");
     auto data_1 = get_input_data("./data1.npy");
     auto data_2 = get_input_data("./data2.npy");
+    auto roi = get_input_data("./roi.npy");
     std::string model_path = "./test_discrete_input.mge";
 
     Config config;
@@ -1403,6 +1404,8 @@ TEST(TestNetWork, Discrete_Input) {
 
     std::shared_ptr<Tensor> data_tensor = network0->get_io_tensor("data");
     data_tensor->share_memory_with(*data);
+    std::shared_ptr<Tensor> roi_tensor = network0->get_io_tensor("roi");
+    roi_tensor->share_memory_with(*roi);
 
     network0->forward();
     network0->wait();
@@ -1417,8 +1420,11 @@ TEST(TestNetWork, Discrete_Input) {
     std::shared_ptr<Network> network1 = std::make_shared<Network>(config, ios);
     network1->load_model(model_path);
 
+    std::shared_ptr<Tensor> roi_tensor1 = network1->get_io_tensor("roi");
+    roi_tensor1->copy_from(*roi);
+
     std::vector<std::shared_ptr<Tensor>> data_tensors =
-            network1->get_io_tensors("data");
+            network1->get_discrete_tensors("data");
     data_tensors[0]->share_memory_with(*data_0);
     data_tensors[1]->share_memory_with(*data_1);
     data_tensors[2]->share_memory_with(*data_2);
@@ -1435,6 +1441,7 @@ TEST(TestNetWork, Discrete_Input_Device) {
     auto data_0 = get_input_data("./data0.npy");
     auto data_1 = get_input_data("./data1.npy");
     auto data_2 = get_input_data("./data2.npy");
+    auto roi = get_input_data("./roi.npy");
     std::string model_path = "./test_discrete_input.mge";
 
     Config config;
@@ -1444,7 +1451,9 @@ TEST(TestNetWork, Discrete_Input_Device) {
     network0->load_model(model_path);
 
     std::shared_ptr<Tensor> data_tensor = network0->get_io_tensor("data");
-    data_tensor->share_memory_with(*data);
+    data_tensor->copy_from(*data);
+    std::shared_ptr<Tensor> roi_tensor = network0->get_io_tensor("roi");
+    roi_tensor->copy_from(*roi);
 
     network0->forward();
     network0->wait();
@@ -1459,8 +1468,10 @@ TEST(TestNetWork, Discrete_Input_Device) {
     std::shared_ptr<Network> network1 = std::make_shared<Network>(config, ios);
     network1->load_model(model_path);
 
+    std::shared_ptr<Tensor> roi_tensor1 = network1->get_io_tensor("roi");
+    roi_tensor1->copy_from(*roi);
     std::vector<std::shared_ptr<Tensor>> data_tensors =
-            network1->get_io_tensors("data");
+            network1->get_discrete_tensors("data");
     auto d0_cuda = Tensor(LiteDeviceType::LITE_CUDA, d_ly);
     auto d1_cuda = Tensor(LiteDeviceType::LITE_CUDA, d_ly);
     auto d2_cuda = Tensor(LiteDeviceType::LITE_CUDA, d_ly);
@@ -1477,6 +1488,48 @@ TEST(TestNetWork, Discrete_Input_Device) {
 
     compare_lite_tensor<float>(output_tensor0, output_tensor1);
 }
+
+TEST(TestNetWork, Discrete_Input_Concat) {
+    auto data = get_input_data("./data_b3.npy");
+    auto data_0 = get_input_data("./data0.npy");
+    auto data_1 = get_input_data("./data1.npy");
+    auto data_2 = get_input_data("./data2.npy");
+    std::string model_path = "./test_discrete_input_concat.mge";
+
+    Config config;
+    config.device_type = LiteDeviceType::LITE_CUDA;
+
+    std::shared_ptr<Network> network0 = std::make_shared<Network>(config);
+    network0->load_model(model_path);
+
+    std::shared_ptr<Tensor> data_tensor = network0->get_io_tensor("data");
+    data_tensor->copy_from(*data);
+
+    network0->forward();
+    network0->wait();
+    std::shared_ptr<Tensor> output_tensor0 = network0->get_output_tensor(0);
+
+    config.discrete_input_name = "data";
+    NetworkIO ios;
+    bool is_host = true;
+    Layout d_ly{{3, 3, 224, 224}, 4, LiteDataType::LITE_FLOAT};
+    ios.inputs.push_back({"data", is_host, LiteIOType::LITE_IO_VALUE, d_ly});
+
+    std::shared_ptr<Network> network1 = std::make_shared<Network>(config, ios);
+    network1->load_model(model_path);
+
+    std::vector<std::shared_ptr<Tensor>> data_tensors =
+            network1->get_discrete_tensors("data");
+    data_tensors[0]->copy_from(*data_0);
+    data_tensors[1]->copy_from(*data_1);
+    data_tensors[2]->copy_from(*data_2);
+
+    network1->forward();
+    network1->wait();
+    std::shared_ptr<Tensor> output_tensor1 = network1->get_output_tensor(0);
+
+    compare_lite_tensor<float>(output_tensor0, output_tensor1);
+}
 #endif
 
 #if MGB_ATLAS || MGB_CAMBRICON
diff --git a/lite/test/test_network_c.cpp b/lite/test/test_network_c.cpp
index 53fe3ee3..7503b2d8 100644
--- a/lite/test/test_network_c.cpp
+++ b/lite/test/test_network_c.cpp
@@ -322,7 +322,7 @@ TEST(TestCapiNetWork, Discrete_Input) {
 
     std::vector<LiteTensor> c_data_tensors(3, nullptr);
     for (size_t i = 0; i < 3; i++) {
-        LITE_CAPI_CHECK(LITE_get_io_tensors(
+        LITE_CAPI_CHECK(LITE_get_discrete_tensor(
                 c_network, "data", i, LITE_INPUT, &c_data_tensors[i]));
        LITE_CAPI_CHECK(LITE_reset_tensor_memory(
                c_data_tensors[i], datas[i]->get_memory_ptr(), data_length_in_byte));