# collate_fns.py — DataLoader collate functions for emotion and sentiment batches.
import torch
import itertools
def collate_fn_emo(data):
    """Collate a batch of emotion samples into padded tensors.

    Args:
        data: list of dicts, each with key "cause" (a token-id sequence:
            list or 1-D tensor) and "emotion" (a label). Sorted in place
            by descending cause length.

    Returns:
        dict with:
            "emotion": list of labels, in the sorted batch order.
            "cause": LongTensor [batch, max_len], zero-padded on the right.
            "cause_attn_mask": LongTensor [batch, max_len], 1 on real tokens.
    """
    def merge(sequences, N=None):
        lengths = [len(seq) for seq in sequences]
        if N is None:  # pad to the longest sequence in this batch
            N = max(lengths)
        padded_seqs = torch.zeros(len(sequences), N).long()
        attention_mask = torch.zeros(len(sequences), N).long()
        for i, seq in enumerate(sequences):
            # as_tensor accepts both lists and tensors (no copy for tensors);
            # explicit conversion matches the *_w_aug variants of this file
            seq = torch.as_tensor(seq, dtype=torch.long)
            end = lengths[i]
            padded_seqs[i, :end] = seq[:end]
            attention_mask[i, :end] = torch.ones(end).long()
        return padded_seqs, attention_mask, lengths

    # sort by source-sequence length, longest first (e.g. for packed RNNs)
    data.sort(key=lambda x: len(x["cause"]), reverse=True)
    item_info = {}
    for key in data[0].keys():
        item_info[key] = [sample[key] for sample in data]
    ## input
    cause_batch, cause_attn_mask, cause_lengths = merge(item_info['cause'])
    d = {}
    d["emotion"] = item_info["emotion"]
    d["cause"] = cause_batch
    d["cause_attn_mask"] = cause_attn_mask
    return d
def collate_fn_w_aug_emo(data):
    """Collate emotion samples that each carry multiple augmented views,
    flattening all views into a single batch.

    Args:
        data: list of dicts; every value is a list of per-view entries,
            e.g. "cause" -> [view0_ids, view1_ids, ...]. Sorted in place
            by each sample's longest view (descending); views of the same
            sample stay adjacent in the flattened batch.

    Returns:
        dict with:
            "emotion": flattened list of labels (one per view).
            "cause": LongTensor [batch*views, max_len], zero-padded.
            "cause_attn_mask": LongTensor of the same shape, 1 on real tokens.
    """
    def merge(sequences, N=None):
        lengths = [len(seq) for seq in sequences]
        if N is None:  # pad to the longest sequence in this batch
            N = max(lengths)
        padded_seqs = torch.zeros(len(sequences), N).long()
        attention_mask = torch.zeros(len(sequences), N).long()
        for i, seq in enumerate(sequences):
            seq = torch.LongTensor(seq)
            end = lengths[i]
            padded_seqs[i, :end] = seq[:end]
            attention_mask[i, :end] = torch.ones(end).long()
        return padded_seqs, attention_mask, lengths

    # sort by the sample's longest view; generalized from the original
    # hard-coded two-view key so any number of views works
    data.sort(key=lambda x: max(len(seq) for seq in x["cause"]), reverse=True)
    item_info = {}
    for key in data[0].keys():
        item_info[key] = [sample[key] for sample in data]
        # unbind the views: every view becomes its own row in the batch
        item_info[key] = list(itertools.chain.from_iterable(item_info[key]))
    ## input
    cause_batch, cause_attn_mask, cause_lengths = merge(item_info['cause'])
    d = {}
    d["emotion"] = item_info["emotion"]
    d["cause"] = cause_batch
    d["cause_attn_mask"] = cause_attn_mask
    return d
def collate_fn_sentiment(data):
    """Collate a batch of sentiment samples into padded tensors.

    Args:
        data: list of dicts, each with key "review" (a token-id sequence:
            list or 1-D tensor) and "sentiment" (a label). Sorted in place
            by descending review length.

    Returns:
        dict with:
            "sentiment": list of labels, in the sorted batch order.
            "review": LongTensor [batch, max_len], zero-padded on the right.
            "review_attn_mask": LongTensor [batch, max_len], 1 on real tokens.
    """
    def merge(sequences, N=None):
        lengths = [len(seq) for seq in sequences]
        if N is None:  # pad to the longest sequence in this batch
            N = max(lengths)
        padded_seqs = torch.zeros(len(sequences), N).long()
        attention_mask = torch.zeros(len(sequences), N).long()
        for i, seq in enumerate(sequences):
            # as_tensor accepts both lists and tensors (no copy for tensors);
            # explicit conversion matches the *_w_aug variants of this file
            seq = torch.as_tensor(seq, dtype=torch.long)
            end = lengths[i]
            padded_seqs[i, :end] = seq[:end]
            attention_mask[i, :end] = torch.ones(end).long()
        return padded_seqs, attention_mask, lengths

    # sort by source-sequence length, longest first (e.g. for packed RNNs)
    data.sort(key=lambda x: len(x["review"]), reverse=True)
    item_info = {}
    for key in data[0].keys():
        item_info[key] = [sample[key] for sample in data]
    ## input
    review_batch, review_attn_mask, review_lengths = merge(item_info['review'])
    d = {}
    d["sentiment"] = item_info["sentiment"]
    d["review"] = review_batch
    d["review_attn_mask"] = review_attn_mask
    return d
def collate_fn_w_aug_sentiment(data):
    """Collate sentiment samples that each carry multiple augmented views,
    flattening all views into a single batch.

    Args:
        data: list of dicts; every value is a list of per-view entries,
            e.g. "review" -> [view0_ids, view1_ids, ...]. Sorted in place
            by each sample's longest view (descending); views of the same
            sample stay adjacent in the flattened batch.

    Returns:
        dict with:
            "sentiment": flattened list of labels (one per view).
            "review": LongTensor [batch*views, max_len], zero-padded.
            "review_attn_mask": LongTensor of the same shape, 1 on real tokens.
    """
    def merge(sequences, N=None):
        lengths = [len(seq) for seq in sequences]
        if N is None:  # pad to the longest sequence in this batch
            N = max(lengths)
        padded_seqs = torch.zeros(len(sequences), N).long()
        attention_mask = torch.zeros(len(sequences), N).long()
        for i, seq in enumerate(sequences):
            seq = torch.LongTensor(seq)
            end = lengths[i]
            padded_seqs[i, :end] = seq[:end]
            attention_mask[i, :end] = torch.ones(end).long()
        return padded_seqs, attention_mask, lengths

    # sort by the sample's longest view; generalized from the original
    # hard-coded two-view key so any number of views works
    data.sort(key=lambda x: max(len(seq) for seq in x["review"]), reverse=True)
    item_info = {}
    for key in data[0].keys():
        item_info[key] = [sample[key] for sample in data]
        # unbind the views: every view becomes its own row in the batch
        item_info[key] = list(itertools.chain.from_iterable(item_info[key]))
    ## input
    review_batch, review_attn_mask, review_lengths = merge(item_info['review'])
    d = {}
    d["sentiment"] = item_info["sentiment"]
    d["review"] = review_batch
    d["review_attn_mask"] = review_attn_mask
    return d