Can't run TextAttack on PSO recipe using CLI #743

Open
spneshaei opened this issue Jun 22, 2023 · 2 comments
Labels
bug (Something isn't working), dependencies

Comments

@spneshaei

Describe the bug
Running the PSOZang recipe from the command line fails with "KeyError: 'pos'" as soon as it starts generating adversarial examples. The same command works fine with other recipes.

To Reproduce
Steps to reproduce the behavior:

  1. Run a command like the one below. The error also occurs with other models and datasets and with different --num-examples values; the only way I have found to avoid it is to use a different recipe: textattack attack --recipe pso --num-examples 470 --model ./outputs/2023-06-19-15-18-19-224515/best_model/ --dataset-from-huggingface rotten_tomatoes --dataset-split test
  2. See error KeyError: 'pos'

Expected behavior
Adversarial examples should be generated, as is the case for other recipes (e.g. textfooler).

Screenshots or Traceback

textattack: Loading datasets dataset rotten_tomatoes, split test.
textattack: Downloading https://textattack.s3.amazonaws.com/transformations/hownet/word_candidates_sense.pkl.
100% 8.39M/8.39M [00:00<00:00, 13.8MB/s]
textattack: Copying /root/.cache/textattack/tmpt00fxsj3.zip to /root/.cache/textattack/transformations/hownet/word_candidates_sense.pkl.
textattack: Successfully saved transformations/hownet/word_candidates_sense.pkl to cache.
textattack: Unknown if model of class <class 'transformers.models.distilbert.modeling_distilbert.DistilBertForSequenceClassification'> compatible with goal function <class 'textattack.goal_functions.classification.untargeted_classification.UntargetedClassification'>.
  0% 0/1000 [00:00<?, ?it/s]
Downloading pytorch_model.bin:   0% 0.00/72.9M [00:00<?, ?B/s]
Downloading pytorch_model.bin:  14% 10.5M/72.9M [00:00<00:03, 17.6MB/s]
Downloading pytorch_model.bin:  29% 21.0M/72.9M [00:00<00:01, 31.4MB/s]
Downloading pytorch_model.bin:  43% 31.5M/72.9M [00:00<00:01, 38.6MB/s]
Downloading pytorch_model.bin:  58% 41.9M/72.9M [00:01<00:00, 43.4MB/s]
Downloading pytorch_model.bin:  72% 52.4M/72.9M [00:01<00:00, 46.1MB/s]
Downloading pytorch_model.bin:  86% 62.9M/72.9M [00:01<00:00, 44.1MB/s]
Downloading pytorch_model.bin: 100% 72.9M/72.9M [00:01<00:00, 37.6MB/s]
╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ /usr/local/bin/textattack:8 in <module>                                      │
│                                                                              │
│   5 from textattack.commands.textattack_cli import main                      │
│   6 if __name__ == '__main__':                                               │
│   7 │   sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])     │
│ ❱ 8 │   sys.exit(main())                                                     │
│   9                                                                          │
│                                                                              │
│ /usr/local/lib/python3.10/dist-packages/textattack/commands/textattack_cli.p │
│ y:50 in main                                                                 │
│                                                                              │
│   47 │   # Run                                                               │
│   48 │   func = args.func                                                    │
│   49 │   del args.func                                                       │
│ ❱ 50 │   func.run(args)                                                      │
│   51                                                                         │
│   52                                                                         │
│   53 if __name__ == "__main__":                                              │
│                                                                              │
│ /usr/local/lib/python3.10/dist-packages/textattack/commands/attack_command.p │
│ y:36 in run                                                                  │
│                                                                              │
│   33 │   │   │   │   attack_args, model_wrapper                              │
│   34 │   │   │   )                                                           │
│   35 │   │   │   attacker = Attacker(attack, dataset, attack_args)           │
│ ❱ 36 │   │   │   attacker.attack_dataset()                                   │
│   37 │                                                                       │
│   38 │   @staticmethod                                                       │
│   39 │   def register_subcommand(main_parser: ArgumentParser):               │
│                                                                              │
│ /usr/local/lib/python3.10/dist-packages/textattack/attacker.py:441 in        │
│ attack_dataset                                                               │
│                                                                              │
│   438 │   │   │   │   )                                                      │
│   439 │   │   │   self._attack_parallel()                                    │
│   440 │   │   else:                                                          │
│ ❱ 441 │   │   │   self._attack()                                             │
│   442 │   │                                                                  │
│   443 │   │   if self.attack_args.silent:                                    │
│   444 │   │   │   logger.setLevel(logging.INFO)                              │
│                                                                              │
│ /usr/local/lib/python3.10/dist-packages/textattack/attacker.py:170 in        │
│ _attack                                                                      │
│                                                                              │
│   167 │   │   │   try:                                                       │
│   168 │   │   │   │   result = self.attack.attack(example, ground_truth_outp │
│   169 │   │   │   except Exception as e:                                     │
│ ❱ 170 │   │   │   │   raise e                                                │
│   171 │   │   │   if (                                                       │
│   172 │   │   │   │   isinstance(result, SkippedAttackResult) and self.attac │
│   173 │   │   │   ) or (                                                     │
│                                                                              │
│ /usr/local/lib/python3.10/dist-packages/textattack/attacker.py:168 in        │
│ _attack                                                                      │
│                                                                              │
│   165 │   │   │   if self.dataset.label_names is not None:                   │
│   166 │   │   │   │   example.attack_attrs["label_names"] = self.dataset.lab │
│   167 │   │   │   try:                                                       │
│ ❱ 168 │   │   │   │   result = self.attack.attack(example, ground_truth_outp │
│   169 │   │   │   except Exception as e:                                     │
│   170 │   │   │   │   raise e                                                │
│   171 │   │   │   if (                                                       │
│                                                                              │
│ /usr/local/lib/python3.10/dist-packages/textattack/attack.py:448 in attack   │
│                                                                              │
│   445 │   │   if goal_function_result.goal_status == GoalFunctionResultStatu │
│   446 │   │   │   return SkippedAttackResult(goal_function_result)           │
│   447 │   │   else:                                                          │
│ ❱ 448 │   │   │   result = self._attack(goal_function_result)                │
│   449 │   │   │   return result                                              │
│   450 │                                                                      │
│   451 │   def __repr__(self):                                                │
│                                                                              │
│ /usr/local/lib/python3.10/dist-packages/textattack/attack.py:396 in _attack  │
│                                                                              │
│   393 │   │   │   A ``SuccessfulAttackResult``, ``FailedAttackResult``,      │
│   394 │   │   │   │   or ``MaximizedAttackResult``.                          │
│   395 │   │   """                                                            │
│ ❱ 396 │   │   final_result = self.search_method(initial_result)              │
│   397 │   │   self.clear_cache()                                             │
│   398 │   │   if final_result.goal_status == GoalFunctionResultStatus.SUCCEE │
│   399 │   │   │   result = SuccessfulAttackResult(                           │
│                                                                              │
│ /usr/local/lib/python3.10/dist-packages/textattack/search_methods/search_met │
│ hod.py:36 in __call__                                                        │
│                                                                              │
│   33 │   │   │   │   "Search Method must have access to filter_transformatio │
│   34 │   │   │   )                                                           │
│   35 │   │                                                                   │
│ ❱ 36 │   │   result = self.perform_search(initial_result)                    │
│   37 │   │   # ensure that the number of queries for this GoalFunctionResult │
│   38 │   │   result.num_queries = self.goal_function.num_queries             │
│   39 │   │   return result                                                   │
│                                                                              │
│ /usr/local/lib/python3.10/dist-packages/textattack/search_methods/particle_s │
│ warm_optimization.py:217 in perform_search                                   │
│                                                                              │
│   214 │                                                                      │
│   215 │   def perform_search(self, initial_result):                          │
│   216 │   │   self._search_over = False                                      │
│ ❱ 217 │   │   population = self._initialize_population(initial_result, self. │
│   218 │   │   # Initialize  up velocities of each word for each population   │
│   219 │   │   v_init = np.random.uniform(-self.v_max, self.v_max, self.pop_s │
│   220 │   │   velocities = np.array(                                         │
│                                                                              │
│ /usr/local/lib/python3.10/dist-packages/textattack/search_methods/particle_s │
│ warm_optimization.py:203 in _initialize_population                           │
│                                                                              │
│   200 │   │   Returns:                                                       │
│   201 │   │   │   population as `list[PopulationMember]`                     │
│   202 │   │   """                                                            │
│ ❱ 203 │   │   best_neighbors, prob_list = self._get_best_neighbors(          │
│   204 │   │   │   initial_result, initial_result                             │
│   205 │   │   )                                                              │
│   206 │   │   population = []                                                │
│                                                                              │
│ /usr/local/lib/python3.10/dist-packages/textattack/search_methods/particle_s │
│ warm_optimization.py:160 in _get_best_neighbors                              │
│                                                                              │
│   157 │   │   """                                                            │
│   158 │   │   current_text = current_result.attacked_text                    │
│   159 │   │   neighbors_list = [[] for _ in range(len(current_text.words))]  │
│ ❱ 160 │   │   transformed_texts = self.get_transformations(                  │
│   161 │   │   │   current_text, original_text=original_result.attacked_text  │
│   162 │   │   )                                                              │
│   163 │   │   for transformed_text in transformed_texts:                     │
│                                                                              │
│ /usr/local/lib/python3.10/dist-packages/textattack/attack.py:303 in          │
│ get_transformations                                                          │
│                                                                              │
│   300 │   │   │   │   ]                                                      │
│   301 │   │   │   │   transformed_texts = list(self.transformation_cache[cac │
│   302 │   │   │   else:                                                      │
│ ❱ 303 │   │   │   │   transformed_texts = self._get_transformations_uncached │
│   304 │   │   │   │   │   current_text, original_text, **kwargs              │
│   305 │   │   │   │   )                                                      │
│   306 │   │   │   │   if utils.hashable(cache_key):                          │
│                                                                              │
│ /usr/local/lib/python3.10/dist-packages/textattack/attack.py:271 in          │
│ _get_transformations_uncached                                                │
│                                                                              │
│   268 │   │   Returns:                                                       │
│   269 │   │   │   A filtered list of transformations where each transformati │
│   270 │   │   """                                                            │
│ ❱ 271 │   │   transformed_texts = self.transformation(                       │
│   272 │   │   │   current_text,                                              │
│   273 │   │   │   pre_transformation_constraints=self.pre_transformation_con │
│   274 │   │   │   **kwargs,                                                  │
│                                                                              │
│ /usr/local/lib/python3.10/dist-packages/textattack/transformations/transform │
│ ation.py:57 in __call__                                                      │
│                                                                              │
│   54 │   │   if return_indices:                                              │
│   55 │   │   │   return indices_to_modify                                    │
│   56 │   │                                                                   │
│ ❱ 57 │   │   transformed_texts = self._get_transformations(current_text, ind │
│   58 │   │   for text in transformed_texts:                                  │
│   59 │   │   │   text.attack_attrs["last_transformation"] = self             │
│   60 │   │   return transformed_texts                                        │
│                                                                              │
│ /usr/local/lib/python3.10/dist-packages/textattack/transformations/word_swap │
│ s/word_swap_hownet.py:68 in _get_transformations                             │
│                                                                              │
│    65 │   │   transformed_texts = []                                         │
│    66 │   │   for i in indices_to_modify:                                    │
│    67 │   │   │   word_to_replace = current_text.words[i]                    │
│ ❱  68 │   │   │   word_to_replace_pos = current_text.pos_of_word_index(i)    │
│    69 │   │   │   replacement_words = self._get_replacement_words(           │
│    70 │   │   │   │   word_to_replace, word_to_replace_pos                   │
│    71 │   │   │   )                                                          │
│                                                                              │
│ /usr/local/lib/python3.10/dist-packages/textattack/shared/attacked_text.py:1 │
│ 41 in pos_of_word_index                                                      │
│                                                                              │
│   138 │   │   │   )                                                          │
│   139 │   │   │   textattack.shared.utils.flair_tag(sentence)                │
│   140 │   │   │   self._pos_tags = sentence                                  │
│ ❱ 141 │   │   flair_word_list, flair_pos_list = textattack.shared.utils.zip_ │
│   142 │   │   │   self._pos_tags                                             │
│   143 │   │   )                                                              │
│   144                                                                        │
│                                                                              │
│ /usr/local/lib/python3.10/dist-packages/textattack/shared/utils/strings.py:2 │
│ 37 in zip_flair_result                                                       │
│                                                                              │
│   234 │   for token in tokens:                                               │
│   235 │   │   word_list.append(token.text)                                   │
│   236 │   │   if "pos" in tag_type:                                          │
│ ❱ 237 │   │   │   pos_list.append(token.annotation_layers["pos"][0]._value)  │
│   238 │   │   elif tag_type == "ner":                                        │
│   239 │   │   │   pos_list.append(token.get_label("ner"))                    │
│   240                                                                        │
╰──────────────────────────────────────────────────────────────────────────────╯
KeyError: 'pos'
  0% 0/1000 [00:05<?, ?it/s]
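
The last frames of the traceback are in AttackedText.pos_of_word_index and zip_flair_result, neither of which appears to depend on the model, dataset, or recipe arguments. If that reading is right, a much smaller script may trigger the same error without loading a model. A minimal sketch, assuming AttackedText can be built directly from a plain string; the function name reproduce_pos_lookup is made up for illustration, and this has not been run against 0.3.8:

```python
def reproduce_pos_lookup(text="the movie was surprisingly good"):
    """Ask TextAttack for the POS tag of the first word of `text`.

    The import lives inside the function so that defining it does not
    require textattack or flair to be installed.
    """
    from textattack.shared import AttackedText

    attacked = AttackedText(text)
    # Same call as word_swap_hownet.py line 68 in the traceback above.
    return attacked.pos_of_word_index(0)


if __name__ == "__main__":
    try:
        print(reproduce_pos_lookup())
    except Exception as exc:  # report whatever is raised, or a missing install
        print(f"{type(exc).__name__}: {exc}")
```

If this raises the same KeyError, the problem is probably in the POS-tagging path rather than in the PSO search itself, which would also explain why recipes that never ask for POS tags are unaffected.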

System Information:

  • Platform: Google Colab
  • Library versions: transformers==4.29.2
  • TextAttack version: 0.3.8
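
The list above does not include the flair version, although the traceback passes through flair_tag. A small standard-library sketch for collecting the relevant versions in one go; the package list is an assumption about what matters here, and installed_versions is an illustrative name:

```python
from importlib import metadata


def installed_versions(packages=("textattack", "flair", "transformers", "torch")):
    """Map each package name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}=={version}" if version else f"{name}: not installed")
```

Pasting this output into the report makes it easier to tell which flair release a given environment actually resolved to.
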
@jxmorris12 added the bug and dependencies labels on Jul 11, 2023
@jxmorris12
Collaborator

Seems like an update in the flair dependency! Can you try downgrading flair for now and see if it works? Long-term, we should either pin flair to a specific version or update our usage to match the most recent version.
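
To illustrate the "update our usage" option: the failing line indexes token.annotation_layers["pos"] directly, so any release that stores the tag under a different key would raise exactly this KeyError. Below is a sketch of a more tolerant lookup, not the actual fix. The alternative key "upos" is a guess, the sketch assumes label objects expose a public value attribute (the traceback reads the private _value), and the tokens are stand-ins rather than real flair objects:

```python
from types import SimpleNamespace


def pos_value(token, layer_names=("pos", "upos")):
    """Return the first tag value found under any candidate layer name.

    `layer_names` is an assumption; adjust it to whatever keys the
    installed tagger actually writes.
    """
    layers = token.annotation_layers
    for name in layer_names:
        labels = layers.get(name)
        if labels:
            return labels[0].value
    # Fail with the keys that do exist, which is more useful than a bare 'pos'.
    raise KeyError(
        f"none of {layer_names} found; available layers: {sorted(layers)}"
    )


# Stand-in tokens that mimic the two layouts (hypothetical, not flair objects).
keyed_by_pos = SimpleNamespace(
    annotation_layers={"pos": [SimpleNamespace(value="NOUN")]}
)
keyed_by_upos = SimpleNamespace(
    annotation_layers={"upos": [SimpleNamespace(value="VERB")]}
)

if __name__ == "__main__":
    print(pos_value(keyed_by_pos), pos_value(keyed_by_upos))
```

Even if the candidate keys turn out to be wrong, the error message lists the layers that are present, which should show what the installed tagger actually produces.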

@colinpannikkat

I am running into the same issue as OP. I downgraded to flair==0.12.1 to no avail. I also tried flair==0.12.0, and the issue still persists. I can't downgrade any further without running into dependency issues with other packages. Any other possible reasons for this issue?
