Browse Source

fix(mgb/core): ensure all VarNodes would be handled in init_ready_event

GitOrigin-RevId: 0b6cb24021
release-1.1
Megvii Engine Team 4 years ago
parent
commit
5cfe73657b
2 changed files with 51 additions and 3 deletions
  1. +2
    -3
      src/core/impl/graph/seq_comp_node_opt_impl.cpp
  2. +49
    -0
      src/core/test/graph/misc.cpp

+ 2
- 3
src/core/impl/graph/seq_comp_node_opt_impl.cpp View File

@@ -303,9 +303,8 @@ void SeqCompNodeOptimizerImpl::init_ready_event(
}

opr->input_waiting_spec(std::move(waiting_spec));
auto&& usable_output = opr->usable_output();
for (size_t i = 0; i < usable_output.size(); ++ i) {
var2step[usable_output[i]] = {cur_step, i};
for (size_t i = 0; i < opr->output().size(); ++ i) {
var2step[opr->output(i)] = {cur_step, i};
}
cur_step ++;
}


+ 49
- 0
src/core/test/graph/misc.cpp View File

@@ -1086,6 +1086,55 @@ TEST(TestGraph, DynShapeDepCrossCN) {
}

namespace {
MGB_DEFINE_OPR_CLASS(CustomCopy, cg::SingleCNOperatorNodeBase) // {
std::shared_ptr<DeviceTensorND> m_data;

void scn_do_execute() override {
using namespace std::literals;
std::this_thread::sleep_for(100ms);
m_data->copy_from(input(0)->dev_tensor());
}

void init_output_static_infer_desc() override {
using namespace cg::static_infer;
owner_graph()->static_infer_manager().register_shape_infer(
output(0), ShapeInferDesc::make_const({}));
}

public:
CustomCopy(VarNode* x, std::shared_ptr<DeviceTensorND> dv)
: Super{x->owner_graph(), {dv->comp_node()}, "d2h", {x}},
m_data(dv) {
add_input({x});
using F = VarNode::Flag;
add_output(None)
->add_flag(F::ALLOW_EMPTY_SHAPE)
.add_flag(F::VOLATILE_CONTENT);
}
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CustomCopy);
}

TEST(TestGraph, DependentOnVolatileContent) {
HostTensorGenerator<> gen;
auto cn0 = CompNode::load("xpu0"),
cn1 = cn0.change_stream(1);
auto host_x = gen({233}, cn0);
auto dev_y = std::make_shared<DeviceTensorND>(cn1);

auto graph = ComputingGraph::make();
auto x = opr::SharedDeviceTensor::make(*graph, *host_x),
y = x.insert_single_output_opr<CustomCopy>(x.node(), dev_y),
x_new = opr::AddUpdate::make(x, x.make_scalar(1));

auto func = graph->compile({{y, {}}, {x_new, {}}});
func->execute().wait();
HostTensorND host_y;
host_y.copy_from(*dev_y).sync();
MGB_ASSERT_TENSOR_EQ(*host_x, host_y);
}

namespace {
void check_wait(SymbolVar dest, SymbolVar dep) {
if (!dep.node()) {
ASSERT_EQ(0u,


Loading…
Cancel
Save