|
|
@@ -10,6 +10,7 @@ |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "megbrain/imperative/opr_utility.h" |
|
|
|
#include "./mgb_core_impl/graph/cg_impl.h" |
|
|
|
|
|
|
|
// FIXME; setup_config_cn is copied from src/opr/impl/utility.cpp |
|
|
|
namespace { |
|
|
@@ -64,14 +65,18 @@ SymbolVarArray InputCallback::make(cg::ComputingGraph& graph, |
|
|
|
|
|
|
|
void InputCallback::init_output_static_infer_desc() { |
|
|
|
if (m_output_shape.ndim) { |
|
|
|
// Write this shape to static infer manager. The effect is |
|
|
|
// that infer_shape_fallible() will return a non-empty shape |
|
|
|
// while get_infer_type() remains NO_DESC. Most places check |
|
|
|
// infer type before relying on inferred shape so things |
|
|
|
// won't break. Memory optimizer however, deliberately omits |
|
|
|
// infer type check so it will be able to use this shape for hint. |
|
|
|
using namespace cg::static_infer; |
|
|
|
auto &&mgr = owner_graph()->static_infer_manager(); |
|
|
|
auto infer_shape = [this](TensorShape &dest, const InpVal &) { |
|
|
|
dest = m_output_shape; |
|
|
|
return true; |
|
|
|
}; |
|
|
|
mgr.register_shape_infer(output(0), |
|
|
|
{SourceType::CONSTANT, {}, infer_shape}); |
|
|
|
auto* var = output(0); |
|
|
|
var->shape(m_output_shape); |
|
|
|
auto&& mgr = cg::ComputingGraphImpl::downcast(owner_graph())->static_infer_manager_impl(); |
|
|
|
auto* handle = mgr.get_tag_handler_for_shape(var); |
|
|
|
handle->sync_from_var(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
@@ -86,9 +91,6 @@ cg::OperatorNodeBase::NodeProp* InputCallback::do_make_node_prop() const { |
|
|
|
|
|
|
|
void InputCallback::scn_do_execute() { |
|
|
|
auto dev_tensor = m_callback(); |
|
|
|
if (m_output_shape.ndim) { |
|
|
|
mgb_assert(dev_tensor.shape().eq_shape(m_output_shape)); |
|
|
|
} |
|
|
|
output(0)->reset_dev_tensor_from_tensor(dev_tensor); |
|
|
|
} |
|
|
|
|
|
|
|