@@ -22,8 +22,8 @@ import numpy as np | |||||
from PIL import Image | from PIL import Image | ||||
import lmdb | import lmdb | ||||
from mindspore.communication.management import get_rank, get_group_size | from mindspore.communication.management import get_rank, get_group_size | ||||
from src.model_utils.config import config | |||||
from src.util import CTCLabelConverter | |||||
from .model_utils.config import config | |||||
from .util import CTCLabelConverter | |||||
class NormalizePAD: | class NormalizePAD: | ||||
@@ -88,7 +88,7 @@ def generate_adv_iii5t_3000(lmdb_paths, lmdb_save_path, perturb_config): | |||||
print(f'num of samples in IIIT dataset: {len(filtered_index_list)}') | print(f'num of samples in IIIT dataset: {len(filtered_index_list)}') | ||||
client = Client("10.113.216.54:5500", "perturbation", "natural_perturbation") | |||||
client = Client("0.0.0.0:5500", "perturbation", "natural_perturbation") | |||||
start_time = time.time() | start_time = time.time() | ||||
result = client.infer(instances) | result = client.infer(instances) | ||||
end_time = time.time() | end_time = time.time() | ||||
@@ -101,9 +101,12 @@ def generate_adv_iii5t_3000(lmdb_paths, lmdb_save_path, perturb_config): | |||||
for i, index in enumerate(filtered_index_list): | for i, index in enumerate(filtered_index_list): | ||||
try: | try: | ||||
file_names = result[i]['file_names'].split(';') | file_names = result[i]['file_names'].split(';') | ||||
except: | |||||
except KeyError: | |||||
error_msg = result[i] | error_msg = result[i] | ||||
raise ValueError(error_msg) | |||||
msg = 'serving failed to generate the {}th image in origin dataset with ' \ | |||||
'error messages: {}'.format(i, error_msg) | |||||
print(KeyError(msg)) | |||||
continue | |||||
length = result[i]['file_length'].tolist() | length = result[i]['file_length'].tolist() | ||||
before = 0 | before = 0 | ||||
@@ -209,7 +209,7 @@ | |||||
3. 请求自然扰动serving服务,并保存serving返回的数据 | 3. 请求自然扰动serving服务,并保存serving返回的数据 | ||||
```python | ```python | ||||
client = Client("10.113.216.54:5500", "perturbation", "natural_perturbation") | |||||
client = Client("0.0.0.0:5500", "perturbation", "natural_perturbation") | |||||
start_time = time.time() | start_time = time.time() | ||||
result = client.infer(instances) | result = client.infer(instances) | ||||
end_time = time.time() | end_time = time.time() | ||||
@@ -47,6 +47,8 @@ def check_inputs(img, perturb_config, methods_number, outputs_number): | |||||
if not np.any(img): | if not np.any(img): | ||||
raise ValueError("img cannot be empty.") | raise ValueError("img cannot be empty.") | ||||
img = Image.open(BytesIO(img)) | img = Image.open(BytesIO(img)) | ||||
if img.mode == "L": | |||||
img = img.convert('RGB') | |||||
img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR) | img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR) | ||||
config = json.loads(perturb_config) | config = json.loads(perturb_config) | ||||