@@ -52,7 +52,7 @@ std::string get_default_device() {
 }
 void init_common(py::module m) {
-    auto&& PyCompNode = py::class_<CompNode>(m, "CompNode")
+    auto PyCompNode = py::class_<CompNode>(m, "CompNode")
         .def(py::init())
         .def(py::init(py::overload_cast<const std::string&>(&CompNode::load)))
         .def_property_readonly("logical_name", [](const CompNode& cn) {
@@ -34,53 +34,36 @@ struct GradSlotWeakPtr {
     size_t idx;
 };
-struct BackwardGraphCache : std::unordered_map<uint64_t, std::shared_ptr<OptimizedBackwardGraphResult>>, CompNodeDepedentObject {
-    std::shared_ptr<void> on_comp_node_finalize() override {
-        clear();
-        return {};
-    }
-} backward_graph_cache;
 std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
         ApplyContext& ctx, const apply_result_t& outputs) {
     // hash
-    static_assert(alignof(size_t) % alignof(bool) == 0);
-    size_t buf_size = (1 + ctx.nargs * 2) * sizeof(size_t) + ctx.nargs * sizeof(bool);
-    alignas(alignof(size_t)) std::byte buf[buf_size];
-    size_t* size_t_ptr = reinterpret_cast<size_t*>(buf);
-    bool* bool_ptr = reinterpret_cast<bool*>(size_t_ptr + (1 + ctx.nargs * 2));
-    bool* bool_ptr0 = bool_ptr;
-    *(size_t_ptr++) = ctx.op->hash();
+    using OptimizedBackwardGraphCache = OpMethResultCache<std::shared_ptr<OptimizedBackwardGraphResult>, SmallVector<bool>>;
+    thread_local OptimizedBackwardGraphCache cache;
+    decltype(cache)::key_t cache_key{ctx.op};
+    SmallVector<LogicalTensorDesc>& input_descs = cache_key.inputs;
+    SmallVector<bool>& input_requires_grad = std::get<0>(cache_key.extras);
+    input_descs.resize(ctx.nargs);
+    input_requires_grad.resize(ctx.nargs);
     for (size_t i = 0; i < ctx.nargs; ++i) {
-        *(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle());
-        *(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node());
-        *(bool_ptr++) = !ctx.args[i]->m_grad_info_dict.empty();
+        input_descs[i].layout.dtype = ctx.args[i]->dtype();
+        input_descs[i].comp_node = ctx.args[i]->comp_node();
+        input_requires_grad[i] = python::input_requires_grad(ctx, i);
     }
-    mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) &&
-               bool_ptr == reinterpret_cast<bool*>(buf + buf_size));
-    uint64_t key = XXHash{}.update(buf, buf_size).digest();
-    auto&& iter = backward_graph_cache.find(key);
-    if (iter != backward_graph_cache.end()) {
+    auto iter = cache.find(cache_key);
+    if (iter != cache.end()) {
         return iter->second;
     }
     // slow path
-    SmallVector<LogicalTensorDesc> inputs(ctx.nargs);
-    SmallVector<bool> input_requires_grad(ctx.nargs, false);
     SmallVector<bool> output_has_grad(outputs.size(), true);
-    for (size_t i = 0; i < ctx.nargs; ++i) {
-        inputs[i].comp_node = ctx.args[i]->comp_node();
-        inputs[i].layout.dtype = ctx.args[i]->dtype();
-        input_requires_grad[i] = python::input_requires_grad(ctx, i);
-    }
     std::shared_ptr<OptimizedBackwardGraphResult> ret;
     auto bg = OpDef::make_backward_graph(
-            *ctx.op, inputs, input_requires_grad, output_has_grad);
+            *ctx.op, input_descs, input_requires_grad, output_has_grad);
     if (!bg.graph.empty()) {
         ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
     }
-    backward_graph_cache.emplace(key, ret);
+    cache.emplace(cache_key, ret);
     return ret;
 }
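Aside, a minimal sketch (not part of the patch) of the key semantics the cache relies on: two OpMethArgs keys that agree on the op and the input descriptors but carry different requires-grad flags compare unequal, because operator== also compares the extras tuple, so they occupy separate cache entries. `op` and `descs` below are hypothetical stand-ins for a non-null std::shared_ptr<OpDef> and a SmallVector<LogicalTensorDesc>.

// Sketch only; `op` and `descs` are hypothetical stand-ins.
SmallVector<bool> requires_grad_a(2, false);
SmallVector<bool> requires_grad_b(2, true);
OpMethArgs<SmallVector<bool>> a{op, descs, {requires_grad_a}};
OpMethArgs<SmallVector<bool>> b{op, descs, {requires_grad_b}};
mgb_assert(!(a == b));  // same op and inputs, different flags: separate cache entries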
@@ -85,7 +85,14 @@ EncodedSubraph OpDef::make_backward_graph(
         const SmallVector<LogicalTensorDesc>& inputs,
         const SmallVector<bool>& input_requires_grad,
         const SmallVector<bool>& output_has_grad) {
-    return def.trait()->make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
+    using BackwardGraphCache = OpMethResultCache<EncodedSubraph, SmallVector<bool>, SmallVector<bool>>;
+    thread_local BackwardGraphCache cache;
+    decltype(cache)::key_t cache_key{const_cast<OpDef&>(def).shared_from_this(), inputs, {input_requires_grad, output_has_grad}};
+    auto iter = cache.find(cache_key);
+    if (iter == cache.end()) {
+        iter = cache.insert({cache_key, def.trait()->make_backward_graph(def, inputs, input_requires_grad, output_has_grad)}).first;
+    }
+    return iter->second;
 }
 std::vector<std::pair<const char*, std::string>> OpDef::props(
@@ -94,7 +101,7 @@ std::vector<std::pair<const char*, std::string>> OpDef::props(
 }
 std::string OpDef::to_string() const {
-    std::string builder = "{";
+    std::string builder = trait()->make_name(*this) + "{";
     for (auto&& [name, value]: props(*this)) {
         builder += name;
         builder += ": ";
@@ -170,7 +177,7 @@ std::string Subgraph::repr() const {
         if (auto* p = op->try_cast_final<OprAttr>()) {
             buf << p->type;
         } else {
-            buf << op->dyn_typeinfo()->name;
+            buf << op->make_name();
         }
         for (size_t i : ins) {
             buf << " ";
@@ -196,6 +203,26 @@ std::string Subgraph::repr() const {
     return buf.str();
 }
+bool Subgraph::is_single() const {
+    if (exprs.size() != 1) {
+        return false;
+    }
+    auto& expr = exprs.at(0);
+    return expr.inputs == inputs && expr.outputs == outputs;
+}
+std::shared_ptr<OpDef> Subgraph::as_single() const {
+    if (is_single()) {
+        return exprs.at(0).op;
+    } else {
+        return nullptr;
+    }
+}
+bool Subgraph::operator==(const Subgraph& rhs) const {
+    mgb_assert(false, "Not Implemented");
+}
 } // namespace imperative
 } // namespace mgb
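Aside, a minimal usage sketch (not part of the patch) for the new helpers: a subgraph whose single expr reads exactly the subgraph inputs and writes exactly the subgraph outputs can be collapsed back to its underlying op; `subgraph` below is a hypothetical Subgraph instance.

// Sketch only; `subgraph` is hypothetical.
if (std::shared_ptr<OpDef> single_op = subgraph.as_single()) {
    // trivial wrapper subgraph: dispatch `single_op` directly
} else {
    // general case: execute the subgraph expr by expr
}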
@@ -12,6 +12,7 @@
 #pragma once
 #include "megbrain/imperative/op_def.h"
+#include "megbrain/imperative/graph_cache.h"
 namespace mgb {
 namespace imperative {
@@ -113,49 +113,12 @@ void execute(const OpDef& def,
 //     return graph->infer_output_attrs_fallible(def, inputs);
 // }
-namespace {
-size_t get_backward_graph_hash_key(const OpDef& def,
-        const SmallVector<LogicalTensorDesc>& inputs,
-        const SmallVector<bool>& input_requires_grad,
-        const SmallVector<bool>& output_has_grad) {
-    XXHash state;
-    size_t length = 0, data[3 + 2 * inputs.size()];
-    data[length ++] = def.hash();
-    for (auto &&i : inputs) {
-        data[length ++] = mgb::hash(i.layout.dtype.handle());
-        data[length ++] = mgb::hash(i.comp_node);
-    }
-    data[length ++] = mgb::hash(input_requires_grad);
-    data[length ++] = mgb::hash(output_has_grad);
-    mgb_assert(length == 3 + 2 * inputs.size());
-    state.update(data, length * sizeof(size_t));
-    return state.digest();
-}
-struct BackwardGraphCache : std::unordered_map<size_t, EncodedSubraph>, CompNodeDepedentObject {
-    std::shared_ptr<void> on_comp_node_finalize() override {
-        clear();
-        return {};
-    }
-} backward_graph_cache;
-} // anonymous namespace
 EncodedSubraph
 make_backward_graph(const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs,
         const SmallVector<bool>& input_requires_grad,
         const SmallVector<bool>& output_has_grad) {
-    auto hash_key = get_backward_graph_hash_key(def, inputs, input_requires_grad, output_has_grad);
-    auto&& iter = backward_graph_cache.find(hash_key);
-    if (iter != backward_graph_cache.end()) {
-        return iter->second;
-    }
-    auto&& graph = ProxyGraph::get_default_graph();
-    auto res = graph->make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
-    backward_graph_cache.emplace(hash_key, res);
-    return res;
+    return ProxyGraph::get_default_graph()->make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
 }
 } // namespace proxy_graph_detail
@@ -0,0 +1,90 @@
+/**
+ * \file imperative/src/include/megbrain/imperative/graph_cache.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#pragma once
+
+#include "megbrain/imperative/subgraph.h"
+#include "megbrain/imperative/op_def.h"
+
+namespace mgb {
+namespace imperative {
+template <typename... TExtraArgs>
+struct OpMethArgs {
+    std::shared_ptr<OpDef> op;
+    SmallVector<LogicalTensorDesc> inputs;
+    std::tuple<TExtraArgs...> extras;
+
+    size_t hash() const;
+
+    bool operator==(const OpMethArgs& rhs) const {
+        if (bool(op) ^ bool(rhs.op)) {
+            return false;
+        }
+        if (op && rhs.op && !op->is_same(*rhs.op)) {
+            return false;
+        }
+        if (inputs.size() != rhs.inputs.size()) {
+            return false;
+        }
+        size_t nr_inputs = inputs.size();
+        for (size_t i = 0; i < nr_inputs; ++i) {
+            if (inputs[i].comp_node != rhs.inputs[i].comp_node) {
+                return false;
+            }
+            if (inputs[i].layout.dtype != rhs.inputs[i].layout.dtype) {
+                return false;
+            }
+        }
+        return extras == rhs.extras;
+    }
+
+    struct hash_t {
+        size_t operator()(const OpMethArgs& key) const {
+            return key.hash();
+        }
+    };
+};
+template <typename... TExtraArgs>
+inline size_t OpMethArgs<TExtraArgs...>::hash() const {
+    XXHash state;
+    size_t length = 0;
+    size_t data[1 + 2 * inputs.size() + sizeof...(TExtraArgs)];
+    auto append = [&](size_t hash) {
+        data[length++] = hash;
+    };
+    append(op->hash());
+    for (auto &&i : inputs) {
+        append(mgb::hash(i.layout.dtype.handle()));
+        append(mgb::hash(i.comp_node));
+    }
+    std::apply([&](auto&&... extras) {
+        (append(mgb::hash(extras)), ...);
+    }, extras);
+    mgb_assert(length == sizeof(data) / sizeof(size_t));
+    state.update(data, sizeof(data));
+    return state.digest();
+}
+template <typename TValue, typename... TExtraArgs>
+struct OpMethResultCache : std::unordered_map<OpMethArgs<TExtraArgs...>, TValue, typename OpMethArgs<TExtraArgs...>::hash_t>, CompNodeDepedentObject {
+    std::shared_ptr<void> on_comp_node_finalize() override {
+        static_cast<std::unordered_map<OpMethArgs<TExtraArgs...>, TValue, typename OpMethArgs<TExtraArgs...>::hash_t>*>(this)->clear();
+        // clear();
+        return {};
+    }
+
+    using key_t = OpMethArgs<TExtraArgs...>;
+};
+
+} // namespace imperative
+} // namespace mgb
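Aside, a minimal usage sketch (not part of the patch) of OpMethResultCache, mirroring the call sites above. `op`, `descs`, `flags` and `compute_value()` are hypothetical stand-ins: an existing std::shared_ptr<OpDef>, the input LogicalTensorDescs, a SmallVector<bool> of per-input flags, and whatever expensive result is being cached. `op` must be non-null, since OpMethArgs::hash() dereferences it.

// Sketch only, inside some function body; the names above are assumptions.
using DemoCache = OpMethResultCache<size_t, SmallVector<bool>>;
thread_local DemoCache demo_cache;            // per-thread; cleared when comp nodes are finalized

DemoCache::key_t key{op};                     // key = op + input descs + extras tuple
key.inputs = descs;
std::get<0>(key.extras) = flags;

auto iter = demo_cache.find(key);
if (iter == demo_cache.end()) {
    iter = demo_cache.emplace(key, compute_value()).first;   // fill on miss
}
size_t cached = iter->second;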