GitOrigin-RevId: e1a1ebdf1c
tags/v1.9.0
@@ -7,6 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
+import itertools
 from typing import Iterable, Union
 
 import numpy as np
@@ -22,6 +23,7 @@ from .._imperative_rt.core2 import (
 )
 from .._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
 from .._wrap import as_device
+from ..autodiff.grad import Function
 from ..ops import builtin
 from ..ops.special import Const
 from .amp import _high_prec_dtype, _low_prec_dtype
@@ -197,8 +199,15 @@ def _normalize_axis(
 _opr_map = {
     ("-", 1): builtin.Elemwise(mode="negate"),
+    ("abs", 1): builtin.Elemwise(mode="abs"),
+    ("exp", 1): builtin.Elemwise(mode="exp"),
+    ("log1p", 1): builtin.Elemwise(mode="log1p"),
+    ("relu", 1): builtin.Elemwise(mode="relu"),
+    ("cond_leq_mov", 3): builtin.Elemwise(mode="cond_leq_mov"),
     ("fma3", 3): builtin.Elemwise(mode="FUSE_MUL_ADD3"),
     ("fma4", 4): builtin.Elemwise(mode="FUSE_MUL_ADD4"),
+    ("[?:]", 2): builtin.Subtensor(items=[(0, True, False, False, False)]),
+    ("[:?]", 2): builtin.Subtensor(items=[(0, False, True, False, False)]),
 }
 
 for name, mode in [
@@ -209,15 +218,21 @@ for name, mode in [
     ("//", "floor_div"),
     ("**", "pow"),
     ("max", "max"),
+    ("min", "min"),
     ("additive", "add"),
     ("exp", "EXP"),
+    ("switch_gt0", "switch_gt0"),
+    ("abs_grad", "abs_grad"),
 ]:
     _opr_map[(name, 2)] = builtin.Elemwise(mode=mode)
 
 
-def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
+def subgraph(
+    name, dtype, device, nr_inputs, gopt_level=None, jit_fusion=False, custom_grad=False
+):
     if device.physical_name.startswith("cpu"):
         gopt_level = None  # disable jit and compile
+        jit_fusion = False
 
     def as_op(op, nargs):
         if isinstance(op, str):
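Note on the `_opr_map` hunks above: ops are keyed by `(name, arity)`, so string names resolve to prebuilt `builtin` ops inside `as_op`. The two new `Subtensor` entries encode one-sided slices; reading the `items` tuple as `(axis, begin, end, step, idx)` presence flags, `"[?:]"` supplies only a start index and `"[:?]"` only a stop index. A minimal sketch of what they compute, assuming the `apply_expr(op, *args)` convention used elsewhere in this file (illustrative, not committed code):

```python
# ("[?:]", 2): Subtensor with a start index only -> roughly x[i:]
# ("[:?]", 2): Subtensor with a stop index only  -> roughly x[:i]
y = apply_expr("[?:]", x, i)  # second operand is the start index
z = apply_expr("[:?]", x, i)  # second operand is the stop index
```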
@@ -241,14 +256,64 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         def apply_const(value, dtype=dtype, device=device):
             return builder.apply_const(value, dtype, device)
 
+        def build(builder, outputs, outputs_has_grad):
+            builder = type(builder)(builder)
+            builder.outputs(outputs)
+            builder.outputs_has_grad(outputs_has_grad)
+            if jit_fusion:
+                assert gopt_level is None
+                op = lambda: builder.jit_fuse()
+            elif gopt_level is None:
+                op = lambda: builder.get()
+            else:
+                op = lambda: builder.compile(gopt_level)
+            return op
+
         inputs = [builder.input() for _ in range(nr_inputs)]
-        outputs, outputs_has_grad = func(inputs, apply_expr, apply_const)
-        builder.outputs(outputs)
-        builder.outputs_has_grad(outputs_has_grad)
-        if gopt_level is None:
-            return lambda: builder.get()
+        if not custom_grad:
+            outputs, outputs_has_grad = func(inputs, apply_expr, apply_const)
+            return build(builder, outputs, outputs_has_grad)
         else:
-            return lambda: builder.compile(gopt_level)
+            gen = func(inputs, apply_expr, apply_const)
+            outputs = gen.send(None)
+            nr_outputs = len(outputs)
+            forward_fn = build(builder, outputs, [False] * nr_outputs)
+            output_grads = [builder.input() for _ in range(nr_outputs)]
+            input_grads = gen.send(output_grads)
+            assert len(input_grads) == nr_inputs
+            input_grads_mask = [input_grad is not None for input_grad in input_grads]
+            indices = [
+                i - 1 if mask else None
+                for i, mask in zip(
+                    itertools.accumulate(input_grads_mask), input_grads_mask
+                )
+            ]
+            encoded_input_grads = [grad for grad in input_grads if grad is not None]
+            backward_fn = build(
+                builder, encoded_input_grads, [False] * len(encoded_input_grads)
+            )
+
+            class SubgraphOp(Function):
+                def __init__(self):
+                    self.inputs = None
+
+                def forward(self, *inputs):
+                    self.inputs = inputs
+                    return apply(forward_fn(), *inputs)
+
+                def backward(self, *output_grads):
+                    inputs = self.inputs
+                    self.inputs = None
+                    encoded_input_grads = apply(backward_fn(), *inputs, *output_grads)
+                    input_grads = [
+                        encoded_input_grads[i] if i is not None else None
+                        for i in indices
+                    ]
+                    return input_grads
+
+            gen.close()
+            return SubgraphOp
 
     return decorator
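With `custom_grad=True`, `func` is treated as a generator implementing a two-phase protocol: its first `yield` hands back the forward outputs and receives one gradient variable per output, and its second `yield` returns the per-input gradients, with `None` marking inputs that get no gradient. A hypothetical single-input example under that protocol (assuming `dtype` and `device` are in scope, and that `("exp", 1)` and a pre-existing `("*", 2)` entry are available in `_opr_map`):

```python
@subgraph("ExpWithGrad", dtype, device, nr_inputs=1, custom_grad=True)
def exp_with_grad(inputs, apply_expr, apply_const):
    (x,) = inputs
    y = apply_expr("exp", x)
    # First yield: forward outputs go out, one grad var per output comes in.
    (dy,) = yield [y]
    # Second yield: gradients w.r.t. each input (None would mean "no grad").
    yield [apply_expr("*", dy, y)]  # d/dx exp(x) = exp(x)
```

The dense gradient packing deserves a worked case: for `input_grads_mask = [True, False, True]`, `itertools.accumulate` gives `[1, 1, 2]`, so `indices = [0, None, 1]`; only the non-`None` gradients are stored in `encoded_input_grads`, and `backward` scatters them back into one slot per input.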
@@ -274,15 +339,37 @@ def interpret_subgraph(func, dtype, device):
             return Const(value, dtype=dtype, device=device)()[0]
 
         outputs, outputs_has_grad = func(args, apply_expr, apply_const)
+        outputs = [
+            output if has_grad else output.detach()
+            for output, has_grad in zip(outputs, outputs_has_grad)
+        ]
         return outputs
 
     return decorated_func
 
 
-def subgraph_fn(name, dtype, device, nr_inputs, gopt_level=None, interpret=False):
+def subgraph_fn(
+    name,
+    dtype,
+    device,
+    nr_inputs,
+    gopt_level=None,
+    jit_fusion=False,
+    custom_grad=False,
+    *,
+    interpret=False
+):
     def decorator(func):
         if not interpret:
-            op = subgraph(name, dtype, device, nr_inputs, gopt_level=gopt_level)(func)
+            op = subgraph(
+                name,
+                dtype,
+                device,
+                nr_inputs,
+                gopt_level=gopt_level,
+                jit_fusion=jit_fusion,
+                custom_grad=custom_grad,
+            )(func)
             return lambda *args: apply(op(), *args)
         else:
             return interpret_subgraph(func, dtype, device)
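`subgraph_fn` composes `subgraph` with `apply`, so the decorated function becomes a plain callable on tensors, while `interpret=True` (now keyword-only) falls back to eager interpretation via `interpret_subgraph`. A hedged usage sketch, assuming `device` comes from `as_device` and using `np.float32` as the dtype:

```python
@subgraph_fn("ClampedExp", dtype=np.float32, device=device, nr_inputs=2)
def clamped_exp(inputs, apply_expr, apply_const):
    x, cap = inputs
    y = apply_expr("exp", apply_expr("min", x, cap))
    return [y], [True]  # outputs and their has-grad mask

(y,) = clamped_exp(x, cap)  # applies the built (or interpreted) op
```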
@@ -33,6 +33,7 @@ from ..core._imperative_rt.ops import (
     ExternOpr,
     RemoteRecv,
     RemoteSend,
+    set_jit_enabled,
 )
 from ..core._trace_option import set_symbolic_shape
 from ..core.tensor import megbrain_graph as G
@@ -711,12 +712,14 @@ class trace:
         graph = G.Graph()
 
+        jit_enabled = set_jit_enabled(False)
         dest_vars = self._trace.dump(
             graph,
             input_bindings,
             [*zip(self._output_bindings, output_names)],
             prefer_input_names,
         )
+        set_jit_enabled(jit_enabled)
 
         # dest_vars = [i._node for i in dest_vars]
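As used here, `set_jit_enabled(False)` returns the previous flag, so JIT fusion is switched off while the trace is serialized and restored afterwards. One observation, not part of the commit: if `self._trace.dump` raises, the flag stays disabled; a `try`/`finally` variant of the same pattern would be exception-safe:

```python
jit_enabled = set_jit_enabled(False)  # returns the prior setting
try:
    dest_vars = self._trace.dump(
        graph,
        input_bindings,
        [*zip(self._output_bindings, output_names)],
        prefer_input_names,
    )
finally:
    set_jit_enabled(jit_enabled)  # restored even if dump() raises
```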
@@ -577,21 +577,26 @@ void init_ops(py::module m) {
     struct PySubgraphBuilder {
         explicit PySubgraphBuilder(std::string name) : name{name} {}
         std::string name;
-        std::shared_ptr<Subgraph> graph_storage = std::make_shared<Subgraph>();
-        std::shared_ptr<UniqueKey> graph_key = std::make_shared<UniqueKey>();
-        Subgraph& graph = *graph_storage;
+        Subgraph graph;
         mgb::SmallVector<bool> output_grad_mask;
         Subgraph::var_t next_var = 1;
+        std::shared_ptr<mgb::Hashable> key = nullptr;
 
-        std::shared_ptr<OpDef> build() const {
-            return SubgraphOp::make(name, graph_storage, output_grad_mask, graph_key);
+        std::shared_ptr<OpDef> build() {
+            if (key == nullptr) {
+                key = std::make_shared<UniqueKey>();
+            }
+            return SubgraphOp::make(
+                    name, std::make_shared<Subgraph>(graph), output_grad_mask, key);
         }
     };
 
     py::class_<PySubgraphBuilder>(m, "SubgraphBuilder")
             .def(py::init<std::string>())
+            .def(py::init<PySubgraphBuilder>())
             .def("input",
                  [](PySubgraphBuilder& self) {
+                     mgb_assert(self.key == nullptr);
                      auto var = self.next_var++;
                      self.graph.inputs.push_back(var);
                      return var;
@@ -599,6 +604,7 @@ void init_ops(py::module m) {
             .def("apply",
                  [](PySubgraphBuilder& self, std::shared_ptr<OpDef> op,
                     Subgraph::vars_t inputs, size_t nr_outputs) {
+                     mgb_assert(self.key == nullptr);
                      Subgraph::vars_t outputs;
                      for (size_t i = 0; i < nr_outputs; ++i) {
                          outputs.push_back(self.next_var++);
@@ -609,6 +615,7 @@ void init_ops(py::module m) {
             .def("apply_const",
                  [](PySubgraphBuilder& self, py::object value, mgb::DType dtype,
                     mgb::CompNode cn) {
+                     mgb_assert(self.key == nullptr);
                      auto var = self.next_var++;
                      mgb::HostTensorND hvalue(cn);
                      npy::np2tensor(
@@ -619,11 +626,13 @@ void init_ops(py::module m) {
                  })
             .def("outputs",
                  [](PySubgraphBuilder& self, Subgraph::vars_t outputs) {
+                     mgb_assert(self.key == nullptr);
                      self.graph.outputs = outputs;
                      self.output_grad_mask.resize(outputs.size(), true);
                  })
             .def("outputs_has_grad",
                  [](PySubgraphBuilder& self, mgb::SmallVector<bool> outputs_has_grad) {
+                     mgb_assert(self.key == nullptr);
                      mgb_assert(
                              self.graph.outputs.size() == self.output_grad_mask.size());
                      self.output_grad_mask = outputs_has_grad;
@@ -632,11 +641,18 @@ void init_ops(py::module m) {
                  [](PySubgraphBuilder& self) {
                      return (std::shared_ptr<OpDef>)self.build();
                  })
-            .def("compile", [](PySubgraphBuilder& self, int gopt_level) {
-                return (std::shared_ptr<OpDef>)CompiledOp::make(
-                        self.build(), gopt_level);
-            });
+            .def("compile",
+                 [](PySubgraphBuilder& self, int gopt_level) {
+                     return (std::shared_ptr<OpDef>)CompiledOp::make(
+                             self.build(), gopt_level);
+                 })
+            .def("jit_fuse", [](PySubgraphBuilder& self) {
+                return (std::shared_ptr<OpDef>)CompiledOp::make(
+                        JITFusionOp::make(self.build()));
+            });
+    m.def("set_jit_enabled", &JITFusionOp::set_enabled);
 
     auto custom = submodule(m, "_custom");
     init_custom(custom);
 }
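The builder bindings now enforce a copy-then-freeze discipline: `py::init<PySubgraphBuilder>` makes the builder copyable from Python, and every mutating method asserts `key == nullptr`, i.e. that `build()` has not yet pinned this builder's identity. This is exactly what the Python-side `build()` helper above relies on; a sketch of the flow it enables (variable names are illustrative):

```python
builder = _SubgraphBuilder("my_subgraph")
x = builder.input()
# ... define the body via builder.apply(...) / builder.apply_const(...)

snapshot = type(builder)(builder)  # copy; the original stays mutable
snapshot.outputs([x])
snapshot.outputs_has_grad([False])
op = snapshot.jit_fuse()  # or snapshot.get() / snapshot.compile(gopt_level)

dy = builder.input()  # still legal: only the snapshot was built and frozen
```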
@@ -12,6 +12,7 @@
 #include "megbrain/imperative/transformations/scalar.h"
 
 #include "megbrain/imperative/ops/autogen.h"
+#include "megbrain/imperative/ops/utility.h"
 
 namespace mgb {
 namespace imperative {
@@ -320,6 +321,24 @@ std::vector<ValueRef> inplace_add_rule(
 }
 }
 
+template <typename T>
+std::vector<ValueRef> subgraph_op_rule(const T& op, Span<ValueRef> inputs) {
+    // TODO: add flag instead of assume
+    bool all_scalar = true;
+    for (auto&& input : inputs) {
+        if (!input.is<ScalarValue>()) {
+            all_scalar = false;
+        }
+    }
+    auto outputs = imperative::apply(op, unwrap_inputs(inputs));
+    if (all_scalar) {
+        for (auto& output : outputs) {
+            output = ScalarValue::make(output);
+        }
+    }
+    return outputs;
+}
+
 struct ScalarRuleRegistry {
     ScalarRuleRegistry() {
         register_scalar_rule(elemwise_rule);
@@ -339,6 +358,8 @@ struct ScalarRuleRegistry {
         register_scalar_rule(broadcast_rule);
         register_scalar_rule(copy_rule);
         register_scalar_rule(inplace_add_rule);
+        register_scalar_rule(subgraph_op_rule<SubgraphOp>);
+        register_scalar_rule(subgraph_op_rule<CompiledOp>);
     }
 } _;
 } // namespace
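`subgraph_op_rule` keeps the scalar transformation transparent to `SubgraphOp` and `CompiledOp`: when every input is a `ScalarValue`, the unwrapped inputs are applied and the outputs re-wrapped as scalars. As the `TODO` concedes, this assumes scalar inputs imply scalar outputs rather than consulting a per-op flag. A Python-flavored paraphrase of the rule (helper names here are hypothetical, for illustration only):

```python
def subgraph_op_rule_sketch(op, inputs):
    all_scalar = all(is_scalar_value(x) for x in inputs)  # hypothetical helper
    outputs = apply(op, *[unwrap(x) for x in inputs])     # hypothetical unwrap
    if all_scalar:
        # assumes scalar in -> scalar out (see the TODO above)
        outputs = [make_scalar_value(y) for y in outputs]
    return outputs
```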