From 6a60d2d8c839dadc1b621bc41788ac8dcbed809e Mon Sep 17 00:00:00 2001 From: liuluobin Date: Mon, 14 Sep 2020 09:56:53 +0800 Subject: [PATCH] Specified output_numpy to be True in function create_tuple_dict. Rename some files. --- examples/ai_fuzzer/lenet5_mnist_coverage.py | 4 +- examples/ai_fuzzer/lenet5_mnist_fuzzing.py | 4 +- .../black_box/mnist_attack_genetic.py | 2 +- .../model_attacks/black_box/mnist_attack_hsja.py | 2 +- .../model_attacks/black_box/mnist_attack_nes.py | 2 +- .../black_box/mnist_attack_pointwise.py | 2 +- .../model_attacks/black_box/mnist_attack_pso.py | 2 +- .../black_box/mnist_attack_salt_and_pepper.py | 2 +- .../model_attacks/white_box/mnist_attack_cw.py | 2 +- .../white_box/mnist_attack_deepfool.py | 2 +- .../model_attacks/white_box/mnist_attack_fgsm.py | 2 +- .../model_attacks/white_box/mnist_attack_jsma.py | 2 +- .../model_attacks/white_box/mnist_attack_lbfgs.py | 2 +- .../white_box/mnist_attack_mdi2fgsm.py | 2 +- .../model_attacks/white_box/mnist_attack_pgd.py | 2 +- .../model_defenses/mnist_defense_nad.py | 4 +- .../model_defenses/mnist_evaluation.py | 2 +- .../model_defenses/mnist_similarity_detector.py | 2 +- examples/privacy/README.md | 8 ++-- .../__init__.py | 0 .../eval.py | 0 .../example_vgg_cifar.py} | 0 .../train.py | 0 .../privacy/evaluation/membership_inference.py | 44 +++++++++++----------- .../adv_robustness/attacks/test_gradient_method.py | 9 ++++- 25 files changed, 54 insertions(+), 49 deletions(-) rename examples/privacy/{membership_inference_attack => membership_inference}/__init__.py (100%) rename examples/privacy/{membership_inference_attack => membership_inference}/eval.py (100%) rename examples/privacy/{membership_inference_attack/vgg_cifar_attack.py => membership_inference/example_vgg_cifar.py} (100%) rename examples/privacy/{membership_inference_attack => membership_inference}/train.py (100%) diff --git a/examples/ai_fuzzer/lenet5_mnist_coverage.py b/examples/ai_fuzzer/lenet5_mnist_coverage.py index 8b6234e..355c95b 100644 --- a/examples/ai_fuzzer/lenet5_mnist_coverage.py +++ b/examples/ai_fuzzer/lenet5_mnist_coverage.py @@ -42,7 +42,7 @@ def test_lenet_mnist_coverage(): batch_size = 32 ds = generate_mnist_dataset(data_list, batch_size, sparse=True) train_images = [] - for data in ds.create_tuple_iterator(): + for data in ds.create_tuple_iterator(output_numpy=True): images = data[0].astype(np.float32) train_images.append(images) train_images = np.concatenate(train_images, axis=0) @@ -57,7 +57,7 @@ def test_lenet_mnist_coverage(): ds = generate_mnist_dataset(data_list, batch_size, sparse=True) test_images = [] test_labels = [] - for data in ds.create_tuple_iterator(): + for data in ds.create_tuple_iterator(output_numpy=True): images = data[0].astype(np.float32) labels = data[1] test_images.append(images) diff --git a/examples/ai_fuzzer/lenet5_mnist_fuzzing.py b/examples/ai_fuzzer/lenet5_mnist_fuzzing.py index 3fb9e4a..b7f2a77 100644 --- a/examples/ai_fuzzer/lenet5_mnist_fuzzing.py +++ b/examples/ai_fuzzer/lenet5_mnist_fuzzing.py @@ -59,7 +59,7 @@ def test_lenet_mnist_fuzzing(): batch_size = 32 ds = generate_mnist_dataset(data_list, batch_size, sparse=False) train_images = [] - for data in ds.create_tuple_iterator(): + for data in ds.create_tuple_iterator(output_numpy=True): images = data[0].astype(np.float32) train_images.append(images) train_images = np.concatenate(train_images, axis=0) @@ -74,7 +74,7 @@ def test_lenet_mnist_fuzzing(): ds = generate_mnist_dataset(data_list, batch_size, sparse=False) test_images = [] test_labels = [] - for data in ds.create_tuple_iterator(): + for data in ds.create_tuple_iterator(output_numpy=True): images = data[0].astype(np.float32) labels = data[1] test_images.append(images) diff --git a/examples/model_security/model_attacks/black_box/mnist_attack_genetic.py b/examples/model_security/model_attacks/black_box/mnist_attack_genetic.py index be34e27..bce2974 100644 --- a/examples/model_security/model_attacks/black_box/mnist_attack_genetic.py +++ b/examples/model_security/model_attacks/black_box/mnist_attack_genetic.py @@ -67,7 +67,7 @@ def test_genetic_attack_on_mnist(): test_labels = [] predict_labels = [] i = 0 - for data in ds.create_tuple_iterator(): + for data in ds.create_tuple_iterator(output_numpy=True): i += 1 images = data[0].astype(np.float32) labels = data[1] diff --git a/examples/model_security/model_attacks/black_box/mnist_attack_hsja.py b/examples/model_security/model_attacks/black_box/mnist_attack_hsja.py index 7f02568..13e23e0 100644 --- a/examples/model_security/model_attacks/black_box/mnist_attack_hsja.py +++ b/examples/model_security/model_attacks/black_box/mnist_attack_hsja.py @@ -88,7 +88,7 @@ def test_hsja_mnist_attack(): test_labels = [] predict_labels = [] i = 0 - for data in ds.create_tuple_iterator(): + for data in ds.create_tuple_iterator(output_numpy=True): i += 1 images = data[0].astype(np.float32) labels = data[1] diff --git a/examples/model_security/model_attacks/black_box/mnist_attack_nes.py b/examples/model_security/model_attacks/black_box/mnist_attack_nes.py index bdde962..96ca217 100644 --- a/examples/model_security/model_attacks/black_box/mnist_attack_nes.py +++ b/examples/model_security/model_attacks/black_box/mnist_attack_nes.py @@ -98,7 +98,7 @@ def test_nes_mnist_attack(): test_labels = [] predict_labels = [] i = 0 - for data in ds.create_tuple_iterator(): + for data in ds.create_tuple_iterator(output_numpy=True): i += 1 images = data[0].astype(np.float32) labels = data[1] diff --git a/examples/model_security/model_attacks/black_box/mnist_attack_pointwise.py b/examples/model_security/model_attacks/black_box/mnist_attack_pointwise.py index febe94c..f835e58 100644 --- a/examples/model_security/model_attacks/black_box/mnist_attack_pointwise.py +++ b/examples/model_security/model_attacks/black_box/mnist_attack_pointwise.py @@ -68,7 +68,7 @@ def test_pointwise_attack_on_mnist(): test_labels = [] predict_labels = [] i = 0 - for data in ds.create_tuple_iterator(): + for data in ds.create_tuple_iterator(output_numpy=True): i += 1 images = data[0].astype(np.float32) labels = data[1] diff --git a/examples/model_security/model_attacks/black_box/mnist_attack_pso.py b/examples/model_security/model_attacks/black_box/mnist_attack_pso.py index 997cfa5..f98bfb2 100644 --- a/examples/model_security/model_attacks/black_box/mnist_attack_pso.py +++ b/examples/model_security/model_attacks/black_box/mnist_attack_pso.py @@ -67,7 +67,7 @@ def test_pso_attack_on_mnist(): test_labels = [] predict_labels = [] i = 0 - for data in ds.create_tuple_iterator(): + for data in ds.create_tuple_iterator(output_numpy=True): i += 1 images = data[0].astype(np.float32) labels = data[1] diff --git a/examples/model_security/model_attacks/black_box/mnist_attack_salt_and_pepper.py b/examples/model_security/model_attacks/black_box/mnist_attack_salt_and_pepper.py index ec813bc..60d4f5c 100644 --- a/examples/model_security/model_attacks/black_box/mnist_attack_salt_and_pepper.py +++ b/examples/model_security/model_attacks/black_box/mnist_attack_salt_and_pepper.py @@ -68,7 +68,7 @@ def test_salt_and_pepper_attack_on_mnist(): test_labels = [] predict_labels = [] i = 0 - for data in ds.create_tuple_iterator(): + for data in ds.create_tuple_iterator(output_numpy=True): i += 1 images = data[0].astype(np.float32) labels = data[1] diff --git a/examples/model_security/model_attacks/white_box/mnist_attack_cw.py b/examples/model_security/model_attacks/white_box/mnist_attack_cw.py index 65602c7..bef43dc 100644 --- a/examples/model_security/model_attacks/white_box/mnist_attack_cw.py +++ b/examples/model_security/model_attacks/white_box/mnist_attack_cw.py @@ -54,7 +54,7 @@ def test_carlini_wagner_attack(): test_labels = [] predict_labels = [] i = 0 - for data in ds.create_tuple_iterator(): + for data in ds.create_tuple_iterator(output_numpy=True): i += 1 images = data[0].astype(np.float32) labels = data[1] diff --git a/examples/model_security/model_attacks/white_box/mnist_attack_deepfool.py b/examples/model_security/model_attacks/white_box/mnist_attack_deepfool.py index 3fccbf1..4ae2adf 100644 --- a/examples/model_security/model_attacks/white_box/mnist_attack_deepfool.py +++ b/examples/model_security/model_attacks/white_box/mnist_attack_deepfool.py @@ -54,7 +54,7 @@ def test_deepfool_attack(): test_labels = [] predict_labels = [] i = 0 - for data in ds.create_tuple_iterator(): + for data in ds.create_tuple_iterator(output_numpy=True): i += 1 images = data[0].astype(np.float32) labels = data[1] diff --git a/examples/model_security/model_attacks/white_box/mnist_attack_fgsm.py b/examples/model_security/model_attacks/white_box/mnist_attack_fgsm.py index cbeeeeb..3aa76f4 100644 --- a/examples/model_security/model_attacks/white_box/mnist_attack_fgsm.py +++ b/examples/model_security/model_attacks/white_box/mnist_attack_fgsm.py @@ -55,7 +55,7 @@ def test_fast_gradient_sign_method(): test_labels = [] predict_labels = [] i = 0 - for data in ds.create_tuple_iterator(): + for data in ds.create_tuple_iterator(output_numpy=True): i += 1 images = data[0].astype(np.float32) labels = data[1] diff --git a/examples/model_security/model_attacks/white_box/mnist_attack_jsma.py b/examples/model_security/model_attacks/white_box/mnist_attack_jsma.py index 7f8f536..f8b3c01 100644 --- a/examples/model_security/model_attacks/white_box/mnist_attack_jsma.py +++ b/examples/model_security/model_attacks/white_box/mnist_attack_jsma.py @@ -54,7 +54,7 @@ def test_jsma_attack(): test_labels = [] predict_labels = [] i = 0 - for data in ds.create_tuple_iterator(): + for data in ds.create_tuple_iterator(output_numpy=True): i += 1 images = data[0].astype(np.float32) labels = data[1] diff --git a/examples/model_security/model_attacks/white_box/mnist_attack_lbfgs.py b/examples/model_security/model_attacks/white_box/mnist_attack_lbfgs.py index dd0a775..38f6943 100644 --- a/examples/model_security/model_attacks/white_box/mnist_attack_lbfgs.py +++ b/examples/model_security/model_attacks/white_box/mnist_attack_lbfgs.py @@ -55,7 +55,7 @@ def test_lbfgs_attack(): test_labels = [] predict_labels = [] i = 0 - for data in ds.create_tuple_iterator(): + for data in ds.create_tuple_iterator(output_numpy=True): i += 1 images = data[0].astype(np.float32) labels = data[1] diff --git a/examples/model_security/model_attacks/white_box/mnist_attack_mdi2fgsm.py b/examples/model_security/model_attacks/white_box/mnist_attack_mdi2fgsm.py index ed2334b..8390277 100644 --- a/examples/model_security/model_attacks/white_box/mnist_attack_mdi2fgsm.py +++ b/examples/model_security/model_attacks/white_box/mnist_attack_mdi2fgsm.py @@ -56,7 +56,7 @@ def test_momentum_diverse_input_iterative_method(): test_labels = [] predict_labels = [] i = 0 - for data in ds.create_tuple_iterator(): + for data in ds.create_tuple_iterator(output_numpy=True): i += 1 images = data[0].astype(np.float32) labels = data[1] diff --git a/examples/model_security/model_attacks/white_box/mnist_attack_pgd.py b/examples/model_security/model_attacks/white_box/mnist_attack_pgd.py index 11ea3f3..1aff061 100644 --- a/examples/model_security/model_attacks/white_box/mnist_attack_pgd.py +++ b/examples/model_security/model_attacks/white_box/mnist_attack_pgd.py @@ -55,7 +55,7 @@ def test_projected_gradient_descent_method(): test_labels = [] predict_labels = [] i = 0 - for data in ds.create_tuple_iterator(): + for data in ds.create_tuple_iterator(output_numpy=True): i += 1 images = data[0].astype(np.float32) labels = data[1] diff --git a/examples/model_security/model_defenses/mnist_defense_nad.py b/examples/model_security/model_defenses/mnist_defense_nad.py index 35b52af..3711c37 100644 --- a/examples/model_security/model_defenses/mnist_defense_nad.py +++ b/examples/model_security/model_defenses/mnist_defense_nad.py @@ -56,7 +56,7 @@ def test_nad_method(): batch_size=batch_size, repeat_size=1) inputs = [] labels = [] - for data in ds_test.create_tuple_iterator(): + for data in ds_test.create_tuple_iterator(output_numpy=True): inputs.append(data[0].astype(np.float32)) labels.append(data[1]) inputs = np.concatenate(inputs) @@ -99,7 +99,7 @@ def test_nad_method(): batch_size=batch_size, repeat_size=1) inputs_train = [] labels_train = [] - for data in ds_train.create_tuple_iterator(): + for data in ds_train.create_tuple_iterator(output_numpy=True): inputs_train.append(data[0].astype(np.float32)) labels_train.append(data[1]) inputs_train = np.concatenate(inputs_train) diff --git a/examples/model_security/model_defenses/mnist_evaluation.py b/examples/model_security/model_defenses/mnist_evaluation.py index a2131c8..3adc619 100644 --- a/examples/model_security/model_defenses/mnist_evaluation.py +++ b/examples/model_security/model_defenses/mnist_evaluation.py @@ -138,7 +138,7 @@ def test_defense_evaluation(): ds_test = generate_mnist_dataset(data_list, batch_size=batch_size) inputs = [] labels = [] - for data in ds_test.create_tuple_iterator(): + for data in ds_test.create_tuple_iterator(output_numpy=True): inputs.append(data[0].astype(np.float32)) labels.append(data[1]) inputs = np.concatenate(inputs).astype(np.float32) diff --git a/examples/model_security/model_defenses/mnist_similarity_detector.py b/examples/model_security/model_defenses/mnist_similarity_detector.py index c179052..99f1506 100644 --- a/examples/model_security/model_defenses/mnist_similarity_detector.py +++ b/examples/model_security/model_defenses/mnist_similarity_detector.py @@ -108,7 +108,7 @@ def test_similarity_detector(): true_labels = [] predict_labels = [] i = 0 - for data in ds.create_tuple_iterator(): + for data in ds.create_tuple_iterator(output_numpy=True): i += 1 images = data[0].astype(np.float32) labels = data[1] diff --git a/examples/privacy/README.md b/examples/privacy/README.md index f09392c..be94d59 100644 --- a/examples/privacy/README.md +++ b/examples/privacy/README.md @@ -24,10 +24,10 @@ With adaptive norm clip mechanism, the norm clip of the gradients would be chang $ cd examples/privacy/diff_privacy $ python lenet5_dp.py ``` -## 3. Membership inference attack -By this attack method, we could judge whether a sample is belongs to training dataset or not. +## 3. Membership inference evaluation +By this evaluation method, we could judge whether a sample is belongs to training dataset or not. ```sh $ cd examples/privacy/membership_inference_attack -$ python vgg_cifar_attack.py +$ python train.py --data_path home_path_to_cifar100 --ckpt_path ./ +$ python example_vgg_cifar.py --data_path home_path_to_cifar100 --pre_trained 0-100_781.ckpt ``` - diff --git a/examples/privacy/membership_inference_attack/__init__.py b/examples/privacy/membership_inference/__init__.py similarity index 100% rename from examples/privacy/membership_inference_attack/__init__.py rename to examples/privacy/membership_inference/__init__.py diff --git a/examples/privacy/membership_inference_attack/eval.py b/examples/privacy/membership_inference/eval.py similarity index 100% rename from examples/privacy/membership_inference_attack/eval.py rename to examples/privacy/membership_inference/eval.py diff --git a/examples/privacy/membership_inference_attack/vgg_cifar_attack.py b/examples/privacy/membership_inference/example_vgg_cifar.py similarity index 100% rename from examples/privacy/membership_inference_attack/vgg_cifar_attack.py rename to examples/privacy/membership_inference/example_vgg_cifar.py diff --git a/examples/privacy/membership_inference_attack/train.py b/examples/privacy/membership_inference/train.py similarity index 100% rename from examples/privacy/membership_inference_attack/train.py rename to examples/privacy/membership_inference/train.py diff --git a/mindarmour/privacy/evaluation/membership_inference.py b/mindarmour/privacy/evaluation/membership_inference.py index d395cb7..e7bc90c 100644 --- a/mindarmour/privacy/evaluation/membership_inference.py +++ b/mindarmour/privacy/evaluation/membership_inference.py @@ -43,7 +43,7 @@ def _eval_info(pred, truth, option): values are 'precision', 'accuracy' and 'recall'. Returns: - float32, Calculated evaluation results. + float32, calculated evaluation results. Raises: ValueError, size of parameter pred or truth is 0. @@ -80,7 +80,7 @@ def _softmax_cross_entropy(logits, labels): labels (numpy.ndarray): Numpy array of shape(N, ). Returns: - numpy.ndarray: Numpy array of shape(N, ), containing loss value for each vector in logits. + numpy.ndarray: numpy array of shape(N, ), containing loss value for each vector in logits. """ labels = np.eye(logits.shape[1])[labels].astype(np.int32) logits = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) @@ -111,11 +111,11 @@ class MembershipInference: >>> # test_1, test_2 are non-overlapping datasets from test dataset of target model. >>> # We use train_1, test_1 to train attack model, and use train_2, test_2 to evaluate attack model. >>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'}) - >>> inference_model = MembershipInference(model, n_jobs=-1) + >>> attack_model = MembershipInference(model, n_jobs=-1) >>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}] - >>> inference_model.train(train_1, test_1, config) + >>> attack_model.train(train_1, test_1, config) >>> metrics = ["precision", "recall", "accuracy"] - >>> result = inference_model.eval(train_2, test_2, metrics) + >>> result = attack_model.eval(train_2, test_2, metrics) Raises: TypeError: If type of model is not mindspore.train.Model. @@ -147,11 +147,11 @@ class MembershipInference: {"method": "lr", "params": {"C": np.logspace(-4, 2, 10)}}]. The support methods are knn, lr, mlp and rf, and the params of each method must within the range of changeable parameters. Tips of params implement - can be found in - https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html - https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html - https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html - https://scikit-learn.org/stable/modules/generated/sklearn.neural_network.MLPRegressor.html + can be found below: + `KNN`_, + `LR`_, + `RF`_, + `MLP`_. Raises: KeyError: If any config in attack_config doesn't have keys {"method", "params"}. @@ -179,7 +179,7 @@ class MembershipInference: must be in ["precision", "accuracy", "recall"]. Default: ["precision"]. Returns: - list, Each element contains an evaluation indicator for the attack model. + list, each element contains an evaluation indicator for the attack model. """ check_param_type("dataset_train", dataset_train, Dataset) check_param_type("dataset_test", dataset_test, Dataset) @@ -207,13 +207,13 @@ class MembershipInference: Generate corresponding loss_logits features and new label, and return after shuffle. Args: - dataset_train: The training set for the target model. - dataset_test: The test set for the target model. + dataset_train (mindspore.dataset): The train set for the target model. + dataset_test (mindspore.dataset): The test set for the target model. Returns: - - numpy.ndarray, Loss_logits features for each sample. Shape is (N, C). + - numpy.ndarray, loss_logits features for each sample. Shape is (N, C). N is the number of sample. C = 1 + dim(logits). - - numpy.ndarray, Labels for each sample, Shape is (N,). + - numpy.ndarray, labels for each sample, Shape is (N,). """ features_train, labels_train = self._generate(dataset_train, 1) features_test, labels_test = self._generate(dataset_test, 0) @@ -231,18 +231,18 @@ class MembershipInference: Return a loss_logits features and labels for training attack model. Args: - input_dataset (mindspore.dataset): The dataset to be generate. - label (int32): Whether input_dataset belongs to the target model. + input_dataset (mindspore.dataset): The dataset to be generated. + label (int): Whether input_dataset belongs to the target model. Returns: - - numpy.ndarray, Loss_logits features for each sample. Shape is (N, C). + - numpy.ndarray, loss_logits features for each sample. Shape is (N, C). N is the number of sample. C = 1 + dim(logits). - - numpy.ndarray, Labels for each sample, Shape is (N,). + - numpy.ndarray, labels for each sample, Shape is (N,). """ loss_logits = np.array([]) - for batch in input_dataset.create_dict_iterator(): - batch_data = Tensor(batch['image'], ms.float32) - batch_labels = batch['label'].astype(np.int32) + for batch in input_dataset.create_tuple_iterator(output_numpy=True): + batch_data = Tensor(batch[0], ms.float32) + batch_labels = batch[1].astype(np.int32) batch_logits = self._model.predict(batch_data).asnumpy() batch_loss = _softmax_cross_entropy(batch_logits, batch_labels) diff --git a/tests/ut/python/adv_robustness/attacks/test_gradient_method.py b/tests/ut/python/adv_robustness/attacks/test_gradient_method.py index b9fa9ad..8e6707c 100644 --- a/tests/ut/python/adv_robustness/attacks/test_gradient_method.py +++ b/tests/ut/python/adv_robustness/attacks/test_gradient_method.py @@ -29,8 +29,6 @@ from mindarmour.adv_robustness.attacks import RandomFastGradientMethod from mindarmour.adv_robustness.attacks import RandomFastGradientSignMethod from mindarmour.adv_robustness.attacks import RandomLeastLikelyClassMethod -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - # for user class Net(Cell): @@ -68,6 +66,7 @@ def test_fast_gradient_method(): """ Fast gradient method unit test. """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) label = np.asarray([2], np.int32) label = np.eye(3)[label].astype(np.float32) @@ -128,6 +127,7 @@ def test_random_fast_gradient_method(): """ Random fast gradient method unit test. """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) label = np.asarray([2], np.int32) label = np.eye(3)[label].astype(np.float32) @@ -149,6 +149,7 @@ def test_fast_gradient_sign_method(): """ Fast gradient sign method unit test. """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) label = np.asarray([2], np.int32) label = np.eye(3)[label].astype(np.float32) @@ -170,6 +171,7 @@ def test_random_fast_gradient_sign_method(): """ Random fast gradient sign method unit test. """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") input_np = np.random.random((1, 28)).astype(np.float32) label = np.asarray([2], np.int32) label = np.eye(28)[label].astype(np.float32) @@ -191,6 +193,7 @@ def test_least_likely_class_method(): """ Least likely class method unit test. """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) label = np.asarray([2], np.int32) label = np.eye(3)[label].astype(np.float32) @@ -212,6 +215,7 @@ def test_random_least_likely_class_method(): """ Random least likely class method unit test. """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) label = np.asarray([2], np.int32) label = np.eye(3)[label].astype(np.float32) @@ -233,6 +237,7 @@ def test_assert_error(): """ Random least likely class method unit test. """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") with pytest.raises(ValueError) as e: assert RandomLeastLikelyClassMethod(Net(), eps=0.05, alpha=0.21) assert str(e.value) == 'eps must be larger than alpha!'