feat(lite): auto deduce output tensor shape before model forward

GitOrigin-RevId: 78e00dab5d
Branch: release-1.7
Megvii Engine Team · 3 years ago
commit 565466c25f
8 changed files with 59 additions and 2 deletions
  1. lite/pylite/megenginelite/struct.py (+1, -0)
  2. lite/pylite/megenginelite/tensor.py (+4, -2)
  3. lite/src/mge/common.cpp (+6, -0)
  4. lite/src/mge/network_impl.cpp (+28, -0)
  5. lite/src/mge/network_impl.h (+4, -0)
  6. lite/src/misc.cpp (+4, -0)
  7. lite/test/test_network.cpp (+5, -0)
  8. src/core/include/megbrain/graph/var_node.h (+7, -0)
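The net effect: for statically-shaped models, output tensor layouts are deduced at load time, before any forward call. A minimal C++ usage sketch of the user-visible behavior, modeled on the updated test in lite/test/test_network.cpp below (the model path is a placeholder; the 1x1000 output shape comes from that test's model):

    #include <cstdio>
    #include <memory>

    #include "lite/network.h"

    int main() {
        std::shared_ptr<lite::Network> network = std::make_shared<lite::Network>();
        network->load_model("./shufflenet.mge");  // placeholder model path

        // With this change the layout is already populated here, even though
        // network->forward() has never run.
        auto output_tensor = network->get_output_tensor(0);
        auto layout = output_tensor->get_layout();
        for (size_t i = 0; i < layout.ndim; ++i)
            printf("dim[%zu] = %zu\n", i, layout.shapes[i]);
        return 0;
    }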

lite/pylite/megenginelite/struct.py (+1, -0)

@@ -31,6 +31,7 @@ class LiteDataType(IntEnum):
     LITE_INT16 = 3
     LITE_INT8 = 4
     LITE_UINT8 = 5
+    LITE_UINT16 = 6


 class LiteTensorPhase(IntEnum):

lite/pylite/megenginelite/tensor.py (+4, -2)

@@ -22,6 +22,7 @@ _lite_type_to_nptypes = {
     LiteDataType.LITE_UINT8: np.uint8,
     LiteDataType.LITE_INT8: np.int8,
     LiteDataType.LITE_INT16: np.int16,
+    LiteDataType.LITE_UINT16: np.uint16,
     LiteDataType.LITE_HALF: np.float16,
 }

@@ -33,6 +34,7 @@ _str_nptypes_to_lite_nptypes = {
     np.dtype("uint8"): LiteDataType.LITE_UINT8,
     np.dtype("int8"): LiteDataType.LITE_INT8,
     np.dtype("int16"): LiteDataType.LITE_INT16,
+    np.dtype("uint16"): LiteDataType.LITE_UINT16,
     np.dtype("float16"): LiteDataType.LITE_HALF,
 }

@@ -43,7 +45,7 @@ ctype_to_lite_dtypes = {
     c_ubyte: LiteDataType.LITE_UINT8,
     c_byte: LiteDataType.LITE_INT8,
     c_short: LiteDataType.LITE_INT16,
-    c_ushort: LiteDataType.LITE_INT16,
+    c_ushort: LiteDataType.LITE_UINT16,
 }

@@ -83,7 +85,7 @@ class LiteLayout(Structure):
     def __repr__(self):
         data = {
-            "shapes": list(self.shapes),
+            "shapes": list(self.shapes)[0 : self.ndim],
             "ndim": self.ndim,
             "data_type": _lite_type_to_nptypes[LiteDataType(self.data_type)],
         }

lite/src/mge/common.cpp (+6, -0)

@@ -100,6 +100,9 @@ LTensorLayout lite::to_impl_layout(const Layout& layout) {
         case LiteDataType::LITE_INT16:
             mge_layout.dtype = mgb::dtype::Int16();
             break;
+        case LiteDataType::LITE_UINT16:
+            mge_layout.dtype = mgb::dtype::Uint16();
+            break;
         default:
             LITE_THROW(mgb::ssprintf(
                     "unsupport dtype in lite enum id is %d.",
@@ -133,6 +136,9 @@ Layout lite::to_lite_layout(const LTensorLayout& mge_layout) {
         case mgb::DTypeEnum::Int16:
             layout.data_type = LiteDataType::LITE_INT16;
             break;
+        case mgb::DTypeEnum::Uint16:
+            layout.data_type = LiteDataType::LITE_UINT16;
+            break;
         case mgb::DTypeEnum::Int8:
             layout.data_type = LiteDataType::LITE_INT8;
             break;
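A short sketch of what the two new switch cases buy: a LITE_UINT16 layout now survives the round trip through MegBrain's layout type instead of hitting the "unsupport dtype" LITE_THROW in the default branch. Note that to_impl_layout/to_lite_layout are internal helpers from this file (presumably declared in lite/src/mge/common.h), so this is illustrative, not public API:

    // Illustrative only; uses lite-internal helpers.
    lite::Layout layout;
    layout.ndim = 2;
    layout.shapes[0] = 1;
    layout.shapes[1] = 1000;
    layout.data_type = LiteDataType::LITE_UINT16;

    auto mge_layout = lite::to_impl_layout(layout);   // dtype -> mgb::dtype::Uint16()
    auto round_trip = lite::to_lite_layout(mge_layout);
    // round_trip.data_type == LiteDataType::LITE_UINT16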


lite/src/mge/network_impl.cpp (+28, -0)

@@ -442,6 +442,24 @@ void NetworkImplDft::set_io(const NetworkIO& network_io) {
     }
 }

+void NetworkImplDft::try_infer_tensor_layout(
+        std::shared_ptr<Tensor> tensor, mgb::cg::SymbolVar var) {
+    auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager();
+    auto infer_trait = var.node()->get_static_infer_trait();
+    if (std::get<0>(infer_trait)) {
+        auto shape = static_infer_mgr.infer_shape_fallible(var.node());
+        if (!shape) {
+            LITE_WARN(
+                    "Lite infer output shape failed, maybe the model is "
+                    "dynamic shape.\n");
+            return;
+        }
+        Layout layout = to_lite_layout(mgb::TensorLayout{*shape, var.dtype()});
+        tensor->set_layout(layout);
+    }
+}
+
 void NetworkImplDft::update_io() {
     update_input();
     update_output();
@@ -564,6 +582,14 @@ void NetworkImplDft::update_output() {
             out_it->lite_tensor =
                     std::make_shared<Tensor>(device_id, stream_id, device_type);
         }
+        mgb::SymbolVar var;
+        for (auto&& out_var : m_load_result.output_var_list) {
+            if (out_var.node()->name() == out_it->name) {
+                var = out_var;
+                break;
+            }
+        }
+        try_infer_tensor_layout(out_it->lite_tensor, var);
     }
     //! user not set, use default output
 } else {
@@ -579,12 +605,14 @@ void NetworkImplDft::update_output() {
                 it->lite_tensor =
                         std::make_shared<Tensor>(device_id, stream_id, device_type);
             }
+            try_infer_tensor_layout(it->lite_tensor, out);
         } else {
             IOInner output;
             output.name = out.node()->name();
             output.lite_tensor = std::make_shared<Tensor>(
                     device_id, stream_id, device_type, true);
             m_network_io->outputs.push_back({output});
+            try_infer_tensor_layout(output.lite_tensor, out);
         }
     }
 }
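try_infer_tensor_layout fails soft: if the output has no static shape-infer trait, or infer_shape_fallible returns null (typically a dynamic-shape model), it logs a warning and leaves the tensor layout untouched. A hedged caller-side sketch; treating ndim == 0 as "not yet known" is an assumption based on the default-constructed Layout:

    std::shared_ptr<lite::Network> network = std::make_shared<lite::Network>();
    network->load_model("./model.mge");  // placeholder path

    auto output = network->get_output_tensor(0);
    auto layout = output->get_layout();
    if (layout.ndim == 0) {
        // Static shape inference failed (e.g. dynamic-shape model); the
        // layout becomes valid only after the first forward() completes.
    } else {
        // Shape deduced at load time; output buffers can be sized up front.
    }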


lite/src/mge/network_impl.h (+4, -0)

@@ -201,6 +201,10 @@ private:
     //! compile the graph to get the execute function
     void compile_graph();

+    //! try to infer output tensor layout
+    void try_infer_tensor_layout(
+            std::shared_ptr<Tensor> tensor, mgb::cg::SymbolVar var);
+
 private:
     bool m_async = false;
     bool m_is_cpu_inplace_mode = false;


lite/src/misc.cpp (+4, -0)

@@ -102,6 +102,8 @@ LiteLogLevel lite::get_log_level() {
 }

 std::string lite::ssprintf(const char* format, ...) {
+    if (!format)
+        return "";
     va_list ap;
     va_start(ap, format);
     auto ret = svsprintf(format, ap);
@@ -110,6 +112,8 @@ std::string lite::ssprintf(const char* format, ...) {
 }

 void lite::print_log(LiteLogLevel level, const char* format, ...) {
+    if (!format)
+        return;
     if (static_cast<uint32_t>(level) < static_cast<uint32_t>(get_log_level())) {
         return;
     }
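These are plain defensive guards: a null format string previously flowed into svsprintf (and ultimately vsnprintf), which is undefined behavior. A sketch of the guarded behavior; both functions are lite-internal utilities, and the WARN enumerator name is assumed from lite's LiteLogLevel:

    // Illustrative only; lite::ssprintf / lite::print_log are internal.
    std::string s = lite::ssprintf(nullptr);       // now returns "" instead of UB
    lite::print_log(LiteLogLevel::WARN, nullptr);  // now a silent no-op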


lite/test/test_network.cpp (+5, -0)

@@ -90,6 +90,11 @@ TEST(TestNetWork, GetAllName) {
     auto input_names = network->get_all_input_name();
     auto output_names = network->get_all_output_name();

+    auto output_tensor = network->get_output_tensor(0);
+    auto out_layout = output_tensor->get_layout();
+    ASSERT_EQ(out_layout.ndim, 2);
+    ASSERT_EQ(out_layout.shapes[0], 1);
+    ASSERT_EQ(out_layout.shapes[1], 1000);
     ASSERT_EQ(input_names.size(), 1);
     ASSERT_EQ(output_names.size(), 1);
     ASSERT_TRUE(input_names[0] == "data");


src/core/include/megbrain/graph/var_node.h (+7, -0)

@@ -488,6 +488,13 @@ public:
      */
     MemAllocPlan& init_mem_plan(const DeviceTensorND* fixed_alloc = nullptr);

+    /*!
+     * \brief get the shape and value infer trait
+     */
+    const std::tuple<void*, void*>& get_static_infer_trait() {
+        return m_static_infer_trait;
+    }
+
 private:
     //! whether its memory should be allocated by mgb system during graph
     //! execution; initialized in VarNodeMemManager::reset_opr_seq()
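Judging from its use in NetworkImplDft::try_infer_tensor_layout above, slot 0 of the returned tuple is the shape-infer trait and slot 1 the value-infer trait. A consumption sketch mirroring that call site (graph and var are assumed to be a valid ComputingGraph* and VarNode*):

    // Check that a shape-infer trait is registered before querying the
    // static infer manager, as try_infer_tensor_layout does.
    auto&& mgr = graph->static_infer_manager();
    auto trait = var->get_static_infer_trait();
    if (std::get<0>(trait)) {
        if (auto* shape = mgr.infer_shape_fallible(var)) {
            // *shape is the statically inferred mgb::TensorShape.
        }
    }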

