Browse Source

refactor(mgb/core): refactor cpu compnode so that default cpu has no ability to record

GitOrigin-RevId: 7de4771476
tags/v1.3.1
Megvii Engine Team 4 years ago
parent
commit
fc8b501bb8
3 changed files with 507 additions and 382 deletions
  1. +484
    -381
      src/core/impl/comp_node/cpu/comp_node.cpp
  2. +3
    -1
      src/core/impl/comp_node/cpu/comp_node.h
  3. +20
    -0
      src/core/test/comp_node_helper.cpp

+ 484
- 381
src/core/impl/comp_node/cpu/comp_node.cpp
File diff suppressed because it is too large
View File


+ 3
- 1
src/core/impl/comp_node/cpu/comp_node.h View File

@@ -54,7 +54,9 @@ namespace mgb {
void add_callback(Task&& task) override;
};

class CompNodeImpl;
class CompNodeBaseImpl;
class CompNodeNoRecorderImpl;
class CompNodeRecorderImpl;

static void foreach(thin_function<void(CompNode)> callback);
static void finalize();


+ 20
- 0
src/core/test/comp_node_helper.cpp View File

@@ -100,6 +100,26 @@ void run_comp_seq_rec_basic_level2(CompNode cn) {
MGB_ASSERT_TENSOR_NEAR(expect, host_z, 1e-3) << "iter " << iter;
}
ASSERT_EQ(executed.size(), 2u);

//! test default_cpu with record2
{
HostTensorND hz;
graph = ComputingGraph::make();
x = opr::Host2DeviceCopy::make(*graph, host_x);
y = opr::Host2DeviceCopy::make(*graph, host_y);
z = opr::ConvBias::make(x, y, param);
z = opr::GetVarShape::make(z);
graph->options().comp_node_seq_record_level = 2;
graph->options().var_sanity_check_first_run = false;
auto func = graph->compile({make_callback_copy(z, hz, true)});
ComputingGraph::assert_destroy(graph);
func->execute();
ASSERT_TRUE(hz.comp_node() == cn);
ASSERT_EQ(hz.ptr<int>()[0], 3);
ASSERT_EQ(hz.ptr<int>()[1], 6);
ASSERT_EQ(hz.ptr<int>()[2], 8);
ASSERT_EQ(hz.ptr<int>()[3], 6);
}
}

void run_comp_seq_rec_dyn_elemwise(CompNode cn, bool fake_first) {


Loading…
Cancel
Save