|
|
@@ -182,11 +182,11 @@ OP_TRAIT_REG(Identity, Identity) |
|
|
|
|
|
|
|
namespace { namespace subgraph { |
|
|
|
|
|
|
|
EncodedSubraph make_forward_graph(const OpDef& def, SmallVector<LogicalTensorDesc> inputs) { |
|
|
|
return EncodedSubraph::make(*def.cast_final_safe<SubgraphOp>().graph); |
|
|
|
EncodedSubgraph make_forward_graph(const OpDef& def, SmallVector<LogicalTensorDesc> inputs) { |
|
|
|
return EncodedSubgraph::make(*def.cast_final_safe<SubgraphOp>().graph); |
|
|
|
} |
|
|
|
|
|
|
|
EncodedSubraph make_backward_graph( |
|
|
|
EncodedSubgraph make_backward_graph( |
|
|
|
const OpDef& def, |
|
|
|
const SmallVector<LogicalTensorDesc>& inputs, |
|
|
|
const SmallVector<bool>& input_requires_grad, |
|
|
@@ -199,7 +199,7 @@ EncodedSubraph make_backward_graph( |
|
|
|
} |
|
|
|
} |
|
|
|
auto bgraph = subgraph_detail::make_backward_graph(def, inputs, input_requires_grad, output_has_grad); |
|
|
|
return EncodedSubraph::make_single( |
|
|
|
return EncodedSubgraph::make_single( |
|
|
|
SubgraphOp::make(op.name + "Grad", |
|
|
|
std::make_shared<Subgraph>(bgraph.graph)), |
|
|
|
bgraph.input_mask, bgraph.output_mask); |
|
|
@@ -430,7 +430,7 @@ std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_de |
|
|
|
return {}; |
|
|
|
} |
|
|
|
|
|
|
|
EncodedSubraph make_backward_graph( |
|
|
|
EncodedSubgraph make_backward_graph( |
|
|
|
const OpDef& def, |
|
|
|
const SmallVector<LogicalTensorDesc>& inputs, |
|
|
|
const SmallVector<bool>& input_requires_grad, |
|
|
@@ -452,7 +452,7 @@ EncodedSubraph make_backward_graph( |
|
|
|
grad_outputs_has_grad, key); |
|
|
|
} |
|
|
|
auto compiled_op = CompiledOp::make(bgraph_op, op.gopt_level); |
|
|
|
auto encoded_graph = EncodedSubraph::make_single(compiled_op, backward_graph.input_mask, backward_graph.output_mask); |
|
|
|
auto encoded_graph = EncodedSubgraph::make_single(compiled_op, backward_graph.input_mask, backward_graph.output_mask); |
|
|
|
return encoded_graph; |
|
|
|
} |
|
|
|
|
|
|
|