GitOrigin-RevId: 78e00dab5d
release-1.7
@@ -31,6 +31,7 @@ class LiteDataType(IntEnum): | |||||
LITE_INT16 = 3 | LITE_INT16 = 3 | ||||
LITE_INT8 = 4 | LITE_INT8 = 4 | ||||
LITE_UINT8 = 5 | LITE_UINT8 = 5 | ||||
LITE_UINT16 = 6 | |||||
class LiteTensorPhase(IntEnum): | class LiteTensorPhase(IntEnum): | ||||
@@ -22,6 +22,7 @@ _lite_type_to_nptypes = { | |||||
LiteDataType.LITE_UINT8: np.uint8, | LiteDataType.LITE_UINT8: np.uint8, | ||||
LiteDataType.LITE_INT8: np.int8, | LiteDataType.LITE_INT8: np.int8, | ||||
LiteDataType.LITE_INT16: np.int16, | LiteDataType.LITE_INT16: np.int16, | ||||
LiteDataType.LITE_UINT16: np.uint16, | |||||
LiteDataType.LITE_HALF: np.float16, | LiteDataType.LITE_HALF: np.float16, | ||||
} | } | ||||
@@ -33,6 +34,7 @@ _str_nptypes_to_lite_nptypes = { | |||||
np.dtype("uint8"): LiteDataType.LITE_UINT8, | np.dtype("uint8"): LiteDataType.LITE_UINT8, | ||||
np.dtype("int8"): LiteDataType.LITE_INT8, | np.dtype("int8"): LiteDataType.LITE_INT8, | ||||
np.dtype("int16"): LiteDataType.LITE_INT16, | np.dtype("int16"): LiteDataType.LITE_INT16, | ||||
np.dtype("uint16"): LiteDataType.LITE_UINT16, | |||||
np.dtype("float16"): LiteDataType.LITE_HALF, | np.dtype("float16"): LiteDataType.LITE_HALF, | ||||
} | } | ||||
@@ -43,7 +45,7 @@ ctype_to_lite_dtypes = { | |||||
c_ubyte: LiteDataType.LITE_UINT8, | c_ubyte: LiteDataType.LITE_UINT8, | ||||
c_byte: LiteDataType.LITE_INT8, | c_byte: LiteDataType.LITE_INT8, | ||||
c_short: LiteDataType.LITE_INT16, | c_short: LiteDataType.LITE_INT16, | ||||
c_ushort: LiteDataType.LITE_INT16, | |||||
c_ushort: LiteDataType.LITE_UINT16, | |||||
} | } | ||||
@@ -83,7 +85,7 @@ class LiteLayout(Structure): | |||||
def __repr__(self): | def __repr__(self): | ||||
data = { | data = { | ||||
"shapes": list(self.shapes), | |||||
"shapes": list(self.shapes)[0 : self.ndim], | |||||
"ndim": self.ndim, | "ndim": self.ndim, | ||||
"data_type": _lite_type_to_nptypes[LiteDataType(self.data_type)], | "data_type": _lite_type_to_nptypes[LiteDataType(self.data_type)], | ||||
} | } | ||||
@@ -100,6 +100,9 @@ LTensorLayout lite::to_impl_layout(const Layout& layout) { | |||||
case LiteDataType::LITE_INT16: | case LiteDataType::LITE_INT16: | ||||
mge_layout.dtype = mgb::dtype::Int16(); | mge_layout.dtype = mgb::dtype::Int16(); | ||||
break; | break; | ||||
case LiteDataType::LITE_UINT16: | |||||
mge_layout.dtype = mgb::dtype::Uint16(); | |||||
break; | |||||
default: | default: | ||||
LITE_THROW(mgb::ssprintf( | LITE_THROW(mgb::ssprintf( | ||||
"unsupport dtype in lite enum id is %d.", | "unsupport dtype in lite enum id is %d.", | ||||
@@ -133,6 +136,9 @@ Layout lite::to_lite_layout(const LTensorLayout& mge_layout) { | |||||
case mgb::DTypeEnum::Int16: | case mgb::DTypeEnum::Int16: | ||||
layout.data_type = LiteDataType::LITE_INT16; | layout.data_type = LiteDataType::LITE_INT16; | ||||
break; | break; | ||||
case mgb::DTypeEnum::Uint16: | |||||
layout.data_type = LiteDataType::LITE_UINT16; | |||||
break; | |||||
case mgb::DTypeEnum::Int8: | case mgb::DTypeEnum::Int8: | ||||
layout.data_type = LiteDataType::LITE_INT8; | layout.data_type = LiteDataType::LITE_INT8; | ||||
break; | break; | ||||
@@ -442,6 +442,24 @@ void NetworkImplDft::set_io(const NetworkIO& network_io) { | |||||
} | } | ||||
} | } | ||||
void NetworkImplDft::try_infer_tensor_layout( | |||||
std::shared_ptr<Tensor> tensor, mgb::cg::SymbolVar var) { | |||||
auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager(); | |||||
auto infer_trait = var.node()->get_static_infer_trait(); | |||||
if (std::get<0>(infer_trait)) { | |||||
auto shape = static_infer_mgr.infer_shape_fallible(var.node()); | |||||
if (!shape) { | |||||
LITE_WARN( | |||||
"Lite infer output shape failed, maybe the model is " | |||||
"dynamic " | |||||
"shape.\n"); | |||||
return; | |||||
} | |||||
Layout layout = to_lite_layout(mgb::TensorLayout{*shape, var.dtype()}); | |||||
tensor->set_layout(layout); | |||||
} | |||||
} | |||||
void NetworkImplDft::update_io() { | void NetworkImplDft::update_io() { | ||||
update_input(); | update_input(); | ||||
update_output(); | update_output(); | ||||
@@ -564,6 +582,14 @@ void NetworkImplDft::update_output() { | |||||
out_it->lite_tensor = | out_it->lite_tensor = | ||||
std::make_shared<Tensor>(device_id, stream_id, device_type); | std::make_shared<Tensor>(device_id, stream_id, device_type); | ||||
} | } | ||||
mgb::SymbolVar var; | |||||
for (auto&& out_var : m_load_result.output_var_list) { | |||||
if (out_var.node()->name() == out_it->name) { | |||||
var = out_var; | |||||
break; | |||||
} | |||||
} | |||||
try_infer_tensor_layout(out_it->lite_tensor, var); | |||||
} | } | ||||
//! user not set, use default output | //! user not set, use default output | ||||
} else { | } else { | ||||
@@ -579,12 +605,14 @@ void NetworkImplDft::update_output() { | |||||
it->lite_tensor = | it->lite_tensor = | ||||
std::make_shared<Tensor>(device_id, stream_id, device_type); | std::make_shared<Tensor>(device_id, stream_id, device_type); | ||||
} | } | ||||
try_infer_tensor_layout(it->lite_tensor, out); | |||||
} else { | } else { | ||||
IOInner output; | IOInner output; | ||||
output.name = out.node()->name(); | output.name = out.node()->name(); | ||||
output.lite_tensor = std::make_shared<Tensor>( | output.lite_tensor = std::make_shared<Tensor>( | ||||
device_id, stream_id, device_type, true); | device_id, stream_id, device_type, true); | ||||
m_network_io->outputs.push_back({output}); | m_network_io->outputs.push_back({output}); | ||||
try_infer_tensor_layout(output.lite_tensor, out); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -201,6 +201,10 @@ private: | |||||
//! compile the graph to get the execute function | //! compile the graph to get the execute function | ||||
void compile_graph(); | void compile_graph(); | ||||
//! try to infer output tensor layout | |||||
void try_infer_tensor_layout( | |||||
std::shared_ptr<Tensor> tensor, mgb::cg::SymbolVar var); | |||||
private: | private: | ||||
bool m_async = false; | bool m_async = false; | ||||
bool m_is_cpu_inplace_mode = false; | bool m_is_cpu_inplace_mode = false; | ||||
@@ -102,6 +102,8 @@ LiteLogLevel lite::get_log_level() { | |||||
} | } | ||||
std::string lite::ssprintf(const char* format, ...) { | std::string lite::ssprintf(const char* format, ...) { | ||||
if (!format) | |||||
return ""; | |||||
va_list ap; | va_list ap; | ||||
va_start(ap, format); | va_start(ap, format); | ||||
auto ret = svsprintf(format, ap); | auto ret = svsprintf(format, ap); | ||||
@@ -110,6 +112,8 @@ std::string lite::ssprintf(const char* format, ...) { | |||||
} | } | ||||
void lite::print_log(LiteLogLevel level, const char* format, ...) { | void lite::print_log(LiteLogLevel level, const char* format, ...) { | ||||
if (!format) | |||||
return; | |||||
if (static_cast<uint32_t>(level) < static_cast<uint32_t>(get_log_level())) { | if (static_cast<uint32_t>(level) < static_cast<uint32_t>(get_log_level())) { | ||||
return; | return; | ||||
} | } | ||||
@@ -90,6 +90,11 @@ TEST(TestNetWork, GetAllName) { | |||||
auto input_names = network->get_all_input_name(); | auto input_names = network->get_all_input_name(); | ||||
auto output_names = network->get_all_output_name(); | auto output_names = network->get_all_output_name(); | ||||
auto output_tensor = network->get_output_tensor(0); | |||||
auto out_layout = output_tensor->get_layout(); | |||||
ASSERT_EQ(out_layout.ndim, 2); | |||||
ASSERT_EQ(out_layout.shapes[0], 1); | |||||
ASSERT_EQ(out_layout.shapes[1], 1000); | |||||
ASSERT_EQ(input_names.size(), 1); | ASSERT_EQ(input_names.size(), 1); | ||||
ASSERT_EQ(output_names.size(), 1); | ASSERT_EQ(output_names.size(), 1); | ||||
ASSERT_TRUE(input_names[0] == "data"); | ASSERT_TRUE(input_names[0] == "data"); | ||||
@@ -488,6 +488,13 @@ public: | |||||
*/ | */ | ||||
MemAllocPlan& init_mem_plan(const DeviceTensorND* fixed_alloc = nullptr); | MemAllocPlan& init_mem_plan(const DeviceTensorND* fixed_alloc = nullptr); | ||||
/*! | |||||
* \brief get the shape and value infer trait | |||||
*/ | |||||
const std::tuple<void*, void*>& get_static_infer_trait() { | |||||
return m_static_infer_trait; | |||||
} | |||||
private: | private: | ||||
//! whether its memory should be allocated by mgb system during graph | //! whether its memory should be allocated by mgb system during graph | ||||
//! execution; initialized in VarNodeMemManager::reset_opr_seq() | //! execution; initialized in VarNodeMemManager::reset_opr_seq() | ||||