|
|
@@ -25,6 +25,7 @@ |
|
|
|
#include "megbrain/utils/timer.h" |
|
|
|
#include "megbrain/comp_node_env.h" |
|
|
|
#include "megbrain/gopt/inference.h" |
|
|
|
#include "megbrain/plugin/profiler.h" |
|
|
|
|
|
|
|
#include "megbrain/test/helper.h" |
|
|
|
#include "megdnn/oprs/base.h" |
|
|
@@ -1993,6 +1994,12 @@ typename megdnn::ExecutionPolicy try_find_any_bias_preprocess_algo( |
|
|
|
void test_free_memory_in_weight_preprocess(int record_level, CompNode cn) { |
|
|
|
HostTensorGenerator<> gen; |
|
|
|
auto graph = ComputingGraph::make(); |
|
|
|
#if MGB_ENABLE_JSON |
|
|
|
std::unique_ptr<GraphProfiler> profiler; |
|
|
|
if(!record_level){ |
|
|
|
profiler = std::make_unique<GraphProfiler>(graph.get()); |
|
|
|
} |
|
|
|
#endif |
|
|
|
graph->options().graph_opt.weight_preprocess = true; |
|
|
|
graph->options().comp_node_seq_record_level = record_level; |
|
|
|
auto mkvar = [&](const char* name, const TensorShape& shp) { |
|
|
@@ -2055,6 +2062,13 @@ void test_free_memory_in_weight_preprocess(int record_level, CompNode cn) { |
|
|
|
if (wp2.val()) { |
|
|
|
check(w2); |
|
|
|
} |
|
|
|
#if MGB_ENABLE_JSON |
|
|
|
if (profiler) { |
|
|
|
func->wait(); |
|
|
|
profiler->to_json_full(func.get()) |
|
|
|
->writeto_fpath(output_file("weight_preprocess.json")); |
|
|
|
} |
|
|
|
#endif |
|
|
|
} |
|
|
|
} // anonymous namespace |
|
|
|
|
|
|
|