|
|
@@ -271,6 +271,69 @@ void NopCallback::do_execute(ExecEnv& env) { |
|
|
|
env.dispatch_on_comp_node(cn, runner); |
|
|
|
} |
|
|
|
|
|
|
|
MGB_DYN_TYPE_OBJ_FINAL_IMPL(MutableTensor); |
|
|
|
MutableTensor::MutableTensor( |
|
|
|
cg::ComputingGraph& graph, std::shared_ptr<DeviceTensorND> dev_tensor, |
|
|
|
std::shared_ptr<HostTensorND> host_tensor, const OperatorNodeConfig& config) |
|
|
|
: Super(&graph, config, {}, {}) { |
|
|
|
m_dev_tensor = dev_tensor; |
|
|
|
m_host_tensor = host_tensor; |
|
|
|
|
|
|
|
add_output(None) |
|
|
|
->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE) |
|
|
|
.add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC) |
|
|
|
.dtype(m_dev_tensor->dtype()); |
|
|
|
add_equivalence_component<ScalarHash<const void*>>(this); |
|
|
|
} |
|
|
|
|
|
|
|
SymbolVar MutableTensor::make( |
|
|
|
cg::ComputingGraph& graph, std::shared_ptr<DeviceTensorND> dev_tensor, |
|
|
|
std::shared_ptr<HostTensorND> host_tensor, const OperatorNodeConfig& config) { |
|
|
|
return graph |
|
|
|
.insert_opr(std::make_unique<MutableTensor>( |
|
|
|
graph, dev_tensor, host_tensor, config)) |
|
|
|
->output(0); |
|
|
|
} |
|
|
|
|
|
|
|
void MutableTensor::init_output_comp_node() { |
|
|
|
if (config().has_comp_node_set()) { |
|
|
|
mgb_assert( |
|
|
|
config().get_single_comp_node() == m_dev_tensor->comp_node(), |
|
|
|
"comp_node mismatch"); |
|
|
|
} |
|
|
|
comp_node(m_dev_tensor->comp_node()); |
|
|
|
} |
|
|
|
|
|
|
|
cg::OperatorNodeBase::NodeProp* MutableTensor::do_make_node_prop() const { |
|
|
|
auto ret = Super::do_make_node_prop(); |
|
|
|
ret->add_flag(NodeProp::Flag::IMPURE_OUTPUT_MEM_PLAN); |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
void MutableTensor::scn_do_execute() { |
|
|
|
output(0)->reset_dev_tensor_from_tensor(*m_dev_tensor); |
|
|
|
} |
|
|
|
|
|
|
|
void MutableTensor::init_output_static_infer_desc() { |
|
|
|
using namespace cg::static_infer; |
|
|
|
auto& mgr = owner_graph()->static_infer_manager(); |
|
|
|
auto infer_shape = [this](TensorShape& dest, const InpVal&) { |
|
|
|
dest = m_dev_tensor->shape(); |
|
|
|
return true; |
|
|
|
}; |
|
|
|
mgr.register_shape_infer(output(0), {SourceType::MUTABLE, {}, infer_shape}); |
|
|
|
if (m_host_tensor) { |
|
|
|
auto infer_value = [this](DeviceTensorND& dest, const InpVal&) { |
|
|
|
if (!m_host_tensor->layout().ndim) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
dest = m_host_tensor->proxy_to_default_cpu(); |
|
|
|
return true; |
|
|
|
}; |
|
|
|
mgr.register_value_infer(output(0), {SourceType::MUTABLE, {}, infer_value}); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
} // namespace opr |
|
|
|
} // namespace mgb |
|
|
|
|
|
|
|