@@ -554,9 +554,7 @@ void WarpPerspectiveForwardImpl::exec( | |||||
cuda_check(cudaMemcpyAsync( | cuda_check(cudaMemcpyAsync( | ||||
bundle.get(i), workspace_cpu.get(0), workspace_cpu.get_size(0), | bundle.get(i), workspace_cpu.get(0), workspace_cpu.get_size(0), | ||||
cudaMemcpyHostToDevice, stream)); | cudaMemcpyHostToDevice, stream)); | ||||
cuda_check(cudaStreamAddCallback( | |||||
stream, callback_free, static_cast<void*>(workspace_cpu_raw), | |||||
0)); | |||||
free(workspace_cpu_raw); | |||||
warp_perspective::forward_proxy_multi_src( | warp_perspective::forward_proxy_multi_src( | ||||
is_nhwc, srcs_gpu, mat.ptr<dt_float32>(), | is_nhwc, srcs_gpu, mat.ptr<dt_float32>(), | ||||
mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr, | mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr, | ||||
@@ -579,9 +577,7 @@ void WarpPerspectiveForwardImpl::exec( | |||||
cuda_check(cudaMemcpyAsync( | cuda_check(cudaMemcpyAsync( | ||||
bundle.get(0), workspace_cpu.get(0), workspace_cpu.get_size(0), | bundle.get(0), workspace_cpu.get(0), workspace_cpu.get_size(0), | ||||
cudaMemcpyHostToDevice, stream)); | cudaMemcpyHostToDevice, stream)); | ||||
cuda_check(cudaStreamAddCallback( | |||||
stream, callback_free, static_cast<void*>(workspace_cpu_raw), | |||||
0)); | |||||
free(workspace_cpu_raw); | |||||
warp_perspective::forward_proxy_multi_src( | warp_perspective::forward_proxy_multi_src( | ||||
is_nhwc, srcs_gpu, mat.ptr<dt_float32>(), | is_nhwc, srcs_gpu, mat.ptr<dt_float32>(), | ||||
mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr, | mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr, | ||||
@@ -299,7 +299,7 @@ public: | |||||
* @param io_name the name of the tensor | * @param io_name the name of the tensor | ||||
* @param phase indicate the tensor is input tensor | * @param phase indicate the tensor is input tensor | ||||
*/ | */ | ||||
std::vector<std::shared_ptr<Tensor>> get_io_tensors( | |||||
std::vector<std::shared_ptr<Tensor>> get_discrete_tensors( | |||||
std::string io_name, LiteTensorPhase phase = LiteTensorPhase::LITE_INPUT); | std::string io_name, LiteTensorPhase phase = LiteTensorPhase::LITE_INPUT); | ||||
//! get the network input tensor by index | //! get the network input tensor by index | ||||
@@ -311,7 +311,7 @@ LITE_API int LITE_get_io_tensor( | |||||
* \param[in] phase The tensor phase | * \param[in] phase The tensor phase | ||||
* \param[out] tensor The IO tensor get from the network | * \param[out] tensor The IO tensor get from the network | ||||
*/ | */ | ||||
LITE_API int LITE_get_io_tensors( | |||||
LITE_API int LITE_get_discrete_tensor( | |||||
LiteNetwork network, const char* io_name, size_t n_idx, LiteTensorPhase phase, | LiteNetwork network, const char* io_name, size_t n_idx, LiteTensorPhase phase, | ||||
LiteTensor* tensor); | LiteTensor* tensor); | ||||
@@ -278,13 +278,13 @@ int LITE_get_io_tensor( | |||||
LITE_CAPI_END(); | LITE_CAPI_END(); | ||||
} | } | ||||
int LITE_get_io_tensors( | |||||
int LITE_get_discrete_tensor( | |||||
LiteNetwork network, const char* io_name, size_t n_idx, LiteTensorPhase phase, | LiteNetwork network, const char* io_name, size_t n_idx, LiteTensorPhase phase, | ||||
LiteTensor* tensor) { | LiteTensor* tensor) { | ||||
LITE_CAPI_BEGIN(); | LITE_CAPI_BEGIN(); | ||||
LITE_ASSERT(network, "The network pass to LITE api is null"); | LITE_ASSERT(network, "The network pass to LITE api is null"); | ||||
auto io_tensors = | auto io_tensors = | ||||
static_cast<lite::Network*>(network)->get_io_tensors(io_name, phase); | |||||
static_cast<lite::Network*>(network)->get_discrete_tensors(io_name, phase); | |||||
LITE_ASSERT( | LITE_ASSERT( | ||||
n_idx < io_tensors.size(), "n_idx should be less than %zu", | n_idx < io_tensors.size(), "n_idx should be less than %zu", | ||||
io_tensors.size()); | io_tensors.size()); | ||||
@@ -542,7 +542,7 @@ class _NetworkAPI(_LiteCObjBase): | |||||
), | ), | ||||
("LITE_extra_configure", [_Cnetwork, LiteExtraConfig]), | ("LITE_extra_configure", [_Cnetwork, LiteExtraConfig]), | ||||
( | ( | ||||
"LITE_get_io_tensors", | |||||
"LITE_get_discrete_tensor", | |||||
[_Cnetwork, c_char_p, c_size_t, c_int, POINTER(_Ctensor)], | [_Cnetwork, c_char_p, c_size_t, c_int, POINTER(_Ctensor)], | ||||
), | ), | ||||
] | ] | ||||
@@ -745,7 +745,7 @@ class LiteNetwork(object): | |||||
tensor.update() | tensor.update() | ||||
return tensor | return tensor | ||||
def get_io_tensors(self, name, n_idx, phase=LiteTensorPhase.LITE_INPUT): | |||||
def get_discrete_tensor(self, name, n_idx, phase=LiteTensorPhase.LITE_INPUT): | |||||
""" | """ | ||||
get the n_idx'th tensor in the network input tensors whose | get the n_idx'th tensor in the network input tensors whose | ||||
input consists of discrete multiple tensors and tensor name is name | input consists of discrete multiple tensors and tensor name is name | ||||
@@ -763,7 +763,7 @@ class LiteNetwork(object): | |||||
else: | else: | ||||
c_name = c_char_p(name) | c_name = c_char_p(name) | ||||
tensor = LiteTensor(physic_construct=False) | tensor = LiteTensor(physic_construct=False) | ||||
self._api.LITE_get_io_tensors( | |||||
self._api.LITE_get_discrete_tensor( | |||||
self._network, c_name, n_idx, phase, byref(tensor._tensor) | self._network, c_name, n_idx, phase, byref(tensor._tensor) | ||||
) | ) | ||||
tensor.update() | tensor.update() | ||||
@@ -504,28 +504,59 @@ class TestNetwork(TestShuffleNet): | |||||
class TestDiscreteInputNet(unittest.TestCase):
    """Shared fixtures and helpers for the discrete-input network tests.

    The fixture arrays are loaded once, at class-definition time, from the
    directory named by the ``LITE_TEST_RESOURCE`` environment variable.
    """

    source_dir = os.getenv("LITE_TEST_RESOURCE")
    # Input fed to the reference (non-discrete) run in check_correct.
    data_path = os.path.join(source_dir, "data_b3.npy")
    # Inputs fed one by one to the discrete tensors in do_forward.
    data0_path = os.path.join(source_dir, "data0.npy")
    data1_path = os.path.join(source_dir, "data1.npy")
    data2_path = os.path.join(source_dir, "data2.npy")
    roi_path = os.path.join(source_dir, "roi.npy")
    model_path = os.path.join(source_dir, "test_discrete_input.mge")
    data = np.load(data_path)
    data0 = np.load(data0_path)
    data1 = np.load(data1_path)
    data2 = np.load(data2_path)
    roi = np.load(roi_path)

    def check_correct(self, out_data, error=1e-4):
        """Compare ``out_data`` against a reference run of the same model.

        A fresh network is loaded with a default ``LiteConfig``, fed
        ``self.data`` and ``self.roi`` through the regular ``get_io_tensor``
        path, and executed once. The test fails unless the flattened outputs
        have the same size and every element differs by strictly less than
        ``error``.

        :param out_data: numpy array produced by the network under test.
        :param error: exclusive upper bound on the per-element absolute
            difference.
        """
        out_data = out_data.flatten()

        config = LiteConfig()
        net = LiteNetwork(config)
        net.load(self.model_path)

        input_tensor = net.get_io_tensor("data")
        input_tensor.set_data_by_share(self.data)
        roi_tensor = net.get_io_tensor("roi")
        roi_tensor.set_data_by_share(self.roi)

        output_name = net.get_output_name(0)
        output_tensor = net.get_io_tensor(output_name)

        net.forward()
        net.wait()

        correct_data = output_tensor.to_numpy().flatten()
        # unittest assertions are used instead of bare ``assert`` so the
        # checks are not stripped when Python runs with -O.
        self.assertEqual(correct_data.size, out_data.size)
        # One vectorized comparison; a NaN difference compares False and
        # therefore fails the check, as it did element by element.
        self.assertTrue(
            bool(np.all(np.abs(out_data - correct_data) < error)),
            "network output differs from the reference output",
        )

    def do_forward(self, network, times=1):
        """Feed the discrete inputs, run ``network`` and verify its output.

        :param network: a loaded network whose input at index 1 is
            configured as a discrete input.
        :param times: number of forward/wait iterations to run before the
            output is read back and checked.
        """
        data_name = network.get_input_name(1)
        # Fetch all three discrete tensors before sharing memory with them;
        # the list also keeps them referenced while the network runs.
        datas = [network.get_discrete_tensor(data_name, idx) for idx in range(3)]
        for tensor, array in zip(datas, (self.data0, self.data1, self.data2)):
            tensor.set_data_by_share(array)

        roi_tensor = network.get_io_tensor("roi")
        roi_tensor.set_data_by_share(self.roi)

        out_name = network.get_output_name(0)
        out_tensor = network.get_io_tensor(out_name)

        for _ in range(times):
            network.forward()
            network.wait()

        out_data = out_tensor.to_numpy()
        self.check_correct(out_data)
class TestDiscreteInput(TestDiscreteInputNet): | class TestDiscreteInput(TestDiscreteInputNet): | ||||
def test_discrete_input(self): | def test_discrete_input(self): | ||||
@@ -268,57 +268,69 @@ void NetworkImplDft::replace_src_discrete_input_opr_pass() { | |||||
gopt::SubGraph graph{dest_with_extra_deps}; | gopt::SubGraph graph{dest_with_extra_deps}; | ||||
auto rewriter = graph.make_rewriter(); | auto rewriter = graph.make_rewriter(); | ||||
auto on_opr = [&](mgb::cg::OperatorNodeBase* opr) { | |||||
if (opr->same_type<mgb::opr::WarpPerspective>()) { | |||||
bool is_h2d = true; | |||||
if (opr->input(0)->owner_opr()->same_type<mgb::opr::Host2DeviceCopy>()) | |||||
is_h2d = true; | |||||
else if (opr->input(0) | |||||
->owner_opr() | |||||
->same_type<mgb::opr::VolatileSharedDeviceTensor>()) | |||||
is_h2d = false; | |||||
else | |||||
return; | |||||
SymbolVarArray srcs; | |||||
if (is_h2d) { | |||||
auto h2d = opr->input(0)->owner_opr(); | |||||
for (auto&& inp : get_io_tensors(m_user_config->discrete_input_name)) { | |||||
auto val = TensorHelper::implement(inp) | |||||
->cast_final_safe<TensorImplDft>() | |||||
.m_host_tensor; | |||||
LITE_ASSERT(val); | |||||
srcs.push_back(mgb::opr::Host2DeviceCopy::make( | |||||
*m_load_result.graph, val, h2d->config())); | |||||
auto on_opr = [&](cg::OperatorNodeBase* opr) { | |||||
bool replace_output = false; | |||||
for (auto inp : opr->input()) { | |||||
if ((inp->owner_opr()->same_type<mgb::opr::Host2DeviceCopy>() || | |||||
inp->owner_opr()->same_type<mgb::opr::VolatileSharedDeviceTensor>()) && | |||||
inp->name() == m_user_config->discrete_input_name) { | |||||
bool is_h2d = true; | |||||
if (inp->owner_opr()->same_type<mgb::opr::Host2DeviceCopy>()) { | |||||
is_h2d = true; | |||||
} else { | |||||
is_h2d = false; | |||||
} | } | ||||
} else { | |||||
auto volatiled = opr->input(0)->owner_opr(); | |||||
for (auto&& inp : get_io_tensors(m_user_config->discrete_input_name)) { | |||||
auto val = TensorHelper::implement(inp) | |||||
->cast_final_safe<TensorImplDft>() | |||||
.m_dev_tensor; | |||||
LITE_ASSERT(val); | |||||
srcs.push_back(mgb::opr::VolatileSharedDeviceTensor::make( | |||||
*m_load_result.graph, val, volatiled->config())); | |||||
SymbolVarArray srcs; | |||||
if (is_h2d) { | |||||
auto h2d = inp->owner_opr(); | |||||
for (auto&& i : | |||||
get_discrete_tensors(m_user_config->discrete_input_name)) { | |||||
auto val = TensorHelper::implement(i) | |||||
->cast_final_safe<TensorImplDft>() | |||||
.m_host_tensor; | |||||
LITE_ASSERT(val); | |||||
srcs.push_back(mgb::opr::Host2DeviceCopy::make( | |||||
*m_load_result.graph, val, h2d->config())); | |||||
} | |||||
} else { | |||||
auto volatiled = inp->owner_opr(); | |||||
for (auto&& i : | |||||
get_discrete_tensors(m_user_config->discrete_input_name)) { | |||||
auto val = TensorHelper::implement(i) | |||||
->cast_final_safe<TensorImplDft>() | |||||
.m_dev_tensor; | |||||
LITE_ASSERT(val); | |||||
srcs.push_back(mgb::opr::VolatileSharedDeviceTensor::make( | |||||
*m_load_result.graph, val, volatiled->config())); | |||||
} | |||||
} | } | ||||
} | |||||
auto& warp = opr->cast_final<mgb::opr::WarpPerspective>(); | |||||
SymbolVar new_out; | |||||
if (opr->input().size() == 3) { | |||||
new_out = mgb::opr::WarpPerspective::make( | |||||
srcs, warp.input(1), warp.input(2), warp.param(), | |||||
warp.config()); | |||||
} else { | |||||
LITE_ASSERT(opr->input().size() == 4); | |||||
new_out = mgb::opr::WarpPerspective::make( | |||||
srcs, warp.input(1), warp.input(2), warp.input(3), warp.param(), | |||||
warp.config()); | |||||
if (opr->same_type<mgb::opr::WarpPerspective>()) { | |||||
auto& warp = opr->cast_final<mgb::opr::WarpPerspective>(); | |||||
SymbolVar new_out; | |||||
if (opr->input().size() == 3) { | |||||
new_out = mgb::opr::WarpPerspective::make( | |||||
srcs, warp.input(1), warp.input(2), warp.param(), | |||||
warp.config()); | |||||
} else { | |||||
LITE_ASSERT(opr->input().size() == 4); | |||||
new_out = mgb::opr::WarpPerspective::make( | |||||
srcs, warp.input(1), warp.input(2), warp.input(3), | |||||
warp.param(), warp.config()); | |||||
} | |||||
rewriter.replace_var( | |||||
warp.output(0), new_out.node(), | |||||
"replace WarpPerspective to WarpPerspective multi src " | |||||
"version."); | |||||
replace_output = true; | |||||
} else { | |||||
auto concat = mgb::opr::Concat::make(srcs, 0); | |||||
rewriter.replace_var(inp, concat.node(), "add a concat opr."); | |||||
} | |||||
} | } | ||||
rewriter.replace_var( | |||||
warp.output(0), new_out.node(), | |||||
"replace WarpPerspective to WarpPerspective multi src version."); | |||||
} else { | |||||
} | |||||
if (!replace_output) { | |||||
rewriter.auto_replace_outputs(opr); | rewriter.auto_replace_outputs(opr); | ||||
} | } | ||||
}; | }; | ||||
@@ -385,6 +397,10 @@ void NetworkImplDft::replace_dev_input_pass() { | |||||
inp_var_map[host_val2var.at(host_val.get())] = dev_var; | inp_var_map[host_val2var.at(host_val.get())] = dev_var; | ||||
name2dev_tensor[config_in.name] = dev_val; | name2dev_tensor[config_in.name] = dev_val; | ||||
} | } | ||||
//! reset lite_tensor in discrete mode | |||||
if (config_in.name == m_user_config->discrete_input_name) { | |||||
config_in.lite_tensor.reset(); | |||||
} | |||||
} | } | ||||
auto new_ovar = mgb::cg::replace_vars(m_load_result.output_var_list, inp_var_map); | auto new_ovar = mgb::cg::replace_vars(m_load_result.output_var_list, inp_var_map); | ||||
for (size_t i = 0; i < new_ovar.size(); ++i) { | for (size_t i = 0; i < new_ovar.size(); ++i) { | ||||
@@ -611,8 +627,9 @@ void NetworkImplDft::configure_after_loaded() { | |||||
void NetworkImplDft::compile_graph() { | void NetworkImplDft::compile_graph() { | ||||
replace_dev_input_pass(); | replace_dev_input_pass(); | ||||
if (!m_user_config->discrete_input_name.empty()) | |||||
if (!m_user_config->discrete_input_name.empty()) { | |||||
replace_src_discrete_input_opr_pass(); | replace_src_discrete_input_opr_pass(); | ||||
} | |||||
make_output_spec(); | make_output_spec(); | ||||
m_execute_func = m_load_result.graph_compile(m_output_spec); | m_execute_func = m_load_result.graph_compile(m_output_spec); | ||||
} | } | ||||
@@ -792,6 +809,7 @@ void NetworkImplDft::update_input() { | |||||
} | } | ||||
} | } | ||||
//! initialize lite_tensors when the input is composed of multiple discrete tensors
void NetworkImplDft::update_input_lite_tensors() { | void NetworkImplDft::update_input_lite_tensors() { | ||||
auto device_type = m_user_config->device_type; | auto device_type = m_user_config->device_type; | ||||
auto device_id = m_compnode_locator.device; | auto device_id = m_compnode_locator.device; | ||||
@@ -801,24 +819,22 @@ void NetworkImplDft::update_input_lite_tensors() { | |||||
if (in_tensor_iter.first != m_user_config->discrete_input_name) { | if (in_tensor_iter.first != m_user_config->discrete_input_name) { | ||||
continue; | continue; | ||||
} | } | ||||
bool found = false; | |||||
for (auto&& config_in : m_network_io->inputs) { | for (auto&& config_in : m_network_io->inputs) { | ||||
if (in_tensor_iter.first == config_in.name) { | if (in_tensor_iter.first == config_in.name) { | ||||
found = true; | |||||
size_t bs = in_tensor_iter.second->shape(0); | size_t bs = in_tensor_iter.second->shape(0); | ||||
auto shape = in_tensor_iter.second->shape(); | auto shape = in_tensor_iter.second->shape(); | ||||
shape.shape[0] = 1; | |||||
if (config_in.config_layout.ndim) { | if (config_in.config_layout.ndim) { | ||||
bs = config_in.config_layout.shapes[0]; | bs = config_in.config_layout.shapes[0]; | ||||
shape.shape[1] = config_in.config_layout.shapes[1]; | |||||
shape.shape[2] = config_in.config_layout.shapes[2]; | |||||
shape.shape[3] = config_in.config_layout.shapes[3]; | |||||
for (size_t i = 0; i < config_in.config_layout.ndim; ++i) { | |||||
shape.shape[i] = config_in.config_layout.shapes[i]; | |||||
} | |||||
} | } | ||||
HostTensorND tensor( | |||||
in_tensor_iter.second->comp_node(), shape, | |||||
in_tensor_iter.second->dtype(), | |||||
in_tensor_iter.second->format()); | |||||
shape.shape[0] = 1; | |||||
for (size_t i = 0; i < bs; ++i) { | for (size_t i = 0; i < bs; ++i) { | ||||
HostTensorND tensor( | |||||
in_tensor_iter.second->comp_node(), shape, | |||||
in_tensor_iter.second->dtype(), | |||||
in_tensor_iter.second->format()); | |||||
if (config_in.is_host) { | if (config_in.is_host) { | ||||
config_in.lite_tensors.push_back(std::make_shared<Tensor>( | config_in.lite_tensors.push_back(std::make_shared<Tensor>( | ||||
device_id, stream_id, device_type, true)); | device_id, stream_id, device_type, true)); | ||||
@@ -839,29 +855,6 @@ void NetworkImplDft::update_input_lite_tensors() { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
if (!found) { | |||||
size_t bs = in_tensor_iter.second->shape(0); | |||||
auto shape = in_tensor_iter.second->shape(); | |||||
shape.shape[0] = 1; | |||||
HostTensorND tensor( | |||||
in_tensor_iter.second->comp_node(), shape, | |||||
in_tensor_iter.second->dtype(), in_tensor_iter.second->format()); | |||||
IOInner io_in; | |||||
io_in.name = in_tensor_iter.first; | |||||
for (size_t i = 0; i < bs; ++i) { | |||||
io_in.lite_tensors.push_back(std::make_shared<Tensor>( | |||||
device_id, stream_id, device_type, true)); | |||||
TensorHelper::implement(io_in.lite_tensors[i]) | |||||
->cast_final_safe<TensorImplDft>() | |||||
.m_host_tensor = std::make_shared<HostTensorND>(tensor); | |||||
TensorHelper::implement(io_in.lite_tensors[i]) | |||||
->cast_final_safe<TensorImplDft>() | |||||
.m_record_reset = | |||||
m_user_config->options.comp_node_seq_record_level > 0; | |||||
io_in.lite_tensors[i]->update_from_implement(); | |||||
} | |||||
m_network_io->inputs.push_back(io_in); | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -997,7 +990,15 @@ std::shared_ptr<Tensor> NetworkImplDft::get_io_tensor( | |||||
if (phase == LiteTensorPhase::LITE_INPUT || phase == LiteTensorPhase::LITE_IO) { | if (phase == LiteTensorPhase::LITE_INPUT || phase == LiteTensorPhase::LITE_IO) { | ||||
for (auto&& config_in : m_network_io->inputs) { | for (auto&& config_in : m_network_io->inputs) { | ||||
if (io_name == config_in.name) { | if (io_name == config_in.name) { | ||||
return config_in.lite_tensor; | |||||
if (config_in.lite_tensor) { | |||||
return config_in.lite_tensor; | |||||
} else { | |||||
LITE_THROW(mgb::ssprintf( | |||||
"%s input tensor is in discrete mode, you can use " | |||||
"get_discrete_tensors to get this input.", | |||||
io_name.c_str())); | |||||
return nullptr; | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -1018,7 +1019,7 @@ std::shared_ptr<Tensor> NetworkImplDft::get_io_tensor( | |||||
return nullptr; | return nullptr; | ||||
} | } | ||||
std::vector<std::shared_ptr<Tensor>> NetworkImplDft::get_io_tensors( | |||||
std::vector<std::shared_ptr<Tensor>> NetworkImplDft::get_discrete_tensors( | |||||
std::string io_name, LiteTensorPhase phase) { | std::string io_name, LiteTensorPhase phase) { | ||||
if (phase == LiteTensorPhase::LITE_INPUT) { | if (phase == LiteTensorPhase::LITE_INPUT) { | ||||
for (auto&& config_in : m_network_io->inputs) { | for (auto&& config_in : m_network_io->inputs) { | ||||
@@ -1038,7 +1039,7 @@ std::shared_ptr<Tensor> NetworkImplDft::get_input_tensor(size_t index) { | |||||
} | } | ||||
std::vector<std::shared_ptr<Tensor>> NetworkImplDft::get_input_tensors(size_t index) { | std::vector<std::shared_ptr<Tensor>> NetworkImplDft::get_input_tensors(size_t index) { | ||||
return get_io_tensors(get_input_name(index)); | |||||
return get_discrete_tensors(get_input_name(index)); | |||||
} | } | ||||
std::shared_ptr<Tensor> NetworkImplDft::get_output_tensor(size_t index) { | std::shared_ptr<Tensor> NetworkImplDft::get_output_tensor(size_t index) { | ||||
@@ -59,7 +59,7 @@ public: | |||||
//! get the network input tensors which input consists of discrete multiple tensors, | //! get the network input tensors which input consists of discrete multiple tensors, | ||||
//! layout (1, c, h, w) | //! layout (1, c, h, w) | ||||
std::vector<std::shared_ptr<Tensor>> get_io_tensors( | |||||
std::vector<std::shared_ptr<Tensor>> get_discrete_tensors( | |||||
std::string io_name, | std::string io_name, | ||||
LiteTensorPhase phase = LiteTensorPhase::LITE_INPUT) override; | LiteTensorPhase phase = LiteTensorPhase::LITE_INPUT) override; | ||||
@@ -127,12 +127,12 @@ std::shared_ptr<Tensor> Network::get_io_tensor( | |||||
LITE_ERROR_HANDLER_END | LITE_ERROR_HANDLER_END | ||||
} | } | ||||
std::vector<std::shared_ptr<Tensor>> Network::get_io_tensors( | |||||
std::vector<std::shared_ptr<Tensor>> Network::get_discrete_tensors( | |||||
std::string name, LiteTensorPhase phase) { | std::string name, LiteTensorPhase phase) { | ||||
LITE_ERROR_HANDLER_BEGIN | LITE_ERROR_HANDLER_BEGIN | ||||
LITE_ASSERT(m_loaded, "get_io_tensor should be used after model loaded."); | LITE_ASSERT(m_loaded, "get_io_tensor should be used after model loaded."); | ||||
LITE_CHECK_NON_NULL_POINTER(m_impl); | LITE_CHECK_NON_NULL_POINTER(m_impl); | ||||
return m_impl->get_io_tensors(name, phase); | |||||
return m_impl->get_discrete_tensors(name, phase); | |||||
LITE_ERROR_HANDLER_END | LITE_ERROR_HANDLER_END | ||||
} | } | ||||
@@ -91,8 +91,10 @@ public: | |||||
//! get the network input tensors which input consists of discrete multiple tensors, | //! get the network input tensors which input consists of discrete multiple tensors, | ||||
//! layout (1, c, h, w) | //! layout (1, c, h, w) | ||||
virtual std::vector<std::shared_ptr<Tensor>> get_io_tensors( | |||||
virtual std::vector<std::shared_ptr<Tensor>> get_discrete_tensors( | |||||
std::string io_name, LiteTensorPhase phase = LiteTensorPhase::LITE_INPUT) { | std::string io_name, LiteTensorPhase phase = LiteTensorPhase::LITE_INPUT) { | ||||
LITE_MARK_USED_VAR(io_name); | |||||
LITE_MARK_USED_VAR(phase); | |||||
return {}; | return {}; | ||||
} | } | ||||
@@ -102,6 +104,7 @@ public: | |||||
//! get the network input tensors which input consists of discrete multiple tensors | //! get the network input tensors which input consists of discrete multiple tensors | ||||
//! by index | //! by index | ||||
virtual std::vector<std::shared_ptr<Tensor>> get_input_tensors(size_t index) { | virtual std::vector<std::shared_ptr<Tensor>> get_input_tensors(size_t index) { | ||||
LITE_MARK_USED_VAR(index); | |||||
return {}; | return {}; | ||||
} | } | ||||
@@ -1393,6 +1393,7 @@ TEST(TestNetWork, Discrete_Input) { | |||||
auto data_0 = get_input_data("./data0.npy"); | auto data_0 = get_input_data("./data0.npy"); | ||||
auto data_1 = get_input_data("./data1.npy"); | auto data_1 = get_input_data("./data1.npy"); | ||||
auto data_2 = get_input_data("./data2.npy"); | auto data_2 = get_input_data("./data2.npy"); | ||||
auto roi = get_input_data("./roi.npy"); | |||||
std::string model_path = "./test_discrete_input.mge"; | std::string model_path = "./test_discrete_input.mge"; | ||||
Config config; | Config config; | ||||
@@ -1403,6 +1404,8 @@ TEST(TestNetWork, Discrete_Input) { | |||||
std::shared_ptr<Tensor> data_tensor = network0->get_io_tensor("data"); | std::shared_ptr<Tensor> data_tensor = network0->get_io_tensor("data"); | ||||
data_tensor->share_memory_with(*data); | data_tensor->share_memory_with(*data); | ||||
std::shared_ptr<Tensor> roi_tensor = network0->get_io_tensor("roi"); | |||||
roi_tensor->share_memory_with(*roi); | |||||
network0->forward(); | network0->forward(); | ||||
network0->wait(); | network0->wait(); | ||||
@@ -1417,8 +1420,11 @@ TEST(TestNetWork, Discrete_Input) { | |||||
std::shared_ptr<Network> network1 = std::make_shared<Network>(config, ios); | std::shared_ptr<Network> network1 = std::make_shared<Network>(config, ios); | ||||
network1->load_model(model_path); | network1->load_model(model_path); | ||||
std::shared_ptr<Tensor> roi_tensor1 = network1->get_io_tensor("roi"); | |||||
roi_tensor1->copy_from(*roi); | |||||
std::vector<std::shared_ptr<Tensor>> data_tensors = | std::vector<std::shared_ptr<Tensor>> data_tensors = | ||||
network1->get_io_tensors("data"); | |||||
network1->get_discrete_tensors("data"); | |||||
data_tensors[0]->share_memory_with(*data_0); | data_tensors[0]->share_memory_with(*data_0); | ||||
data_tensors[1]->share_memory_with(*data_1); | data_tensors[1]->share_memory_with(*data_1); | ||||
data_tensors[2]->share_memory_with(*data_2); | data_tensors[2]->share_memory_with(*data_2); | ||||
@@ -1435,6 +1441,7 @@ TEST(TestNetWork, Discrete_Input_Device) { | |||||
auto data_0 = get_input_data("./data0.npy"); | auto data_0 = get_input_data("./data0.npy"); | ||||
auto data_1 = get_input_data("./data1.npy"); | auto data_1 = get_input_data("./data1.npy"); | ||||
auto data_2 = get_input_data("./data2.npy"); | auto data_2 = get_input_data("./data2.npy"); | ||||
auto roi = get_input_data("./roi.npy"); | |||||
std::string model_path = "./test_discrete_input.mge"; | std::string model_path = "./test_discrete_input.mge"; | ||||
Config config; | Config config; | ||||
@@ -1444,7 +1451,9 @@ TEST(TestNetWork, Discrete_Input_Device) { | |||||
network0->load_model(model_path); | network0->load_model(model_path); | ||||
std::shared_ptr<Tensor> data_tensor = network0->get_io_tensor("data"); | std::shared_ptr<Tensor> data_tensor = network0->get_io_tensor("data"); | ||||
data_tensor->share_memory_with(*data); | |||||
data_tensor->copy_from(*data); | |||||
std::shared_ptr<Tensor> roi_tensor = network0->get_io_tensor("roi"); | |||||
roi_tensor->copy_from(*roi); | |||||
network0->forward(); | network0->forward(); | ||||
network0->wait(); | network0->wait(); | ||||
@@ -1459,8 +1468,10 @@ TEST(TestNetWork, Discrete_Input_Device) { | |||||
std::shared_ptr<Network> network1 = std::make_shared<Network>(config, ios); | std::shared_ptr<Network> network1 = std::make_shared<Network>(config, ios); | ||||
network1->load_model(model_path); | network1->load_model(model_path); | ||||
std::shared_ptr<Tensor> roi_tensor1 = network1->get_io_tensor("roi"); | |||||
roi_tensor1->copy_from(*roi); | |||||
std::vector<std::shared_ptr<Tensor>> data_tensors = | std::vector<std::shared_ptr<Tensor>> data_tensors = | ||||
network1->get_io_tensors("data"); | |||||
network1->get_discrete_tensors("data"); | |||||
auto d0_cuda = Tensor(LiteDeviceType::LITE_CUDA, d_ly); | auto d0_cuda = Tensor(LiteDeviceType::LITE_CUDA, d_ly); | ||||
auto d1_cuda = Tensor(LiteDeviceType::LITE_CUDA, d_ly); | auto d1_cuda = Tensor(LiteDeviceType::LITE_CUDA, d_ly); | ||||
auto d2_cuda = Tensor(LiteDeviceType::LITE_CUDA, d_ly); | auto d2_cuda = Tensor(LiteDeviceType::LITE_CUDA, d_ly); | ||||
@@ -1477,6 +1488,48 @@ TEST(TestNetWork, Discrete_Input_Device) { | |||||
compare_lite_tensor<float>(output_tensor0, output_tensor1); | compare_lite_tensor<float>(output_tensor0, output_tensor1); | ||||
} | } | ||||
//! A model whose discrete input feeds a non-WarpPerspective consumer must
//! produce the same output whether "data" is supplied as one batched tensor
//! or as three discrete tensors.
TEST(TestNetWork, Discrete_Input_Concat) {
    auto data = get_input_data("./data_b3.npy");
    auto data_0 = get_input_data("./data0.npy");
    auto data_1 = get_input_data("./data1.npy");
    auto data_2 = get_input_data("./data2.npy");
    std::string model_path = "./test_discrete_input_concat.mge";

    Config config;
    config.device_type = LiteDeviceType::LITE_CUDA;

    //! reference run: regular batched input
    std::shared_ptr<Network> network0 = std::make_shared<Network>(config);
    network0->load_model(model_path);
    std::shared_ptr<Tensor> data_tensor = network0->get_io_tensor("data");
    data_tensor->copy_from(*data);

    network0->forward();
    network0->wait();
    std::shared_ptr<Tensor> output_tensor0 = network0->get_output_tensor(0);

    //! discrete run: same model, "data" configured as a discrete input
    config.discrete_input_name = "data";
    NetworkIO ios;
    bool is_host = true;
    Layout d_ly{{3, 3, 224, 224}, 4, LiteDataType::LITE_FLOAT};
    ios.inputs.push_back({"data", is_host, LiteIOType::LITE_IO_VALUE, d_ly});
    std::shared_ptr<Network> network1 = std::make_shared<Network>(config, ios);
    network1->load_model(model_path);

    std::vector<std::shared_ptr<Tensor>> data_tensors =
            network1->get_discrete_tensors("data");
    //! fail the test cleanly instead of indexing out of bounds when the
    //! network does not expose three discrete tensors
    ASSERT_EQ(data_tensors.size(), 3u);
    data_tensors[0]->copy_from(*data_0);
    data_tensors[1]->copy_from(*data_1);
    data_tensors[2]->copy_from(*data_2);

    network1->forward();
    network1->wait();
    std::shared_ptr<Tensor> output_tensor1 = network1->get_output_tensor(0);

    compare_lite_tensor<float>(output_tensor0, output_tensor1);
}
#endif | #endif | ||||
#if MGB_ATLAS || MGB_CAMBRICON | #if MGB_ATLAS || MGB_CAMBRICON | ||||
@@ -322,7 +322,7 @@ TEST(TestCapiNetWork, Discrete_Input) { | |||||
std::vector<LiteTensor> c_data_tensors(3, nullptr); | std::vector<LiteTensor> c_data_tensors(3, nullptr); | ||||
for (size_t i = 0; i < 3; i++) { | for (size_t i = 0; i < 3; i++) { | ||||
LITE_CAPI_CHECK(LITE_get_io_tensors( | |||||
LITE_CAPI_CHECK(LITE_get_discrete_tensor( | |||||
c_network, "data", i, LITE_INPUT, &c_data_tensors[i])); | c_network, "data", i, LITE_INPUT, &c_data_tensors[i])); | ||||
LITE_CAPI_CHECK(LITE_reset_tensor_memory( | LITE_CAPI_CHECK(LITE_reset_tensor_memory( | ||||
c_data_tensors[i], datas[i]->get_memory_ptr(), data_length_in_byte)); | c_data_tensors[i], datas[i]->get_memory_ptr(), data_length_in_byte)); | ||||