"""
This script provides necessary functions that help the evaluation in eval_qags.py
"""
from typing import List, Dict, Tuple
import json
import rouge_score.scoring
from scipy.stats import pearsonr, spearmanr
from rouge_score import rouge_scorer
from tqdm import tqdm
from mosestokenizer import MosesDetokenizer
def most_common_list_element(lst: List):
    """
    Return the most frequently occurring element of *lst*.

    For example, ['yes', 'no', 'yes'] yields "yes".  Ties are broken
    arbitrarily.
    """
    counts = {}
    for item in lst:
        counts[item] = counts.get(item, 0) + 1
    return max(counts, key=counts.get)
def get_scorer(fast: bool = False) -> rouge_scorer.RougeScorer:
    """
    Build a RougeScorer with stemming enabled.

    When *fast* is True, rougeL is dropped; skipping the LCS computation
    gives roughly a 10x speedup, which is handy while debugging.
    """
    rouge_types = ["rouge1", "rouge2"]
    if not fast:
        rouge_types.append("rougeL")
    return rouge_scorer.RougeScorer(rouge_types, use_stemmer=True)
def generate_rouge_scores(
    generated_summary: str, source_text: str
) -> Dict[str, rouge_score.scoring.Score]:
    """
    Score *generated_summary* against *source_text* with ROUGE-1/2/L.

    Args:
        generated_summary: the system-produced summary.
        source_text: the reference text to score against.

    Returns:
        Mapping from rouge type (e.g. "rouge1") to a Score namedtuple
        with precision/recall/fmeasure, as produced by RougeScorer.score.
    """
    # Renamed the local result from `rouge_score` to `scores`: the old name
    # shadowed the imported `rouge_score` module inside this function.
    scores = get_scorer().score(source_text, generated_summary)
    return scores
def get_qag_whole_summary_sents(sample: dict) -> str:
    """
    Join the per-sentence summary pieces of *sample* into one summary string.

    Expects sample["summary_sentences"] to be a list of dicts, each with a
    "sentence" key; sentences are joined with single spaces and any
    trailing whitespace is stripped.
    """
    sentences = (item["sentence"] for item in sample["summary_sentences"])
    return " ".join(sentences).rstrip()
def calculate_correlation_score(
    lst1: List[float], lst2: List[float]
) -> Tuple[float, float, float, float]:
    """
    Compute and print Pearson and Spearman correlations between two lists.

    Args:
        lst1: first sequence of scores.
        lst2: second sequence of scores, same length as lst1.

    Returns:
        (pearson_corr, pearson_p_value, spearman_corr, spearman_p_value).
        (The annotation previously said Tuple[float] although four values
        are returned.)
    """
    pearson_corr, pearson_p_value = pearsonr(lst1, lst2)
    spearman_corr, spearman_p_value = spearmanr(lst1, lst2)
    print(
        f"pearson correlation is {pearson_corr} with pearson_p_value {pearson_p_value};\n"
        f"spearman correlation is {spearman_corr} with spearman_p_value {spearman_p_value}"
    )
    return pearson_corr, pearson_p_value, spearman_corr, spearman_p_value
def detokenize(text: str) -> str:
    """
    Detokenize a space-tokenized English string with the Moses detokenizer.

    Args:
        text: tokens separated by single spaces.

    Returns:
        The detokenized string.
    """
    words = text.split(" ")
    # MosesDetokenizer wraps an external process; the original never closed
    # it, leaking one subprocess per call.  Using it as a context manager
    # guarantees the process is terminated after each use.
    with MosesDetokenizer("en") as detokenizer:
        return detokenizer(words)
def get_src_sys_lines_for_BART(
    samples: List[dict], json_file_name: str
) -> Tuple[List[str], List[str]]:
    """
    Extract (source, system-summary) line pairs for BART-based evaluation.

    Args:
        samples: parsed records from the evaluation jsonl file.
        json_file_name: one of "qags-cnndm.jsonl", "qags-xsum.jsonl" or
            "summeval.jsonl"; selects which record keys hold the data.

    Returns:
        (src_lines, sys_lines), both detokenized.

    Raises:
        ValueError: if json_file_name is not a recognized dataset file.
    """
    if json_file_name in ("qags-cnndm.jsonl", "qags-xsum.jsonl"):
        src_lines: List[str] = [sample["article"] for sample in tqdm(samples)]
        sys_lines: List[str] = [
            get_qag_whole_summary_sents(sample) for sample in tqdm(samples)
        ]
    elif json_file_name == "summeval.jsonl":
        src_lines = [sample["text"] for sample in tqdm(samples)]
        sys_lines = [sample["decoded"] for sample in tqdm(samples)]
    else:
        # The original fell through with src_lines/sys_lines unbound and
        # crashed with a NameError; fail loudly with a clear message instead.
        raise ValueError(f"unsupported json_file_name: {json_file_name!r}")
    src_lines = [detokenize(line) for line in src_lines]
    sys_lines = [detokenize(line) for line in sys_lines]
    return src_lines, sys_lines
def save_data(data, arg) -> None:
    """
    Serialize *data* as JSON to the path given by arg.output.

    Args:
        data: any JSON-serializable object.
        arg: argparse-style namespace exposing an `output` path attribute.
    """
    # The original opened the file inline and never closed it; a context
    # manager guarantees the handle is flushed and closed even if
    # json.dump raises.
    with open(arg.output, "w") as fp:
        json.dump(data, fp)