
feat(subgraph): subgraph builder supports jit and custom grad

GitOrigin-RevId: e1a1ebdf1c
tags/v1.9.0
Megvii Engine Team, 3 years ago
commit 2775f4580c
4 changed files with 143 additions and 16 deletions
  1. imperative/python/megengine/core/tensor/utils.py (+96, -9)
  2. imperative/python/megengine/jit/tracing.py (+3, -0)
  3. imperative/python/src/ops.cpp (+23, -7)
  4. imperative/src/impl/transformations/scalar.cpp (+21, -0)

imperative/python/megengine/core/tensor/utils.py (+96, -9)

@@ -7,6 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
+import itertools
 from typing import Iterable, Union

 import numpy as np
@@ -22,6 +23,7 @@ from .._imperative_rt.core2 import (
 )
 from .._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
 from .._wrap import as_device
+from ..autodiff.grad import Function
 from ..ops import builtin
 from ..ops.special import Const
 from .amp import _high_prec_dtype, _low_prec_dtype
@@ -197,8 +199,15 @@ def _normalize_axis(


 _opr_map = {
     ("-", 1): builtin.Elemwise(mode="negate"),
+    ("abs", 1): builtin.Elemwise(mode="abs"),
+    ("exp", 1): builtin.Elemwise(mode="exp"),
+    ("log1p", 1): builtin.Elemwise(mode="log1p"),
+    ("relu", 1): builtin.Elemwise(mode="relu"),
+    ("cond_leq_mov", 3): builtin.Elemwise(mode="cond_leq_mov"),
     ("fma3", 3): builtin.Elemwise(mode="FUSE_MUL_ADD3"),
     ("fma4", 4): builtin.Elemwise(mode="FUSE_MUL_ADD4"),
+    ("[?:]", 2): builtin.Subtensor(items=[(0, True, False, False, False)]),
+    ("[:?]", 2): builtin.Subtensor(items=[(0, False, True, False, False)]),
 }

@@ -209,15 +218,21 @@ for name, mode in [
     ("//", "floor_div"),
     ("**", "pow"),
     ("max", "max"),
+    ("min", "min"),
     ("additive", "add"),
     ("exp", "EXP"),
+    ("switch_gt0", "switch_gt0"),
+    ("abs_grad", "abs_grad"),
 ]:
     _opr_map[(name, 2)] = builtin.Elemwise(mode=mode)


-def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
+def subgraph(
+    name, dtype, device, nr_inputs, gopt_level=None, jit_fusion=False, custom_grad=False
+):
     if device.physical_name.startswith("cpu"):
         gopt_level = None  # disable jit and compile
+        jit_fusion = False

     def as_op(op, nargs):
         if isinstance(op, str):
@@ -241,14 +256,64 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         def apply_const(value, dtype=dtype, device=device):
             return builder.apply_const(value, dtype, device)

+        def build(builder, outputs, outputs_has_grad):
+            builder = type(builder)(builder)
+            builder.outputs(outputs)
+            builder.outputs_has_grad(outputs_has_grad)
+            if jit_fusion:
+                assert gopt_level is None
+                op = lambda: builder.jit_fuse()
+            elif gopt_level is None:
+                op = lambda: builder.get()
+            else:
+                op = lambda: builder.compile(gopt_level)
+            return op
+
         inputs = [builder.input() for _ in range(nr_inputs)]
-        outputs, outputs_has_grad = func(inputs, apply_expr, apply_const)
-        builder.outputs(outputs)
-        builder.outputs_has_grad(outputs_has_grad)
-        if gopt_level is None:
-            return lambda: builder.get()
+        if not custom_grad:
+            outputs, outputs_has_grad = func(inputs, apply_expr, apply_const)
+            return build(builder, outputs, outputs_has_grad)
         else:
-            return lambda: builder.compile(gopt_level)
+            gen = func(inputs, apply_expr, apply_const)
+            outputs = gen.send(None)
+            nr_outputs = len(outputs)
+            forward_fn = build(builder, outputs, [False] * nr_outputs)
+
+            output_grads = [builder.input() for _ in range(nr_outputs)]
+            input_grads = gen.send(output_grads)
+            assert len(input_grads) == nr_inputs
+            input_grads_mask = [input_grad is not None for input_grad in input_grads]
+            indices = [
+                i - 1 if mask else None
+                for i, mask in zip(
+                    itertools.accumulate(input_grads_mask), input_grads_mask
+                )
+            ]
+            encoded_input_grads = [grad for grad in input_grads if grad is not None]
+            backward_fn = build(
+                builder, encoded_input_grads, [False] * len(encoded_input_grads)
+            )
+
+            class SubgraphOp(Function):
+                def __init__(self):
+                    self.inputs = None
+
+                def forward(self, *inputs):
+                    self.inputs = inputs
+                    return apply(forward_fn(), *inputs)
+
+                def backward(self, *output_grads):
+                    inputs = self.inputs
+                    self.inputs = None
+                    encoded_input_grads = apply(backward_fn(), *inputs, *output_grads)
+                    input_grads = [
+                        encoded_input_grads[i] if i is not None else None
+                        for i in indices
+                    ]
+                    return input_grads
+
+            gen.close()
+            return SubgraphOp

     return decorator

@@ -274,15 +339,37 @@ def interpret_subgraph(func, dtype, device):
             return Const(value, dtype=dtype, device=device)()[0]

         outputs, outputs_has_grad = func(args, apply_expr, apply_const)
+        outputs = [
+            output if has_grad else output.detach()
+            for output, has_grad in zip(outputs, outputs_has_grad)
+        ]
         return outputs

     return decorated_func


-def subgraph_fn(name, dtype, device, nr_inputs, gopt_level=None, interpret=False):
+def subgraph_fn(
+    name,
+    dtype,
+    device,
+    nr_inputs,
+    gopt_level=None,
+    jit_fusion=False,
+    custom_grad=False,
+    *,
+    interpret=False
+):
     def decorator(func):
         if not interpret:
-            op = subgraph(name, dtype, device, nr_inputs, gopt_level=gopt_level)(func)
+            op = subgraph(
+                name,
+                dtype,
+                device,
+                nr_inputs,
+                gopt_level=gopt_level,
+                jit_fusion=jit_fusion,
+                custom_grad=custom_grad,
+            )(func)
             return lambda *args: apply(op(), *args)
         else:
             return interpret_subgraph(func, dtype, device)
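Note: a hedged usage sketch (not part of this commit) of the new custom_grad path. With custom_grad=True the decorated function is driven as a generator: it yields the forward outputs, is resumed with the gradients of those outputs (the gen.send(output_grads) call above), and finally yields one gradient per input, with None marking an input that needs no gradient. The wrapper name _get_custom_relu_op is hypothetical; "relu" and "switch_gt0" are taken from the _opr_map entries added in this diff.

from megengine.core.tensor.utils import subgraph_fn


def _get_custom_relu_op(dtype, device):
    @subgraph_fn(
        "CustomRelu",
        dtype=dtype,
        device=device,
        nr_inputs=1,
        jit_fusion=True,
        custom_grad=True,
    )
    def custom_relu(inputs, apply_expr, apply_const):
        (x,) = inputs
        # forward: y = relu(x)
        y = apply_expr("relu", x)
        # first yield hands the forward outputs to subgraph(); the generator is
        # resumed with the gradients flowing into those outputs
        (dy,) = yield [y]
        # backward: dy where x > 0, else 0 (Elemwise switch_gt0)
        dx = apply_expr("switch_gt0", x, dy)
        # second yield returns one gradient per input (None = no gradient)
        yield [dx]

    return custom_relu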


imperative/python/megengine/jit/tracing.py (+3, -0)

@@ -33,6 +33,7 @@ from ..core._imperative_rt.ops import (
     ExternOpr,
     RemoteRecv,
     RemoteSend,
+    set_jit_enabled,
 )
 from ..core._trace_option import set_symbolic_shape
 from ..core.tensor import megbrain_graph as G
@@ -711,12 +712,14 @@ class trace:

         graph = G.Graph()

+        jit_enabled = set_jit_enabled(False)
         dest_vars = self._trace.dump(
             graph,
             input_bindings,
             [*zip(self._output_bindings, output_names)],
             prefer_input_names,
         )
+        set_jit_enabled(jit_enabled)

         # dest_vars = [i._node for i in dest_vars]
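Note: the dump() change above works because set_jit_enabled() returns the previously enabled state, so the trace dump runs with JIT fusion off and the old setting is restored afterwards. A hedged sketch (not part of this commit) of the same pattern packaged as a context manager; the helper name _jit_disabled is hypothetical.

from contextlib import contextmanager

from megengine.core._imperative_rt.ops import set_jit_enabled


@contextmanager
def _jit_disabled():
    # set_jit_enabled(False) turns JIT fusion off and hands back the old flag
    previous = set_jit_enabled(False)
    try:
        yield
    finally:
        # restore whatever was enabled before entering the block
        set_jit_enabled(previous)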




imperative/python/src/ops.cpp (+23, -7)

@@ -577,21 +577,26 @@ void init_ops(py::module m) {
     struct PySubgraphBuilder {
         explicit PySubgraphBuilder(std::string name) : name{name} {}
         std::string name;
-        std::shared_ptr<Subgraph> graph_storage = std::make_shared<Subgraph>();
-        std::shared_ptr<UniqueKey> graph_key = std::make_shared<UniqueKey>();
-        Subgraph& graph = *graph_storage;
+        Subgraph graph;
         mgb::SmallVector<bool> output_grad_mask;
         Subgraph::var_t next_var = 1;
+        std::shared_ptr<mgb::Hashable> key = nullptr;

-        std::shared_ptr<OpDef> build() const {
-            return SubgraphOp::make(name, graph_storage, output_grad_mask, graph_key);
+        std::shared_ptr<OpDef> build() {
+            if (key == nullptr) {
+                key = std::make_shared<UniqueKey>();
+            }
+            return SubgraphOp::make(
+                    name, std::make_shared<Subgraph>(graph), output_grad_mask, key);
         }
     };

     py::class_<PySubgraphBuilder>(m, "SubgraphBuilder")
             .def(py::init<std::string>())
+            .def(py::init<PySubgraphBuilder>())
             .def("input",
                  [](PySubgraphBuilder& self) {
+                     mgb_assert(self.key == nullptr);
                      auto var = self.next_var++;
                      self.graph.inputs.push_back(var);
                      return var;
@@ -599,6 +604,7 @@ void init_ops(py::module m) {
             .def("apply",
                  [](PySubgraphBuilder& self, std::shared_ptr<OpDef> op,
                     Subgraph::vars_t inputs, size_t nr_outputs) {
+                     mgb_assert(self.key == nullptr);
                      Subgraph::vars_t outputs;
                      for (size_t i = 0; i < nr_outputs; ++i) {
                          outputs.push_back(self.next_var++);
@@ -609,6 +615,7 @@ void init_ops(py::module m) {
             .def("apply_const",
                  [](PySubgraphBuilder& self, py::object value, mgb::DType dtype,
                     mgb::CompNode cn) {
+                     mgb_assert(self.key == nullptr);
                      auto var = self.next_var++;
                      mgb::HostTensorND hvalue(cn);
                      npy::np2tensor(
@@ -619,11 +626,13 @@ void init_ops(py::module m) {
                  })
             .def("outputs",
                  [](PySubgraphBuilder& self, Subgraph::vars_t outputs) {
+                     mgb_assert(self.key == nullptr);
                      self.graph.outputs = outputs;
                      self.output_grad_mask.resize(outputs.size(), true);
                  })
             .def("outputs_has_grad",
                  [](PySubgraphBuilder& self, mgb::SmallVector<bool> outputs_has_grad) {
+                     mgb_assert(self.key == nullptr);
                      mgb_assert(
                              self.graph.outputs.size() == self.output_grad_mask.size());
                      self.output_grad_mask = outputs_has_grad;
@@ -632,11 +641,18 @@ void init_ops(py::module m) {
                 [](PySubgraphBuilder& self) {
                     return (std::shared_ptr<OpDef>)self.build();
                 })
-            .def("compile", [](PySubgraphBuilder& self, int gopt_level) {
+            .def("compile",
+                 [](PySubgraphBuilder& self, int gopt_level) {
+                     return (std::shared_ptr<OpDef>)CompiledOp::make(
+                             self.build(), gopt_level);
+                 })
+            .def("jit_fuse", [](PySubgraphBuilder& self) {
                 return (std::shared_ptr<OpDef>)CompiledOp::make(
-                        self.build(), gopt_level);
+                        JITFusionOp::make(self.build()));
             });
+
+    m.def("set_jit_enabled", &JITFusionOp::set_enabled);

     auto custom = submodule(m, "_custom");
     init_custom(custom);
 }
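Note: a hedged sketch (not part of this commit) of what the new bindings allow from Python. The copy constructor (py::init<PySubgraphBuilder>) snapshots a builder so the original can keep growing, which is what build() in utils.py relies on, while get() / compile(gopt_level) / jit_fuse() each turn a snapshot into an OpDef; the mgb_assert(self.key == nullptr) checks then keep a builder from being mutated after it has produced an op. The concrete op and variable names below are illustrative.

from megengine.core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from megengine.core.ops import builtin

builder = _SubgraphBuilder("Example")
x = builder.input()
y = builder.input()
(z,) = builder.apply(builtin.Elemwise(mode="add"), [x, y], 1)

snapshot = _SubgraphBuilder(builder)    # new: copy-construct a frozen snapshot
snapshot.outputs([z])
snapshot.outputs_has_grad([True])
op = snapshot.get()                     # or snapshot.compile(2) / snapshot.jit_fuse()
# the original builder may keep growing, e.g. extra inputs for a backward graph
dz = builder.input()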


imperative/src/impl/transformations/scalar.cpp (+21, -0)

@@ -12,6 +12,7 @@
 #include "megbrain/imperative/transformations/scalar.h"

 #include "megbrain/imperative/ops/autogen.h"
+#include "megbrain/imperative/ops/utility.h"

 namespace mgb {
 namespace imperative {
@@ -320,6 +321,24 @@ std::vector<ValueRef> inplace_add_rule(
     }
 }

+template <typename T>
+std::vector<ValueRef> subgraph_op_rule(const T& op, Span<ValueRef> inputs) {
+    // TODO: add flag instead of assume
+    bool all_scalar = true;
+    for (auto&& input : inputs) {
+        if (!input.is<ScalarValue>()) {
+            all_scalar = false;
+        }
+    }
+    auto outputs = imperative::apply(op, unwrap_inputs(inputs));
+    if (all_scalar) {
+        for (auto& output : outputs) {
+            output = ScalarValue::make(output);
+        }
+    }
+    return outputs;
+}
+
 struct ScalarRuleRegistry {
     ScalarRuleRegistry() {
         register_scalar_rule(elemwise_rule);
@@ -339,6 +358,8 @@ struct ScalarRuleRegistry {
         register_scalar_rule(broadcast_rule);
         register_scalar_rule(copy_rule);
         register_scalar_rule(inplace_add_rule);
+        register_scalar_rule(subgraph_op_rule<SubgraphOp>);
+        register_scalar_rule(subgraph_op_rule<CompiledOp>);
     }
 } _;
 } // namespace
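Note: a hedged sketch (not part of this commit) of the user-visible effect of subgraph_op_rule: when every input of the resulting SubgraphOp/CompiledOp is a 0-dim (scalar) tensor, the outputs are re-tagged as scalars instead of coming back with shape (1,). The example reuses only opr names visible in this diff ("exp", "additive"); the function name and the dtype/device plumbing via x.dtype / x.device are assumptions.

import megengine as mge
from megengine.core.tensor.utils import subgraph_fn

x = mge.tensor(2.0, dtype="float32")       # 0-dim (scalar) tensor

@subgraph_fn("ExpPlusOne", dtype=x.dtype, device=x.device, nr_inputs=1)
def exp_plus_one(inputs, apply_expr, apply_const):
    (inp,) = inputs
    one = apply_const(1.0)
    # "exp" (unary) and "additive" (binary add) come from _opr_map above
    out = apply_expr("additive", apply_expr("exp", inp), one)
    return [out], [True]                   # outputs, outputs_has_grad

(y,) = exp_plus_one(x)
assert y.ndim == 0                         # kept scalar by subgraph_op_rule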

