GitOrigin-RevId: 56d90be0e7
tags/v1.6.0-rc1
@@ -77,7 +77,7 @@ std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
     std::shared_ptr<OptimizedBackwardGraphResult> ret;
     auto bg = OpDef::make_backward_graph(
             *ctx.op, inputs, input_requires_grad, output_has_grad);
-    if (!bg.backward.empty()) {
+    if (!bg.graph.empty()) {
         ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
     }
     backward_graph_cache.emplace(key, ret);
@@ -37,7 +37,7 @@ void init_imperative_rt(py::module m) {
             const SmallVector<bool>& input_requires_grad,
             const SmallVector<bool>& output_has_grad){
         auto result = OpDef::make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
-        return std::make_tuple("backward_graph", result.save_for_backward, result.input_has_grad);
+        return std::make_tuple("backward_graph", result.input_mask, result.output_mask);
     };
     m.def("make_backward_graph", make_backward_graph);
 }
@@ -16,19 +16,19 @@
 using namespace mgb;
 using namespace imperative;
-OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphResult& src)
-        : input_has_grad(src.input_has_grad) {
-    if (src.backward.exprs.size() <= 1) {
+OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const EncodedSubraph& src)
+        : input_has_grad(src.output_mask) {
+    if (src.graph.exprs.size() <= 1) {
         // backward graph only contains a single op
-        backward = src.backward;
-        save_for_backward = src.save_for_backward;
+        backward = src.graph;
+        save_for_backward = src.input_mask;
         return;
     }
-    save_for_backward.resize(src.save_for_backward.size(), false);
+    save_for_backward.resize(src.input_mask.size(), false);
-    auto&& graph = src.backward;
-    auto&& mask = src.save_for_backward;
-    size_t input_size = src.input_has_grad.size();
+    auto&& graph = src.graph;
+    auto&& mask = src.input_mask;
+    size_t input_size = src.output_mask.size();
     size_t output_size = (mask.size() - input_size) / 2;
     mgb_assert(input_size + output_size * 2 == mask.size());
@@ -80,7 +80,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> OpDef::infer_output_attrs_falli
     return def.trait()->infer_output_attrs_fallible(def, inputs);
 }
-BackwardGraphResult OpDef::make_backward_graph(
+EncodedSubraph OpDef::make_backward_graph(
         const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs,
         const SmallVector<bool>& input_requires_grad,
@@ -668,14 +668,14 @@ struct ProxyGraph::GradGraph {
     cg::VarNode* grad;
 };
-BackwardGraphResult
+EncodedSubraph
 ProxyGraph::make_backward_graph(
         const OpDef& opdef,
         const SmallVector<LogicalTensorDesc>& input_descs,
         const SmallVector<bool>& input_requires_grad,
         const SmallVector<bool>& output_has_grad) {
     ThinHashMap<VarNode*, size_t> var2idx;
-    auto push = [&var2idx, cnt=0](VarNode* var) mutable {
+    auto push = [&var2idx, cnt=1](VarNode* var) mutable { // cnt is always greater than zero
         auto&& ret = var2idx.emplace(var, cnt ++);
         mgb_assert(ret.second, "var %s has been already inserted", var->cname());
         return ret.first->second;
@@ -702,8 +702,8 @@ ProxyGraph::make_backward_graph(
     }
     auto* gfunc = cg::lookup_grad_func(fwd->dyn_typeinfo());
-    BackwardGraphResult result;
-    auto&& igraph = result.backward;
+    EncodedSubraph result;
+    auto&& igraph = result.graph;
     size_t nr_backward_graph_inputs = 0;
     auto gen_expr = [this, &var2idx, &igraph, &push, &fwd,
@@ -735,7 +735,7 @@ ProxyGraph::make_backward_graph(
     // set backward graph outputs
     cg::DepOprIter iter{gen_expr};
     iter.set_visited(fwd);
-    result.input_has_grad.resize(inputs.size());
+    result.output_mask.resize(inputs.size());
     VarNodeArray output_grads_with_unused_var;
     {
@@ -760,6 +760,7 @@ ProxyGraph::make_backward_graph(
         if (grad_results.valid()) {
             grad = grad_results.val()[i];
         } else {
+            mgb_assert(gfunc, "could not find grad function");
             auto res = (*gfunc)(fwd, i, output_grads_with_unused_var);
             if (res.from_single()) {
                 grad = res.single();
@@ -776,9 +777,9 @@ ProxyGraph::make_backward_graph(
                     fwd->dyn_typeinfo()->name, i);
             iter.add(grad);
             igraph.outputs.push_back(var2idx.at(grad));
-            result.input_has_grad[i] = true;
+            result.output_mask[i] = true;
         } else {
-            result.input_has_grad[i] = false;
+            result.output_mask[i] = false;
         }
     }
     if (igraph.outputs.empty()) {
@@ -787,15 +788,15 @@ ProxyGraph::make_backward_graph(
     // set backward graph inputs
     igraph.inputs.reserve(nr_backward_graph_inputs);
-    result.save_for_backward.reserve(nr_backward_graph_inputs);
+    result.input_mask.reserve(nr_backward_graph_inputs);
     auto write_inputs = [&igraph, &var2idx, &result](const VarNodeArray& vars) {
         for (auto&& i: vars) {
             auto&& iter = var2idx.find(i);
             if (iter != var2idx.end()) {
                 igraph.inputs.push_back(iter->second);
-                result.save_for_backward.push_back(true);
+                result.input_mask.push_back(true);
             } else {
-                result.save_for_backward.push_back(false);
+                result.input_mask.push_back(false);
             }
         }
     };
@@ -40,7 +40,7 @@ public:
             const SmallVector<Tensor*>& outputs,
            const SmallVector<Tensor*>& workspace);
-    BackwardGraphResult make_backward_graph(
+    EncodedSubraph make_backward_graph(
            const OpDef& opdef,
            const SmallVector<LogicalTensorDesc>& input_descs,
            const SmallVector<bool>& input_requires_grad,
@@ -133,7 +133,7 @@ size_t get_backward_graph_hash_key(const OpDef& def,
     return state.digest();
 }
-struct BackwardGraphCache : std::unordered_map<size_t, BackwardGraphResult>, CompNodeDepedentObject {
+struct BackwardGraphCache : std::unordered_map<size_t, EncodedSubraph>, CompNodeDepedentObject {
     std::shared_ptr<void> on_comp_node_finalize() override {
         clear();
         return {};
@@ -142,7 +142,7 @@ struct BackwardGraphCache : std::unordered_map<size_t, BackwardGraphResult>, Com
 } // anonymous namespace
-BackwardGraphResult
+EncodedSubraph
 make_backward_graph(const OpDef& def,
                     const SmallVector<LogicalTensorDesc>& inputs,
                     const SmallVector<bool>& input_requires_grad,
@@ -101,5 +101,26 @@ void Subgraph::replace_vars(
     }
 }
+std::string EncodedSubraph::repr() const {
+    std::string buffer;
+    buffer.push_back('|');
+    for (size_t i = 0; i < input_mask.size(); ++i) {
+        buffer.push_back(input_mask[i] ? '#' : ' ');
+    }
+    buffer.push_back('|');
+    buffer.push_back('\n');
+    buffer.append(graph.repr());
+    buffer.push_back('|');
+    for (size_t i = 0; i < output_mask.size(); ++i) {
+        buffer.push_back(output_mask[i] ? '#' : ' ');
+    }
+    buffer.push_back('|');
+    return buffer;
+}
+size_t EncodedSubraph::hash() const {
+    return std::hash<std::string>{}(repr());
+}
 } // namespace imperative
 } // namespace mgb
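Editor's note, for intuition only (the mask values below are illustrative, not taken from this commit): with input_mask = {true, false, true} and output_mask = {false, true}, the repr() added above renders "|# #|" on the first line and "| #|" on the last line, with the subgraph's own graph.repr() in between; hash() then simply hashes that string.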
@@ -19,7 +19,7 @@ struct OptimizedBackwardGraphResult {
     SmallVector<bool> save_for_backward;
     SmallVector<bool> input_has_grad;
-    OptimizedBackwardGraphResult(const BackwardGraphResult& bgraph);
+    OptimizedBackwardGraphResult(const EncodedSubraph& bgraph);
 };
 } // namespace mgb::imperative
@@ -29,12 +29,6 @@ enum DispatchMode {
 using SharedOp = std::shared_ptr<OpDef>;
-struct BackwardGraphResult {
-    Subgraph backward;
-    SmallVector<bool> save_for_backward;
-    SmallVector<bool> input_has_grad;
-};
 class OpDef : public Hashable,
               public NonCopyableObj,
               public std::enable_shared_from_this<OpDef> {
@@ -91,7 +85,7 @@ public:
             const SmallVector<TensorPtr>& inputs_tensors,
             const SmallVector<MemoryDesc>& inputs_mems);
-    static BackwardGraphResult make_backward_graph(
+    static EncodedSubraph make_backward_graph(
             const OpDef& def,
             const SmallVector<LogicalTensorDesc>& inputs,
             const SmallVector<bool>& input_requires_grad,
@@ -38,7 +38,7 @@ void exec(const OpDef& def,
           const SmallVector<TensorPtr>& inputs,
           const SmallVector<TensorPtr>& outputs);
-BackwardGraphResult
+EncodedSubraph
 make_backward_graph(const OpDef& def,
                     const SmallVector<LogicalTensorDesc>& inputs,
                     const SmallVector<bool>& input_requires_grad,
@@ -96,5 +96,185 @@ struct Subgraph {
     bool operator==(const Subgraph& rhs) const;
 };
+struct EncodedSubraph {
+    Subgraph graph;
+    SmallVector<bool> input_mask;
+    SmallVector<bool> output_mask;
+    template <typename TContainer>
+    TContainer encode_inputs(TContainer inputs) const {
+        TContainer encoded_inputs;
+        size_t index = 0;
+        for (auto&& input : inputs) {
+            mgb_assert(index < input_mask.size(), "index out of range");
+            if (input_mask[index++]) {
+                encoded_inputs.push_back(input);
+            }
+        }
+        mgb_assert(index == input_mask.size(), "mask size mismatch");
+        return encoded_inputs;
+    }
+    template <typename TContainer>
+    TContainer encode_outputs(TContainer outputs) const {
+        TContainer encoded_outputs;
+        size_t index = 0;
+        for (auto&& output : outputs) {
+            mgb_assert(index < output_mask.size(), "index out of range");
+            if (output_mask[index++]) {
+                encoded_outputs.push_back(output);
+            }
+        }
+        mgb_assert(index == output_mask.size(), "mask size mismatch");
+        return encoded_outputs;
+    }
+    template <typename TContainer>
+    TContainer decode_outputs(TContainer outputs) const {
+        TContainer decoded_outputs;
+        size_t index = 0;
+        for (size_t i = 0; i < output_mask.size(); i++) {
+            mgb_assert(index < output_mask.size(), "index out of range");
+            if (output_mask[i]) {
+                decoded_outputs.push_back(outputs[index++]);
+            } else {
+                decoded_outputs.emplace_back();
+            }
+        }
+        mgb_assert(decoded_outputs.size() == output_mask.size(),
+                   "mask size mismatch");
+        return decoded_outputs;
+    }
+    static EncodedSubraph make(Subgraph graph) {
+        EncodedSubraph result;
+        result.input_mask = graph.gen_input_mask();
+        result.output_mask = graph.gen_output_mask();
+        graph.inputs = result.encode_inputs(graph.inputs);
+        graph.outputs = result.encode_outputs(graph.outputs);
+        result.graph = graph;
+        return result;
+    }
+    static EncodedSubraph make_single(
+            std::shared_ptr<OpDef> op,
+            SmallVector<bool> input_mask,
+            SmallVector<bool> output_mask) {
+        EncodedSubraph result;
+        result.input_mask = input_mask;
+        result.output_mask = output_mask;
+        Subgraph::var_t last_var = 0;
+        for (auto&& mask: input_mask) {
+            if (mask) {
+                result.graph.inputs.push_back(++last_var);
+            }
+        }
+        for (auto&& mask: output_mask) {
+            if (mask) {
+                result.graph.outputs.push_back(++last_var);
+            }
+        }
+        result.graph.exprs = {Subgraph::expr_t{op, result.graph.inputs, result.graph.outputs}};
+        return result;
+    }
+    template <typename T, typename F, typename C>
+    SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const {
+        auto encoded_inputs = encode_inputs(input_vars);
+        auto encoded_outputs = graph.apply(encoded_inputs, std::forward<F>(f),
+                                           std::forward<C>(c));
+        return decode_outputs(encoded_outputs);
+    }
+    std::string repr() const;
+    size_t hash() const;
+};
+template <typename T>
+class GradContext {
+public:
+    using var_t = T;
+    using vars_t = SmallVector<var_t>;
+    using expr_t = Expr<T>;
+private:
+    std::unordered_map<var_t, var_t> m_grads;
+    std::unordered_set<var_t> m_vars_require_grad;
+    std::function<var_t(var_t, var_t)> m_accumulator;
+    std::vector<expr_t> m_exprs;
+public:
+    GradContext(std::function<var_t(var_t, var_t)> accumulator): m_accumulator{std::move(accumulator)}{}
+    SmallVector<bool> get_require_grads(vars_t dests) {
+        SmallVector<bool> mask;
+        for (auto&& dest: dests) {
+            mask.push_back(bool(m_vars_require_grad.count(dest)));
+        }
+        return mask;
+    }
+    SmallVector<bool> get_has_grads(vars_t dests) {
+        SmallVector<bool> mask;
+        for (auto&& dest: dests) {
+            mask.push_back(bool(m_grads.count(dest)));
+        }
+        return mask;
+    }
+    void mark_require_grads(vars_t dests) {
+        for (auto&& dest: dests) {
+            m_vars_require_grad.insert(dest);
+        }
+    }
+    var_t accumulate_grad(var_t dest, var_t grad) {
+        if (!m_grads.count(dest)) {
+            return m_grads[dest] = grad;
+        } else {
+            return m_grads[dest] = m_accumulator(m_grads[dest], grad);
+        }
+    }
+    void record_expr(std::shared_ptr<OpDef> op, vars_t inputs, vars_t outputs) {
+        bool require_grad = false;
+        for (auto&& input: inputs) {
+            if (m_vars_require_grad.count(input)) {
+                require_grad = true;
+                break;
+            }
+        }
+        if (require_grad) {
+            m_exprs.push_back({op, inputs, outputs});
+            mark_require_grads(outputs);
+        }
+    }
+    template <typename TFunctor>
+    void backward(vars_t outputs, vars_t output_grads, TFunctor functor) {
+        size_t nr_outputs = outputs.size();
+        for (size_t i = 0; i < nr_outputs; ++i) {
+            m_grads[outputs[i]] = output_grads[i];
+        }
+        auto exprs = m_exprs;
+        std::reverse(exprs.begin(), exprs.end());
+        for (const expr_t& expr: exprs) {
+            size_t nr_inputs = expr.inputs.size();
+            vars_t input_grads = functor(expr, get_grads(expr.outputs));
+            mgb_assert(input_grads.size() == nr_inputs, "input size mismatch");
+            for (size_t i = 0; i < nr_inputs; ++i) {
+                if (input_grads[i] && m_vars_require_grad.count(expr.inputs[i])) {
+                    accumulate_grad(expr.inputs[i], input_grads[i]);
+                }
+            }
+        }
+    }
+    var_t get_grad(var_t dest) {
+        if (m_grads.count(dest)) {
+            return m_grads.at(dest);
+        }
+        return 0;
+    }
+    vars_t get_grads(vars_t dests) {
+        vars_t grads;
+        for (auto&& dest: dests) {
+            grads.push_back(get_grad(dest));
+        }
+        return grads;
+    }
+};
 } // namespace imperative
 } // namespace mgb
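Editor's note on the EncodedSubraph API introduced above: a minimal standalone sketch (plain std::vector<int> and hand-picked masks, not the actual MegEngine types or methods) of what the input_mask/output_mask bookkeeping does. encode_inputs drops the entries that the compact graph does not need; decode_outputs re-expands the compact results, filling default placeholders where no value was produced.

    #include <cassert>
    #include <vector>

    int main() {
        // Hypothetical masks: of 3 candidate backward inputs only the 1st and
        // 3rd are kept; of 2 forward inputs only the 2nd receives a grad.
        std::vector<bool> input_mask{true, false, true};
        std::vector<bool> output_mask{false, true};

        // encode_inputs analogue: keep only the entries selected by input_mask.
        std::vector<int> full_inputs{10, 20, 30};
        std::vector<int> encoded;
        for (size_t i = 0; i < input_mask.size(); ++i) {
            if (input_mask[i]) encoded.push_back(full_inputs[i]);
        }
        assert((encoded == std::vector<int>{10, 30}));

        // decode_outputs analogue: re-expand the compact results, inserting a
        // default placeholder (0 here) wherever output_mask is false.
        std::vector<int> compact_grads{42};
        std::vector<int> decoded;
        size_t j = 0;
        for (size_t i = 0; i < output_mask.size(); ++i) {
            decoded.push_back(output_mask[i] ? compact_grads[j++] : 0);
        }
        assert((decoded == std::vector<int>{0, 42}));
        return 0;
    }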
@@ -22,22 +22,22 @@ using namespace cg;
 using namespace imperative;
 template <typename T>
-T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs,
+T prepare_backward_graph_inputs(const EncodedSubraph& bg, const T& inputs,
                                 const T& outputs, const T& grads) {
     T ret;
     size_t i = 0;
     for (auto&& t : inputs) {
-        if (bg.save_for_backward[i++]) {
+        if (bg.input_mask[i++]) {
             ret.push_back(t);
         }
     }
     for (auto&& t : outputs) {
-        if (bg.save_for_backward[i++]) {
+        if (bg.input_mask[i++]) {
             ret.push_back(t);
         }
     }
     for (auto&& t : grads) {
-        if (bg.save_for_backward[i++]) {
+        if (bg.input_mask[i++]) {
             ret.push_back(t);
         }
     }
@@ -45,10 +45,10 @@ T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs,
 }
 template <typename T, typename U>
-T expand_grads(const U& bg, const T& outputs) {
-    T ret(bg.input_has_grad.size());
-    for (size_t i = 0, j = 0; i < bg.input_has_grad.size(); ++i) {
-        if (bg.input_has_grad[i]) {
+T expand_grads(const U& mask, const T& outputs) {
+    T ret(mask.size());
+    for (size_t i = 0, j = 0; i < mask.size(); ++i) {
+        if (mask[i]) {
             ret[i] = outputs[j++];
         }
     }
@@ -80,7 +80,7 @@ T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg,
 }
 SmallVector<TensorPtr> apply_shared_on_physical_tensor(
-        std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs) {
+        std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs, size_t nr_outputs) {
     return OpDef::apply_on_physical_tensor(*def, inputs);
 }
@@ -104,8 +104,8 @@ TEST(TestImperative, BackwardGraphBasic) {
     }
     auto result = OpDef::make_backward_graph(*attr, input_descs, {true, true},
                                              {true});
-    auto&& save_for_backward = result.save_for_backward;
-    auto&& input_has_grad = result.input_has_grad;
+    auto&& save_for_backward = result.input_mask;
+    auto&& input_has_grad = result.output_mask;
     auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs);
     inputs.push_back(outputs[0]);
@@ -124,7 +124,7 @@ TEST(TestImperative, BackwardGraphBasic) {
         }
     }
     inputs.clear();
-    auto input_grads = result.backward.apply(backward_graph_inputs,
+    auto input_grads = result.graph.apply(backward_graph_inputs,
                                              apply_shared_on_physical_tensor,
                                              [&](auto&& x) { return x; });
     mgb_assert(input_grads.size() == input_has_grad.size());
@@ -159,8 +159,8 @@ TEST(TestImperative, BackwardGraphIdentity) {
     input_descs.push_back({a->layout(), a->comp_node()});
     auto result =
             OpDef::make_backward_graph(*attr, input_descs, {true}, {true});
-    auto&& save_for_backward = result.save_for_backward;
-    auto&& input_has_grad = result.input_has_grad;
+    auto&& save_for_backward = result.input_mask;
+    auto&& input_has_grad = result.output_mask;
     auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs);
     inputs.push_back(outputs[0]);
@@ -178,7 +178,7 @@ TEST(TestImperative, BackwardGraphIdentity) {
         }
     }
     inputs.clear();
-    auto input_grads = result.backward.apply(backward_graph_inputs,
+    auto input_grads = result.graph.apply(backward_graph_inputs,
                                              apply_shared_on_physical_tensor,
                                              [&](auto&& x) { return x; });
     mgb_assert(input_grads.size() == input_has_grad.size());
@@ -245,7 +245,7 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
             prepare_backward_graph_inputs<SmallVector<TensorPtr>>(
                     bg, {a_tn, b_tn}, {c_tn}, {dc_tn});
     auto grads =
-            expand_grads(bg, bg.backward.apply(backward_graph_inputs,
+            expand_grads(bg.output_mask, bg.graph.apply(backward_graph_inputs,
                                                apply_shared_on_physical_tensor,
                                                [&](auto&& x) { return x; }));
@@ -262,7 +262,7 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
             prepare_optimized_backward_inputs<SmallVector<TensorPtr>>(
                     obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn});
     auto grads2 = expand_grads(
-            obg,
+            obg.input_has_grad,
             obg.backward.apply(backward_inputs, apply_shared_on_physical_tensor,
                                [&](auto&& x) { return x; }));
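Editor's note on the GradContext template added in the subgraph header: a rough, self-contained sketch of the reverse traversal that GradContext::backward performs, using ints as variable ids and a toy functor instead of Expr<T>/OpDef (everything here is a hypothetical stand-in, not the real API). Grads of the outputs are seeded, the recorded exprs are replayed in reverse order, and each produced input grad is accumulated into its variable.

    #include <cassert>
    #include <unordered_map>
    #include <vector>

    // Hypothetical stand-in for a recorded expression: just its input/output ids.
    struct ToyExpr {
        std::vector<int> inputs, outputs;
    };

    int main() {
        // Forward trace: y = f(a, b); z = g(y). Ids: a=1, b=2, y=3, z=4.
        std::vector<ToyExpr> exprs{{{1, 2}, {3}}, {{3}, {4}}};
        std::unordered_map<int, double> grads{{4, 1.0}};  // seed dz = 1

        // Pretend every op simply forwards its output grad to each input
        // (a real functor would emit grad exprs via the op's grad rule).
        auto functor = [](const ToyExpr& e, double out_grad) {
            return std::vector<double>(e.inputs.size(), out_grad);
        };

        // Replay the recorded exprs in reverse, accumulating per-input grads,
        // mirroring the loop in GradContext::backward.
        for (auto it = exprs.rbegin(); it != exprs.rend(); ++it) {
            double out_grad = grads.count(it->outputs[0]) ? grads[it->outputs[0]] : 0.0;
            auto input_grads = functor(*it, out_grad);
            for (size_t i = 0; i < it->inputs.size(); ++i) {
                grads[it->inputs[i]] += input_grads[i];  // accumulate_grad analogue
            }
        }
        assert(grads[1] == 1.0 && grads[2] == 1.0);
        return 0;
    }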