GitOrigin-RevId: 56d90be0e7
tags/v1.6.0-rc1
@@ -77,7 +77,7 @@ std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph( | |||||
std::shared_ptr<OptimizedBackwardGraphResult> ret; | std::shared_ptr<OptimizedBackwardGraphResult> ret; | ||||
auto bg = OpDef::make_backward_graph( | auto bg = OpDef::make_backward_graph( | ||||
*ctx.op, inputs, input_requires_grad, output_has_grad); | *ctx.op, inputs, input_requires_grad, output_has_grad); | ||||
if (!bg.backward.empty()) { | |||||
if (!bg.graph.empty()) { | |||||
ret = std::make_shared<OptimizedBackwardGraphResult>(bg); | ret = std::make_shared<OptimizedBackwardGraphResult>(bg); | ||||
} | } | ||||
backward_graph_cache.emplace(key, ret); | backward_graph_cache.emplace(key, ret); | ||||
@@ -37,7 +37,7 @@ void init_imperative_rt(py::module m) { | |||||
const SmallVector<bool>& input_requires_grad, | const SmallVector<bool>& input_requires_grad, | ||||
const SmallVector<bool>& output_has_grad){ | const SmallVector<bool>& output_has_grad){ | ||||
auto result = OpDef::make_backward_graph(def, inputs, input_requires_grad, output_has_grad); | auto result = OpDef::make_backward_graph(def, inputs, input_requires_grad, output_has_grad); | ||||
return std::make_tuple("backward_graph", result.save_for_backward, result.input_has_grad); | |||||
return std::make_tuple("backward_graph", result.input_mask, result.output_mask); | |||||
}; | }; | ||||
m.def("make_backward_graph", make_backward_graph); | m.def("make_backward_graph", make_backward_graph); | ||||
} | } |
@@ -16,19 +16,19 @@ | |||||
using namespace mgb; | using namespace mgb; | ||||
using namespace imperative; | using namespace imperative; | ||||
OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphResult& src) | |||||
: input_has_grad(src.input_has_grad) { | |||||
if (src.backward.exprs.size() <= 1) { | |||||
OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const EncodedSubraph& src) | |||||
: input_has_grad(src.output_mask) { | |||||
if (src.graph.exprs.size() <= 1) { | |||||
// backward graph only contains a single op | // backward graph only contains a single op | ||||
backward = src.backward; | |||||
save_for_backward = src.save_for_backward; | |||||
backward = src.graph; | |||||
save_for_backward = src.input_mask; | |||||
return; | return; | ||||
} | } | ||||
save_for_backward.resize(src.save_for_backward.size(), false); | |||||
save_for_backward.resize(src.input_mask.size(), false); | |||||
auto&& graph = src.backward; | |||||
auto&& mask = src.save_for_backward; | |||||
size_t input_size = src.input_has_grad.size(); | |||||
auto&& graph = src.graph; | |||||
auto&& mask = src.input_mask; | |||||
size_t input_size = src.output_mask.size(); | |||||
size_t output_size = (mask.size() - input_size) / 2; | size_t output_size = (mask.size() - input_size) / 2; | ||||
mgb_assert(input_size + output_size * 2 == mask.size()); | mgb_assert(input_size + output_size * 2 == mask.size()); | ||||
@@ -80,7 +80,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> OpDef::infer_output_attrs_falli | |||||
return def.trait()->infer_output_attrs_fallible(def, inputs); | return def.trait()->infer_output_attrs_fallible(def, inputs); | ||||
} | } | ||||
BackwardGraphResult OpDef::make_backward_graph( | |||||
EncodedSubraph OpDef::make_backward_graph( | |||||
const OpDef& def, | const OpDef& def, | ||||
const SmallVector<LogicalTensorDesc>& inputs, | const SmallVector<LogicalTensorDesc>& inputs, | ||||
const SmallVector<bool>& input_requires_grad, | const SmallVector<bool>& input_requires_grad, | ||||
@@ -668,14 +668,14 @@ struct ProxyGraph::GradGraph { | |||||
cg::VarNode* grad; | cg::VarNode* grad; | ||||
}; | }; | ||||
BackwardGraphResult | |||||
EncodedSubraph | |||||
ProxyGraph::make_backward_graph( | ProxyGraph::make_backward_graph( | ||||
const OpDef& opdef, | const OpDef& opdef, | ||||
const SmallVector<LogicalTensorDesc>& input_descs, | const SmallVector<LogicalTensorDesc>& input_descs, | ||||
const SmallVector<bool>& input_requires_grad, | const SmallVector<bool>& input_requires_grad, | ||||
const SmallVector<bool>& output_has_grad) { | const SmallVector<bool>& output_has_grad) { | ||||
ThinHashMap<VarNode*, size_t> var2idx; | ThinHashMap<VarNode*, size_t> var2idx; | ||||
auto push = [&var2idx, cnt=0](VarNode* var) mutable { | |||||
auto push = [&var2idx, cnt=1](VarNode* var) mutable { //cnt is always greater non zero | |||||
auto&& ret = var2idx.emplace(var, cnt ++); | auto&& ret = var2idx.emplace(var, cnt ++); | ||||
mgb_assert(ret.second, "var %s has been already inserted", var->cname()); | mgb_assert(ret.second, "var %s has been already inserted", var->cname()); | ||||
return ret.first->second; | return ret.first->second; | ||||
@@ -702,8 +702,8 @@ ProxyGraph::make_backward_graph( | |||||
} | } | ||||
auto* gfunc = cg::lookup_grad_func(fwd->dyn_typeinfo()); | auto* gfunc = cg::lookup_grad_func(fwd->dyn_typeinfo()); | ||||
BackwardGraphResult result; | |||||
auto&& igraph = result.backward; | |||||
EncodedSubraph result; | |||||
auto&& igraph = result.graph; | |||||
size_t nr_backward_graph_inputs = 0; | size_t nr_backward_graph_inputs = 0; | ||||
auto gen_expr = [this, &var2idx, &igraph, &push, &fwd, | auto gen_expr = [this, &var2idx, &igraph, &push, &fwd, | ||||
@@ -735,7 +735,7 @@ ProxyGraph::make_backward_graph( | |||||
// set backward graph outputs | // set backward graph outputs | ||||
cg::DepOprIter iter{gen_expr}; | cg::DepOprIter iter{gen_expr}; | ||||
iter.set_visited(fwd); | iter.set_visited(fwd); | ||||
result.input_has_grad.resize(inputs.size()); | |||||
result.output_mask.resize(inputs.size()); | |||||
VarNodeArray output_grads_with_unused_var; | VarNodeArray output_grads_with_unused_var; | ||||
{ | { | ||||
@@ -760,6 +760,7 @@ ProxyGraph::make_backward_graph( | |||||
if (grad_results.valid()) { | if (grad_results.valid()) { | ||||
grad = grad_results.val()[i]; | grad = grad_results.val()[i]; | ||||
} else { | } else { | ||||
mgb_assert(gfunc, "could not find grad function"); | |||||
auto res = (*gfunc)(fwd, i, output_grads_with_unused_var); | auto res = (*gfunc)(fwd, i, output_grads_with_unused_var); | ||||
if (res.from_single()) { | if (res.from_single()) { | ||||
grad = res.single(); | grad = res.single(); | ||||
@@ -776,9 +777,9 @@ ProxyGraph::make_backward_graph( | |||||
fwd->dyn_typeinfo()->name, i); | fwd->dyn_typeinfo()->name, i); | ||||
iter.add(grad); | iter.add(grad); | ||||
igraph.outputs.push_back(var2idx.at(grad)); | igraph.outputs.push_back(var2idx.at(grad)); | ||||
result.input_has_grad[i] = true; | |||||
result.output_mask[i] = true; | |||||
} else { | } else { | ||||
result.input_has_grad[i] = false; | |||||
result.output_mask[i] = false; | |||||
} | } | ||||
} | } | ||||
if (igraph.outputs.empty()) { | if (igraph.outputs.empty()) { | ||||
@@ -787,15 +788,15 @@ ProxyGraph::make_backward_graph( | |||||
// set backward graph inputs | // set backward graph inputs | ||||
igraph.inputs.reserve(nr_backward_graph_inputs); | igraph.inputs.reserve(nr_backward_graph_inputs); | ||||
result.save_for_backward.reserve(nr_backward_graph_inputs); | |||||
result.input_mask.reserve(nr_backward_graph_inputs); | |||||
auto write_inputs = [&igraph, &var2idx, &result](const VarNodeArray& vars) { | auto write_inputs = [&igraph, &var2idx, &result](const VarNodeArray& vars) { | ||||
for (auto&& i: vars) { | for (auto&& i: vars) { | ||||
auto&& iter = var2idx.find(i); | auto&& iter = var2idx.find(i); | ||||
if (iter != var2idx.end()) { | if (iter != var2idx.end()) { | ||||
igraph.inputs.push_back(iter->second); | igraph.inputs.push_back(iter->second); | ||||
result.save_for_backward.push_back(true); | |||||
result.input_mask.push_back(true); | |||||
} else { | } else { | ||||
result.save_for_backward.push_back(false); | |||||
result.input_mask.push_back(false); | |||||
} | } | ||||
} | } | ||||
}; | }; | ||||
@@ -40,7 +40,7 @@ public: | |||||
const SmallVector<Tensor*>& outputs, | const SmallVector<Tensor*>& outputs, | ||||
const SmallVector<Tensor*>& workspace); | const SmallVector<Tensor*>& workspace); | ||||
BackwardGraphResult make_backward_graph( | |||||
EncodedSubraph make_backward_graph( | |||||
const OpDef& opdef, | const OpDef& opdef, | ||||
const SmallVector<LogicalTensorDesc>& input_descs, | const SmallVector<LogicalTensorDesc>& input_descs, | ||||
const SmallVector<bool>& input_requires_grad, | const SmallVector<bool>& input_requires_grad, | ||||
@@ -133,7 +133,7 @@ size_t get_backward_graph_hash_key(const OpDef& def, | |||||
return state.digest(); | return state.digest(); | ||||
} | } | ||||
struct BackwardGraphCache : std::unordered_map<size_t, BackwardGraphResult>, CompNodeDepedentObject { | |||||
struct BackwardGraphCache : std::unordered_map<size_t, EncodedSubraph>, CompNodeDepedentObject { | |||||
std::shared_ptr<void> on_comp_node_finalize() override { | std::shared_ptr<void> on_comp_node_finalize() override { | ||||
clear(); | clear(); | ||||
return {}; | return {}; | ||||
@@ -142,7 +142,7 @@ struct BackwardGraphCache : std::unordered_map<size_t, BackwardGraphResult>, Com | |||||
} // anonymous namespace | } // anonymous namespace | ||||
BackwardGraphResult | |||||
EncodedSubraph | |||||
make_backward_graph(const OpDef& def, | make_backward_graph(const OpDef& def, | ||||
const SmallVector<LogicalTensorDesc>& inputs, | const SmallVector<LogicalTensorDesc>& inputs, | ||||
const SmallVector<bool>& input_requires_grad, | const SmallVector<bool>& input_requires_grad, | ||||
@@ -101,5 +101,26 @@ void Subgraph::replace_vars( | |||||
} | } | ||||
} | } | ||||
std::string EncodedSubraph::repr() const { | |||||
std::string buffer; | |||||
buffer.push_back('|'); | |||||
for (size_t i = 0; i < input_mask.size(); ++i) { | |||||
buffer.push_back(input_mask[i] ? '#' : ' '); | |||||
} | |||||
buffer.push_back('|'); | |||||
buffer.push_back('\n'); | |||||
buffer.append(graph.repr()); | |||||
buffer.push_back('|'); | |||||
for (size_t i = 0; i < output_mask.size(); ++i) { | |||||
buffer.push_back(output_mask[i] ? '#' : ' '); | |||||
} | |||||
buffer.push_back('|'); | |||||
return buffer; | |||||
} | |||||
size_t EncodedSubraph::hash() const { | |||||
return std::hash<std::string>{}(repr()); | |||||
} | |||||
} // namespace imperative | } // namespace imperative | ||||
} // namespace mgb | } // namespace mgb |
@@ -19,7 +19,7 @@ struct OptimizedBackwardGraphResult { | |||||
SmallVector<bool> save_for_backward; | SmallVector<bool> save_for_backward; | ||||
SmallVector<bool> input_has_grad; | SmallVector<bool> input_has_grad; | ||||
OptimizedBackwardGraphResult(const BackwardGraphResult& bgraph); | |||||
OptimizedBackwardGraphResult(const EncodedSubraph& bgraph); | |||||
}; | }; | ||||
} // namespace mgb::imperative | } // namespace mgb::imperative |
@@ -29,12 +29,6 @@ enum DispatchMode { | |||||
using SharedOp = std::shared_ptr<OpDef>; | using SharedOp = std::shared_ptr<OpDef>; | ||||
struct BackwardGraphResult { | |||||
Subgraph backward; | |||||
SmallVector<bool> save_for_backward; | |||||
SmallVector<bool> input_has_grad; | |||||
}; | |||||
class OpDef : public Hashable, | class OpDef : public Hashable, | ||||
public NonCopyableObj, | public NonCopyableObj, | ||||
public std::enable_shared_from_this<OpDef> { | public std::enable_shared_from_this<OpDef> { | ||||
@@ -91,7 +85,7 @@ public: | |||||
const SmallVector<TensorPtr>& inputs_tensors, | const SmallVector<TensorPtr>& inputs_tensors, | ||||
const SmallVector<MemoryDesc>& inputs_mems); | const SmallVector<MemoryDesc>& inputs_mems); | ||||
static BackwardGraphResult make_backward_graph( | |||||
static EncodedSubraph make_backward_graph( | |||||
const OpDef& def, | const OpDef& def, | ||||
const SmallVector<LogicalTensorDesc>& inputs, | const SmallVector<LogicalTensorDesc>& inputs, | ||||
const SmallVector<bool>& input_requires_grad, | const SmallVector<bool>& input_requires_grad, | ||||
@@ -38,7 +38,7 @@ void exec(const OpDef& def, | |||||
const SmallVector<TensorPtr>& inputs, | const SmallVector<TensorPtr>& inputs, | ||||
const SmallVector<TensorPtr>& outputs); | const SmallVector<TensorPtr>& outputs); | ||||
BackwardGraphResult | |||||
EncodedSubraph | |||||
make_backward_graph(const OpDef& def, | make_backward_graph(const OpDef& def, | ||||
const SmallVector<LogicalTensorDesc>& inputs, | const SmallVector<LogicalTensorDesc>& inputs, | ||||
const SmallVector<bool>& input_requires_grad, | const SmallVector<bool>& input_requires_grad, | ||||
@@ -96,5 +96,185 @@ struct Subgraph { | |||||
bool operator==(const Subgraph& rhs) const; | bool operator==(const Subgraph& rhs) const; | ||||
}; | }; | ||||
struct EncodedSubraph { | |||||
Subgraph graph; | |||||
SmallVector<bool> input_mask; | |||||
SmallVector<bool> output_mask; | |||||
template <typename TContainer> | |||||
TContainer encode_inputs(TContainer inputs) const { | |||||
TContainer encoded_inputs; | |||||
size_t index = 0; | |||||
for (auto&& input : inputs) { | |||||
mgb_assert(index < input_mask.size(), "index out of range"); | |||||
if (input_mask[index++]) { | |||||
encoded_inputs.push_back(input); | |||||
} | |||||
} | |||||
mgb_assert(index == input_mask.size(), "mask size mismatch"); | |||||
return encoded_inputs; | |||||
} | |||||
template <typename TContainer> | |||||
TContainer encode_outputs(TContainer outputs) const { | |||||
TContainer encoded_outputs; | |||||
size_t index = 0; | |||||
for (auto&& output : outputs) { | |||||
mgb_assert(index < output_mask.size(), "index out of range"); | |||||
if (output_mask[index++]) { | |||||
encoded_outputs.push_back(output); | |||||
} | |||||
} | |||||
mgb_assert(index == output_mask.size(), "mask size mismatch"); | |||||
return encoded_outputs; | |||||
} | |||||
template <typename TContainer> | |||||
TContainer decode_outputs(TContainer outputs) const { | |||||
TContainer decoded_outputs; | |||||
size_t index = 0; | |||||
for (size_t i = 0; i < output_mask.size(); i++) { | |||||
mgb_assert(index < output_mask.size(), "index out of range"); | |||||
if (output_mask[i]) { | |||||
decoded_outputs.push_back(outputs[index++]); | |||||
} else { | |||||
decoded_outputs.emplace_back(); | |||||
} | |||||
} | |||||
mgb_assert(decoded_outputs.size() == output_mask.size(), | |||||
"mask size mismatch"); | |||||
return decoded_outputs; | |||||
} | |||||
static EncodedSubraph make(Subgraph graph) { | |||||
EncodedSubraph result; | |||||
result.input_mask = graph.gen_input_mask(); | |||||
result.output_mask = graph.gen_output_mask(); | |||||
graph.inputs = result.encode_inputs(graph.inputs); | |||||
graph.outputs = result.encode_outputs(graph.outputs); | |||||
result.graph = graph; | |||||
return result; | |||||
} | |||||
static EncodedSubraph make_single( | |||||
std::shared_ptr<OpDef> op, | |||||
SmallVector<bool> input_mask, | |||||
SmallVector<bool> output_mask) { | |||||
EncodedSubraph result; | |||||
result.input_mask = input_mask; | |||||
result.output_mask = output_mask; | |||||
Subgraph::var_t last_var = 0; | |||||
for (auto&& mask: input_mask) { | |||||
if (mask) { | |||||
result.graph.inputs.push_back(++last_var); | |||||
} | |||||
} | |||||
for (auto&& mask: output_mask) { | |||||
if (mask) { | |||||
result.graph.outputs.push_back(++last_var); | |||||
} | |||||
} | |||||
result.graph.exprs = {Subgraph::expr_t{op, result.graph.inputs, result.graph.outputs}}; | |||||
return result; | |||||
} | |||||
template <typename T, typename F, typename C> | |||||
SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const { | |||||
auto encoded_inputs = encode_inputs(input_vars); | |||||
auto encoded_outputs = graph.apply(encoded_inputs, std::forward<F>(f), | |||||
std::forward<C>(c)); | |||||
return decode_outputs(encoded_outputs); | |||||
} | |||||
std::string repr() const; | |||||
size_t hash() const; | |||||
}; | |||||
template <typename T> | |||||
class GradContext { | |||||
public: | |||||
using var_t = T; | |||||
using vars_t = SmallVector<var_t>; | |||||
using expr_t = Expr<T>; | |||||
private: | |||||
std::unordered_map<var_t, var_t> m_grads; | |||||
std::unordered_set<var_t> m_vars_require_grad; | |||||
std::function<var_t(var_t, var_t)> m_accumulator; | |||||
std::vector<expr_t> m_exprs; | |||||
public: | |||||
GradContext(std::function<var_t(var_t, var_t)> accumulator): m_accumulator{std::move(accumulator)}{} | |||||
SmallVector<bool> get_require_grads(vars_t dests) { | |||||
SmallVector<bool> mask; | |||||
for (auto&& dest: dests) { | |||||
mask.push_back(bool(m_vars_require_grad.count(dest))); | |||||
} | |||||
return mask; | |||||
} | |||||
SmallVector<bool> get_has_grads(vars_t dests) { | |||||
SmallVector<bool> mask; | |||||
for (auto&& dest: dests) { | |||||
mask.push_back(bool(m_grads.count(dest))); | |||||
} | |||||
return mask; | |||||
} | |||||
void mark_require_grads(vars_t dests) { | |||||
for (auto&& dest: dests) { | |||||
m_vars_require_grad.insert(dest); | |||||
} | |||||
} | |||||
var_t accumulate_grad(var_t dest, var_t grad) { | |||||
if (!m_grads.count(dest)) { | |||||
return m_grads[dest] = grad; | |||||
} else { | |||||
return m_grads[dest] = m_accumulator(m_grads[dest], grad); | |||||
} | |||||
} | |||||
void record_expr(std::shared_ptr<OpDef> op, vars_t inputs, vars_t outputs) { | |||||
bool require_grad = false; | |||||
for (auto&& input: inputs) { | |||||
if (m_vars_require_grad.count(input)) { | |||||
require_grad = true; | |||||
break; | |||||
} | |||||
} | |||||
if (require_grad) { | |||||
m_exprs.push_back({op, inputs, outputs}); | |||||
mark_require_grads(outputs); | |||||
} | |||||
} | |||||
template <typename TFunctor> | |||||
void backward(vars_t outputs, vars_t output_grads, TFunctor functor) { | |||||
size_t nr_outputs = outputs.size(); | |||||
for (size_t i = 0; i < nr_outputs; ++i) { | |||||
m_grads[outputs[i]] = output_grads[i]; | |||||
} | |||||
auto exprs = m_exprs; | |||||
std::reverse(exprs.begin(), exprs.end()); | |||||
for (const expr_t& expr: exprs) { | |||||
size_t nr_inputs = expr.inputs.size(); | |||||
vars_t input_grads = functor(expr, get_grads(expr.outputs)); | |||||
mgb_assert(input_grads.size() == nr_inputs, "input size mismatch"); | |||||
for (size_t i = 0; i < nr_inputs; ++i) { | |||||
if (input_grads[i] && m_vars_require_grad.count(expr.inputs[i])) { | |||||
accumulate_grad(expr.inputs[i], input_grads[i]); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
var_t get_grad(var_t dest) { | |||||
if (m_grads.count(dest)) { | |||||
return m_grads.at(dest); | |||||
} | |||||
return 0; | |||||
} | |||||
vars_t get_grads(vars_t dests) { | |||||
vars_t grads; | |||||
for (auto&& dest: dests) { | |||||
grads.push_back(get_grad(dest)); | |||||
} | |||||
return grads; | |||||
} | |||||
}; | |||||
} // namespace imperative | } // namespace imperative | ||||
} // namespace mgb | } // namespace mgb |
@@ -22,22 +22,22 @@ using namespace cg; | |||||
using namespace imperative; | using namespace imperative; | ||||
template <typename T> | template <typename T> | ||||
T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs, | |||||
T prepare_backward_graph_inputs(const EncodedSubraph& bg, const T& inputs, | |||||
const T& outputs, const T& grads) { | const T& outputs, const T& grads) { | ||||
T ret; | T ret; | ||||
size_t i = 0; | size_t i = 0; | ||||
for (auto&& t : inputs) { | for (auto&& t : inputs) { | ||||
if (bg.save_for_backward[i++]) { | |||||
if (bg.input_mask[i++]) { | |||||
ret.push_back(t); | ret.push_back(t); | ||||
} | } | ||||
} | } | ||||
for (auto&& t : outputs) { | for (auto&& t : outputs) { | ||||
if (bg.save_for_backward[i++]) { | |||||
if (bg.input_mask[i++]) { | |||||
ret.push_back(t); | ret.push_back(t); | ||||
} | } | ||||
} | } | ||||
for (auto&& t : grads) { | for (auto&& t : grads) { | ||||
if (bg.save_for_backward[i++]) { | |||||
if (bg.input_mask[i++]) { | |||||
ret.push_back(t); | ret.push_back(t); | ||||
} | } | ||||
} | } | ||||
@@ -45,10 +45,10 @@ T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs, | |||||
} | } | ||||
template <typename T, typename U> | template <typename T, typename U> | ||||
T expand_grads(const U& bg, const T& outputs) { | |||||
T ret(bg.input_has_grad.size()); | |||||
for (size_t i = 0, j = 0; i < bg.input_has_grad.size(); ++i) { | |||||
if (bg.input_has_grad[i]) { | |||||
T expand_grads(const U& mask, const T& outputs) { | |||||
T ret(mask.size()); | |||||
for (size_t i = 0, j = 0; i < mask.size(); ++i) { | |||||
if (mask[i]) { | |||||
ret[i] = outputs[j++]; | ret[i] = outputs[j++]; | ||||
} | } | ||||
} | } | ||||
@@ -80,7 +80,7 @@ T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, | |||||
} | } | ||||
SmallVector<TensorPtr> apply_shared_on_physical_tensor( | SmallVector<TensorPtr> apply_shared_on_physical_tensor( | ||||
std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs) { | |||||
std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs, size_t nr_outputs) { | |||||
return OpDef::apply_on_physical_tensor(*def, inputs); | return OpDef::apply_on_physical_tensor(*def, inputs); | ||||
} | } | ||||
@@ -104,8 +104,8 @@ TEST(TestImperative, BackwardGraphBasic) { | |||||
} | } | ||||
auto result = OpDef::make_backward_graph(*attr, input_descs, {true, true}, | auto result = OpDef::make_backward_graph(*attr, input_descs, {true, true}, | ||||
{true}); | {true}); | ||||
auto&& save_for_backward = result.save_for_backward; | |||||
auto&& input_has_grad = result.input_has_grad; | |||||
auto&& save_for_backward = result.input_mask; | |||||
auto&& input_has_grad = result.output_mask; | |||||
auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs); | auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs); | ||||
inputs.push_back(outputs[0]); | inputs.push_back(outputs[0]); | ||||
@@ -124,7 +124,7 @@ TEST(TestImperative, BackwardGraphBasic) { | |||||
} | } | ||||
} | } | ||||
inputs.clear(); | inputs.clear(); | ||||
auto input_grads = result.backward.apply(backward_graph_inputs, | |||||
auto input_grads = result.graph.apply(backward_graph_inputs, | |||||
apply_shared_on_physical_tensor, | apply_shared_on_physical_tensor, | ||||
[&](auto&& x) { return x; }); | [&](auto&& x) { return x; }); | ||||
mgb_assert(input_grads.size() == input_has_grad.size()); | mgb_assert(input_grads.size() == input_has_grad.size()); | ||||
@@ -159,8 +159,8 @@ TEST(TestImperative, BackwardGraphIdentity) { | |||||
input_descs.push_back({a->layout(), a->comp_node()}); | input_descs.push_back({a->layout(), a->comp_node()}); | ||||
auto result = | auto result = | ||||
OpDef::make_backward_graph(*attr, input_descs, {true}, {true}); | OpDef::make_backward_graph(*attr, input_descs, {true}, {true}); | ||||
auto&& save_for_backward = result.save_for_backward; | |||||
auto&& input_has_grad = result.input_has_grad; | |||||
auto&& save_for_backward = result.input_mask; | |||||
auto&& input_has_grad = result.output_mask; | |||||
auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs); | auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs); | ||||
inputs.push_back(outputs[0]); | inputs.push_back(outputs[0]); | ||||
@@ -178,7 +178,7 @@ TEST(TestImperative, BackwardGraphIdentity) { | |||||
} | } | ||||
} | } | ||||
inputs.clear(); | inputs.clear(); | ||||
auto input_grads = result.backward.apply(backward_graph_inputs, | |||||
auto input_grads = result.graph.apply(backward_graph_inputs, | |||||
apply_shared_on_physical_tensor, | apply_shared_on_physical_tensor, | ||||
[&](auto&& x) { return x; }); | [&](auto&& x) { return x; }); | ||||
mgb_assert(input_grads.size() == input_has_grad.size()); | mgb_assert(input_grads.size() == input_has_grad.size()); | ||||
@@ -245,7 +245,7 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { | |||||
prepare_backward_graph_inputs<SmallVector<TensorPtr>>( | prepare_backward_graph_inputs<SmallVector<TensorPtr>>( | ||||
bg, {a_tn, b_tn}, {c_tn}, {dc_tn}); | bg, {a_tn, b_tn}, {c_tn}, {dc_tn}); | ||||
auto grads = | auto grads = | ||||
expand_grads(bg, bg.backward.apply(backward_graph_inputs, | |||||
expand_grads(bg.output_mask, bg.graph.apply(backward_graph_inputs, | |||||
apply_shared_on_physical_tensor, | apply_shared_on_physical_tensor, | ||||
[&](auto&& x) { return x; })); | [&](auto&& x) { return x; })); | ||||
@@ -262,7 +262,7 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { | |||||
prepare_optimized_backward_inputs<SmallVector<TensorPtr>>( | prepare_optimized_backward_inputs<SmallVector<TensorPtr>>( | ||||
obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn}); | obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn}); | ||||
auto grads2 = expand_grads( | auto grads2 = expand_grads( | ||||
obg, | |||||
obg.input_has_grad, | |||||
obg.backward.apply(backward_inputs, apply_shared_on_physical_tensor, | obg.backward.apply(backward_inputs, apply_shared_on_physical_tensor, | ||||
[&](auto&& x) { return x; })); | [&](auto&& x) { return x; })); | ||||