You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

_server_test.py 12 kB

first commit Former-commit-id: 08bc23ba02cffbce3cf63962390a65459a132e48 [formerly 0795edd4834b9b7dc66db8d10d4cbaf42bbf82cb] [formerly b5010b42541add7e2ea2578bf2da537efc457757 [formerly a7ca09c2c34c4fc8b3d8e01fcfa08eeeb2cae99d]] [formerly 615058473a2177ca5b89e9edbb797f4c2a59c7e5 [formerly 743d8dfc6843c4c205051a8ab309fbb2116c895e] [formerly bb0ea98b1e14154ef464e2f7a16738705894e54b [formerly 960a69da74b81ef8093820e003f2d6c59a34974c]]] [formerly 2fa3be52c1b44665bc81a7cc7d4cea4bbf0d91d5 [formerly 2054589f0898627e0a17132fd9d4cc78efc91867] [formerly 3b53730e8a895e803dfdd6ca72bc05e17a4164c1 [formerly 8a2fa8ab7baf6686d21af1f322df46fd58c60e69]] [formerly 87d1e3a07a19d03c7d7c94d93ab4fa9f58dada7c [formerly f331916385a5afac1234854ee8d7f160f34b668f] [formerly 69fb3c78a483343f5071da4f7e2891b83a49dd18 [formerly 386086f05aa9487f65bce2ee54438acbdce57650]]]] Former-commit-id: a00aed8c934a6460c4d9ac902b9a74a3d6864697 [formerly 26fdeca29c2f07916d837883983ca2982056c78e] [formerly 0e3170d41a2f99ecf5c918183d361d4399d793bf [formerly 3c12ad4c88ac5192e0f5606ac0d88dd5bf8602dc]] [formerly d5894f84f2fd2e77a6913efdc5ae388cf1be0495 [formerly ad3e7bc670ff92c992730d29c9d3aa1598d844e8] [formerly 69fb3c78a483343f5071da4f7e2891b83a49dd18]] Former-commit-id: 3c19c9fae64f6106415fbc948a4dc613b9ee12f8 [formerly 467ddc0549c74bb007e8f01773bb6dc9103b417d] [formerly 5fa518345d958e2760e443b366883295de6d991c [formerly 3530e130b9fdb7280f638dbc2e785d2165ba82aa]] Former-commit-id: 9f5d473d42a435ec0d60149939d09be1acc25d92 [formerly be0b25c4ec2cde052a041baf0e11f774a158105d] Former-commit-id: 9eca71cb73ba9edccd70ac06a3b636b8d4093b04
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. # from __future__ import print_function
  2. import argparse
  3. import os
  4. import pathlib
  5. from pprint import pprint
  6. import grpc
  7. from d3m import utils as d3m_utils, runtime as runtime_module
  8. from d3m.metadata import problem as problem_module
  9. from ta3ta2_api import core_pb2, core_pb2_grpc, value_pb2, utils
  10. from axolotl.utils import pipeline as pipeline_utils
  11. from axolotl.d3m_grpc import constants
  12. # with d3m_utils.silence():
  13. # d3m_index.load_all(blocklist=constants.PrimitivesList.BLACK_LIST)
  14. # primitives = [
  15. # 'd3m.primitives.datasets.DatasetToDataFrame',
  16. # 'd3m.primitives.data_transformation.denormalize.Common'
  17. # ]
  18. #
  19. # with d3m_utils.silence():
  20. # for primitive in primitives:
  21. # d3m_index.get_primitive(primitive)
# Width (in characters) of the banner/separator lines printed by
# print_name() and print_space().
LENGTH = 60
# Value types this client advertises in SearchSolutionsRequest.allowed_value_types.
ALLOWED_VALUE_TYPES = ['DATASET_URI', 'CSV_URI', 'RAW']
# Fully specified pipeline sent as the search template when specified_template == 'FULL'.
FULL_SPECIFIED_PIPELINE_PATH = 'modules/server/test_full_pipeline.json'
# Preprocessing pipeline (with placeholder) sent when specified_template == 'PRE'.
PRE_SPECIFIED_PIPELINE_PATH = 'modules/server/test_placeholder.json'
# PRE_SPECIFIED_PIPELINE_PATH = 'modules/server/test_placeholder_pipeline.json'
  27. def hello_request():
  28. request = core_pb2.HelloRequest()
  29. return request
  30. def list_primitives_request():
  31. request = core_pb2.ListPrimitivesRequest()
  32. return request
  33. def search_solutions_request(test_paths, specified_template=None):
  34. user_agent = "test_agent"
  35. version = core_pb2.DESCRIPTOR.GetOptions().Extensions[core_pb2.protocol_version]
  36. time_bound = 0.5
  37. priority = 10
  38. # allowed_value_types = [value_pb2.ValueType.Value(value) for value in ALLOWED_VALUE_TYPES]
  39. problem_description = utils.encode_problem_description(
  40. problem_module.Problem.load(test_paths['TRAIN']['problem'])
  41. )
  42. template = None
  43. if specified_template == 'FULL':
  44. with d3m_utils.silence():
  45. pipeline = pipeline_utils.load_pipeline(FULL_SPECIFIED_PIPELINE_PATH)
  46. template = utils.encode_pipeline_description(pipeline, ALLOWED_VALUE_TYPES, constants.Path.TEMP_STORAGE_ROOT)
  47. elif specified_template == 'PRE': # PRE for PREPROCESSING
  48. pipeline = runtime_module.get_pipeline(PRE_SPECIFIED_PIPELINE_PATH, load_all_primitives=False)
  49. template = utils.encode_pipeline_description(pipeline, ALLOWED_VALUE_TYPES, constants.Path.TEMP_STORAGE_ROOT)
  50. inputs = [
  51. value_pb2.Value(
  52. dataset_uri=test_paths['TRAIN']['dataset']
  53. )
  54. ]
  55. request = core_pb2.SearchSolutionsRequest(
  56. user_agent=user_agent,
  57. version=version,
  58. time_bound_search=time_bound,
  59. priority=priority,
  60. allowed_value_types=ALLOWED_VALUE_TYPES,
  61. problem=problem_description,
  62. template=template,
  63. inputs=inputs
  64. )
  65. return request
  66. def get_search_solution_results_request(search_id):
  67. request = core_pb2.GetSearchSolutionsResultsRequest(search_id=search_id)
  68. return request
  69. def fit_solution_request(solution_id, test_paths):
  70. inputs = [
  71. value_pb2.Value(
  72. dataset_uri=test_paths['TRAIN']['dataset']
  73. )
  74. ]
  75. expose_outputs = ['outputs.0']
  76. expose_value_types = ['CSV_URI']
  77. users = [
  78. core_pb2.SolutionRunUser(
  79. id='test_user',
  80. chosen=True,
  81. reason='just because'
  82. )
  83. ]
  84. request = core_pb2.FitSolutionRequest(
  85. solution_id=solution_id,
  86. inputs=inputs,
  87. expose_outputs=expose_outputs,
  88. expose_value_types=expose_value_types,
  89. users=users
  90. )
  91. return request
  92. def get_fit_solution_results_request(request_id):
  93. request = core_pb2.GetFitSolutionResultsRequest(
  94. request_id=request_id
  95. )
  96. return request
  97. def produce_solution_request(fitted_solution_id, test_paths):
  98. inputs = [
  99. value_pb2.Value(
  100. dataset_uri=test_paths['TEST']['dataset']
  101. )
  102. ]
  103. expose_outputs = ['outputs.0']
  104. expose_value_types = ['CSV_URI']
  105. users = [
  106. core_pb2.SolutionRunUser(
  107. id='test_user',
  108. chosen=True,
  109. reason='just because'
  110. )
  111. ]
  112. request = core_pb2.ProduceSolutionRequest(
  113. fitted_solution_id=fitted_solution_id,
  114. inputs=inputs,
  115. expose_outputs=expose_outputs,
  116. expose_value_types=expose_value_types,
  117. users=users
  118. )
  119. return request
  120. def get_produce_solution_results_request(request_id):
  121. request = core_pb2.GetProduceSolutionResultsRequest(
  122. request_id=request_id
  123. )
  124. return request
  125. def describe_solution_request(solution_id):
  126. request = core_pb2.DescribeSolutionRequest(
  127. solution_id=solution_id
  128. )
  129. return request
  130. def score_solution_request(solution_id, test_paths):
  131. inputs = [
  132. value_pb2.Value(
  133. dataset_uri=test_paths['SCORE']['dataset']
  134. )
  135. ]
  136. problem = problem_module.Problem.load(test_paths['SCORE']['problem'])
  137. performance_metrics = []
  138. for performance_metric in problem['problem'].get('performance_metrics', []):
  139. performance_metrics.append(utils.encode_performance_metric(performance_metric))
  140. # TODO add support for more evaluation methods
  141. users = []
  142. evaluation_method = 'K_FOLD'
  143. configuration = core_pb2.ScoringConfiguration(
  144. method=evaluation_method,
  145. folds=2,
  146. # train_test_ratio
  147. shuffle=True,
  148. random_seed=42,
  149. stratified=True,
  150. )
  151. request = core_pb2.ScoreSolutionRequest(
  152. solution_id=solution_id,
  153. inputs=inputs,
  154. performance_metrics=performance_metrics,
  155. users=users,
  156. configuration=configuration
  157. )
  158. return request
  159. def get_score_solution_request(solution_id):
  160. request = core_pb2.ScoreSolutionRequest(
  161. solution_id=solution_id
  162. )
  163. return request
  164. def solution_export_request(solution_id):
  165. rank = 0.1
  166. request = core_pb2.SolutionExportRequest(
  167. solution_id=solution_id,
  168. rank=rank
  169. )
  170. return request
  171. def end_search_solutions_request(search_id):
  172. request = core_pb2.EndSearchSolutionsRequest(search_id=search_id)
  173. return request
  174. def stop_search_solution_request(search_id):
  175. request = core_pb2.StopSearchSolutionsRequest(search_id=search_id)
  176. return request
def run(test_paths, specified_template=None):
    """Exercise the full TA3->TA2 gRPC workflow against a local server.

    Connects to localhost:45042 and walks the whole call sequence:
    Hello -> ListPrimitives -> SearchSolutions -> GetSearchSolutionsResults
    -> DescribeSolution -> FitSolution -> GetFitSolutionResults
    -> ProduceSolution -> GetProduceSolutionResults -> ScoreSolution
    -> GetScoreSolutionResults -> SolutionExport -> StopSearchSolutions
    -> EndSearchSolutions, pretty-printing every response.

    test_paths: dict from get_paths() with TRAIN/TEST/SCORE dataset paths
    and problem URIs.
    specified_template: forwarded to search_solutions_request() (None,
    'FULL' or 'PRE').
    """
    # NOTE(review): the port is hard-coded; assumes a TA2 server is already
    # listening on localhost:45042.
    channel = grpc.insecure_channel('localhost:45042')
    stub = core_pb2_grpc.CoreStub(channel)
    print_name('Hello')
    hello_r = stub.Hello(hello_request())
    pprint(hello_r)
    print_name('ListPrimitive')
    list_primitives_r = stub.ListPrimitives(list_primitives_request())
    for _primitive in list_primitives_r.primitives:
        print_space()
        pprint(_primitive)
    print_name('SearchSolution')
    search_solutions_r = stub.SearchSolutions(search_solutions_request(test_paths, specified_template))
    search_id = search_solutions_r.search_id
    pprint(search_solutions_r)
    print_name('GetSearchSolutionsResults')
    # Stream search progress; keep the last non-empty solution_id reported.
    solution_id = None
    for get_search_solution_r in stub.GetSearchSolutionsResults(get_search_solution_results_request(search_id)):
        print_space()
        pprint(get_search_solution_r)
        if get_search_solution_r.solution_id:
            solution_id = get_search_solution_r.solution_id
    print_name('DescribeSolution')
    describe_solution_r = stub.DescribeSolution(describe_solution_request(solution_id))
    pprint(describe_solution_r)
    print_name('FitSolution')
    fit_solution_r = stub.FitSolution(fit_solution_request(solution_id, test_paths))
    fit_request_id = fit_solution_r.request_id
    pprint(fit_solution_r)
    print_name('GetFitSolutionResultsRequest')
    # Stream fit progress; the final messages carry the fitted solution id.
    fitted_solution_id = None
    for get_git_solution_results_r in stub.GetFitSolutionResults(get_fit_solution_results_request(fit_request_id)):
        print_space()
        pprint(get_git_solution_results_r)
        fitted_solution_id = get_git_solution_results_r.fitted_solution_id
    print_name('ProduceSolutionRequest')
    produce_solution_r = stub.ProduceSolution(produce_solution_request(fitted_solution_id, test_paths))
    produce_request_id = produce_solution_r.request_id
    pprint(produce_solution_r)
    print_name('GetProduceSolutionResultsRequest')
    for get_produce_solution_results_r in stub.GetProduceSolutionResults(
            get_produce_solution_results_request(produce_request_id)):
        print_space()
        pprint(get_produce_solution_results_r)
    print_name('ScoreSolution')
    score_solution_r = stub.ScoreSolution(score_solution_request(solution_id, test_paths))
    score_request_id = score_solution_r.request_id
    pprint(score_solution_r)
    print_name('GetScoreSolutionResults')
    # NOTE(review): get_score_solution_request is fed the *request id* here —
    # verify it builds the message type GetScoreSolutionResults expects.
    for score_solution_r in stub.GetScoreSolutionResults(get_score_solution_request(score_request_id)):
        print_space()
        pprint(score_solution_r)
    print_name('SolutionExport')
    solution_export_r = stub.SolutionExport(solution_export_request(solution_id))
    pprint(solution_export_r)
    print_name('StopSearchSolutions')
    stop_search_solution_r = stub.StopSearchSolutions(stop_search_solution_request(search_id))
    pprint(stop_search_solution_r)
    print_name('EndSearchSolutions')
    end_search_solutions_r = stub.EndSearchSolutions(end_search_solutions_request(search_id))
    pprint(end_search_solutions_r)
  238. def print_name(name):
  239. length = LENGTH
  240. free_space = length - len(name) - 2
  241. space = int(free_space / 2)
  242. name = '#' + ' ' * space + name + ' ' * space
  243. if free_space % 2 == 0:
  244. name = name + '#'
  245. else:
  246. name = name + ' #'
  247. print("#" * length)
  248. print(name)
  249. print("#" * length)
  250. def print_space():
  251. print('-' * LENGTH)
  252. def configure_parser(parser, *, skip_arguments=()):
  253. parser.add_argument(
  254. '-t', '--test-path', type=str, default="/D3M/internal_d3m/Winter_2018_tamuta2/datasets/26/",
  255. help="path of d3m dataset to test."
  256. )
  257. def get_problem_id(test_path):
  258. problem_description = problem_module.Problem.load(test_path)
  259. print(problem_description)
  260. problem_id = problem_description.get('id', None)
  261. return problem_id
  262. def get_paths(test_path):
  263. # Classification Score dataset path is (problem_SCORE, dataset_SCORE) not
  264. # However, regression and other Score dataset path is (problem_TEST, dataset_TEST)
  265. score_problem_relative_path = os.path.join(test_path, 'SCORE/problem_SCORE/problemDoc.json')
  266. score_dataset_relative_path = os.path.join(test_path, 'SCORE/dataset_SCORE/datasetDoc.json')
  267. if not os.path.exists(score_problem_relative_path) or not os.path.exists(score_dataset_relative_path):
  268. score_problem_relative_path = os.path.join(test_path, 'SCORE/problem_TEST/problemDoc.json')
  269. score_dataset_relative_path = os.path.join(test_path, 'SCORE/dataset_TEST/datasetDoc.json')
  270. test_paths = {
  271. 'TRAIN': {
  272. 'dataset': os.path.join(test_path, 'TRAIN/dataset_TRAIN/datasetDoc.json'),
  273. 'problem': pathlib.Path(
  274. os.path.abspath(os.path.join(test_path, 'TRAIN/problem_TRAIN/problemDoc.json'))).as_uri()
  275. },
  276. 'TEST': {
  277. 'dataset': os.path.join(test_path, 'TEST/dataset_TEST/datasetDoc.json'),
  278. 'problem': pathlib.Path(
  279. os.path.abspath(os.path.join(test_path, 'TEST/problem_TEST/problemDoc.json'))).as_uri()
  280. },
  281. 'SCORE': {
  282. 'dataset': os.path.join(test_path, score_dataset_relative_path),
  283. 'problem': pathlib.Path(os.path.abspath(score_problem_relative_path)).as_uri()
  284. },
  285. }
  286. return test_paths
if __name__ == '__main__':
    # Build the command-line parser and read the dataset root path.
    parser = argparse.ArgumentParser(description="Test from command line")
    configure_parser(parser)
    arguments = parser.parse_args()
    # Root directory of the d3m dataset under test.
    test_path = arguments.test_path
    # Resolve TRAIN/TEST/SCORE dataset paths and problem URIs.
    test_paths = get_paths(test_path)
    # Use the TEST split's problem id to label the run.
    test_id = get_problem_id(test_paths['TEST']['problem'])
    print_name('Starting Test: ' + test_id)
    # No template ('FULL'/'PRE' would constrain the search with a pipeline).
    run(test_paths, None)
    print_name('Finishing Test: ' + test_id)

全栈的自动化机器学习系统，主要针对多变量时间序列数据的异常检测。TODS提供了详尽的用于构建基于机器学习的异常检测系统的模块，它们包括：数据处理（data processing）、时间序列处理（time series processing）、特征分析（feature analysis）、检测算法（detection algorithms）和强化模块（reinforcement module）。这些模块所提供的功能包括常见的数据预处理、时间序列数据的平滑或变换、从时域或频域中抽取特征，以及多种多样的检测算法。

Contributors (1)