Browse Source

fix(opr): remove constant value infer for const SharedDeviceTensor

GitOrigin-RevId: 8fa47b35ad
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
354b478c1f
3 changed files with 19 additions and 41 deletions
  1. +15
    -5
      src/opr/impl/dnn/convolution.cpp
  2. +0
    -32
      src/opr/impl/io.cpp
  3. +4
    -4
      src/opr/include/megbrain/opr/io.h

+ 15
- 5
src/opr/impl/dnn/convolution.cpp View File

@@ -963,11 +963,21 @@ void mixin::WeightPreprocessExecutor::record_preprocessed_weight(

bool mixin::WeightPreprocessExecutor::mixin_allow_weight_preprocess(
const cg::OperatorNodeBase& opr) const {
bool param_merged = opr.input(1)
->owner_opr()
->same_type<opr::MultipleDeviceTensorHolder>();
return opr.input(1)->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE) &&
(cg::is_const_var_value(opr.input(1)) || param_merged);
if (!opr.input(1)->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE))
return false;
if (cg::is_const_var_value(opr.input(1)))
return true;
auto* input_opr = opr.input(1)->owner_opr();
if (input_opr->same_type<opr::MultipleDeviceTensorHolder>() ||
input_opr->same_type<opr::MultipleDeviceTensorWithFormatHolder>())
return true;
auto* sdt = input_opr->try_cast_final<opr::SharedDeviceTensor>();
if (sdt && sdt->const_value())
return true;
auto* sdtf = input_opr->try_cast_final<opr::SharedDeviceTensorWithFormat>();
if (sdtf && sdtf->const_value())
return true;
return false;
}

/* ==================== ConvolutionForward ==================== */


+ 0
- 32
src/opr/impl/io.cpp View File

@@ -307,20 +307,6 @@ void intl::SharedDeviceTensorBase::init_output_comp_node() {
comp_node(m_dev_data->comp_node());
}

bool intl::SharedDeviceTensorBase::fill_in_static_infer(DeviceTensorND* dest) {
if (m_const_value) {
if (dest) {
if (m_static_infer.empty()) {
m_static_infer.comp_node(CompNode::default_cpu())
.copy_from(*m_dev_data);
}
*dest = m_static_infer;
}
return true;
}
return false;
}

cg::static_infer::SourceType SharedDeviceTensor::static_infer_src_type() const {
return cg::static_infer::SourceType::CONSTANT;
}
@@ -886,24 +872,6 @@ void intl::MultipleDeviceTensorHolderBase::init_output_static_infer_desc() {
};
mgr.register_shape_infer(output(i),
{SourceType::CONSTANT, {}, infer_shp});

auto infer_val = [this, i](DeviceTensorND& dest, const InpVal&) {
if (m_host_values.empty()) {
m_host_values.resize(m_values.size());
}
if (m_host_values[i].empty()) {
m_host_values[i]
.comp_node(CompNode::default_cpu())
.copy_from(*m_values[i]);
}
if (!m_host_values[i].empty()) {
dest = m_host_values[i];
return true;
}
return false;
};
mgr.register_value_infer(output(i),
{SourceType::CONSTANT, {}, infer_val});
}
}



+ 4
- 4
src/opr/include/megbrain/opr/io.h View File

@@ -75,12 +75,14 @@ class DeviceTensorHolder: public HostIONodeBase {
*/
MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // {
std::shared_ptr<DeviceTensorND> m_dev_data;
DeviceTensorND m_static_infer;
bool m_const_value;

const TensorShape& get_output_shape() override;

bool fill_in_static_infer(DeviceTensorND* dest) override;
bool fill_in_static_infer(DeviceTensorND* dest) override {
MGB_MARK_USED_VAR(dest);
return false;
}

void init_output_comp_node() override;

@@ -131,8 +133,6 @@ private:
void init_output_comp_node() override;
void init_output_static_infer_desc() override;
NodeProp* do_make_node_prop() const override;

SmallVector<DeviceTensorND> m_host_values;
};

} // namespace intl


Loading…
Cancel
Save