10 add evaluation pipeline #25
base: develop
Conversation
…d losses across potential answers
It would be nice if the tests capture some of what we were trying to think through earlier today, e.g. check that the truth ratio of one of the dummy forget models is larger than that of the dummy fine-tuned model, and similar. I haven't checked back through the tests, so it might be that you've already done that.
… outputs the metrics we want to track. Also hidden away/cleaned up evaluation scripts and moved most functions to the utils file
I've made some changes to my pull request now. I've added a function in
It outputs a dictionary containing:
This should be everything we want to track in the wandb. I've added some tests testing it, and I've moved the old scripts I wrote into a
…he max() function in Table 1 of the TOFU paper (it wasn't)
Just something minor I didn't explicitly point out in the above: the path for the base model truth ratios should currently be the relative path to where the forget truth ratios are stored. The all_eval script will calculate and save these, provided you give it the forget dataset. These are the only values that need to be stored locally for evaluate_model to run; everything else should be calculated within the function.
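For reference, a minimal sketch of how a truth ratio can be computed from stored per-answer losses (the function name and loss convention here are assumptions, not the PR's actual code; the TOFU paper defines the ratio via the average normalized probability of the perturbed answers relative to the paraphrased answer):

```python
import numpy as np


def truth_ratio(paraphrased_loss: float, perturbed_losses: np.ndarray) -> float:
    # Losses are assumed to be mean per-token negative log-likelihoods,
    # so exp(-loss) recovers a length-normalised sequence probability.
    perturbed_probs = np.exp(-np.asarray(perturbed_losses))
    paraphrased_prob = np.exp(-paraphrased_loss)
    # Average perturbed-answer probability relative to the paraphrased one.
    return float(perturbed_probs.mean() / paraphrased_prob)
```

A model that assigns equal probability to true and perturbed answers gives a ratio of 1; lower values indicate the model prefers the true answer.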
I haven't got my head around it fully yet; I'll look again next week. I left a few comments about things being hardcoded, but not for everything. For this PR they can be kept hardcoded, but if so we should make an issue listing everything that's left outstanding or will need to be changed for future runs/experiments.
return f"Question: {question}\nAnswer: {answer}"

class EvalQADataset(Dataset):

def qa_formatter_autoregression(qa: tuple[str, str, int]) -> str:
Maybe qa_formatter_blank or similar (i.e. nothing added other than question and answer).
@@ -57,7 +58,7 @@ def get_data(
     return data


-def qa_formatter_basic(qa: tuple[str, str]) -> str:
+def qa_formatter_basic(qa: tuple[str, str, int]) -> str:
Note to self: Forgetting branch has refactored QA formatters.
def batch_formatter(
    self,
Note to self: Data collators/padding choices.
perturbed_options = self.data.filter(
    lambda sample: sample["author_index"] == author_n
    and sample["question_index"] != question_n
).shuffle(seed=self.rand_gen.seed())
I've not gone through to check, but is setting the seed like this OK, or will it give the same perturbed samples every time? I.e. is this function used multiple times during an evaluation run, and should/does it give different perturbed rows each time?
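To make the concern concrete, here is a minimal sketch (plain `random` stands in for the datasets shuffle; all names are illustrative, not the PR's code) of the difference between reusing a fixed seed and drawing a fresh seed from a persistent generator:

```python
import random

options = list(range(100))

# Reusing the same seed reproduces the identical "shuffle" on every call.
first = random.Random(42).sample(options, k=len(options))
second = random.Random(42).sample(options, k=len(options))
assert first == second

# Drawing a fresh seed from a persistent generator varies per call.
rand_gen = random.Random(0)
third = random.Random(rand_gen.randint(0, 2**31 - 1)).sample(options, k=len(options))
fourth = random.Random(rand_gen.randint(0, 2**31 - 1)).sample(options, k=len(options))
assert third != fourth  # equal only with negligible probability
```

If `self.rand_gen.seed()` returns the same value on every call, each evaluation pass would see the same perturbed ordering; if a fresh value is drawn each time, the perturbed rows would differ across calls.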
forget_set.num_authors = TOFU_NUM_AUTHORS
forget_set.q_per_author = TOFU_Q_PER_AUTHOR

retain_set.num_authors = TOFU_NUM_AUTHORS
retain_set.q_per_author = TOFU_Q_PER_AUTHOR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these still needed/used?
) -> dict[float, float, float, float, float]:
    """
Suggested change:
-) -> dict[float, float, float, float, float]:
+) -> dict[str, float]:
     """
def all_eval(
    model: torch.nn.Module,
All model typehints maybe should be transformers.PreTrainedModel (but it could be they all work with torch modules anyway).

Suggested change:
-    model: torch.nn.Module,
+    model: transformers.PreTrainedModel,
"all_losses": torch.zeros(
    (dataset.__len__(), n_perturbed + 1), dtype=torch.float64
),
"truth_ratios": torch.zeros(dataset.__len__()),
"rougeL_recall": torch.zeros(dataset.__len__()),
Use len rather than calling __len__ directly.

Suggested change:
-"all_losses": torch.zeros(
-    (dataset.__len__(), n_perturbed + 1), dtype=torch.float64
-),
-"truth_ratios": torch.zeros(dataset.__len__()),
-"rougeL_recall": torch.zeros(dataset.__len__()),
+"all_losses": torch.zeros(
+    (len(dataset), n_perturbed + 1), dtype=torch.float64
+),
+"truth_ratios": torch.zeros(len(dataset)),
+"rougeL_recall": torch.zeros(len(dataset)),
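A quick illustration of the suggestion: `len(obj)` dispatches to the type's `__len__`, so the two spellings are equivalent and `len()` is the idiomatic one (the toy class below is just for illustration):

```python
class ToyDataset:
    def __init__(self, items: list) -> None:
        self.items = items

    def __len__(self) -> int:
        return len(self.items)


ds = ToyDataset([1, 2, 3])
# len() calls ToyDataset.__len__ under the hood, so these agree.
assert len(ds) == ds.__len__() == 3
```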
def get_analysis_values(
    model_dir: str,
) -> dict[np.ndarray, np.ndarray, np.ndarray, torch.Tensor, torch.Tensor]:
    """
Suggested change:
-) -> dict[np.ndarray, np.ndarray, np.ndarray, torch.Tensor, torch.Tensor]:
+) -> dict[str, np.ndarray | torch.Tensor]:
vals["forget_losses"] = np.loadtxt(model_dir + "/eval/forget/all_losses.txt")
vals["retain_losses"] = np.loadtxt(model_dir + "/eval/retain/all_losses.txt")
vals["rouge_scores"] = np.loadtxt(model_dir + "/eval/retain/rougeL_scores.txt")
# we re-calculate the truth ratio, since torch calculated many as NaNs
Could we check for NaNs and log a warning somewhere if they appear?
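One possible shape for that check, as a hedged sketch (the helper name and call site are assumptions, not part of the PR):

```python
import logging

import numpy as np

logger = logging.getLogger(__name__)


def check_for_nans(arr: np.ndarray, name: str) -> np.ndarray:
    # Hypothetical helper: log a warning if any NaNs appear in a loaded
    # loss array, so silent NaNs from the torch computation become visible.
    n_nan = int(np.isnan(arr).sum())
    if n_nan:
        logger.warning("%d NaN value(s) found in %s", n_nan, name)
    return arr
```

It could wrap each load, e.g. `vals["forget_losses"] = check_for_nans(np.loadtxt(...), "forget_losses")`, so the array passes through unchanged while the warning is recorded.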
Evaluation pipeline
Has a few utils files containing metrics and utility functions, and some scripts which perform evaluation on a selected model. I'll briefly go over the evaluation scripts as they currently stand, and the changes I made to the evaluation dataset class. The evaluation scripts can be run periodically throughout training to give a clearer picture of model performance as it is being trained.
quantitative_eval.py
Performs quantitative evaluation over a test set with a selected model; it compares ground-truth inputs against perturbed inputs as in the paper. Unlike the paper, we don't generate these; rather, they are answers to different questions, randomly sampled from within the same author. In the future we can change this according to our work package. The script outputs the truth ratio values and the raw losses, which are output as a numpy array for further processing. Currently, if running as main, these are saved to a .np file in a separate folder in the parent folder of the one where the model weights are stored.
qualitative_eval.py
This performs a qualitative evaluation of the model. It loops over the test data and generates an output answer for each input question. Both are printed along with the target to allow qualitative comparison against the target answer.
EvalQADataset() changes
I made some changes to allow the quantitative evaluation script to work. Namely, I added a batch formatter which, given a question, outputs input IDs, labels, and attention masks with appropriate padding for batch computation. Furthermore, a method which locates perturbed answers was added: given a question index, it locates a random question pertaining to the same author, which can be used as a perturbed answer.
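The padding logic described above can be sketched roughly as follows (a simplified stand-in using plain lists of token IDs; the real method works on tokenizer output, and the -100 label value is the convention Hugging Face loss functions ignore):

```python
def pad_batch(
    token_id_batches: list[list[int]],
    pad_token_id: int = 0,
    label_pad: int = -100,
) -> dict[str, list[list[int]]]:
    # Right-pad every sequence to the longest in the batch; mask the
    # padding in the attention mask and use -100 so the loss skips it.
    max_len = max(len(ids) for ids in token_id_batches)
    input_ids, labels, attention_mask = [], [], []
    for ids in token_id_batches:
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_token_id] * n_pad)
        labels.append(ids + [label_pad] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}
```

This is only a sketch of the padding convention, not the PR's implementation; the actual batch formatter also handles the question/answer formatting before tokenization.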