Browse Source

feat(opr): add mutable tensor opr

GitOrigin-RevId: 7f8a3d7b66
tags/v1.9.0
Megvii Engine Team 3 years ago
parent
commit
b458178847
2 changed files with 86 additions and 0 deletions
  1. +63
    -0
      imperative/src/impl/opr_utility.cpp
  2. +23
    -0
      imperative/src/include/megbrain/imperative/opr_utility.h

+ 63
- 0
imperative/src/impl/opr_utility.cpp View File

@@ -271,6 +271,69 @@ void NopCallback::do_execute(ExecEnv& env) {
env.dispatch_on_comp_node(cn, runner);
}

MGB_DYN_TYPE_OBJ_FINAL_IMPL(MutableTensor);
MutableTensor::MutableTensor(
cg::ComputingGraph& graph, std::shared_ptr<DeviceTensorND> dev_tensor,
std::shared_ptr<HostTensorND> host_tensor, const OperatorNodeConfig& config)
: Super(&graph, config, {}, {}) {
m_dev_tensor = dev_tensor;
m_host_tensor = host_tensor;

add_output(None)
->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
.add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)
.dtype(m_dev_tensor->dtype());
add_equivalence_component<ScalarHash<const void*>>(this);
}

SymbolVar MutableTensor::make(
cg::ComputingGraph& graph, std::shared_ptr<DeviceTensorND> dev_tensor,
std::shared_ptr<HostTensorND> host_tensor, const OperatorNodeConfig& config) {
return graph
.insert_opr(std::make_unique<MutableTensor>(
graph, dev_tensor, host_tensor, config))
->output(0);
}

void MutableTensor::init_output_comp_node() {
if (config().has_comp_node_set()) {
mgb_assert(
config().get_single_comp_node() == m_dev_tensor->comp_node(),
"comp_node mismatch");
}
comp_node(m_dev_tensor->comp_node());
}

cg::OperatorNodeBase::NodeProp* MutableTensor::do_make_node_prop() const {
auto ret = Super::do_make_node_prop();
ret->add_flag(NodeProp::Flag::IMPURE_OUTPUT_MEM_PLAN);
return ret;
}

void MutableTensor::scn_do_execute() {
output(0)->reset_dev_tensor_from_tensor(*m_dev_tensor);
}

void MutableTensor::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto& mgr = owner_graph()->static_infer_manager();
auto infer_shape = [this](TensorShape& dest, const InpVal&) {
dest = m_dev_tensor->shape();
return true;
};
mgr.register_shape_infer(output(0), {SourceType::MUTABLE, {}, infer_shape});
if (m_host_tensor) {
auto infer_value = [this](DeviceTensorND& dest, const InpVal&) {
if (!m_host_tensor->layout().ndim) {
return false;
}
dest = m_host_tensor->proxy_to_default_cpu();
return true;
};
mgr.register_value_infer(output(0), {SourceType::MUTABLE, {}, infer_value});
}
}

} // namespace opr
} // namespace mgb



+ 23
- 0
imperative/src/include/megbrain/imperative/opr_utility.h View File

@@ -16,6 +16,7 @@
#include "megbrain/opr/internal/identical_fwd.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/internal/param_tag_defs.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/param_defs.h"
#include "megbrain/serialization/sereg.h"

@@ -106,6 +107,28 @@ protected:
private:
callback_t m_callback;
};

MGB_DEFINE_OPR_CLASS(MutableTensor, cg::SingleCNOperatorNodeBase) // {
public:
MutableTensor(
cg::ComputingGraph& graph, std::shared_ptr<DeviceTensorND> dev_tensor,
std::shared_ptr<HostTensorND> host_tensor,
const OperatorNodeConfig& config);
static SymbolVar make(
cg::ComputingGraph& graph, std::shared_ptr<DeviceTensorND> dev_tensor,
std::shared_ptr<HostTensorND> host_tensor = {},
const OperatorNodeConfig& config = {});

protected:
void init_output_comp_node() override;
void init_output_static_infer_desc() override;
cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override;
void scn_do_execute() override;

private:
std::shared_ptr<DeviceTensorND> m_dev_tensor;
std::shared_ptr<HostTensorND> m_host_tensor;
};
} // namespace opr
} // namespace mgb



Loading…
Cancel
Save