
[P0] Intervention scheduling for generation #110

Open: nathankim7 wants to merge 14 commits into main
Conversation

nathankim7 (Collaborator) commented Feb 15, 2024

Description

Basic functionality for scheduling interventions on positions not present in the prompt (i.e., generated tokens). Ideally this should follow the same procedure for GRU.

Changelog:

  • timestep_selector, a list of num_intv boolean callbacks with signature Callable[[int, torch.Tensor], bool], can be passed to generate() calls. Each intervention calls its callback with the current position to decide whether it should operate on that position (see the sketch after this list).
  • New handling of None values in unit locations: if None is specified at the batch dimension, interventions are not applied to those examples in the batch.
  • Removed the confusing logic where the _intervention_getter() and _intervention_setter() functions were called with single interventions even though they were written to handle a list of intervention keys and return a list of handlers.
  • Efficiency and readability improvements in gather_neurons() and scatter_neurons()
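
A minimal sketch of how timestep_selector could be used, assuming a configured IntervenableModel named intervenable and a tokenized batch inputs (these names and the other generate() keyword arguments are illustrative placeholders, not confirmed API):

import torch

# Hypothetical setup: `intervenable` is a configured IntervenableModel
# and `inputs` is a tokenized batch; both names are placeholders.
prompt_len = inputs["input_ids"].shape[-1]
num_intv = len(intervenable.interventions)

# One boolean callback per intervention, with the signature
# Callable[[int, torch.Tensor], bool]. Returning True makes the
# intervention fire at that position; here we only touch generated
# tokens, i.e. positions at or past the end of the prompt. The second
# argument (the current representation) is unused in this example.
timestep_selector = [
    lambda position, representation: position >= prompt_len
    for _ in range(num_intv)
]

outputs = intervenable.generate(
    inputs,
    timestep_selector=timestep_selector,
    max_new_tokens=20,
)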

Testing Done

  • Tests added: test_nulling_intervention, test_generation_with_source_intervened_prompt, test_dynamic_static_generation_intervention_parity, test_generation_noops
  • Tests fixed: test_with_subspace_negative, test_scatter_neurons_gpt2_attn_with_head_positive

Checklist:

  • My PR title strictly follows the format: [Your Priority] Your Title
  • I have attached the testing log above
  • I have provided enough comments in my code
  • I have updated the documentation
  • I have added tests for my changes

@nathankim7 nathankim7 changed the title Intervention scheduling for generation [P0] Intervention scheduling for generation Apr 18, 2024
In pyvene/models/intervenable_base.py:

if unit_locations is None:
    # this means, we don't filter based on location at all.
    return {"sources->base": ([None]*len(self.interventions), [None]*len(self.interventions))}

if self.mode == "parallel":

Collaborator:

Now that self.mode does not control this logic block, what is the difference between wait_for_forward_with_parallel_intervention() and wait_for_forward_with_serial_intervention()? Is there still a need to separate these two?

intervention, module_hook = self.interventions[key]

def hook_callback(model, args, kwargs, output=None):
    if self._is_generation:

Collaborator:

Sorry if I'm misunderstanding, but could you explain the rationale for allowing hook_callback to run when self._skip_forward is True?

Collaborator (Author):

This was already dead code, IIRC, since it's just being passed through here. Correct me if I'm wrong, but since the getter hook is used to gather source representations, wouldn't it still need to run even if a generate() call skips intervening on the base (prompt)?
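
For intuition, a generic sketch of a getter-style forward hook (not pyvene's actual implementation): it only records activations, so skipping it during generation would also mean never collecting the source representations that a later setter needs.

import torch
import torch.nn as nn

collected = []

def getter_hook(module, inputs, output):
    # A getter only reads: it stashes the source representation for a
    # later setter to consume and leaves the module output untouched.
    collected.append(output.detach())
    return output

layer = nn.Linear(4, 4)
handle = layer.register_forward_hook(getter_hook)
layer(torch.randn(2, 4))  # the hook fires and records the activation
handle.remove()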

@@ -149,13 +178,12 @@ def test_with_subspace_negative(self):
         Negative test case to check input length.
         """
         intervenable = IntervenableModel(
-            self.test_subspace_intervention_link_config, self.mlp
+            self.test_negative_subspace_config, self.mlp

Collaborator:

What happens if you replace this test_negative_subspace_config with test_subspace_intervention_link_config?

Collaborator (Author):
This test case was intended to test defining an intervention with subspace partitions that exceed the dimension of the model. That is why test_subspace_intervention_link_config wasn't triggering an IndexError at all: it and the inputs in this test case are both of dim 3. (It was passing in previous commits because of an entirely unrelated and problematic IndexError that this PR should actually fix.)

Since changing the current config would break all the other tests in this file that rely on it, I decided to just copy it over to a new one.

@@ -102,15 +107,15 @@ def test_scatter_neurons_gpt2_batch_diff_fast_no_head_positive(self):
         golden_output = tensor_input.clone()

Collaborator:

Since there is no fast path anymore, we can remove all fast_path tests, and remove the fast_path parameter in modeling_utils.py as well.

Collaborator (Author):
I was curious about that; good to know we can remove it. I'd rather handle that in a separate PR, though.

             ] = replacing_tensor_input[:, i]
         else:
-            tensor_input[_batch_idx, unit_locations] = replacing_tensor_input
+            tensor_input[_batch_idx, unit_locations] = replacing_tensor_input[_batch_idx]

Collaborator:

Good job! This removes the for loop in the assignment.
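
For reference, a standalone illustration of the vectorized advanced-indexing assignment; the shapes and selected indices below are made up for the example:

import torch

batch_size, seq_len, hidden = 4, 6, 3
tensor_input = torch.zeros(batch_size, seq_len, hidden)
replacing_tensor_input = torch.randn(batch_size, 2, hidden)

# Suppose only examples 0 and 2 are intervened on (e.g. the other
# rows of the unit locations were None).
selected = torch.tensor([0, 2])
_batch_idx = selected.unsqueeze(1)               # shape (2, 1)
unit_locations = torch.tensor([[1, 2], [3, 4]])  # shape (2, 2)

# Loop version (what the old code effectively did):
# for row, b in enumerate(selected):
#     tensor_input[b, unit_locations[row]] = replacing_tensor_input[b]

# Vectorized version: _batch_idx (2, 1) broadcasts against
# unit_locations (2, 2), selecting a (2, 2, hidden) slab that is
# overwritten in a single advanced-indexing assignment.
tensor_input[_batch_idx, unit_locations] = replacing_tensor_input[selected]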
