@@ -7,6 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import itertools
from typing import Iterable, Union

import numpy as np

@@ -22,6 +23,7 @@ from .._imperative_rt.core2 import (
)
from .._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from .._wrap import as_device
from ..autodiff.grad import Function
from ..ops import builtin
from ..ops.special import Const
from .amp import _high_prec_dtype, _low_prec_dtype

@@ -197,8 +199,15 @@ def _normalize_axis(
_opr_map = {
    ("-", 1): builtin.Elemwise(mode="negate"),
    ("abs", 1): builtin.Elemwise(mode="abs"),
    ("exp", 1): builtin.Elemwise(mode="exp"),
    ("log1p", 1): builtin.Elemwise(mode="log1p"),
    ("relu", 1): builtin.Elemwise(mode="relu"),
    ("cond_leq_mov", 3): builtin.Elemwise(mode="cond_leq_mov"),
    ("fma3", 3): builtin.Elemwise(mode="FUSE_MUL_ADD3"),
    ("fma4", 4): builtin.Elemwise(mode="FUSE_MUL_ADD4"),
    ("[?:]", 2): builtin.Subtensor(items=[(0, True, False, False, False)]),
    ("[:?]", 2): builtin.Subtensor(items=[(0, False, True, False, False)]),
}

for name, mode in [

@@ -209,15 +218,21 @@ for name, mode in [
    ("//", "floor_div"),
    ("**", "pow"),
    ("max", "max"),
    ("min", "min"),
    ("additive", "add"),
    ("exp", "EXP"),
    ("switch_gt0", "switch_gt0"),
    ("abs_grad", "abs_grad"),
]:
    _opr_map[(name, 2)] = builtin.Elemwise(mode=mode)


def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
def subgraph(
    name, dtype, device, nr_inputs, gopt_level=None, jit_fusion=False, custom_grad=False
):
    if device.physical_name.startswith("cpu"):
        gopt_level = None  # disable jit and compile
        jit_fusion = False

    def as_op(op, nargs):
        if isinstance(op, str):

@@ -241,14 +256,64 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
        def apply_const(value, dtype=dtype, device=device):
            return builder.apply_const(value, dtype, device)

        def build(builder, outputs, outputs_has_grad):
            builder = type(builder)(builder)
            builder.outputs(outputs)
            builder.outputs_has_grad(outputs_has_grad)
            if jit_fusion:
                assert gopt_level is None
                op = lambda: builder.jit_fuse()
            elif gopt_level is None:
                op = lambda: builder.get()
            else:
                op = lambda: builder.compile(gopt_level)
            return op

        inputs = [builder.input() for _ in range(nr_inputs)]
        outputs, outputs_has_grad = func(inputs, apply_expr, apply_const)
        builder.outputs(outputs)
        builder.outputs_has_grad(outputs_has_grad)
        if gopt_level is None:
            return lambda: builder.get()
        if not custom_grad:
            outputs, outputs_has_grad = func(inputs, apply_expr, apply_const)
            return build(builder, outputs, outputs_has_grad)
        else:
            return lambda: builder.compile(gopt_level)
            gen = func(inputs, apply_expr, apply_const)
            outputs = gen.send(None)
            nr_outputs = len(outputs)
            forward_fn = build(builder, outputs, [False] * nr_outputs)

            output_grads = [builder.input() for _ in range(nr_outputs)]
            input_grads = gen.send(output_grads)
            assert len(input_grads) == nr_inputs
            input_grads_mask = [input_grad is not None for input_grad in input_grads]
            indices = [
                i - 1 if mask else None
                for i, mask in zip(
                    itertools.accumulate(input_grads_mask), input_grads_mask
                )
            ]
            encoded_input_grads = [grad for grad in input_grads if grad is not None]
            backward_fn = build(
                builder, encoded_input_grads, [False] * len(encoded_input_grads)
            )

            class SubgraphOp(Function):
                def __init__(self):
                    self.inputs = None

                def forward(self, *inputs):
                    self.inputs = inputs
                    return apply(forward_fn(), *inputs)

                def backward(self, *output_grads):
                    inputs = self.inputs
                    self.inputs = None
                    encoded_input_grads = apply(backward_fn(), *inputs, *output_grads)
                    input_grads = [
                        encoded_input_grads[i] if i is not None else None
                        for i in indices
                    ]
                    return input_grads

            gen.close()
            return SubgraphOp

    return decorator

@@ -274,15 +339,37 @@ def interpret_subgraph(func, dtype, device):
            return Const(value, dtype=dtype, device=device)()[0]

        outputs, outputs_has_grad = func(args, apply_expr, apply_const)
        outputs = [
            output if has_grad else output.detach()
            for output, has_grad in zip(outputs, outputs_has_grad)
        ]
        return outputs

    return decorated_func


def subgraph_fn(name, dtype, device, nr_inputs, gopt_level=None, interpret=False):
def subgraph_fn(
    name,
    dtype,
    device,
    nr_inputs,
    gopt_level=None,
    jit_fusion=False,
    custom_grad=False,
    *,
    interpret=False
):
    def decorator(func):
        if not interpret:
            op = subgraph(name, dtype, device, nr_inputs, gopt_level=gopt_level)(func)
            op = subgraph(
                name,
                dtype,
                device,
                nr_inputs,
                gopt_level=gopt_level,
                jit_fusion=jit_fusion,
                custom_grad=custom_grad,
            )(func)
            return lambda *args: apply(op(), *args)
        else:
            return interpret_subgraph(func, dtype, device)
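
For reference, a small standalone illustration (not part of the patch) of how the indices list decodes the packed gradients: backward_fn only emits gradients for inputs whose entry in input_grads_mask is True, and itertools.accumulate turns that mask into positions in encoded_input_grads. The example mask below is made up.

# Suppose inputs 0 and 2 receive gradients and input 1 does not.
import itertools

input_grads_mask = [True, False, True]
indices = [
    i - 1 if mask else None
    for i, mask in zip(itertools.accumulate(input_grads_mask), input_grads_mask)
]
assert indices == [0, None, 1]  # slots in encoded_input_grads; None means no gradient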
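
As a usage sketch of the new custom_grad generator protocol: subgraph() first resumes the decorated generator with gen.send(None) to collect the forward outputs, then with gen.send(output_grads), and that second resume must yield one gradient per input (or None). The op below is hypothetical; only subgraph_fn, the func(inputs, apply_expr, apply_const) calling convention (abbreviated (inputs, f, c) here), and the "relu" / "switch_gt0" entries of _opr_map come from the code above, and a concrete dtype plus a device object exposing physical_name are assumed to be available.

def make_custom_relu(dtype, device):
    # Hypothetical example, not part of this patch.
    @subgraph_fn(
        "CustomRelu", dtype=dtype, device=device, nr_inputs=1, custom_grad=True
    )
    def custom_relu(inputs, f, c):
        (x,) = inputs
        y = f("relu", x)
        # First resume (gen.send(None)) stops at this yield and takes the forward
        # outputs; the yield expression later receives the output gradients.
        (dy,) = yield [y]
        # Second resume: yield one gradient per input (None where not needed).
        yield [f("switch_gt0", x, dy)]

    return custom_relu

Per the wrapper returned by subgraph_fn, calling the resulting function applies the compiled forward graph and routes gradients through SubgraphOp.backward.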