Browse Source

fix(mgb/core): add warning information about const_var_shape when record mode

GitOrigin-RevId: a99f9c4e5d
tags/v1.3.1
Megvii Engine Team 4 years ago
parent
commit
e1c83d8d51
3 changed files with 31 additions and 14 deletions
  1. +13
    -13
      src/core/impl/comp_node/cpu/comp_node.cpp
  2. +1
    -1
      src/core/impl/comp_node/cpu/comp_node.h
  3. +17
    -0
      src/core/impl/graph/cg_impl_seq.cpp

+ 13
- 13
src/core/impl/comp_node/cpu/comp_node.cpp View File

@@ -243,7 +243,7 @@ public:
};

using CompNodeBaseImpl = CpuCompNode::CompNodeBaseImpl;
using CompNodeNoRecorderImpl = CpuCompNode::CompNodeNoRecorderImpl;
using CompNodeDefaultImpl = CpuCompNode::CompNodeDefaultImpl;
using CompNodeRecorderImpl = CpuCompNode::CompNodeRecorderImpl;

//! ==================== CompNodeBaseImpl ======================
@@ -466,29 +466,29 @@ public:
}
};

//! ==================== CompNodeNoRecorderImpl ======================
//! ==================== CompNodeDefaultImpl ======================
/**
* \note: CompNodeNoRecorderImpl will use most implements in base including:
* \note: CompNodeDefaultImpl will use most implements in base including:
* alloc_device, alloc_host, copy_to_host, copy_to_device, peer_copy_to,
* add_callback ...
*/
class CpuCompNode::CompNodeNoRecorderImpl final : public CompNodeBaseImpl {
class CpuCompNode::CompNodeDefaultImpl final : public CompNodeBaseImpl {
MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
//! ptr to default cpu, only used by check_global_finalized
static CompNodeNoRecorderImpl* sm_default_cpu_comp_node_ptr;
static CompNodeDefaultImpl* sm_default_cpu_comp_node_ptr;

static void static_free_device(ImplBase* self, void* ptr) {
static_cast<CompNodeNoRecorderImpl*>(self)->free_device(ptr);
static_cast<CompNodeDefaultImpl*>(self)->free_device(ptr);
}

static void static_free_host(ImplBase* self, void* ptr) {
static_cast<CompNodeNoRecorderImpl*>(self)->free_host(ptr);
static_cast<CompNodeDefaultImpl*>(self)->free_host(ptr);
}
using CpuEventImpl = CpuDispatchableBase::EventImpl;

CompNodeNoRecorderImpl(const Locator& locator,
CompNodeDefaultImpl(const Locator& locator,
const Locator& locator_logical)
: CompNodeBaseImpl(locator, locator_logical, static_free_device,
static_free_host) {
@@ -501,7 +501,7 @@ public:
sm_default_cpu_comp_node_ptr = this;
}

~CompNodeNoRecorderImpl() {
~CompNodeDefaultImpl() {
m_env.fini();
sm_default_cpu_comp_node_ptr = nullptr;
}
@@ -551,8 +551,8 @@ public:

SeqRecorderImpl* cur_recorder() const override { return nullptr; }
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CompNodeNoRecorderImpl);
CompNodeNoRecorderImpl* CompNodeNoRecorderImpl::sm_default_cpu_comp_node_ptr =
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CompNodeDefaultImpl);
CompNodeDefaultImpl* CompNodeDefaultImpl::sm_default_cpu_comp_node_ptr =
nullptr;

//! ==================== CompNodeRecorderImpl ======================
@@ -746,7 +746,7 @@ public:
void peer_copy_to(Impl* dest_impl, void* dest, const void* src,
size_t size) override {
//! copy to default_cpu
if (dest_impl->same_type<CpuCompNode::CompNodeNoRecorderImpl>()) {
if (dest_impl->same_type<CpuCompNode::CompNodeDefaultImpl>()) {
CompNodeBaseImpl::peer_copy_to(dest_impl, dest, src, size);
return;
}
@@ -986,7 +986,7 @@ void CpuCompNode::sync_all() {
// CpuCompNode::Pool
CompNode CompNode::default_cpu() {
static Locator locator{DeviceType::CPU, Locator::DEVICE_CPU_DEFAULT, {-1}};
static CompNodeNoRecorderImpl impl{locator, locator};
static CompNodeDefaultImpl impl{locator, locator};
return &impl;
}



+ 1
- 1
src/core/impl/comp_node/cpu/comp_node.h View File

@@ -55,7 +55,7 @@ namespace mgb {
};

class CompNodeBaseImpl;
class CompNodeNoRecorderImpl;
class CompNodeDefaultImpl;
class CompNodeRecorderImpl;

static void foreach(thin_function<void(CompNode)> callback);


+ 17
- 0
src/core/impl/graph/cg_impl_seq.cpp View File

@@ -11,6 +11,7 @@

#include "./cg_impl_seq.h"
#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/opr/tensor_manip.h"

using namespace mgb;
using namespace cg;
@@ -255,6 +256,22 @@ ComputingGraphImpl::ComputingSequence::check_enable_comp_node_seq_recorder() {
}
}
}
auto check_const_shape = [&]() {
for (auto i : *m_opr_seq) {
for (auto j : i->output()) {
if (j->shape().ndim && !is_const_var_shape(j)) {
mgb_log_warn(
"Non-const var shape detected. Make sure all "
"shapes are constant. Check whether "
"'const_var_shape' is set "
"in GraphLoadConfig under record mode");
return;
}
}
}
};
check_const_shape();

auto cn = *m_used_comp_node.begin();
auto rec = cn.create_seq_recorder(m_owner_graph);
if (!rec) {


Loading…
Cancel
Save