Also, add a CONSTANT value inference tag to outputs of
MultipleDeviceTensorHolder.
GitOrigin-RevId: 82a805ed5f
tags/v1.0.0-rc1
@@ -570,10 +570,10 @@ void ParamFusePass::apply(OptState &state) const { | |||||
*var->owner_graph(), hv, var_namer.name(var)); | *var->owner_graph(), hv, var_namer.name(var)); | ||||
} else { | } else { | ||||
if (is_default_format) { | if (is_default_format) { | ||||
new_var = opr::SharedDeviceTensor::make( | |||||
new_var = opr::SharedDeviceTensor::make_const( | |||||
*var->owner_graph(), inferred_val, var_namer.name(var)); | *var->owner_graph(), inferred_val, var_namer.name(var)); | ||||
} else { | } else { | ||||
new_var = opr::SharedDeviceTensorWithFormat::make( | |||||
new_var = opr::SharedDeviceTensorWithFormat::make_const( | |||||
*var->owner_graph(), inferred_val, var_namer.name(var)); | *var->owner_graph(), inferred_val, var_namer.name(var)); | ||||
} | } | ||||
} | } | ||||
@@ -281,11 +281,11 @@ void Host2DeviceCopy::record_execute_deps(ExecDependencyArray& deps) { | |||||
/* ===================== SharedDeviceTensor related ===================== */ | /* ===================== SharedDeviceTensor related ===================== */ | ||||
intl::SharedDeviceTensorBase::SharedDeviceTensorBase( | intl::SharedDeviceTensorBase::SharedDeviceTensorBase( | ||||
ComputingGraph &graph, const std::shared_ptr<DeviceTensorND> &dev_data, | |||||
const OperatorNodeConfig &config): | |||||
Super{&graph, config, "shared", {}}, | |||||
m_dev_data{dev_data} | |||||
{ | |||||
ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data, | |||||
bool const_value, const OperatorNodeConfig& config) | |||||
: Super{&graph, config, "shared", {}}, | |||||
m_dev_data{dev_data}, | |||||
m_const_value(const_value) { | |||||
if (config.has_comp_node_set()) { | if (config.has_comp_node_set()) { | ||||
mgb_assert(config.get_single_comp_node() == dev_data->comp_node()); | mgb_assert(config.get_single_comp_node() == dev_data->comp_node()); | ||||
} | } | ||||
@@ -307,26 +307,42 @@ void intl::SharedDeviceTensorBase::init_output_comp_node() { | |||||
comp_node(m_dev_data->comp_node()); | comp_node(m_dev_data->comp_node()); | ||||
} | } | ||||
bool intl::SharedDeviceTensorBase::fill_in_static_infer(DeviceTensorND* dest) { | |||||
if (m_const_value) { | |||||
if (dest) { | |||||
if (m_static_infer.empty()) { | |||||
m_static_infer.comp_node(CompNode::default_cpu()) | |||||
.copy_from(*m_dev_data); | |||||
} | |||||
*dest = m_static_infer; | |||||
} | |||||
return true; | |||||
} | |||||
return false; | |||||
} | |||||
cg::static_infer::SourceType SharedDeviceTensor::static_infer_src_type() const { | cg::static_infer::SourceType SharedDeviceTensor::static_infer_src_type() const { | ||||
return cg::static_infer::SourceType::CONSTANT; | return cg::static_infer::SourceType::CONSTANT; | ||||
} | } | ||||
SymbolVar SharedDeviceTensor::make(ComputingGraph &graph, | SymbolVar SharedDeviceTensor::make(ComputingGraph &graph, | ||||
const std::shared_ptr<DeviceTensorND> &dev_data, | const std::shared_ptr<DeviceTensorND> &dev_data, | ||||
bool const_value, | |||||
const OperatorNodeConfig &config) { | const OperatorNodeConfig &config) { | ||||
return graph.insert_opr(std::make_unique<SharedDeviceTensor>( | return graph.insert_opr(std::make_unique<SharedDeviceTensor>( | ||||
graph, dev_data, config))->output(0); | |||||
graph, dev_data, const_value, config))->output(0); | |||||
} | } | ||||
SymbolVar SharedDeviceTensor::make(ComputingGraph &graph, | SymbolVar SharedDeviceTensor::make(ComputingGraph &graph, | ||||
const HostTensorND &value, | const HostTensorND &value, | ||||
bool const_value, | |||||
const OperatorNodeConfig &config) { | const OperatorNodeConfig &config) { | ||||
auto cn = value.comp_node(); | auto cn = value.comp_node(); | ||||
if (config.has_comp_node_set()) | if (config.has_comp_node_set()) | ||||
cn = config.get_single_comp_node(); | cn = config.get_single_comp_node(); | ||||
auto dev_v = std::make_shared<DeviceTensorND>(); | auto dev_v = std::make_shared<DeviceTensorND>(); | ||||
dev_v->comp_node(cn).copy_from(value).sync(); | dev_v->comp_node(cn).copy_from(value).sync(); | ||||
return make(graph, dev_v, config); | |||||
return make(graph, dev_v, const_value, config); | |||||
} | } | ||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(SharedDeviceTensor); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(SharedDeviceTensor); | ||||
@@ -342,7 +358,7 @@ SymbolVar VolatileSharedDeviceTensor::make(ComputingGraph &graph, | |||||
const std::shared_ptr<DeviceTensorND> &dev_data, | const std::shared_ptr<DeviceTensorND> &dev_data, | ||||
const OperatorNodeConfig &config) { | const OperatorNodeConfig &config) { | ||||
return graph.insert_opr(std::make_unique<VolatileSharedDeviceTensor>( | return graph.insert_opr(std::make_unique<VolatileSharedDeviceTensor>( | ||||
graph, dev_data, config))->output(0); | |||||
graph, dev_data, false, config))->output(0); | |||||
} | } | ||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(VolatileSharedDeviceTensor); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(VolatileSharedDeviceTensor); | ||||
@@ -354,10 +370,10 @@ void SharedDeviceTensorWithFormat::init_output_format() { | |||||
SymbolVar SharedDeviceTensorWithFormat::make( | SymbolVar SharedDeviceTensorWithFormat::make( | ||||
ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data, | ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data, | ||||
const OperatorNodeConfig& config) { | |||||
bool const_value, const OperatorNodeConfig& config) { | |||||
auto&& opr = | auto&& opr = | ||||
graph.insert_opr(std::make_unique<SharedDeviceTensorWithFormat>( | graph.insert_opr(std::make_unique<SharedDeviceTensorWithFormat>( | ||||
graph, dev_data, config)) | |||||
graph, dev_data, const_value, config)) | |||||
->cast_final_safe<SharedDeviceTensorWithFormat>(); | ->cast_final_safe<SharedDeviceTensorWithFormat>(); | ||||
return opr.output(0); | return opr.output(0); | ||||
} | } | ||||
@@ -870,6 +886,24 @@ void intl::MultipleDeviceTensorHolderBase::init_output_static_infer_desc() { | |||||
}; | }; | ||||
mgr.register_shape_infer(output(i), | mgr.register_shape_infer(output(i), | ||||
{SourceType::CONSTANT, {}, infer_shp}); | {SourceType::CONSTANT, {}, infer_shp}); | ||||
auto infer_val = [this, i](DeviceTensorND& dest, const InpVal&) { | |||||
if (m_host_values.empty()) { | |||||
m_host_values.resize(m_values.size()); | |||||
} | |||||
if (m_host_values[i].empty()) { | |||||
m_host_values[i] | |||||
.comp_node(CompNode::default_cpu()) | |||||
.copy_from(*m_values[i]); | |||||
} | |||||
if (!m_host_values[i].empty()) { | |||||
dest = m_host_values[i]; | |||||
return true; | |||||
} | |||||
return false; | |||||
}; | |||||
mgr.register_value_infer(output(i), | |||||
{SourceType::CONSTANT, {}, infer_val}); | |||||
} | } | ||||
} | } | ||||
@@ -79,6 +79,10 @@ namespace serialization { | |||||
HostTensorND val; | HostTensorND val; | ||||
val.copy_from(opr.get_dev_tensor()).sync(); | val.copy_from(opr.get_dev_tensor()).sync(); | ||||
ctx.dump_tensor(opr.name(), val, Meth::VALUE_SHARED); | ctx.dump_tensor(opr.name(), val, Meth::VALUE_SHARED); | ||||
// Note that we don't persist opr.m_const_value, because it does not | |||||
// affect correctness, and SharedDeviceTensor will be bundled | |||||
// together as MultipleDeviceTensorHolder in optimize_for_inference | |||||
// before being dumped. | |||||
} | } | ||||
static cg::OperatorNodeBase* load( | static cg::OperatorNodeBase* load( | ||||
@@ -280,9 +284,10 @@ namespace opr { | |||||
const OperatorNodeConfig &config) { | const OperatorNodeConfig &config) { | ||||
mgb_assert(inputs.empty()); | mgb_assert(inputs.empty()); | ||||
auto &&opr = opr_.cast_final_safe<Opr>(); | auto &&opr = opr_.cast_final_safe<Opr>(); | ||||
return Opr::make( | |||||
*ctx.owner_graph(opr, inputs), opr.dev_data(), config). | |||||
node()->owner_opr(); | |||||
return Opr::make(*ctx.owner_graph(opr, inputs), opr.dev_data(), | |||||
opr.const_value(), config) | |||||
.node() | |||||
->owner_opr(); | |||||
} | } | ||||
cg::OperatorNodeBase* opr_shallow_copy_immutable_tensor( | cg::OperatorNodeBase* opr_shallow_copy_immutable_tensor( | ||||
@@ -75,19 +75,22 @@ class DeviceTensorHolder: public HostIONodeBase { | |||||
*/ | */ | ||||
MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // { | MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // { | ||||
std::shared_ptr<DeviceTensorND> m_dev_data; | std::shared_ptr<DeviceTensorND> m_dev_data; | ||||
DeviceTensorND m_static_infer; | |||||
bool m_const_value; | |||||
const TensorShape& get_output_shape() override; | const TensorShape& get_output_shape() override; | ||||
bool fill_in_static_infer(DeviceTensorND* dest) override { | |||||
MGB_MARK_USED_VAR(dest); | |||||
return false; | |||||
} | |||||
bool fill_in_static_infer(DeviceTensorND* dest) override; | |||||
void init_output_comp_node() override; | void init_output_comp_node() override; | ||||
public: | public: | ||||
//! const_value marks whether the device value of this operator should | |||||
//! be treated as constant during graph execution. Should be false in | |||||
//! most cases. | |||||
SharedDeviceTensorBase(ComputingGraph &graph, | SharedDeviceTensorBase(ComputingGraph &graph, | ||||
const std::shared_ptr<DeviceTensorND> &dev_data, | const std::shared_ptr<DeviceTensorND> &dev_data, | ||||
bool const_value, | |||||
const OperatorNodeConfig &config); | const OperatorNodeConfig &config); | ||||
const DeviceTensorND& get_dev_tensor() const override { | const DeviceTensorND& get_dev_tensor() const override { | ||||
@@ -97,6 +100,8 @@ MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // { | |||||
const std::shared_ptr<DeviceTensorND>& dev_data() const { | const std::shared_ptr<DeviceTensorND>& dev_data() const { | ||||
return m_dev_data; | return m_dev_data; | ||||
} | } | ||||
bool const_value() const { return m_const_value; } | |||||
}; | }; | ||||
/*! | /*! | ||||
@@ -104,6 +109,7 @@ MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // { | |||||
* device tensors | * device tensors | ||||
* | * | ||||
* This opr is used to speed up inference by packing params together. | * This opr is used to speed up inference by packing params together. | ||||
* This operator assumes the device tensors are constant. | |||||
*/ | */ | ||||
MGB_DEFINE_CLS_WITH_SUPER(MultipleDeviceTensorHolderBase, | MGB_DEFINE_CLS_WITH_SUPER(MultipleDeviceTensorHolderBase, | ||||
cg::OperatorNodeBase) // { | cg::OperatorNodeBase) // { | ||||
@@ -125,6 +131,8 @@ private: | |||||
void init_output_comp_node() override; | void init_output_comp_node() override; | ||||
void init_output_static_infer_desc() override; | void init_output_static_infer_desc() override; | ||||
NodeProp* do_make_node_prop() const override; | NodeProp* do_make_node_prop() const override; | ||||
SmallVector<DeviceTensorND> m_host_values; | |||||
}; | }; | ||||
} // namespace intl | } // namespace intl | ||||
@@ -249,16 +257,43 @@ MGB_DEFINE_OPR_CLASS(SharedDeviceTensor, intl::SharedDeviceTensorBase) // { | |||||
public: | public: | ||||
using Super::Super; | using Super::Super; | ||||
static SymbolVar make(ComputingGraph &graph, | |||||
const std::shared_ptr<DeviceTensorND> &dev_data, | |||||
const OperatorNodeConfig &config = {}); | |||||
static SymbolVar make(ComputingGraph& graph, | |||||
const std::shared_ptr<DeviceTensorND>& dev_data, | |||||
bool const_value, | |||||
const OperatorNodeConfig& config); | |||||
static SymbolVar make(ComputingGraph& graph, | |||||
const std::shared_ptr<DeviceTensorND>& dev_data, | |||||
const OperatorNodeConfig& config = {}) { | |||||
return make(graph, dev_data, false, config); | |||||
} | |||||
static SymbolVar make_const( | |||||
ComputingGraph& graph, | |||||
const std::shared_ptr<DeviceTensorND>& dev_data, | |||||
const OperatorNodeConfig& config = {}) { | |||||
return make(graph, dev_data, true, config); | |||||
} | |||||
/*! | /*! | ||||
* \brief make a SharedDeviceTensor by first coping from host to device | * \brief make a SharedDeviceTensor by first coping from host to device | ||||
* | |||||
* See SharedDeviceTensorBase::SharedDeviceTensorBase for const_value. | |||||
*/ | */ | ||||
static SymbolVar make(ComputingGraph &graph, | |||||
const HostTensorND &value, | |||||
const OperatorNodeConfig &config = {}); | |||||
static SymbolVar make(ComputingGraph& graph, const HostTensorND& value, | |||||
bool const_value, | |||||
const OperatorNodeConfig& config); | |||||
static SymbolVar make(ComputingGraph& graph, const HostTensorND& value, | |||||
const OperatorNodeConfig& config = {}) { | |||||
return make(graph, value, false, config); | |||||
} | |||||
static SymbolVar make_const(ComputingGraph& graph, | |||||
const HostTensorND& value, | |||||
const OperatorNodeConfig& config = {}) { | |||||
return make(graph, value, false, config); | |||||
} | |||||
}; | }; | ||||
/*! | /*! | ||||
@@ -276,7 +311,19 @@ public: | |||||
static SymbolVar make(ComputingGraph& graph, | static SymbolVar make(ComputingGraph& graph, | ||||
const std::shared_ptr<DeviceTensorND>& dev_data, | const std::shared_ptr<DeviceTensorND>& dev_data, | ||||
const OperatorNodeConfig& config = {}); | |||||
bool const_value, const OperatorNodeConfig& config); | |||||
static SymbolVar make(ComputingGraph& graph, | |||||
const std::shared_ptr<DeviceTensorND>& dev_data, | |||||
const OperatorNodeConfig& config = {}) { | |||||
return make(graph, dev_data, false, config); | |||||
} | |||||
static SymbolVar make_const(ComputingGraph& graph, | |||||
const std::shared_ptr<DeviceTensorND>& dev_data, | |||||
const OperatorNodeConfig& config = {}) { | |||||
return make(graph, dev_data, true, config); | |||||
} | |||||
}; | }; | ||||
/*! | /*! | ||||
@@ -297,6 +344,15 @@ MGB_DEFINE_OPR_CLASS( | |||||
static SymbolVar make(ComputingGraph &graph, | static SymbolVar make(ComputingGraph &graph, | ||||
const std::shared_ptr<DeviceTensorND> &dev_data, | const std::shared_ptr<DeviceTensorND> &dev_data, | ||||
const OperatorNodeConfig &config = {}); | const OperatorNodeConfig &config = {}); | ||||
//! adapter for io.sereg.h: opr_shallow_copy_shared_device_tensor | |||||
static SymbolVar make(ComputingGraph& graph, | |||||
const std::shared_ptr<DeviceTensorND>& dev_data, | |||||
bool const_value, | |||||
const OperatorNodeConfig& config) { | |||||
mgb_assert(!const_value); | |||||
return make(graph, dev_data, false, config); | |||||
} | |||||
}; | }; | ||||
/*! | /*! | ||||