You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 3.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import os
  2. from typing import Union, Dict
  3. def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
  4. """
  5. 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果
  6. {
  7. 'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。
  8. 'test': 'xxx' # 可能有,也可能没有
  9. ...
  10. }
  11. 如果paths为不合法的,将直接进行raise相应的错误
  12. :param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名
  13. 中包含train这个字段), test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。
  14. :return:
  15. """
  16. if isinstance(paths, str):
  17. if os.path.isfile(paths):
  18. return {'train': paths}
  19. elif os.path.isdir(paths):
  20. filenames = os.listdir(paths)
  21. files = {}
  22. for filename in filenames:
  23. path_pair = None
  24. if 'train' in filename:
  25. path_pair = ('train', filename)
  26. if 'dev' in filename:
  27. if path_pair:
  28. raise Exception("File:{} in {} contains both `{}` and `dev`.".format(filename, paths, path_pair[0]))
  29. path_pair = ('dev', filename)
  30. if 'test' in filename:
  31. if path_pair:
  32. raise Exception("File:{} in {} contains both `{}` and `test`.".format(filename, paths, path_pair[0]))
  33. path_pair = ('test', filename)
  34. if path_pair:
  35. if path_pair[0] in files:
  36. raise RuntimeError(f"Multiple file under {paths} have '{path_pair[0]}' in their filename.")
  37. files[path_pair[0]] = os.path.join(paths, path_pair[1])
  38. return files
  39. else:
  40. raise FileNotFoundError(f"{paths} is not a valid file path.")
  41. elif isinstance(paths, dict):
  42. if paths:
  43. if 'train' not in paths:
  44. raise KeyError("You have to include `train` in your dict.")
  45. for key, value in paths.items():
  46. if isinstance(key, str) and isinstance(value, str):
  47. if not os.path.isfile(value):
  48. raise TypeError(f"{value} is not a valid file.")
  49. else:
  50. raise TypeError("All keys and values in paths should be str.")
  51. return paths
  52. else:
  53. raise ValueError("Empty paths is not allowed.")
  54. else:
  55. raise TypeError(f"paths only supports str and dict. not {type(paths)}.")
  56. def get_tokenizer():
  57. try:
  58. import spacy
  59. spacy.prefer_gpu()
  60. en = spacy.load('en')
  61. print('use spacy tokenizer')
  62. return lambda x: [w.text for w in en.tokenizer(x)]
  63. except Exception as e:
  64. print('use raw tokenizer')
  65. return lambda x: x.split()