|
|
@@ -513,13 +513,18 @@ void GraphDumperOSSV2::dump_tensor( |
|
|
|
check_tensor_value_valid(name, tensor); |
|
|
|
auto&& dumper = m_config.tensor_value_dumper; |
|
|
|
if (dumper) { |
|
|
|
mgb_log_warn( |
|
|
|
"serialization v2 format is pure flatbuffer format, not support " |
|
|
|
"user tensor value dumper callback."); |
|
|
|
std::vector<uint8_t> out_vec; |
|
|
|
auto temp_out_file = OutputFile::make_vector_proxy(&out_vec); |
|
|
|
dumper(*temp_out_file, *m_cur_opr, tensor); |
|
|
|
data = m_builder.CreateVector( |
|
|
|
reinterpret_cast<uint8_t*>(out_vec.data()), out_vec.size()); |
|
|
|
m_cur_rst.tensor_value_bytes += out_vec.size(); |
|
|
|
} else { |
|
|
|
data = m_builder.CreateVector( |
|
|
|
reinterpret_cast<uint8_t*>(tensor.raw_ptr()), |
|
|
|
layout.span().high_byte); |
|
|
|
m_cur_rst.tensor_value_bytes += layout.span().high_byte; |
|
|
|
} |
|
|
|
data = m_builder.CreateVector( |
|
|
|
reinterpret_cast<uint8_t*>(tensor.raw_ptr()), layout.span().high_byte); |
|
|
|
m_cur_rst.tensor_value_bytes += layout.span().high_byte; |
|
|
|
} |
|
|
|
|
|
|
|
auto fbname = should_keep_name ? m_builder.CreateSharedString(name) : 0; |
|
|
@@ -688,14 +693,9 @@ std::shared_ptr<HostTensorND> GraphLoaderOSSV2::OprLoadContextImpl::load_tensor( |
|
|
|
|
|
|
|
auto&& loader = m_loader->m_cur_load_config->tensor_value_loader; |
|
|
|
if (tensor->data() && tensor->data()->size() > 0) { |
|
|
|
if (loader) { |
|
|
|
mgb_log_warn( |
|
|
|
"serialization v2 format is pure flatbuffer format, not support " |
|
|
|
"user tensor value loader callback."); |
|
|
|
} |
|
|
|
fill_tensor_memory( |
|
|
|
*ret, tensor->data()->data(), tensor->data()->size(), |
|
|
|
m_loader->m_file->is_shared_memory()); |
|
|
|
m_loader->m_file->is_shared_memory(), loader); |
|
|
|
} |
|
|
|
if (tensor->name()) { |
|
|
|
m_tensor_map[tensor->name()->str()] = ret; |
|
|
@@ -737,6 +737,7 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl:: |
|
|
|
shared_pair.first = tensor->name()->str(); |
|
|
|
} |
|
|
|
|
|
|
|
auto loader = m_loader->m_cur_load_config->tensor_value_loader; |
|
|
|
if (comp_node.mem_node() == CompNode::default_cpu().mem_node() || copy_immediatly) { |
|
|
|
// directly forward CPU memory |
|
|
|
shared_tensor_ref = std::make_shared<DeviceTensorND>(); |
|
|
@@ -745,7 +746,7 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl:: |
|
|
|
hv.dtype(layout.dtype).resize(layout); |
|
|
|
fill_tensor_memory( |
|
|
|
hv, tensor->data()->data(), tensor->data()->size(), |
|
|
|
m_loader->m_file->is_shared_memory()); |
|
|
|
m_loader->m_file->is_shared_memory(), loader); |
|
|
|
} |
|
|
|
if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) { |
|
|
|
*shared_tensor_ref = DeviceTensorND::make_proxy(hv); |
|
|
@@ -761,7 +762,7 @@ std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl:: |
|
|
|
hv.dtype(layout.dtype).resize(layout); |
|
|
|
fill_tensor_memory( |
|
|
|
hv, tensor->data()->data(), tensor->data()->size(), |
|
|
|
m_loader->m_file->is_shared_memory()); |
|
|
|
m_loader->m_file->is_shared_memory(), loader); |
|
|
|
} |
|
|
|
shared_tensor_ref = m_device_value_loader.make(comp_node, std::move(hv)); |
|
|
|
} |
|
|
@@ -947,7 +948,7 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re |
|
|
|
if (m_shared_tensor_map.empty()) { |
|
|
|
m_shared_tensor_map.resize(m_model->nr_shared_tensor()); |
|
|
|
} else { |
|
|
|
mgb_assert(m_shared_tensor_map.size() == m_model->nr_shared_tensor()); |
|
|
|
mgb_assert(m_shared_tensor_map.size() >= m_model->nr_shared_tensor()); |
|
|
|
} |
|
|
|
SharedTensorAlignMent tensor_alignment( |
|
|
|
m_model_buf, m_file.get(), |
|
|
|