|
|
@@ -9,7 +9,11 @@ |
|
|
|
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
*/ |
|
|
|
|
|
|
|
#include <sstream> |
|
|
|
#include <range/v3/all.hpp> |
|
|
|
|
|
|
|
#include "megbrain/imperative/ops/backward_graph.h" |
|
|
|
#include "megbrain/imperative/ops/opr_attr.h" |
|
|
|
#include "../op_trait.h" |
|
|
|
|
|
|
|
namespace mgb { |
|
|
@@ -66,6 +70,67 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> BackwardGraph::InternalGraph::i |
|
|
|
return {ret, validated}; |
|
|
|
} |
|
|
|
|
|
|
|
std::string BackwardGraph::InternalGraph::repr() { |
|
|
|
std::ostringstream buf; |
|
|
|
buf << "("; |
|
|
|
for (size_t i = 0; i < inputs.size(); ++i) { |
|
|
|
if (i > 0) buf << ", "; |
|
|
|
buf << "%" << inputs[i]; |
|
|
|
} |
|
|
|
buf << ") => {\n"; |
|
|
|
auto fmt_const = [](size_t i, TensorPtr& t) { |
|
|
|
if (t->shape().ndim == 1 && t->shape()[0] == 1) { |
|
|
|
auto&& v = t->get_value(); |
|
|
|
if (v.dtype() == dtype::Float32{}) { |
|
|
|
return std::to_string(*v.ptr<dt_float32>()); |
|
|
|
} else if (v.dtype() == dtype::Int32{}) { |
|
|
|
return std::to_string(*v.ptr<int32_t>()); |
|
|
|
} |
|
|
|
} |
|
|
|
return std::string("%c") + std::to_string(i); |
|
|
|
}; |
|
|
|
std::unordered_map<size_t, std::string> const_reps; |
|
|
|
for (auto&& [i, t] : constants) { |
|
|
|
const_reps.emplace(i, fmt_const(i, t)); |
|
|
|
} |
|
|
|
for (auto& [op, ins, outs] : exprs) { |
|
|
|
buf << " "; |
|
|
|
if (outs.size()) { |
|
|
|
for (size_t i = 0; i < outs.size(); ++i) { |
|
|
|
if (i > 0) buf << ", "; |
|
|
|
buf << "%" << outs[i]; |
|
|
|
} |
|
|
|
buf << " = "; |
|
|
|
} |
|
|
|
if (auto* p = op->try_cast_final<OprAttr>()) { |
|
|
|
buf << p->type; |
|
|
|
} else { |
|
|
|
buf << op->dyn_typeinfo()->name; |
|
|
|
} |
|
|
|
for (size_t i : ins) { |
|
|
|
buf << " "; |
|
|
|
auto&& it = const_reps.find(i); |
|
|
|
if (it != const_reps.end()) { |
|
|
|
buf << it->second; |
|
|
|
} else { |
|
|
|
buf << "%" << i; |
|
|
|
} |
|
|
|
} |
|
|
|
buf << "\n"; |
|
|
|
} |
|
|
|
buf << " "; |
|
|
|
if (outputs.size()) { |
|
|
|
for (size_t i = 0; i < outputs.size(); ++i) { |
|
|
|
if (i > 0) buf << ", "; |
|
|
|
buf << "%" << outputs[i]; |
|
|
|
} |
|
|
|
} else { |
|
|
|
buf << "()"; |
|
|
|
} |
|
|
|
buf << "\n}\n"; |
|
|
|
return buf.str(); |
|
|
|
} |
|
|
|
|
|
|
|
MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardGraph); |
|
|
|
|
|
|
|
namespace { |
|
|
|