|
|
@@ -124,7 +124,14 @@ class MembershipInference: |
|
|
|
raise TypeError("Type of parameter 'test_train' must be Dataset, " |
|
|
|
"but got {}".format(type(dataset_train))) |
|
|
|
|
|
|
|
if not isinstance(attack_config, list): |
|
|
|
raise TypeError("Type of parameter 'attack_config' must be list, " |
|
|
|
"but got {}.".format(type(attack_config))) |
|
|
|
|
|
|
|
for config in attack_config: |
|
|
|
if not isinstance(config, dict): |
|
|
|
raise TypeError("Type of each config in 'attack_config' must be dict, " |
|
|
|
"but got {}.".format(type(config))) |
|
|
|
if {"params", "method"} != set(config.keys()): |
|
|
|
raise KeyError("Each config in attack_config must have keys 'method' and 'params', " |
|
|
|
"but your key value is {}.".format(set(config.keys()))) |
|
|
|