GitOrigin-RevId: bb9150eb83
tags/v1.0.0-rc1
@@ -537,14 +537,28 @@ std::shared_ptr<json::Value> ComputingGraphImpl::ComputingSequence::to_json() | |||||
comp_seq->add(json::String::make(i->id_str())); | comp_seq->add(json::String::make(i->id_str())); | ||||
} | } | ||||
// expand opr and var nodes that do not appear in comp seq | |||||
// expand opr and var nodes that do not appear in comp seq, | |||||
// also expand var nodes which are only used in static infer | |||||
{ | { | ||||
VarNodeArray new_var_node; | VarNodeArray new_var_node; | ||||
auto&& mgr = m_owner_graph->static_infer_manager_impl(); | |||||
auto check_opr_input = [&](OperatorNodeBase* opr) { | auto check_opr_input = [&](OperatorNodeBase* opr) { | ||||
auto update = [&](VarNode* var) { | |||||
if (!(all_var_node.count(var))) { | |||||
all_var_node.insert(var); | |||||
new_var_node.push_back(var); | |||||
} | |||||
}; | |||||
for (auto i : opr->input()) { | for (auto i : opr->input()) { | ||||
if (!(all_var_node.count(i))) { | |||||
all_var_node.insert(i); | |||||
new_var_node.push_back(i); | |||||
update(i); | |||||
} | |||||
for (auto &&out : opr->output()) { | |||||
using DepType = static_infer::DepType; | |||||
for (auto&& i : mgr.get_deps({out, DepType::SHAPE})) { | |||||
update(i.dest); | |||||
} | |||||
for (auto&& i : mgr.get_deps({out, DepType::VALUE})) { | |||||
update(i.dest); | |||||
} | } | ||||
} | } | ||||
}; | }; | ||||
@@ -245,6 +245,9 @@ MGB_DEFINE_CLS_WITH_SUPER(StaticInferManagerImpl::TagTraitMutableBase, | |||||
return m_infer_withoutexc_ret; | return m_infer_withoutexc_ret; | ||||
} | } | ||||
//! original deps given in the InferDesc by the caller | |||||
virtual const DepVal& raw_deps() = 0; | |||||
protected: | protected: | ||||
//! current infer result, to be used by dependents | //! current infer result, to be used by dependents | ||||
InpElement m_inp_element; | InpElement m_inp_element; | ||||
@@ -300,9 +303,6 @@ MGB_DEFINE_CLS_WITH_SUPER(StaticInferManagerImpl::TagTraitMutableBase, | |||||
//! all missing inputs | //! all missing inputs | ||||
SharedSet<TagHandler*, TagHandlerSet> m_missing_input; | SharedSet<TagHandler*, TagHandlerSet> m_missing_input; | ||||
//! original deps given in the InferDesc by the caller | |||||
virtual const DepVal& raw_deps() = 0; | |||||
//! recursively set m_inp_element_synced of this and all receivers to | //! recursively set m_inp_element_synced of this and all receivers to | ||||
//! false | //! false | ||||
void reset_inp_element_synced(); | void reset_inp_element_synced(); | ||||
@@ -1027,6 +1027,14 @@ void StaticInferManagerImpl::update_mutable_src_shape(Tag dest) { | |||||
MGB_CATCH(MegBrainError & exc, { update_rethrow_exc(dest, exc); }) | MGB_CATCH(MegBrainError & exc, { update_rethrow_exc(dest, exc); }) | ||||
} | } | ||||
DepVal StaticInferManagerImpl::get_deps(const DepElement &elem) { | |||||
auto trait_base = get_tag_trait_container(elem.dest).select(elem.type); | |||||
if (!trait_base || trait_base->is_const()) | |||||
return {}; | |||||
return trait_base->as_mutable_safe()->raw_deps(); | |||||
} | |||||
/* ===================== CompSeqManager ===================== */ | /* ===================== CompSeqManager ===================== */ | ||||
class CompSeqManager::VersionedTagTrait { | class CompSeqManager::VersionedTagTrait { | ||||
@@ -99,6 +99,17 @@ class StaticInferManagerImpl final: public StaticInferManager { | |||||
*/ | */ | ||||
void update_mutable_src_shape(Tag tag); | void update_mutable_src_shape(Tag tag); | ||||
/*! | |||||
* \brief get original deps given in the InferDesc which is registered | |||||
* by register_shape_infer or register_value_infer | |||||
* | |||||
* Note: the \p elem with DepType::SHAPE and InferType::CONST shows no | |||||
* deps since the StaticInferManagerImpl folds the infererence chain of | |||||
* the const var shape | |||||
*/ | |||||
DepVal get_deps(const DepElement &elem); | |||||
private: | private: | ||||
friend class CompSeqManager; | friend class CompSeqManager; | ||||
@@ -396,6 +396,108 @@ VarNode& VarNode::comp_node(const CompNode &cn) { | |||||
} | } | ||||
#if MGB_ENABLE_JSON | #if MGB_ENABLE_JSON | ||||
std::shared_ptr<json::Value> | |||||
VarNode::dump_static_infer_info_to_json() const { | |||||
using namespace cg::static_infer; | |||||
auto&& mgr = static_cast<cg::ComputingGraphImpl*>( | |||||
owner_graph())->static_infer_manager_impl(); | |||||
auto get_dep_type = [](const DepType& type) -> std::string { | |||||
switch (type) { | |||||
#define cb(name) \ | |||||
case DepType::name: \ | |||||
return #name; | |||||
cb(SHAPE) | |||||
cb(VALUE) | |||||
#undef cb | |||||
default: | |||||
mgb_throw(MegBrainError, "unknown dep type"); | |||||
} | |||||
}; | |||||
auto get_infer_type = [](const InferType::Flag& type) { | |||||
switch (type) { | |||||
#define cb(name) \ | |||||
case InferType::Flag::name: \ | |||||
return json::String::make(#name); | |||||
cb(NO_DESC) | |||||
cb(CONST) | |||||
cb(RT_STATIC) | |||||
cb(MISSING_INP) | |||||
#undef cb | |||||
default: | |||||
mgb_throw(MegBrainError, "unknown infer type"); | |||||
} | |||||
}; | |||||
auto make_tag = [&](const DepType& type) { | |||||
VarNode* self = const_cast<VarNode*>(this); | |||||
auto c_deps = mgr.get_deps({self, type}); | |||||
auto deps = json::Array::make(); | |||||
for (auto&& i : c_deps) { | |||||
mgb_assert(i.dest); | |||||
deps->add(json::Object::make({ | |||||
{"var", json::String::make(i.dest->id_str())}, | |||||
{"dep_type", json::String::make(get_dep_type(i.type))} | |||||
})); | |||||
} | |||||
auto infer_type_handle = mgr.get_infer_type(self); | |||||
auto inferred_result = json::Null::make(); | |||||
auto infer_type = type == DepType::SHAPE ? infer_type_handle.shape | |||||
: infer_type_handle.value; | |||||
if (infer_type != InferType::Flag::NO_DESC) { | |||||
if (type == DepType::SHAPE) { | |||||
if (auto shape = mgr.infer_shape_fallible(self)) { | |||||
auto inferred_shape = json::Array::make(); | |||||
for (size_t i = 0; i < shape->ndim; ++ i) { | |||||
inferred_shape->add(json::Number::make((*shape)[i])); | |||||
} | |||||
inferred_result = inferred_shape; | |||||
} | |||||
} else { | |||||
if (auto p = mgr.infer_value_fallible(self)) { | |||||
auto&& dev = *p; | |||||
if (dev.shape().ndim == 1 && | |||||
dev.shape(0) < TensorShape::MAX_NDIM && | |||||
mgb_likely(dev.comp_node() == CompNode::default_cpu())) { | |||||
MGB_TRY { | |||||
size_t nr_elems = dev.shape(0); | |||||
auto&& dtype = dev.dtype(); | |||||
void* vptr = dev.raw_ptr(); | |||||
double data[nr_elems]; | |||||
HostTensorND contig; | |||||
if (!dev.layout().is_contiguous()) { | |||||
// both src and dst are placed on default cpu, | |||||
// no need for sync | |||||
contig.copy_from(dev); | |||||
mgb_assert(contig.layout().is_contiguous()); | |||||
vptr = contig.raw_ptr(); | |||||
} | |||||
static_cast_dtype(data, dtype, vptr, nr_elems); | |||||
auto inferred_value = json::Array::make(); | |||||
for (size_t i = 0; i < nr_elems; ++ i) { | |||||
inferred_value->add(json::Number::make(data[i])); | |||||
} | |||||
inferred_result = inferred_value; | |||||
} | |||||
MGB_CATCH(ConversionError&, {}); | |||||
} else { | |||||
inferred_result = json::String::make("Large Array"); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
return json::Object::make({ | |||||
{"node_type", json::String::make("static_infer_tag")}, | |||||
{"infer_type", get_infer_type(infer_type)}, | |||||
{"inferred_result", inferred_result}, | |||||
{"deps", deps} | |||||
}); | |||||
}; | |||||
return json::Object::make({ | |||||
#define TAG(type) {get_dep_type(type), make_tag(type)} | |||||
TAG(DepType::SHAPE), TAG(DepType::VALUE) | |||||
#undef TAG | |||||
}); | |||||
} | |||||
std::shared_ptr<json::Value> VarNode::to_json() const { | std::shared_ptr<json::Value> VarNode::to_json() const { | ||||
auto get_var = [](VarNode *p) -> std::shared_ptr<json::Value> { | auto get_var = [](VarNode *p) -> std::shared_ptr<json::Value> { | ||||
if(p) | if(p) | ||||
@@ -443,8 +545,10 @@ std::shared_ptr<json::Value> VarNode::to_json() const { | |||||
{"dev_ptr", json::Null::make()}, | {"dev_ptr", json::Null::make()}, | ||||
{"prev_dev_ptr", json::NumberInt::make(reinterpret_cast<size_t>( | {"prev_dev_ptr", json::NumberInt::make(reinterpret_cast<size_t>( | ||||
m_prev_dev_ptr))}, | m_prev_dev_ptr))}, | ||||
{"flag", flag} | |||||
{"flag", flag}, | |||||
{"static_infer_tags", dump_static_infer_info_to_json()} | |||||
}); | }); | ||||
if (m_prev_dev_ptr) { | if (m_prev_dev_ptr) { | ||||
(*rst)["prev_dev_ptr_end"] = json::NumberInt::make( | (*rst)["prev_dev_ptr_end"] = json::NumberInt::make( | ||||
reinterpret_cast<size_t>(m_prev_dev_ptr) + | reinterpret_cast<size_t>(m_prev_dev_ptr) + | ||||
@@ -575,6 +575,10 @@ class VarNode final: public GraphNodeBase { | |||||
void assign_dev_tensor_from_tensor(const DeviceTensorND &value); | void assign_dev_tensor_from_tensor(const DeviceTensorND &value); | ||||
#if MGB_ENABLE_JSON | |||||
std::shared_ptr<json::Value> dump_static_infer_info_to_json() const; | |||||
#endif | |||||
friend class static_infer::StaticInferManagerImpl; | friend class static_infer::StaticInferManagerImpl; | ||||
friend class VarNodeMemManager; | friend class VarNodeMemManager; | ||||
friend class VarDevMemDefragmenter; | friend class VarDevMemDefragmenter; | ||||