You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dataset.py 3.7 kB

2 years ago
Lines 1–112
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """test dataset module"""
  16. import pytest
  17. from easydict import EasyDict as edict
  18. from mindelec.data import Dataset
  19. from mindelec.geometry import create_config_from_edict
  20. from mindelec.geometry import Disk, Rectangle, TimeDomain, GeometryWithTime
  21. from mindelec.data import BoundaryBC, BoundaryIC
  22. from config import ds_config, src_sampling_config, no_src_sampling_config, bc_sampling_config
  23. ic_bc_config = edict({
  24. 'domain': edict({
  25. 'random_sampling': False,
  26. 'size': [10, 20],
  27. }),
  28. 'BC': edict({
  29. 'random_sampling': True,
  30. 'size': 10,
  31. 'with_normal': True,
  32. }),
  33. 'IC': edict({
  34. 'random_sampling': True,
  35. 'size': 10,
  36. }),
  37. 'time': edict({
  38. 'random_sampling': False,
  39. 'size': 10,
  40. })
  41. })
  42. @pytest.mark.level0
  43. @pytest.mark.platform_arm_ascend_training
  44. @pytest.mark.platform_x86_ascend_training
  45. @pytest.mark.env_onecard
  46. def test_dataset_allnone():
  47. with pytest.raises(ValueError):
  48. Dataset()
  49. @pytest.mark.level0
  50. @pytest.mark.platform_arm_ascend_training
  51. @pytest.mark.platform_x86_ascend_training
  52. @pytest.mark.env_onecard
  53. def test_dataset():
  54. """test dataset"""
  55. disk = Disk("src", (0.0, 0.0), 0.2)
  56. rectangle = Rectangle("rect", (-1, -1), (1, 1))
  57. diff = rectangle - disk
  58. time = TimeDomain("time", 0.0, 1.0)
  59. # check datalist
  60. rect_with_time = GeometryWithTime(rectangle, time)
  61. rect_with_time.set_sampling_config(create_config_from_edict(ic_bc_config))
  62. bc = BoundaryBC(rect_with_time)
  63. ic = BoundaryIC(rect_with_time)
  64. dataset = Dataset(dataset_list=bc)
  65. dataset.set_constraint_type("Equation")
  66. c_type1 = {bc: "Equation", ic: "Equation"}
  67. with pytest.raises(ValueError):
  68. dataset.set_constraint_type(c_type1)
  69. no_src_region = GeometryWithTime(diff, time)
  70. no_src_region.set_name("no_src")
  71. no_src_region.set_sampling_config(create_config_from_edict(no_src_sampling_config))
  72. src_region = GeometryWithTime(disk, time)
  73. src_region.set_name("src")
  74. src_region.set_sampling_config(create_config_from_edict(src_sampling_config))
  75. boundary = GeometryWithTime(rectangle, time)
  76. boundary.set_name("bc")
  77. boundary.set_sampling_config(create_config_from_edict(bc_sampling_config))
  78. geom_dict = ['1', '2']
  79. with pytest.raises(TypeError):
  80. Dataset(geom_dict)
  81. geom_dict = {src_region: ["test"]}
  82. with pytest.raises(KeyError):
  83. Dataset(geom_dict)
  84. geom_dict = {src_region: ["domain", "IC"],
  85. no_src_region: ["domain", "IC"],
  86. boundary: ["BC"]}
  87. dataset = Dataset(geom_dict)
  88. with pytest.raises(ValueError):
  89. print(dataset[0])
  90. with pytest.raises(ValueError):
  91. len(dataset)
  92. with pytest.raises(ValueError):
  93. dataset.get_columns_list()
  94. with pytest.raises(ValueError):
  95. dataset.create_dataset(batch_size=ds_config.train.batch_size,
  96. shuffle=ds_config.train.shuffle,
  97. prebatched_data=True,
  98. drop_remainder=False)