
perf(functional): rewrite several elemwise ops with jit subgraph

GitOrigin-RevId: 26247e21d9
Tag: tags/v1.9.0
Author: Megvii Engine Team, 3 years ago
Parent commit: df3474ca1d
4 changed files with 225 additions and 17 deletions:
  1. imperative/python/megengine/functional/nn.py (+181, -6)
  2. imperative/python/megengine/functional/tensor.py (+40, -10)
  3. imperative/python/test/unit/functional/test_elemwise.py (+1, -1)
  4. imperative/python/test/unit/functional/test_functional.py (+3, -0)
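
Each rewritten op below follows the same two-phase generator pattern: the decorated function builds the forward expression with `f` (ops) and `c` (constants), yields the forward outputs, receives the output gradients back, and finally yields the hand-written input gradients. A minimal pure-Python sketch of that control flow (toy driver and toy op only, not the MegEngine `subgraph_fn` API):

def drive(graph_fn, *inputs):
    """Toy driver: run the forward phase, feed back fake output grads, collect input grads."""
    gen = graph_fn(inputs)
    outputs = next(gen)                          # phase 1: forward outputs
    output_grads = tuple(1.0 for _ in outputs)   # pretend upstream gradients
    input_grads = gen.send(output_grads)         # phase 2: hand-written input gradients
    return outputs, input_grads

def toy_relu(inputs):
    (x,) = inputs
    y = max(x, 0.0)
    (dy,) = yield (y,)           # yield forward result, receive its gradient
    dx = dy if x > 0 else 0.0
    yield (dx,)                  # yield gradient w.r.t. the input

print(drive(toy_relu, 2.0))      # ((2.0,), (1.0,))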

imperative/python/megengine/functional/nn.py (+181, -6)

@@ -36,6 +36,7 @@ from ..core.tensor.utils import (
     convert_single_value,
     make_shape_tuple,
     subgraph,
+    subgraph_fn,
 )
 from ..device import get_default_device
 from ..distributed import WORLD, is_distributed
@@ -824,9 +825,37 @@ def sigmoid(x):
     return _elwise(x, mode=Elemwise.Mode.SIGMOID)


+@lru_cache(maxsize=None)
+def _get_hsigmoid_op(dtype=None, device=None):
+    @subgraph_fn(
+        "Hsigmoid",
+        dtype=dtype,
+        device=device,
+        nr_inputs=1,
+        jit_fusion=True,
+        custom_grad=True,
+    )
+    def hsigmoid(inputs, f, c):
+        (inp,) = inputs[0:1]
+        inp = f("+", inp, c(3))
+        max_0 = f("max", inp, c(0))
+        min_6 = f("min", max_0, c(6))
+        oup = f("/", min_6, c(6))
+        (oup_grad,) = yield (oup,)
+        inp_grad = f("/", oup_grad, c(6))
+        inp_grad = f("cond_leq_mov", max_0, c(6), inp_grad)
+        inp_grad = f("cond_leq_mov", c(0), inp, inp_grad)
+        yield (inp_grad,)
+
+    return hsigmoid
+
+
 def hsigmoid(x):
     r"""Element-wise `relu6(x + 3) / 6`."""
-    return relu6(x + 3) / 6
+    hsigmoid = _get_hsigmoid_op(x.dtype, x.device)
+    (x,) = hsigmoid(x)
+    return x
+    # return relu6(x + 3) / 6


 def relu(x):
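
For reference, a NumPy sketch of the forward/backward pair the Hsigmoid subgraph encodes (the boundary behaviour of `cond_leq_mov` is assumed here, not taken from the source):

import numpy as np

def hsigmoid_ref(x, dy):
    # forward: relu6(x + 3) / 6
    y = np.minimum(np.maximum(x + 3.0, 0.0), 6.0) / 6.0
    # backward: dy / 6 inside the linear region, 0 where the clamp saturates
    dx = np.where((x > -3.0) & (x < 3.0), dy / 6.0, 0.0)
    return y, dx

x = np.linspace(-5.0, 5.0, 11).astype("float32")
y, dx = hsigmoid_ref(x, np.ones_like(x))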
@@ -834,9 +863,60 @@ def relu(x):
     return _elwise(x, mode=Elemwise.Mode.RELU)


+@lru_cache(maxsize=None)
+def _get_relu6_op(dtype=None, device=None):
+    @subgraph_fn(
+        "ReLU6",
+        dtype=dtype,
+        device=device,
+        nr_inputs=1,
+        jit_fusion=True,
+        custom_grad=True,
+    )
+    def relu6(inputs, f, c):
+        (inp,) = inputs[0:1]
+        max_0 = f("max", inp, c(0))
+        min_6 = f("min", max_0, c(6))
+        oup = min_6
+        (oup_grad,) = yield (oup,)
+        inp_grad = f("cond_leq_mov", max_0, c(6), oup_grad)
+        inp_grad = f("cond_leq_mov", c(0), inp, inp_grad)
+        yield (inp_grad,)
+
+    return relu6
+
+
 def relu6(x):
     r"""Element-wise `min(max(x, 0), 6)`."""
-    return minimum(maximum(x, 0), 6)
+    relu6 = _get_relu6_op(x.dtype, x.device)
+    (x,) = relu6(x)
+    return x


+@lru_cache(maxsize=None)
+def _get_prelu_op(dtype=None, device=None):
+    @subgraph_fn(
+        "PReLU",
+        dtype=dtype,
+        device=device,
+        nr_inputs=2,
+        jit_fusion=True,
+        custom_grad=True,
+    )
+    def prelu(inputs, f, c):
+        (inp, weight) = inputs[0:2]
+        max_0 = f("max", inp, c(0))
+        min_0 = f("min", inp, c(0))
+        oup = f("fma3", min_0, weight, max_0)
+        (oup_grad,) = yield (oup,)
+        inp_grad_0 = f("cond_leq_mov", inp, c(0), oup_grad)
+        inp_grad_1 = f("*", oup_grad, weight)
+        inp_grad_1 = f("cond_leq_mov", c(0), inp, inp_grad_1)
+        inp_grad = f("+", inp_grad_0, inp_grad_1)
+        weight_grad = f("*", oup_grad, min_0)
+        yield (inp_grad, weight_grad)
+
+    return prelu
+
+
 def prelu(inp: Tensor, weight: Tensor) -> Tensor:
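
A matching NumPy reference for the ReLU6 subgraph above (forward `min(max(x, 0), 6)`, gradient passed only inside the un-clamped region; boundary handling assumed):

import numpy as np

def relu6_ref(x, dy):
    y = np.minimum(np.maximum(x, 0.0), 6.0)        # forward
    dx = np.where((x > 0.0) & (x < 6.0), dy, 0.0)  # backward: grad only where neither clamp is active
    return y, dx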
@@ -844,7 +924,34 @@ def prelu(inp: Tensor, weight: Tensor) -> Tensor:

     Refer to :class:`~.PReLU` for more information.
     """
-    return maximum(inp, 0) + weight * minimum(inp, 0)
+    prelu = _get_prelu_op(dtype=inp.dtype, device=inp.device)
+    (oup,) = prelu(inp, weight)
+    return oup


+@lru_cache(maxsize=None)
+def _get_leagk_relu_op(negative_slope, *, dtype=None, device=None):
+    @subgraph_fn(
+        "LeakyReLU",
+        dtype=dtype,
+        device=device,
+        nr_inputs=1,
+        jit_fusion=True,
+        custom_grad=True,
+    )
+    def leakyReLU(inputs, f, c):
+        (inp,) = inputs[0:1]
+        max_0 = f("max", inp, c(0))
+        min_0 = f("min", inp, c(0))
+        oup = f("+", max_0, f("*", min_0, c(negative_slope)))
+        (oup_grad,) = yield (oup,)
+        inp_grad_0 = f("cond_leq_mov", c(0), inp, oup_grad)
+        inp_grad_1 = f("*", oup_grad, c(negative_slope))
+        inp_grad_1 = f("cond_leq_mov", inp, c(negative_slope), inp_grad_1)
+        inp_grad = f("+", inp_grad_0, inp_grad_1)
+        yield (inp_grad,)
+
+    return leakyReLU
+
+
 def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
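
A NumPy reference for the PReLU subgraph (assuming `fma3` is fused multiply-add, `a * b + c`); the sketch is elementwise only and omits any broadcast reduction of the weight gradient:

import numpy as np

def prelu_ref(x, w, dy):
    # forward: max(x, 0) + w * min(x, 0)
    y = np.maximum(x, 0.0) + w * np.minimum(x, 0.0)
    dx = np.where(x > 0.0, dy, dy * w)    # grad w.r.t. x: 1 for positive inputs, w otherwise
    dw = dy * np.minimum(x, 0.0)          # grad w.r.t. weight, per element
    return y, dx, dw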
@@ -852,7 +959,9 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:

     Refer to :class:`~.LeakyReLU` for more information.
     """
-    return maximum(inp, 0) + negative_slope * minimum(inp, 0)
+    leakyReLU = _get_leagk_relu_op(negative_slope, dtype=inp.dtype, device=inp.device)
+    (oup,) = leakyReLU(inp)
+    return oup


 def silu(x):
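
And the LeakyReLU counterpart, `max(x, 0) + negative_slope * min(x, 0)`, whose gradient is 1 for positive inputs and `negative_slope` otherwise (a reference sketch, not the subgraph itself):

import numpy as np

def leaky_relu_ref(x, dy, negative_slope=0.01):
    y = np.maximum(x, 0.0) + negative_slope * np.minimum(x, 0.0)  # forward
    dx = np.where(x > 0.0, dy, negative_slope * dy)               # backward
    return y, dx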
@@ -871,6 +980,36 @@ def gelu(x):
     return _elwise(x, mode=Elemwise.Mode.GELU)


+@lru_cache(maxsize=None)
+def _get_softplus_op(dtype=None, device=None):
+    @subgraph_fn(
+        "Softplus",
+        dtype=dtype,
+        device=device,
+        nr_inputs=1,
+        jit_fusion=True,
+        # gopt_level=0,
+        custom_grad=True,
+    )
+    def softplus(inputs, f, c):
+        (inp,) = inputs[0:1]
+        neg_abs = f("-", f("abs", inp))
+        exp = f("exp", neg_abs)
+        oup = f("log1p", exp)
+        oup = f("+", oup, f("relu", inp))
+        (oup_grad,) = yield (oup,)
+        inp_grad_0 = f("switch_gt0", inp, oup_grad)
+        inp_grad_1 = oup_grad
+        inp_grad_1 = f("/", oup_grad, f("+", exp, c(1)))
+        inp_grad_1 = f("*", oup_grad, exp)
+        inp_grad_1 = f("-", inp_grad_1)
+        inp_grad_1 = f("abs_grad", inp, inp_grad_1)
+        inp_grad = f("+", inp_grad_0, inp_grad_1)
+        yield (inp_grad,)
+
+    return softplus
+
+
 def softplus(inp: Tensor) -> Tensor:
     r"""Applies the element-wise function:

@@ -904,7 +1043,9 @@ def softplus(inp: Tensor) -> Tensor:

        [0.0486 0.1269 0.3133 0.6931 1.3133 2.1269]
     """
-    return log1p(exp(-abs(inp))) + relu(inp)
+    softplus = _get_softplus_op(inp.dtype, inp.device)
+    (oup,) = softplus(inp)
+    return oup


 def logsoftmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
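
The Softplus subgraph uses the numerically stable form `log1p(exp(-|x|)) + relu(x)`, which equals `log(1 + exp(x))`; its derivative is `sigmoid(x)`. A NumPy reference:

import numpy as np

def softplus_ref(x, dy):
    y = np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0.0)  # stable log(1 + exp(x))
    dx = dy / (1.0 + np.exp(-x))                            # d/dx softplus(x) = sigmoid(x)
    return y, dx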
@@ -944,6 +1085,38 @@ def logsoftmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
     return inp - logsumexp(inp, axis, keepdims=True)


+@lru_cache(maxsize=None)
+def _get_logsigmoid_op(dtype=None, device=None):
+    @subgraph_fn(
+        "LogSigmoid",
+        dtype=dtype,
+        device=device,
+        nr_inputs=1,
+        jit_fusion=True,
+        custom_grad=True,
+    )
+    def logsigmoid(inputs, f, c):
+        (inp,) = inputs[0:1]
+        neg_abs = f("-", f("abs", inp))
+        exp = f("exp", neg_abs)
+        oup = f("log1p", exp)
+        oup = f("+", oup, f("relu", f("-", inp)))
+        oup = f("-", oup)
+        (oup_grad,) = yield (oup,)
+        oup_grad = f("-", oup_grad)
+        inp_grad_0 = f("switch_gt0", inp, oup_grad)
+        inp_grad_0 = f("-", inp_grad_0)
+        inp_grad_1 = oup_grad
+        inp_grad_1 = f("/", oup_grad, f("+", exp, c(1)))
+        inp_grad_1 = f("*", oup_grad, exp)
+        inp_grad_1 = f("-", inp_grad_1)
+        inp_grad_1 = f("abs_grad", inp, inp_grad_1)
+        inp_grad = f("+", inp_grad_0, inp_grad_1)
+        yield (inp_grad,)
+
+    return logsigmoid
+
+
 def logsigmoid(inp: Tensor) -> Tensor:
     r"""Applies the element-wise function:

@@ -972,7 +1145,9 @@ def logsigmoid(inp: Tensor) -> Tensor:
        [-5.0067 -4.0182 -3.0486 -2.1269 -1.3133 -0.6931 -0.3133 -0.1269 -0.0486
         -0.0181]
     """
-    return -softplus(-inp)
+    logsigmoid = _get_logsigmoid_op(inp.dtype, inp.device)
+    (oup,) = logsigmoid(inp)
+    return oup


 def logsumexp(
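
LogSigmoid is the negation of softplus applied at `-x`: `log(sigmoid(x)) = -(log1p(exp(-|x|)) + relu(-x))`, with derivative `sigmoid(-x) = 1 - sigmoid(x)`. A NumPy reference:

import numpy as np

def logsigmoid_ref(x, dy):
    y = -(np.log1p(np.exp(-np.abs(x))) + np.maximum(-x, 0.0))  # log(sigmoid(x))
    dx = dy * (1.0 / (1.0 + np.exp(x)))                        # sigmoid(-x)
    return y, dx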


imperative/python/megengine/functional/tensor.py (+40, -10)

@@ -6,6 +6,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from functools import lru_cache
 from typing import Iterable, Optional, Sequence, Tuple, Union

 import numpy as np
@@ -17,7 +18,14 @@ from ..core.ops import builtin
 from ..core.ops.builtin import Copy, Identity
 from ..core.ops.special import Const
 from ..core.tensor.array_method import _broadcast, _remove_axis
-from ..core.tensor.utils import astensor1d, convert_inputs, get_device
+from ..core.tensor.utils import (
+    astensor1d,
+    convert_inputs,
+    get_device,
+    isscalar,
+    setscalar,
+    subgraph_fn,
+)
 from ..device import get_default_device
 from ..tensor import Tensor
 from .elemwise import ceil
@@ -731,6 +739,29 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
     return inp


+@lru_cache(maxsize=None)
+def _get_where_op(dtype=None, device=None):
+    @subgraph_fn(
+        "Where",
+        dtype=dtype,
+        device=device,
+        nr_inputs=3,
+        jit_fusion=True,
+        custom_grad=True,
+    )
+    def where(inputs, f, c):
+        (mask, x, y) = inputs[0:3]
+        oup = f("switch_gt0", mask, x)
+        ksam = f("-", c(1), mask)
+        oup = f("+", oup, f("switch_gt0", ksam, y))
+        (oup_grad,) = yield (oup,)
+        x_grad = f("switch_gt0", mask, oup_grad)
+        y_grad = f("switch_gt0", ksam, oup_grad)
+        yield (None, x_grad, y_grad)
+
+    return where
+
+
 def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
     r"""Selects elements either from Tensor x or Tensor y, according to mask.

@@ -780,20 +811,19 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
         raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device))

     dtype = dtype_promotion(x, y)
+    device = x.device

     if x.dtype != dtype:
         x = x.astype(dtype)
     if y.dtype != dtype:
         y = y.astype(dtype)
+    mask = mask.astype(dtype)

-    v0, index0 = cond_take(mask, x)
-    v1, index1 = cond_take(~mask, y)
-
-    out = concat([v0, v1])
-
-    out[index0] = v0
-    out[index1] = v1
-    out = out.reshape(x.shape)
-    return out
+    where = _get_where_op(dtype=dtype, device=device)
+    (oup,) = where(mask, x, y)
+    if isscalar(mask):
+        setscalar(oup)
+    return oup


 def cond_take(mask: Tensor, x: Tensor) -> Tensor:
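
The rewritten `where` casts the boolean mask to the promoted dtype so the subgraph can select with `switch_gt0` (assumed here to pass values through where its first operand is positive) and route gradients only to the selected branch. A NumPy reference of the forward/backward behaviour:

import numpy as np

def where_ref(mask, x, y, dout):
    m = mask.astype(x.dtype)           # bool mask cast to the promoted dtype, as in the new code
    out = m * x + (1.0 - m) * y        # select x where mask, y elsewhere
    dx = dout * m                      # gradient flows to x only where mask is set
    dy = dout * (1.0 - m)              # gradient flows to y elsewhere; the mask gets no gradient
    return out, dx, dy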


imperative/python/test/unit/functional/test_elemwise.py (+1, -1)

@@ -166,7 +166,7 @@ def test_hsigmoid():
     x = np.random.randn(100).astype("float32")
     y_np = np.minimum(np.maximum(x + 3, 0), 6) / 6
     y_mge = F.hsigmoid(tensor(x)).numpy()
-    np.testing.assert_equal(y_np, y_mge)
+    np.testing.assert_almost_equal(y_np, y_mge, decimal=6)


 def test_logical_oprs():


imperative/python/test/unit/functional/test_functional.py (+3, -0)

@@ -27,6 +27,8 @@ from megengine.core.tensor.utils import make_shape_tuple
 from megengine.device import get_device_count
 from megengine.module import LayerNorm

+_assert_allclose = partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)
+

 def test_where():
     maskv0 = np.array([[1, 0], [0, 1]], dtype=np.bool_)
@@ -627,6 +629,7 @@ def test_binary_cross_entropy():
         {"input": [data1, label1], "output": expect1,},
         {"input": [data2, label2], "output": expect2,},
     ]
+
     opr_test(cases, F.nn.binary_cross_entropy, compare_fn=compare_fn)

     cases = [

