Also, add a CONSTANT value inference tag to outputs of
MultipleDeviceTensorHolder.
GitOrigin-RevId: 82a805ed5f
tags/v1.0.0-rc1
@@ -570,10 +570,10 @@ void ParamFusePass::apply(OptState &state) const { | |||
*var->owner_graph(), hv, var_namer.name(var)); | |||
} else { | |||
if (is_default_format) { | |||
new_var = opr::SharedDeviceTensor::make( | |||
new_var = opr::SharedDeviceTensor::make_const( | |||
*var->owner_graph(), inferred_val, var_namer.name(var)); | |||
} else { | |||
new_var = opr::SharedDeviceTensorWithFormat::make( | |||
new_var = opr::SharedDeviceTensorWithFormat::make_const( | |||
*var->owner_graph(), inferred_val, var_namer.name(var)); | |||
} | |||
} | |||
@@ -281,11 +281,11 @@ void Host2DeviceCopy::record_execute_deps(ExecDependencyArray& deps) { | |||
/* ===================== SharedDeviceTensor related ===================== */ | |||
intl::SharedDeviceTensorBase::SharedDeviceTensorBase( | |||
ComputingGraph &graph, const std::shared_ptr<DeviceTensorND> &dev_data, | |||
const OperatorNodeConfig &config): | |||
Super{&graph, config, "shared", {}}, | |||
m_dev_data{dev_data} | |||
{ | |||
ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data, | |||
bool const_value, const OperatorNodeConfig& config) | |||
: Super{&graph, config, "shared", {}}, | |||
m_dev_data{dev_data}, | |||
m_const_value(const_value) { | |||
if (config.has_comp_node_set()) { | |||
mgb_assert(config.get_single_comp_node() == dev_data->comp_node()); | |||
} | |||
@@ -307,26 +307,42 @@ void intl::SharedDeviceTensorBase::init_output_comp_node() { | |||
comp_node(m_dev_data->comp_node()); | |||
} | |||
/*!
 * \brief supply the value used for static inference, if this operator was
 *      created with const_value set
 *
 * When m_const_value is false nothing is written and false is returned.
 * Otherwise true is returned; if \p dest is non-null it receives a copy of
 * m_dev_data that is cached in m_static_infer on CompNode::default_cpu() the
 * first time it is requested.
 *
 * \param dest output tensor; may be nullptr to only query availability
 */
bool intl::SharedDeviceTensorBase::fill_in_static_infer(DeviceTensorND* dest) {
    // non-constant tensors never expose a statically inferable value
    if (!m_const_value) {
        return false;
    }
    // caller only asks whether a value is available
    if (!dest) {
        return true;
    }
    // populate the cached copy lazily, on first use
    if (m_static_infer.empty()) {
        auto&& cache = m_static_infer.comp_node(CompNode::default_cpu());
        cache.copy_from(*m_dev_data);
    }
    *dest = m_static_infer;
    return true;
}
//! static inference source type reported by SharedDeviceTensor: always
//! SourceType::CONSTANT
cg::static_infer::SourceType SharedDeviceTensor::static_infer_src_type() const {
    using SrcType = cg::static_infer::SourceType;
    return SrcType::CONSTANT;
}
SymbolVar SharedDeviceTensor::make(ComputingGraph &graph, | |||
const std::shared_ptr<DeviceTensorND> &dev_data, | |||
bool const_value, | |||
const OperatorNodeConfig &config) { | |||
return graph.insert_opr(std::make_unique<SharedDeviceTensor>( | |||
graph, dev_data, config))->output(0); | |||
graph, dev_data, const_value, config))->output(0); | |||
} | |||
SymbolVar SharedDeviceTensor::make(ComputingGraph &graph, | |||
const HostTensorND &value, | |||
bool const_value, | |||
const OperatorNodeConfig &config) { | |||
auto cn = value.comp_node(); | |||
if (config.has_comp_node_set()) | |||
cn = config.get_single_comp_node(); | |||
auto dev_v = std::make_shared<DeviceTensorND>(); | |||
dev_v->comp_node(cn).copy_from(value).sync(); | |||
return make(graph, dev_v, config); | |||
return make(graph, dev_v, const_value, config); | |||
} | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(SharedDeviceTensor); | |||
@@ -342,7 +358,7 @@ SymbolVar VolatileSharedDeviceTensor::make(ComputingGraph &graph, | |||
const std::shared_ptr<DeviceTensorND> &dev_data, | |||
const OperatorNodeConfig &config) { | |||
return graph.insert_opr(std::make_unique<VolatileSharedDeviceTensor>( | |||
graph, dev_data, config))->output(0); | |||
graph, dev_data, false, config))->output(0); | |||
} | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(VolatileSharedDeviceTensor); | |||
@@ -354,10 +370,10 @@ void SharedDeviceTensorWithFormat::init_output_format() { | |||
SymbolVar SharedDeviceTensorWithFormat::make( | |||
ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data, | |||
const OperatorNodeConfig& config) { | |||
bool const_value, const OperatorNodeConfig& config) { | |||
auto&& opr = | |||
graph.insert_opr(std::make_unique<SharedDeviceTensorWithFormat>( | |||
graph, dev_data, config)) | |||
graph, dev_data, const_value, config)) | |||
->cast_final_safe<SharedDeviceTensorWithFormat>(); | |||
return opr.output(0); | |||
} | |||
@@ -870,6 +886,24 @@ void intl::MultipleDeviceTensorHolderBase::init_output_static_infer_desc() { | |||
}; | |||
mgr.register_shape_infer(output(i), | |||
{SourceType::CONSTANT, {}, infer_shp}); | |||
auto infer_val = [this, i](DeviceTensorND& dest, const InpVal&) { | |||
if (m_host_values.empty()) { | |||
m_host_values.resize(m_values.size()); | |||
} | |||
if (m_host_values[i].empty()) { | |||
m_host_values[i] | |||
.comp_node(CompNode::default_cpu()) | |||
.copy_from(*m_values[i]); | |||
} | |||
if (!m_host_values[i].empty()) { | |||
dest = m_host_values[i]; | |||
return true; | |||
} | |||
return false; | |||
}; | |||
mgr.register_value_infer(output(i), | |||
{SourceType::CONSTANT, {}, infer_val}); | |||
} | |||
} | |||
@@ -79,6 +79,10 @@ namespace serialization { | |||
HostTensorND val; | |||
val.copy_from(opr.get_dev_tensor()).sync(); | |||
ctx.dump_tensor(opr.name(), val, Meth::VALUE_SHARED); | |||
// Note that we don't persist opr.m_const_value, because it does not | |||
// affect correctness, and SharedDeviceTensor will be bundled | |||
// together as MultipleDeviceTensorHolder in optimize_for_inference | |||
// before being dumped. | |||
} | |||
static cg::OperatorNodeBase* load( | |||
@@ -280,9 +284,10 @@ namespace opr { | |||
const OperatorNodeConfig &config) { | |||
mgb_assert(inputs.empty()); | |||
auto &&opr = opr_.cast_final_safe<Opr>(); | |||
return Opr::make( | |||
*ctx.owner_graph(opr, inputs), opr.dev_data(), config). | |||
node()->owner_opr(); | |||
return Opr::make(*ctx.owner_graph(opr, inputs), opr.dev_data(), | |||
opr.const_value(), config) | |||
.node() | |||
->owner_opr(); | |||
} | |||
cg::OperatorNodeBase* opr_shallow_copy_immutable_tensor( | |||
@@ -75,19 +75,22 @@ class DeviceTensorHolder: public HostIONodeBase { | |||
*/ | |||
MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // { | |||
std::shared_ptr<DeviceTensorND> m_dev_data; | |||
DeviceTensorND m_static_infer; | |||
bool m_const_value; | |||
const TensorShape& get_output_shape() override; | |||
bool fill_in_static_infer(DeviceTensorND* dest) override { | |||
MGB_MARK_USED_VAR(dest); | |||
return false; | |||
} | |||
bool fill_in_static_infer(DeviceTensorND* dest) override; | |||
void init_output_comp_node() override; | |||
public: | |||
//! const_value marks whether the device value of this operator should | |||
//! be treated as constant during graph execution. Should be false in | |||
//! most cases. | |||
SharedDeviceTensorBase(ComputingGraph &graph, | |||
const std::shared_ptr<DeviceTensorND> &dev_data, | |||
bool const_value, | |||
const OperatorNodeConfig &config); | |||
const DeviceTensorND& get_dev_tensor() const override { | |||
@@ -97,6 +100,8 @@ MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // { | |||
const std::shared_ptr<DeviceTensorND>& dev_data() const { | |||
return m_dev_data; | |||
} | |||
bool const_value() const { return m_const_value; } | |||
}; | |||
/*! | |||
@@ -104,6 +109,7 @@ MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // { | |||
* device tensors | |||
* | |||
* This opr is used to speed up inference by packing params together. | |||
* This operator assumes the device tensors are constant. | |||
*/ | |||
MGB_DEFINE_CLS_WITH_SUPER(MultipleDeviceTensorHolderBase, | |||
cg::OperatorNodeBase) // { | |||
@@ -125,6 +131,8 @@ private: | |||
void init_output_comp_node() override; | |||
void init_output_static_infer_desc() override; | |||
NodeProp* do_make_node_prop() const override; | |||
SmallVector<DeviceTensorND> m_host_values; | |||
}; | |||
} // namespace intl | |||
@@ -249,16 +257,43 @@ MGB_DEFINE_OPR_CLASS(SharedDeviceTensor, intl::SharedDeviceTensorBase) // { | |||
public: | |||
using Super::Super; | |||
static SymbolVar make(ComputingGraph &graph, | |||
const std::shared_ptr<DeviceTensorND> &dev_data, | |||
const OperatorNodeConfig &config = {}); | |||
static SymbolVar make(ComputingGraph& graph, | |||
const std::shared_ptr<DeviceTensorND>& dev_data, | |||
bool const_value, | |||
const OperatorNodeConfig& config); | |||
//! make from an existing device tensor; forwards to the four-argument
//! overload with const_value = false
static SymbolVar make(ComputingGraph& graph,
                      const std::shared_ptr<DeviceTensorND>& dev_data,
                      const OperatorNodeConfig& config = {}) {
    constexpr bool is_const = false;
    return make(graph, dev_data, is_const, config);
}
//! make from an existing device tensor; forwards to the four-argument
//! make() with const_value = true
static SymbolVar make_const(
        ComputingGraph& graph,
        const std::shared_ptr<DeviceTensorND>& dev_data,
        const OperatorNodeConfig& config = {}) {
    constexpr bool is_const = true;
    return make(graph, dev_data, is_const, config);
}
/*! | |||
* \brief make a SharedDeviceTensor by first copying from host to device
* | |||
* See SharedDeviceTensorBase::SharedDeviceTensorBase for const_value. | |||
*/ | |||
static SymbolVar make(ComputingGraph &graph, | |||
const HostTensorND &value, | |||
const OperatorNodeConfig &config = {}); | |||
static SymbolVar make(ComputingGraph& graph, const HostTensorND& value, | |||
bool const_value, | |||
const OperatorNodeConfig& config); | |||
//! make from a host tensor; forwards to the four-argument overload with
//! const_value = false
static SymbolVar make(ComputingGraph& graph, const HostTensorND& value,
                      const OperatorNodeConfig& config = {}) {
    constexpr bool is_const = false;
    return make(graph, value, is_const, config);
}
/*!
 * \brief make a constant SharedDeviceTensor from a host tensor
 *
 * Forwards to the four-argument make() with const_value = true, matching the
 * device-tensor overload of make_const().
 *
 * \param graph computing graph passed through to make()
 * \param value host tensor passed through to make()
 * \param config operator config passed through to make()
 */
static SymbolVar make_const(ComputingGraph& graph,
                            const HostTensorND& value,
                            const OperatorNodeConfig& config = {}) {
    // must be true: passing false would make this identical to make() and
    // silently drop the constant marking requested by the caller
    return make(graph, value, true, config);
}
}; | |||
/*! | |||
@@ -276,7 +311,19 @@ public: | |||
static SymbolVar make(ComputingGraph& graph, | |||
const std::shared_ptr<DeviceTensorND>& dev_data, | |||
const OperatorNodeConfig& config = {}); | |||
bool const_value, const OperatorNodeConfig& config); | |||
//! make from an existing device tensor; forwards to the four-argument
//! overload with const_value = false
static SymbolVar make(ComputingGraph& graph,
                      const std::shared_ptr<DeviceTensorND>& dev_data,
                      const OperatorNodeConfig& config = {}) {
    constexpr bool is_const = false;
    return make(graph, dev_data, is_const, config);
}
//! make from an existing device tensor; forwards to the four-argument
//! make() with const_value = true
static SymbolVar make_const(ComputingGraph& graph,
                            const std::shared_ptr<DeviceTensorND>& dev_data,
                            const OperatorNodeConfig& config = {}) {
    constexpr bool is_const = true;
    return make(graph, dev_data, is_const, config);
}
}; | |||
/*! | |||
@@ -297,6 +344,15 @@ MGB_DEFINE_OPR_CLASS( | |||
static SymbolVar make(ComputingGraph &graph, | |||
const std::shared_ptr<DeviceTensorND> &dev_data, | |||
const OperatorNodeConfig &config = {}); | |||
/*!
 * \brief adapter for io.sereg.h: opr_shallow_copy_shared_device_tensor
 *
 * Gives this class the same four-argument make() signature as the other
 * shared device tensor operators. \p const_value is asserted to be false and
 * the call is forwarded to the three-argument make().
 */
static SymbolVar make(ComputingGraph& graph,
                      const std::shared_ptr<DeviceTensorND>& dev_data,
                      bool const_value,
                      const OperatorNodeConfig& config) {
    mgb_assert(!const_value);
    // forward to the three-argument overload; passing a bool here would
    // select this very function again and recurse without end
    return make(graph, dev_data, config);
}
}; | |||
/*! | |||