GitOrigin-RevId: 8fa47b35ad
tags/v1.0.0-rc1
@@ -963,11 +963,21 @@ void mixin::WeightPreprocessExecutor::record_preprocessed_weight( | |||||
bool mixin::WeightPreprocessExecutor::mixin_allow_weight_preprocess( | bool mixin::WeightPreprocessExecutor::mixin_allow_weight_preprocess( | ||||
const cg::OperatorNodeBase& opr) const { | const cg::OperatorNodeBase& opr) const { | ||||
bool param_merged = opr.input(1) | |||||
->owner_opr() | |||||
->same_type<opr::MultipleDeviceTensorHolder>(); | |||||
return opr.input(1)->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE) && | |||||
(cg::is_const_var_value(opr.input(1)) || param_merged); | |||||
if (!opr.input(1)->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE)) | |||||
return false; | |||||
if (cg::is_const_var_value(opr.input(1))) | |||||
return true; | |||||
auto* input_opr = opr.input(1)->owner_opr(); | |||||
if (input_opr->same_type<opr::MultipleDeviceTensorHolder>() || | |||||
input_opr->same_type<opr::MultipleDeviceTensorWithFormatHolder>()) | |||||
return true; | |||||
auto* sdt = input_opr->try_cast_final<opr::SharedDeviceTensor>(); | |||||
if (sdt && sdt->const_value()) | |||||
return true; | |||||
auto* sdtf = input_opr->try_cast_final<opr::SharedDeviceTensorWithFormat>(); | |||||
if (sdtf && sdtf->const_value()) | |||||
return true; | |||||
return false; | |||||
} | } | ||||
/* ==================== ConvolutionForward ==================== */ | /* ==================== ConvolutionForward ==================== */ | ||||
@@ -307,20 +307,6 @@ void intl::SharedDeviceTensorBase::init_output_comp_node() { | |||||
comp_node(m_dev_data->comp_node()); | comp_node(m_dev_data->comp_node()); | ||||
} | } | ||||
bool intl::SharedDeviceTensorBase::fill_in_static_infer(DeviceTensorND* dest) { | |||||
if (m_const_value) { | |||||
if (dest) { | |||||
if (m_static_infer.empty()) { | |||||
m_static_infer.comp_node(CompNode::default_cpu()) | |||||
.copy_from(*m_dev_data); | |||||
} | |||||
*dest = m_static_infer; | |||||
} | |||||
return true; | |||||
} | |||||
return false; | |||||
} | |||||
cg::static_infer::SourceType SharedDeviceTensor::static_infer_src_type() const { | cg::static_infer::SourceType SharedDeviceTensor::static_infer_src_type() const { | ||||
return cg::static_infer::SourceType::CONSTANT; | return cg::static_infer::SourceType::CONSTANT; | ||||
} | } | ||||
@@ -886,24 +872,6 @@ void intl::MultipleDeviceTensorHolderBase::init_output_static_infer_desc() { | |||||
}; | }; | ||||
mgr.register_shape_infer(output(i), | mgr.register_shape_infer(output(i), | ||||
{SourceType::CONSTANT, {}, infer_shp}); | {SourceType::CONSTANT, {}, infer_shp}); | ||||
auto infer_val = [this, i](DeviceTensorND& dest, const InpVal&) { | |||||
if (m_host_values.empty()) { | |||||
m_host_values.resize(m_values.size()); | |||||
} | |||||
if (m_host_values[i].empty()) { | |||||
m_host_values[i] | |||||
.comp_node(CompNode::default_cpu()) | |||||
.copy_from(*m_values[i]); | |||||
} | |||||
if (!m_host_values[i].empty()) { | |||||
dest = m_host_values[i]; | |||||
return true; | |||||
} | |||||
return false; | |||||
}; | |||||
mgr.register_value_infer(output(i), | |||||
{SourceType::CONSTANT, {}, infer_val}); | |||||
} | } | ||||
} | } | ||||
@@ -75,12 +75,14 @@ class DeviceTensorHolder: public HostIONodeBase { | |||||
*/ | */ | ||||
MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // { | MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // { | ||||
std::shared_ptr<DeviceTensorND> m_dev_data; | std::shared_ptr<DeviceTensorND> m_dev_data; | ||||
DeviceTensorND m_static_infer; | |||||
bool m_const_value; | bool m_const_value; | ||||
const TensorShape& get_output_shape() override; | const TensorShape& get_output_shape() override; | ||||
bool fill_in_static_infer(DeviceTensorND* dest) override; | |||||
bool fill_in_static_infer(DeviceTensorND* dest) override { | |||||
MGB_MARK_USED_VAR(dest); | |||||
return false; | |||||
} | |||||
void init_output_comp_node() override; | void init_output_comp_node() override; | ||||
@@ -131,8 +133,6 @@ private: | |||||
void init_output_comp_node() override; | void init_output_comp_node() override; | ||||
void init_output_static_infer_desc() override; | void init_output_static_infer_desc() override; | ||||
NodeProp* do_make_node_prop() const override; | NodeProp* do_make_node_prop() const override; | ||||
SmallVector<DeviceTensorND> m_host_values; | |||||
}; | }; | ||||
} // namespace intl | } // namespace intl | ||||