Browse Source

image transform method of real world robustness test

tags/v1.8.0
ZhidanLiu 3 years ago
parent
commit
ffcbb88ff9
10 changed files with 1487 additions and 0 deletions
  1. +0
    -0
      examples/natural_robustness/__init__.py
  2. +176
    -0
      examples/natural_robustness/natural_robustness_example.py
  3. +18
    -0
      mindarmour/natural_robustness/__init__.py
  4. +37
    -0
      mindarmour/natural_robustness/image/__init__.py
  5. +193
    -0
      mindarmour/natural_robustness/image/blur.py
  6. +251
    -0
      mindarmour/natural_robustness/image/corruption.py
  7. +287
    -0
      mindarmour/natural_robustness/image/luminance.py
  8. +159
    -0
      mindarmour/natural_robustness/image/natural_perturb.py
  9. +365
    -0
      mindarmour/natural_robustness/image/transformation.py
  10. +1
    -0
      setup.py

+ 0
- 0
examples/natural_robustness/__init__.py View File


+ 176
- 0
examples/natural_robustness/natural_robustness_example.py View File

@@ -0,0 +1,176 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example for natural robustness methods."""

import numpy as np
import cv2

from mindarmour.natural_robustness.image import Translate, Curve, Perspective, Scale, Shear, Rotate, SaltAndPepperNoise, \
NaturalNoise, GaussianNoise, UniformNoise, MotionBlur, GaussianBlur, GradientBlur, Contrast, GradientLuminance


def test_perspective(image):
"""Test perspective."""
ori_pos = [[0, 0], [0, 800], [800, 0], [800, 800]]
dst_pos = [[50, 0], [0, 800], [780, 0], [800, 800]]
trans = Perspective(ori_pos, dst_pos)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()


def test_uniform_noise(image):
"""Test uniform noise."""
trans = UniformNoise(factor=0.1)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()


def test_gaussian_noise(image):
"""Test gaussian noise."""
trans = GaussianNoise(factor=0.1)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()


def test_contrast(image):
"""Test contrast."""
trans = Contrast(alpha=2, beta=0)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()


def test_gaussian_blur(image):
"""Test gaussian blur."""
trans = GaussianBlur(ksize=5)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()


def test_salt_and_pepper_noise(image):
"""Test salt and pepper noise."""
trans = SaltAndPepperNoise(factor=0.01)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()


def test_translate(image):
"""Test translate."""
trans = Translate(x_bias=0.1, y_bias=0.1)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()


def test_scale(image):
"""Test scale."""
trans = Scale(factor_x=0.7, factor_y=0.7)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()


def test_shear(image):
"""Test shear."""
trans = Shear(factor=0.2)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()


def test_rotate(image):
"""Test rotate."""
trans = Rotate(angle=20)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()


def test_curve(image):
"""Test curve."""
trans = Curve(curves=2, depth=1.5, mode='horizontal')
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()


def test_natural_noise(image):
"""Test natural noise."""
trans = NaturalNoise(ratio=0.0001, k_x_range=(1, 30), k_y_range=(1, 10), auto_param=True)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()


def test_gradient_luminance(image):
"""Test gradient luminance."""
height, width = image.shape[:2]
point = (height // 4, width // 2)
start = (255, 255, 255)
end = (0, 0, 0)
scope = 0.3
bright_rate = 0.4
trans = GradientLuminance(start, end, start_point=point, scope=scope, pattern='dark', bright_rate=bright_rate,
mode='horizontal')
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()


def test_motion_blur(image):
"""Test motion blur."""
angle = -10.5
i = 10
trans = MotionBlur(degree=i, angle=angle)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()


def test_gradient_blur(image):
"""Test gradient blur."""
number = 10
h, w = image.shape[:2]
point = (int(h / 5), int(w / 5))
center = False
trans = GradientBlur(point, number, center)
dst = trans(image)
cv2.imshow('dst', dst)
cv2.waitKey()


if __name__ == '__main__':
img = cv2.imread('1.jpeg')
img = np.array(img)
test_uniform_noise(img)
test_gaussian_noise(img)
test_motion_blur(img)
test_gradient_blur(img)
test_gradient_luminance(img) #####
test_natural_noise(img)
test_curve(img)
test_rotate(img)
test_shear(img)
test_scale(img)
test_translate(img)
test_salt_and_pepper_noise(img)
test_gaussian_blur(img)
test_contrast(img)
test_perspective(img)

+ 18
- 0
mindarmour/natural_robustness/__init__.py View File

@@ -0,0 +1,18 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Natural Robustness.

This module is a tool box to enhance model security and against natural perturb examples.
"""

+ 37
- 0
mindarmour/natural_robustness/image/__init__.py View File

@@ -0,0 +1,37 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This package include methods to generate natural perturbation samples.
"""

from .transformation import Translate, Scale, Shear, Rotate, Perspective, Curve
from .blur import GaussianBlur, MotionBlur, GradientBlur
from .luminance import Contrast, GradientLuminance
from .corruption import UniformNoise, GaussianNoise, SaltAndPepperNoise, NaturalNoise

__all__ = ['Translate',
'Scale',
'Shear',
'Rotate',
'Perspective',
'Curve',
'GaussianBlur',
'MotionBlur',
'GradientBlur',
'Contrast',
'GradientLuminance',
'UniformNoise',
'GaussianNoise',
'SaltAndPepperNoise',
'NaturalNoise']

+ 193
- 0
mindarmour/natural_robustness/image/blur.py View File

@@ -0,0 +1,193 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image Blur
"""

import numpy as np
import cv2

from mindarmour.natural_robustness.image.natural_perturb import _NaturalPerturb
from mindarmour.utils._check_param import check_param_multi_types, check_int_positive, check_param_type
from mindarmour.utils.logger import LogUtil

LOGGER = LogUtil.get_instance()
TAG = 'Image Blur'


class GaussianBlur(_NaturalPerturb):
"""
Blurs the image using Gaussian blur filter.

Args:
ksize (int): Size of gaussian kernel, this value must be non-negnative.
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('1.png')
>>> img = np.array(img)
>>> ksize = 5
>>> trans = GaussianBlur(ksize)
>>> dst = trans(img)
"""

def __init__(self, ksize=2, auto_param=False):
super(GaussianBlur, self).__init__()
ksize = check_int_positive('ksize', ksize)
if auto_param:
ksize = 2 * np.random.randint(0, 5) + 1
else:
ksize = 2 * ksize + 1
self.ksize = (ksize, ksize)

def __call__(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
new_img = cv2.GaussianBlur(image, self.ksize, 0)
new_img = self._original_format(new_img, chw, normalized, gray3dim)
return new_img.astype(ori_dtype)


class MotionBlur(_NaturalPerturb):
"""
Motion blur for a given image.

Args:
degree (int): Degree of blur. This value must be positive. Suggested value range in [1, 15].
angle: (union[float, int]): Direction of motion blur. Angle=0 means up and down motion blur. Angle is
counterclockwise.
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('1.png')
>>> img = np.array(img)
>>> angle = 0
>>> degree = 5
>>> trans = MotionBlur(degree=degree, angle=angle)
>>> new_img = trans(img)
"""

def __init__(self, degree=5, angle=45, auto_param=False):
super(MotionBlur, self).__init__()
self.degree = check_int_positive('degree', degree)
self.degree = check_param_multi_types('degree', degree, [float, int])
auto_param = check_param_type('auto_param', auto_param, bool)
if auto_param:
self.degree = np.random.randint(1, 5)
self.angle = np.random.uniform(0, 360)
else:
self.angle = angle - 45

def __call__(self, image):
"""
Motion blur for a given image.

Args:
image (numpy.ndarray): Original image.

Returns:
numpy.ndarray, image after motion blur.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
matrix = cv2.getRotationMatrix2D((self.degree / 2, self.degree / 2), self.angle, 1)
motion_blur_kernel = np.diag(np.ones(self.degree))
motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, matrix, (self.degree, self.degree))
motion_blur_kernel = motion_blur_kernel / self.degree
blurred = cv2.filter2D(image, -1, motion_blur_kernel)
# convert to uint8
cv2.normalize(blurred, blurred, 0, 255, cv2.NORM_MINMAX)
blurred = self._original_format(blurred, chw, normalized, gray3dim)

return blurred.astype(ori_dtype)


class GradientBlur(_NaturalPerturb):
"""
Gradient blur.

Args:
point (union[tuple, list]): 2D coordinate of the Blur center point.
kernel_num (int): Number of blur kernels. Suggested value range in [1, 8].
center (bool): Blurred or clear at the center of a specified point.
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('xx.png')
>>> img = np.array(img)
>>> number = 5
>>> h, w = img.shape[:2]
>>> point = (int(h / 5), int(w / 5))
>>> center = True
>>> trans = GradientBlur(point, number, center)
>>> new_img = trans(img)
"""

def __init__(self, point, kernel_num=3, center=True, auto_param=False):
super(GradientBlur).__init__()
point = check_param_multi_types('point', point, [list, tuple])
self.auto_param = check_param_type('auto_param', auto_param, bool)
self.point = tuple(point)
self.kernel_num = check_int_positive('kernel_num', kernel_num)
self.center = check_param_type('center', center, bool)

def _auto_param(self, h, w):
self.point = (int(np.random.uniform(0, h)), int(np.random.uniform(0, w)))
self.kernel_num = np.random.randint(1, 6)
self.center = np.random.choice([True, False])

def __call__(self, image):
"""
Args:
image (numpy.ndarray): Original image.

Returns:
numpy.ndarray, gradient blurred image.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
w, h = image.shape[:2]
if self.auto_param:
self._auto_param(h, w)
mask = np.zeros(image.shape, dtype=np.uint8)
masks = []
radius = max(w - self.point[0], self.point[0], h - self.point[1], self.point[1])
radius = int(radius / self.kernel_num)
for i in range(self.kernel_num):
circle = cv2.circle(mask.copy(), self.point, radius * (1 + i), (1, 1, 1), -1)
masks.append(circle)
blurs = []
for i in range(3, 3 + 2 * self.kernel_num, 2):
ksize = (i, i)
blur = cv2.GaussianBlur(image, ksize, 0)
blurs.append(blur)

dst = image.copy()
if self.center:
for i in range(self.kernel_num):
dst = masks[i] * dst + (1 - masks[i]) * blurs[i]
else:
for i in range(self.kernel_num - 1, -1, -1):
dst = masks[i] * blurs[self.kernel_num - 1 - i] + (1 - masks[i]) * dst
dst = self._original_format(dst, chw, normalized, gray3dim)
return dst.astype(ori_dtype)

+ 251
- 0
mindarmour/natural_robustness/image/corruption.py View File

@@ -0,0 +1,251 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image corruption.
"""
import math
import numpy as np
import cv2

from mindarmour.natural_robustness.image.natural_perturb import _NaturalPerturb
from mindarmour.utils._check_param import check_param_multi_types, check_param_type
from mindarmour.utils.logger import LogUtil

LOGGER = LogUtil.get_instance()
TAG = 'Image corruption'


class UniformNoise(_NaturalPerturb):
"""
Add uniform noise of an image.

Args:
factor (float): Noise density, the proportion of noise points per unit pixel area. Suggested value range in
[0.001, 0.15].
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('1.png')
>>> img = np.array(img)
>>> factor = 0.1
>>> trans = UniformNoise(factor)
>>> dst = trans(img)
"""

def __init__(self, factor=0.1, auto_param=False):
super(UniformNoise, self).__init__()
self.factor = check_param_multi_types('factor', factor, [int, float])
check_param_type('auto_param', auto_param, bool)
if auto_param:
self.factor = np.random.uniform(0, 0.15)

def __call__(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
low, high = (0, 255)
weight = self.factor * (high - low)
noise = np.random.uniform(-weight, weight, size=image.shape)
trans_image = np.clip(image + noise, low, high)
trans_image = self._original_format(trans_image, chw, normalized, gray3dim)

return trans_image.astype(ori_dtype)


class GaussianNoise(_NaturalPerturb):
"""
Add gaussian noise of an image.

Args:
factor (float): Noise density, the proportion of noise points per unit pixel area. Suggested value range in
[0.001, 0.15].
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('1.png')
>>> img = np.array(img)
>>> factor = 0.1
>>> trans = GaussianNoise(factor)
>>> dst = trans(img)
"""

def __init__(self, factor=0.1, auto_param=False):
super(GaussianNoise, self).__init__()
self.factor = check_param_multi_types('factor', factor, [int, float])
check_param_type('auto_param', auto_param, bool)
if auto_param:
self.factor = np.random.uniform(0, 0.15)

def __call__(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
ori_dtype = image.dtype
low, high = (0, 255)
_, chw, normalized, gray3dim, image = self._check(image)
std = self.factor / math.sqrt(3) * (high - low)
noise = np.random.normal(scale=std, size=image.shape)
trans_image = np.clip(image + noise, low, high)
trans_image = self._original_format(trans_image, chw, normalized, gray3dim)
return trans_image.astype(ori_dtype)


class SaltAndPepperNoise(_NaturalPerturb):
"""
Add salt and pepper noise of an image.

Args:
factor (float): Noise density, the proportion of noise points per unit pixel area. Suggested value range in
[0.001, 0.15].
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('1.png')
>>> img = np.array(img)
>>> factor = 0.1
>>> trans = SaltAndPepperNoise(factor)
>>> dst = trans(img)
"""

def __init__(self, factor=0, auto_param=False):
super(SaltAndPepperNoise, self).__init__()
self.factor = check_param_multi_types('factor', factor, [int, float])
check_param_type('auto_param', auto_param, bool)
if auto_param:
self.factor = np.random.uniform(0, 0.15)

def __call__(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
low, high = (0, 255)
noise = np.random.uniform(low=-1, high=1, size=(image.shape[0], image.shape[1]))
trans_image = np.copy(image)
threshold = 1 - self.factor
trans_image[noise < -threshold] = low
trans_image[noise > threshold] = high
trans_image = self._original_format(trans_image, chw, normalized, gray3dim)
return trans_image.astype(ori_dtype)


class NaturalNoise(_NaturalPerturb):
"""
Add natural noise to an image.

Args:
ratio (float): Noise density, the proportion of noise blocks per unit pixel area. Suggested value range in
[0.00001, 0.001].
k_x_range (union[list, tuple]): Value range of the noise block length.
k_y_range (union[list, tuple]): Value range of the noise block width.
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Examples:
>>> img = cv2.imread('xx.png')
>>> img = np.array(img)
>>> ratio = 0.0002
>>> k_x_range = (1, 5)
>>> k_y_range = (3, 25)
>>> trans = NaturalNoise(ratio, k_x_range, k_y_range)
>>> new_img = trans(img)
"""

def __init__(self, ratio=0.0002, k_x_range=(1, 5), k_y_range=(3, 25), auto_param=False):
super(NaturalNoise).__init__()
self.ratio = check_param_type('ratio', ratio, float)
k_x_range = check_param_multi_types('k_x_range', k_x_range, [list, tuple])
k_y_range = check_param_multi_types('k_y_range', k_y_range, [list, tuple])
self.k_x_range = tuple(k_x_range)
self.k_y_range = tuple(k_y_range)
self.auto_param = check_param_type('auto_param', auto_param, bool)

def __call__(self, image):
"""
Add natural noise to given image.

Args:
image (numpy.ndarray): Original image.

Returns:
numpy.ndarray, image with natural noise.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
randon_range = 100
w, h = image.shape[:2]
channel = len(np.shape(image))

if self.auto_param:
self.ratio = np.random.uniform(0, 0.001)
self.k_x_range = (1, 0.1 * w)
self.k_y_range = (1, 0.1 * h)

for _ in range(5):
if channel == 3:
noise = np.ones((w, h, 3), dtype=np.uint8) * 255
dst = np.ones((w, h, 3), dtype=np.uint8) * 255
else:
noise = np.ones((w, h), dtype=np.uint8) * 255
dst = np.ones((w, h), dtype=np.uint8) * 255

rate = self.ratio / 5
mask = np.random.uniform(size=(w, h)) < rate
noise[mask] = np.random.randint(0, randon_range)

k_x, k_y = np.random.randint(*self.k_x_range), np.random.randint(*self.k_y_range)
kernel = np.ones((k_x, k_y), np.uint8)
erode = cv2.erode(noise, kernel, iterations=1)
dst = erode * (erode < randon_range) + dst * (1 - erode < randon_range)
# Add black point
for _ in range(np.random.randint(math.ceil(k_x * k_y / 2))):
x = np.random.randint(-k_x, k_x)
y = np.random.randint(-k_y, k_y)
matrix = np.array([[1, 0, y], [0, 1, x]], dtype=np.float)
affine = cv2.warpAffine(noise, matrix, (h, w))
dst = affine * (affine < randon_range) + dst * (1 - affine < randon_range)
# Add white point
for _ in range(int(k_x * k_y / 2)):
x = np.random.randint(-k_x / 2 - 1, k_x / 2 + 1)
y = np.random.randint(-k_y / 2 - 1, k_y / 2 + 1)
matrix = np.array([[1, 0, y], [0, 1, x]], dtype=np.float)
affine = cv2.warpAffine(noise, matrix, (h, w))
white = affine < randon_range
dst[white] = 255

mask = dst < randon_range
dst = image * (1 - mask) + dst * mask
dst = self._original_format(dst, chw, normalized, gray3dim)

return dst.astype(ori_dtype)

+ 287
- 0
mindarmour/natural_robustness/image/luminance.py View File

@@ -0,0 +1,287 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image luminance.
"""
import math
import numpy as np
import cv2

from mindarmour.natural_robustness.image.natural_perturb import _NaturalPerturb
from mindarmour.utils._check_param import check_param_multi_types, check_param_in_range, check_param_type, \
check_value_non_negative
from mindarmour.utils.logger import LogUtil

LOGGER = LogUtil.get_instance()
TAG = 'Image Luminance'


class Contrast(_NaturalPerturb):
"""
Contrast of an image.

Args:
alpha (Union[float, int]): Control the contrast of an image. :math:`out_image = in_image*alpha+beta`.
Suggested value range in [0.2, 2].
beta (Union[float, int]): Delta added to alpha. Default: 0.
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('1.png')
>>> img = np.array(img)
>>> alpha = 0.1
>>> beta = 1
>>> trans = Contrast(alpha, beta)
>>> dst = trans(img)
"""

def __init__(self, alpha=1, beta=0, auto_param=False):
super(Contrast, self).__init__()
self.alpha = check_param_multi_types('factor', alpha, [int, float])
self.beta = check_param_multi_types('factor', beta, [int, float])
auto_param = check_param_type('auto_param', auto_param, bool)
if auto_param:
self.alpha = np.random.uniform(0.2, 2)
self.beta = np.random.uniform(-20, 20)

def __call__(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
dst = cv2.convertScaleAbs(image, alpha=self.alpha, beta=self.beta)
dst = self._original_format(dst, chw, normalized, gray3dim)
return dst.astype(ori_dtype)


def _circle_gradient_mask(img_src, color_start, color_end, scope=0.5, point=None):
"""
Generate circle gradient mask.

Args:
img_src (numpy.ndarray): Source image.
color_start (union([tuple, list])): Color of circle gradient center.
color_end (union([tuple, list])): Color of circle gradient edge.
scope (float): Range of the gradient. A larger value indicates a larger gradient range.
point (union([tuple, list]): Gradient center point.

Returns:
numpy.ndarray, gradients mask.
"""
if not isinstance(img_src, np.ndarray):
raise TypeError('`src` must be numpy.ndarray type, but got {0}.'.format(type(img_src)))

shape = img_src.shape
height, width = shape[:2]
rgb = False
if len(shape) == 3:
rgb = True
if point is None:
point = (height // 2, width // 2)
x, y = point

# upper left
bound_upper_left = math.ceil(math.sqrt(x ** 2 + y ** 2))
# upper right
bound_upper_right = math.ceil(math.sqrt(height ** 2 + (width - y) ** 2))
# lower left
bound_lower_left = math.ceil(math.sqrt((height - x) ** 2 + y ** 2))
# lower right
bound_lower_right = math.ceil(math.sqrt((height - x) ** 2 + (width - y) ** 2))

radius = max(bound_lower_left, bound_lower_right, bound_upper_left, bound_upper_right) * scope

img_grad = np.ones_like(img_src, dtype=np.uint8) * max(color_end)
# opencv use BGR format
grad_b = float(color_end[0] - color_start[0]) / radius
grad_g = float(color_end[1] - color_start[1]) / radius
grad_r = float(color_end[2] - color_start[2]) / radius

for i in range(height):
for j in range(width):
distance = math.ceil(math.sqrt((x - i) ** 2 + (y - j) ** 2))
if distance >= radius:
continue
if rgb:
img_grad[i, j, 0] = color_start[0] + distance * grad_b
img_grad[i, j, 1] = color_start[1] + distance * grad_g
img_grad[i, j, 2] = color_start[2] + distance * grad_r
else:
img_grad[i, j] = color_start[0] + distance * grad_b

return img_grad.astype(np.uint8)


def _line_gradient_mask(image, start_pos=None, start_color=(0, 0, 0), end_color=(255, 255, 255), mode='horizontal'):
"""
Generate liner gradient mask.

Args:
image (numpy.ndarray): Original image.
start_pos (union[tuple, list]): 2D coordinate of gradient center.
start_color (union([tuple, list])): Color of circle gradient center.
end_color (union([tuple, list])): Color of circle gradient edge.
mode (str): Direction of gradient. Optional value is 'vertical' or 'horizontal'.

Returns:
numpy.ndarray, gradients mask.
"""
shape = image.shape
h, w = shape[:2]
rgb = False
if len(shape) == 3:
rgb = True
if start_pos is None:
start_pos = 0.5
else:
if mode == 'horizontal':
if start_pos[0] > h:
start_pos = 1
else:
start_pos = start_pos[0] / h
else:
if start_pos[1] > w:
start_pos = 1
else:
start_pos = start_pos[1] / w
start_color = np.array(start_color)
end_color = np.array(end_color)
if mode == 'horizontal':
w_l = int(w * start_pos)
w_r = w - w_l
if w_l > w_r:
r_end_color = (end_color - start_color) / start_pos * (1 - start_pos) + start_color
left = np.linspace(end_color, start_color, w_l)
right = np.linspace(start_color, r_end_color, w_r)
else:
l_end_color = (end_color - start_color) / (1 - start_pos) * start_pos + start_color
left = np.linspace(l_end_color, start_color, w_l)
right = np.linspace(start_color, end_color, w_r)
line = np.concatenate((left, right), axis=0)
mask = np.reshape(np.tile(line, (h, 1)), (h, w, 3))
else:
# 'vertical'
h_t = int(h * start_pos)
h_b = h - h_t
if h_t > h_b:
b_end_color = (end_color - start_color) / start_pos * (1 - start_pos) + start_color
top = np.linspace(end_color, start_color, h_t)
bottom = np.linspace(start_color, b_end_color, h_b)
else:
t_end_color = (end_color - start_color) / (1 - start_pos) * start_pos + start_color
top = np.linspace(t_end_color, start_color, h_t)
bottom = np.linspace(start_color, end_color, h_b)
line = np.concatenate((top, bottom), axis=0)
mask = np.reshape(np.tile(line, (w, 1)), (w, h, 3))
mask = np.transpose(mask, [1, 0, 2])
if not rgb:
mask = mask[:, :, 0]
return mask.astype(np.uint8)


class GradientLuminance(_NaturalPerturb):
"""
Gradient adjusts the luminance of picture.

Args:
color_start (union[tuple, list]): Color of gradient center. Default:(0, 0, 0).
color_end (union[tuple, list]): Color of gradient edge. Default:(255, 255, 255).
start_point (union[tuple, list]): 2D coordinate of gradient center.
scope (float): Range of the gradient. A larger value indicates a larger gradient range. Default: 0.3.
pattern (str): Dark or light, this value must be in ['light', 'dark'].
bright_rate (float): Control brightness. A larger value indicates a larger gradient range. If parameter
'pattern' is 'light', Suggested value range in [0.1, 0.7], if parameter 'pattern' is 'dark', Suggested value
range in [0.1, 0.9].
mode (str): Gradient mode, value must be in ['circle', 'horizontal', 'vertical'].
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Examples:
>>> img = cv2.imread('x.png')
>>> height, width = img.shape[:2]
>>> point = (height // 4, width // 2)
>>> start = (255, 255, 255)
>>> end = (0, 0, 0)
>>> scope = 0.3
>>> pattern='light'
>>> bright_rate = 0.3
>>> trans = GradientLuminance(start, end, point, scope, pattern, bright_rate, mode='circle')
>>> img_new = trans(img)
"""

def __init__(self, color_start=(0, 0, 0), color_end=(255, 255, 255), start_point=(10, 10), scope=0.5,
pattern='light', bright_rate=0.3, mode='circle', auto_param=False):
super(GradientLuminance, self).__init__()
self.color_start = check_param_multi_types('color_start', color_start, [list, tuple])
self.color_end = check_param_multi_types('color_end', color_end, [list, tuple])
self.start_point = check_param_multi_types('start_point', start_point, [list, tuple])
self.scope = check_value_non_negative('scope', scope)
self.bright_rate = check_param_type('bright_rate', bright_rate, float)
self.bright_rate = check_param_in_range('bright_rate', bright_rate, 0, 1)
self.auto_param = check_param_type('auto_param', auto_param, bool)

if pattern in ['light', 'dark']:
self.pattern = pattern
else:
msg = "Value of param pattern must be in ['light', 'dark']"
LOGGER.error(TAG, msg)
raise ValueError(msg)
if mode in ['circle', 'horizontal', 'vertical']:
self.mode = mode
else:
msg = "Value of param mode must be in ['circle', 'horizontal', 'vertical']"
LOGGER.error(TAG, msg)
raise ValueError(msg)

def _set_auto_param(self, w, h):
self.color_start = (np.random.uniform(0, 255),) * 3
self.color_end = (np.random.uniform(0, 255),) * 3
self.start_point = (np.random.uniform(0, w), np.random.uniform(0, h))
self.scope = np.random.uniform(0, 1)
self.bright_rate = np.random.uniform(0.1, 0.9)
self.pattern = np.random.choice(['light', 'dark'])
self.mode = np.random.choice(['circle', 'horizontal', 'vertical'])

def __call__(self, image):
"""
Gradient adjusts the luminance of picture.

Args:
image (numpy.ndarray): Original image.

Returns:
numpy.ndarray, image with perlin noise.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
w, h = image.shape[:2]
if self.auto_param:
self._set_auto_param(w, h)
if self.mode == 'circle':
mask = _circle_gradient_mask(image, self.color_start, self.color_end, self.scope, self.start_point)
else:
mask = _line_gradient_mask(image, self.start_point, self.color_start, self.color_end, mode=self.mode)

if self.pattern == 'light':
img_new = cv2.addWeighted(image, 1, mask, self.bright_rate, 0.0)
else:
img_new = cv2.addWeighted(image, self.bright_rate, mask, 1 - self.bright_rate, 0.0)
img_new = self._original_format(img_new, chw, normalized, gray3dim)
return img_new.astype(ori_dtype)

+ 159
- 0
mindarmour/natural_robustness/image/natural_perturb.py View File

@@ -0,0 +1,159 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base class for image natural perturbation.
"""
import numpy as np

from mindspore.dataset.vision.py_transforms_util import is_numpy, hwc_to_chw
from mindarmour.utils.logger import LogUtil

LOGGER = LogUtil.get_instance()
TAG = 'Image Transformation'


def _chw_to_hwc(img):
"""
Transpose the input image; shape (C, H, W) to shape (H, W, C).

Args:
img (numpy.ndarray): Image to be converted.

Returns:
img (numpy.ndarray), Converted image.
"""
if is_numpy(img):
return img.transpose(1, 2, 0).copy()
raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img)))


def _is_hwc(img):
"""
Check if the input image is shape (H, W, C).

Args:
img (numpy.ndarray): Image to be checked.

Returns:
Bool, True if input is shape (H, W, C).
"""
if is_numpy(img):
img_shape = np.shape(img)
if img_shape[2] == 3 and img_shape[1] > 3 and img_shape[0] > 3:
return True
return False
raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img)))


def _is_chw(img):
"""
Check if the input image is shape (H, W, C).

Args:
img (numpy.ndarray): Image to be checked.

Returns:
Bool, True if input is shape (H, W, C).
"""
if is_numpy(img):
img_shape = np.shape(img)
if img_shape[0] == 3 and img_shape[1] > 3 and img_shape[2] > 3:
return True
return False
raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img)))


def _is_rgb(img):
"""
Check if the input image is RGB.

Args:
img (numpy.ndarray): Image to be checked.

Returns:
Bool, True if input is RGB.
"""
if is_numpy(img):
img_shape = np.shape(img)
if len(np.shape(img)) == 3 and (img_shape[0] == 3 or img_shape[2] == 3):
return True
return False
raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img)))


def _is_normalized(img):
"""
Check if the input image is normalized between 0 to 1.

Args:
img (numpy.ndarray): Image to be checked.

Returns:
Bool, True if input is normalized between 0 to 1.
"""
if is_numpy(img):
minimal = np.min(img)
maximum = np.max(img)
if minimal >= 0 and maximum <= 1:
return True
return False
raise TypeError('img should be Numpy array. Got {}'.format(type(img)))


class _NaturalPerturb:
"""
The abstract base class for all image natural perturbation classes.
"""

def __init__(self):
pass

def _check(self, image):
""" Check image format. If input image is RGB and its shape
is (C, H, W), it will be transposed to (H, W, C). If the value
of the image is not normalized , it will be rescaled between 0 to 255."""
rgb = _is_rgb(image)
chw = False
gray3dim = False
normalized = _is_normalized(image)
if rgb:
chw = _is_chw(image)
if chw:
image = _chw_to_hwc(image)
else:
image = image
else:
if len(np.shape(image)) == 3:
gray3dim = True
image = image[0]
else:
image = image
if normalized:
image = image * 255
return rgb, chw, normalized, gray3dim, np.uint8(image)

def _original_format(self, image, chw, normalized, gray3dim):
""" Return image with original format. """
if not is_numpy(image):
image = np.array(image)
if chw:
image = hwc_to_chw(image)
if normalized:
image = image / 255
if gray3dim:
image = np.expand_dims(image, 0)
return image

def __call__(self, image):
pass

+ 365
- 0
mindarmour/natural_robustness/image/transformation.py View File

@@ -0,0 +1,365 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image transformation.
"""
import math
import numpy as np
import cv2

from mindarmour.natural_robustness.image.natural_perturb import _NaturalPerturb
from mindarmour.utils._check_param import check_param_multi_types, check_param_type, check_value_non_negative
from mindarmour.utils.logger import LogUtil

LOGGER = LogUtil.get_instance()
TAG = 'Image Transformation'


class Translate(_NaturalPerturb):
"""
Translate an image.

Args:
x_bias (Union[int, float]): X-direction translation, x = x + x_bias*image_width. Suggested value range
in [-0.1, 0.1].
y_bias (Union[int, float]): Y-direction translation, y = y + y_bias*image_length. Suggested value range
in [-0.1, 0.1].
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('1.png')
>>> img = np.array(img)
>>> x_bias = 0.1
>>> y_bias = 0.1
>>> trans = Translate(x_bias, y_bias)
>>> dst = trans(img)
"""

def __init__(self, x_bias=0, y_bias=0, auto_param=False):
super(Translate, self).__init__()
self.x_bias = check_param_multi_types('x_bias', x_bias, [int, float])
self.y_bias = check_param_multi_types('y_bias', y_bias, [int, float])
if auto_param:
self.x_bias = np.random.uniform(-0.1, 0.1)
self.y_bias = np.random.uniform(-0.1, 0.1)

def __call__(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
h, w = image.shape[:2]
matrix = np.array([[1, 0, self.x_bias * w], [0, 1, self.y_bias * h]], dtype=np.float)
new_img = cv2.warpAffine(image, matrix, (w, h))
new_img = self._original_format(new_img, chw, normalized, gray3dim)
return new_img.astype(ori_dtype)


class Scale(_NaturalPerturb):
"""
Scale an image in the middle.

Args:
factor_x (Union[float, int]): Rescale in X-direction, x=factor_x*x. Suggested value range in [0.5, 1] and
abs(factor_y - factor_x) < 0.5.
factor_y (Union[float, int]): Rescale in Y-direction, y=factor_y*y. Suggested value range in [0.5, 1] and
abs(factor_y - factor_x) < 0.5.
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('1.png')
>>> img = np.array(img)
>>> factor_x = 0.7
>>> factor_y = 0.6
>>> trans = Scale(factor_x, factor_y)
>>> dst = trans(img)
"""

def __init__(self, factor_x=1, factor_y=1, auto_param=False):
super(Scale, self).__init__()
self.factor_x = check_param_multi_types('factor_x', factor_x, [int, float])
self.factor_y = check_param_multi_types('factor_y', factor_y, [int, float])
auto_param = check_param_type('auto_param', auto_param, bool)
if auto_param:
self.factor_x = np.random.uniform(0.5, 1)
self.factor_y = np.random.uniform(0.5, 1)

def __call__(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
h, w = image.shape[:2]
matrix = np.array([[self.factor_x, 0, 0], [0, self.factor_y, 0]], dtype=np.float)
new_img = cv2.warpAffine(image, matrix, (w, h))
new_img = self._original_format(new_img, chw, normalized, gray3dim)
return new_img.astype(ori_dtype)


class Shear(_NaturalPerturb):
"""
Shear an image, for each pixel (x, y) in the sheared image, the new value is taken from a position
(x+factor_x*y, factor_y*x+y) in the origin image. Then the sheared image will be rescaled to fit original size.

Args:
factor (Union[float, int]): Shear rate in shear direction. Suggested value range in [0.05, 0.5].
direction (str): Direction of deformation. Optional value is 'vertical' or 'horizontal'.
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('1.png')
>>> img = np.array(img)
>>> factor = 0.2
>>> trans = Shear(factor, direction='horizontal')
>>> dst = trans(img)
"""

def __init__(self, factor=0.2, direction='horizontal', auto_param=False):
super(Shear, self).__init__()
self.factor = check_param_multi_types('factor', factor, [int, float])
if direction not in ['horizontal', 'vertical']:
msg = "'direction must be in ['horizontal', 'vertical'], but got {}".format(direction)
raise ValueError(msg)
self.direction = direction
auto_param = check_param_type('auto_params', auto_param, bool)
if auto_param:
self.factor = np.random.uniform(0.05, 0.5)
self.direction = np.random.choice(['horizontal', 'vertical'])

def __call__(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
h, w = image.shape[:2]
if self.direction == 'horizontal':
matrix = np.array([[1, self.factor, 0], [0, 1, 0]], dtype=np.float)
nw = int(w + self.factor * h)
nh = h
else:
matrix = np.array([[1, 0, 0], [self.factor, 1, 0]], dtype=np.float)
nw = w
nh = int(h + self.factor * w)
new_img = cv2.warpAffine(image, matrix, (nw, nh))
new_img = cv2.resize(new_img, (w, h))
new_img = self._original_format(new_img, chw, normalized, gray3dim)
return new_img.astype(ori_dtype)


class Rotate(_NaturalPerturb):
"""
Rotate an image of counter clockwise around its center.

Args:
angle (Union[float, int]): Degrees of counter clockwise. Suggested value range in [-60, 60].
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('1.png')
>>> img = np.array(img)
>>> angle = 20
>>> trans = Rotate(angle)
>>> dst = trans(img)
"""

def __init__(self, angle=20, auto_param=False):
super(Rotate, self).__init__()
self.angle = check_param_multi_types('angle', angle, [int, float])
auto_param = check_param_type('auto_param', auto_param, bool)
if auto_param:
self.angle = np.random.uniform(0, 360)

def __call__(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, rotated image.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
h, w = image.shape[:2]
center = (w // 2, h // 2)
matrix = cv2.getRotationMatrix2D(center, -self.angle, 1.0)
cos = np.abs(matrix[0, 0])
sin = np.abs(matrix[0, 1])

# Calculate new edge after rotated
nw = int((h * sin) + (w * cos))
nh = int((h * cos) + (w * sin))
# Adjust move distance of rotate matrix.
matrix[0, 2] += (nw / 2) - center[0]
matrix[1, 2] += (nh / 2) - center[1]
rotate = cv2.warpAffine(image, matrix, (nw, nh))
rotate = cv2.resize(rotate, (w, h))
new_img = self._original_format(rotate, chw, normalized, gray3dim)
return new_img.astype(ori_dtype)


class Perspective(_NaturalPerturb):
"""
Perform perspective transformation on a given picture.

Args:
ori_pos (list): Four points in original image.
dst_pos (list): The point coordinates of the 4 points in ori_pos after perspective transformation.
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Example:
>>> img = cv2.imread('1.png')
>>> img = np.array(img)
>>> ori_pos = [[0, 0], [0, 800], [800, 0], [800, 800]]
>>> dst_pos = [[50, 0], [0, 800], [780, 0], [800, 800]]
>>> trans = Perspective(ori_pos, dst_pos)
>>> dst = trans(img)
"""

def __init__(self, ori_pos, dst_pos, auto_param=False):
super(Perspective, self).__init__()
ori_pos = check_param_type('ori_pos', ori_pos, list)
dst_pos = check_param_type('dst_pos', dst_pos, list)
self.ori_pos = np.float32(ori_pos)
self.dst_pos = np.float32(dst_pos)
self.auto_param = check_param_type('auto_param', auto_param, bool)

def _set_auto_param(self, w, h):
self.ori_pos = [[h * 0.25, w * 0.25], [h * 0.25, w * 0.75], [h * 0.75, w * 0.25], [h * 0.75, w * 0.75]]
self.dst_pos = [[np.random.uniform(0, h * 0.5), np.random.uniform(0, w * 0.5)],
[np.random.uniform(0, h * 0.5), np.random.uniform(w * 0.5, w)],
[np.random.uniform(h * 0.5, h), np.random.uniform(0, w * 0.5)],
[np.random.uniform(h * 0.5, h), np.random.uniform(w * 0.5, w)]]
self.ori_pos = np.float32(self.ori_pos)
self.dst_pos = np.float32(self.dst_pos)

def __call__(self, image):
"""
Transform the image.

Args:
image (numpy.ndarray): Original image to be transformed.

Returns:
numpy.ndarray, transformed image.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
h, w = image.shape[:2]
if self.auto_param:
self._set_auto_param(w, h)
matrix = cv2.getPerspectiveTransform(self.ori_pos, self.dst_pos)
new_img = cv2.warpPerspective(image, matrix, (w, h))
new_img = self._original_format(new_img, chw, normalized, gray3dim)
return new_img.astype(ori_dtype)


class Curve(_NaturalPerturb):
"""
Curve picture using sin method.

Args:
curves (union[float, int]): Divide width to curves of `2*math.pi`, which means how many curve cycles. Suggested
value range in [0.1. 5].
depth (union[float, int]): Amplitude of sin method. Suggested value not exceed 1/10 of the length of the
picture.
mode (str): Direction of deformation. Optional value is 'vertical' or 'horizontal'.
auto_param (bool): Auto selected parameters. Selected parameters will preserve semantics of image.

Examples:
>>> img = cv2.imread('x.png')
>>> curves =1
>>> depth = 10
>>> trans = Curve(curves, depth, mode='vertical')
>>> img_new = trans(img)
"""

def __init__(self, curves=3, depth=10, mode='vertical', auto_param=False):
super(Curve).__init__()
self.curves = check_value_non_negative('curves', curves)
self.depth = check_value_non_negative('depth', depth)
if mode in ['vertical', 'horizontal']:
self.mode = mode
else:
msg = "Value of param mode must be in ['vertical', 'horizontal']"
LOGGER.error(TAG, msg)
raise ValueError(msg)
self.auto_param = check_param_type('auto_param', auto_param, bool)

def _set_auto_params(self, height, width):
if self.auto_param:
self.curves = np.random.uniform(1, 5)
self.mode = np.random.choice(['vertical', 'horizontal'])
if self.mode == 'vertical':
self.depth = np.random.uniform(1, 0.1 * width)
else:
self.depth = np.random.uniform(1, 0.1 * height)

def __call__(self, image):
"""
Curve picture using sin method.

Args:
image (numpy.ndarray): Original image.

Returns:
numpy.ndarray, curved image.
"""
ori_dtype = image.dtype
_, chw, normalized, gray3dim, image = self._check(image)
shape = image.shape
height, width = shape[:2]
if self.mode == 'vertical':
if len(shape) == 3:
image = np.transpose(image, [1, 0, 2])
else:
image = np.transpose(image, [1, 0])
src_x = np.zeros((height, width), np.float32)
src_y = np.zeros((height, width), np.float32)

for y in range(height):
for x in range(width):
src_x[y, x] = x
src_y[y, x] = y + self.depth * math.sin(x / (width / self.curves / 2 / math.pi))
img_new = cv2.remap(image, src_x, src_y, cv2.INTER_LINEAR)

if self.mode == 'vertical':
if len(shape) == 3:
img_new = np.transpose(img_new, [1, 0, 2])
else:
img_new = np.transpose(image, [1, 0])
new_img = self._original_format(img_new, chw, normalized, gray3dim)
return new_img.astype(ori_dtype)

+ 1
- 0
setup.py View File

@@ -132,6 +132,7 @@ setup(
'pillow >= 6.2.0',
'scikit-learn >= 0.23.1',
'easydict >= 1.9',
'opencv-python >= 4.1.2.30',
],
classifiers=[
'License :: OSI Approved :: Apache Software License'


Loading…
Cancel
Save