/* Viterbi.java */
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Scanner;
import java.util.Set;
/**
* As a class, <code>Viterbi</code> contains both the HMM model
* and the functions to run the Viterbi algorithm on it.
* <p>
* The class can train itself from probabilities provided by
* a dictionary, and then save and load the needed data into a
* file. There are two required files for training: a corpus tagset
* and the directory of the corpus.
* <p>
* The corpus tagset should be in the form of having
* each part of speech on its own line, each POS symbol exactly 1 "word"
* long, followed by a tab and the full English term for the POS
* such as <br>
* [POS symbol] \t [POS term]
* e.g.
* cs conjunction, subordinating
* <p>
* Each corpus text file should be a complete English text, with each word
* tagged with a part of speech. It should conform to the format set by
* the Brown corpus, i.e. in the form [word]/[part of speech].
* <p>
* The training file will be printed in the form <br>
* numPOS <br>
* numWords <br>
* "word1" index_POS1 log(probability POS1) index_POS2 log(probability POS2) ... <br>
* ... <br>
* "wordlast" index_POS1 log(probability POS1) index_POS2 log(probability POS2) ... <br>
* "POS1" log(probability POS1 following) log(probability POS2 following) ... log(probability POSlast following)<br>
* ... <br>
* "POS2" log(probability POS1 following) log(probability POS2 following) ... log(probability POSlast following)<br>
* where "word1", "POS1", etc are replaced by the actual words or
* names of parts of speech
* <p>
* In particular, the corpus directory may contain subdirectories,
* but should not contain non-corpus files.
* <p>
*/
public class Viterbi
{
    /** The number of distinct words in the trained model. */
    private int numWords;

    /** The number of parts of speech in the tagset. */
    private int numPOS;

    /**
     * The logs of the emission probabilities of the model: maps each word
     * to an array of log-probabilities indexed by POS index. POS entries
     * absent from the training file are {@code Float.NEGATIVE_INFINITY}.
     */
    private HashMap<String, float[]> p_emission = new HashMap<String, float[]>();

    /**
     * The logs of the transition probabilities of the model;
     * {@code p_transmission[a][b]} is the log-probability of
     * POS b immediately following POS a.
     */
    private float[][] p_transmission;

    /**
     * Constructor; initializes the probability tables from the
     * given saved training file (see the class description for the format).
     *
     * @param tagset   the file containing a list of parts of speech
     * @param gtagset  text file containing a legend for the tags; maps a
     *                 simplified tag name to a comma-separated list of POS indices
     * @param datafile the name of the file of saved probability data
     * @throws FileNotFoundException if a file cannot be opened
     * @throws WrongFormatException  if the data file does not match the tagset
     * @throws POSNotFoundException  if a POS symbol cannot be resolved
     */
    public Viterbi(String tagset, String gtagset, String datafile)
        throws FileNotFoundException, WrongFormatException, POSNotFoundException
    {
        POS.loadFromFile(tagset, gtagset);
        Scanner sc = null;
        try
        {
            sc = new Scanner(new BufferedReader(new FileReader(datafile)));
            // header: number of parts of speech, then number of words
            this.numPOS = sc.nextInt();
            sc.nextLine();
            this.numWords = sc.nextInt();
            sc.nextLine();
            // basic check that the number of parts of speech matches up
            if (this.numPOS != POS.numPOS())
                throw new WrongFormatException("The training file does not seem to match the indicated tagset.");
            this.p_transmission = new float[this.numPOS][this.numPOS];
            // emission probabilities: one line per word,
            // "word idx1 logp1 idx2 logp2 ..."
            for (int i = 0; i < numWords; i++)
            {
                float[] probability = new float[this.numPOS];
                Arrays.fill(probability, Float.NEGATIVE_INFINITY);
                String word = sc.next("\\S+");
                sc.skip(" ");
                // BUG FIX: the original stepped j by 2 while bounding j by the
                // number of (index, value) PAIRS, silently dropping roughly half
                // of the entries whenever a word had more than one tagged POS.
                readPairs(sc.nextLine().split(" "), probability);
                this.p_emission.put(word, probability);
            }
            // transition probabilities: one line per POS, in tagset index order
            for (int i = 0; i < numPOS; i++)
            {
                // gets POS and checks that it has the right index
                String symbol = sc.next("\\S+");
                sc.skip(" ");
                if (i != POS.getIndexBySymbol(symbol))
                    throw new WrongFormatException("The training file does not seem to match the indicated tagset.");
                Arrays.fill(p_transmission[i], Float.NEGATIVE_INFINITY);
                // BUG FIX: same pair-count/stride mismatch as above.
                readPairs(sc.nextLine().split(" "), p_transmission[i]);
            }
        }
        finally
        {
            if (sc != null)
                sc.close();
        }
    }

    /**
     * Reads alternating "index value" token pairs into {@code dest}.
     * Any trailing token without a partner is ignored.
     *
     * @param tokens whitespace-split tokens of one data-file line
     * @param dest   array of log-probabilities indexed by POS index
     */
    private static void readPairs(String[] tokens, float[] dest)
    {
        for (int j = 0; j + 1 < tokens.length; j += 2)
            dest[Integer.parseInt(tokens[j])] = Float.parseFloat(tokens[j + 1]);
    }

    /**
     * Iterates through the corpus and calculates the frequencies of
     * neighboring parts of speech and the frequencies of each part
     * of speech for each word, then writes the derived log-probabilities
     * to {@code saveLocation} (see the class description for the format).
     *
     * @param tagset          the file containing the tagset, as defined in the
     *                        class description
     * @param gtagset         text file containing a legend for the tags; maps a
     *                        simplified tag name to a comma-separated list of POS indices
     * @param corpusDirectory the name of the directory containing the corpus
     * @param saveLocation    where the probabilities are to be saved
     * @throws IOException           if the save file cannot be written
     * @throws FileNotFoundException if the corpus directory is invalid
     * @throws WrongFormatException  if a corpus token is malformed
     * @throws POSNotFoundException  if a tag is not in the tagset
     */
    public static void loadCorpusForTraining (String tagset, String gtagset,
                                              String corpusDirectory,
                                              String saveLocation)
        throws IOException, FileNotFoundException,
               WrongFormatException, POSNotFoundException
    {
        // load all of the corpus tags from the corpus tagset
        POS.loadFromFile (tagset, gtagset);
        int numPOS = POS.numPOS();
        // number of times each word appears for each part of speech;
        // maps word -> int[numPOS] of per-POS counts
        HashMap<String, int[]> word_to_pos = new HashMap<String, int[]>();
        // pos_to_pos[a][b] = number of times POS b directly follows POS a
        int[][] pos_to_pos = new int[numPOS][numPOS];
        // total occurrences of each POS in the corpus
        int[] pos_frequencies = new int[numPOS];
        // find all corpus data files in the directory
        File dir = new File(corpusDirectory);
        File[] fl = dir.listFiles();
        if (fl == null)
            throw new FileNotFoundException("Corpus directory was not valid.");
        for (int i = 0; i < fl.length; i++)
        {
            // BUG FIX: reset the previous POS per file so a spurious transition
            // is not recorded across file boundaries (the original carried it over).
            int lastPOSIndex = -1;
            Scanner scanner = null;
            try
            {
                scanner = new Scanner(new BufferedReader(new FileReader(fl[i])));
                while (scanner.hasNext())
                {
                    String s = scanner.next();
                    // token format is word/POS; the POS follows the LAST slash
                    int lastIndex = s.lastIndexOf("/");
                    String word = s.substring(0, lastIndex).toLowerCase();
                    String symbol = s.substring(lastIndex + 1).
                        replaceAll(POS.getIgnoreRegex(), "");
                    // get the index of the POS; throws if unknown
                    int POSIndex = POS.getIndexBySymbol(symbol);
                    // add to word_to_pos
                    int[] arr = word_to_pos.get(word);
                    if (arr == null)
                    {
                        arr = new int[numPOS];
                        word_to_pos.put(word, arr);
                    }
                    arr[POSIndex]++;
                    // BUG FIX: the original `continue`d on the first token, so
                    // that token's POS was never counted in pos_frequencies;
                    // count every token unconditionally.
                    pos_frequencies[POSIndex]++;
                    // add to pos_to_pos (no predecessor for the first token)
                    if (lastPOSIndex >= 0)
                        pos_to_pos[lastPOSIndex][POSIndex]++;
                    lastPOSIndex = POSIndex;
                }
            }
            finally
            {
                // close file
                if (scanner != null)
                    scanner.close();
            }
        }
        int numWords = word_to_pos.size();
        PrintWriter saveFile = null;
        try
        {
            saveFile = new PrintWriter(new FileWriter(saveLocation));
            saveFile.println(numPOS);
            saveFile.println(numWords);
            // emission probabilities: log(count(word, POS) / count(POS))
            for (Map.Entry<String, int[]> e : word_to_pos.entrySet())
            {
                int[] p = e.getValue();
                StringBuilder line = new StringBuilder(e.getKey()).append(' ');
                for (int j = 0; j < numPOS; j++)
                {
                    if (p[j] > 0)
                        line.append(j).append(' ')
                            .append(Math.log((float) p[j] / pos_frequencies[j]))
                            .append(' ');
                }
                saveFile.println(line);
            }
            // transition probabilities: log(count(a -> b) / count(a))
            for (int i = 0; i < numPOS; i++)
            {
                StringBuilder line =
                    new StringBuilder(POS.getPOSbyIndex(i).getSymbol()).append(' ');
                for (int j = 0; j < numPOS; j++)
                {
                    if (pos_to_pos[i][j] > 0)
                        line.append(j).append(' ')
                            .append(Math.log((float) pos_to_pos[i][j] / pos_frequencies[i]))
                            .append(' ');
                }
                saveFile.println(line);
            }
        }
        finally
        {
            if (saveFile != null)
                saveFile.close();
        }
    }

    /**
     * Takes a sequence of outputs (the words of a sentence) and determines
     * the most likely underlying state (POS) sequence via the Viterbi
     * algorithm.
     *
     * @param results list of outputs
     * @return list of (word, POS) pairs in sentence order; empty for empty input
     * @throws POSNotFoundException if a POS index cannot be resolved
     */
    public ArrayList<Pair<String, POS>> parse(ArrayList<String> results)
        throws POSNotFoundException
    {
        int length = results.size(); // number of pieces of the sentence
        // BUG FIX: guard the empty sentence; the original indexed column -1.
        if (length == 0)
            return new ArrayList<Pair<String, POS>>();
        float[][] probs = new float[numPOS][length]; // best log-prob per (POS, position)
        int[][] parts = new int[numPOS][length];     // back-pointers to previous POS
        for (int i = 0; i < length; i++) // per word
        {
            // trims whitespace, lowercases for dictionary matching
            String word = results.get(i).trim().toLowerCase();
            float[] emission = p_emission.get(word);
            if (emission == null)
            {
                // unknown word: no emission information, so carry the best
                // previous probability forward uniformly for every POS
                int index = 0;
                float value = 0;
                if (i > 0)
                {
                    index = maxInRowIndex(probs, i - 1, numPOS);
                    value = probs[index][i - 1];
                }
                for (int j = 0; j < numPOS; j++)
                {
                    probs[j][i] = value;
                    parts[j][i] = index;
                }
            }
            else
            {
                // true while every candidate value is negative infinity
                boolean allneg = true;
                // loop over all possible POS assignments for this word
                for (int j = 0; j < numPOS; j++)
                {
                    float value = Float.NEGATIVE_INFINITY;
                    int index = 0;
                    // loop over all possible POS assignments for the previous word
                    for (int k = 0; k < numPOS; k++)
                    {
                        float vtmp = (i > 0)
                            ? probs[k][i - 1] + p_transmission[k][j] + emission[j]
                            : emission[j];
                        if (vtmp > value)
                        {
                            value = vtmp;
                            index = k;
                        }
                    }
                    if (value > Float.NEGATIVE_INFINITY)
                        allneg = false;
                    parts[j][i] = index;
                    probs[j][i] = value;
                }
                // no POS reachable (e.g. a zero-probability transition chain):
                // restart from the best previous column plus emission
                if (allneg && i > 0)
                {
                    int index = maxInRowIndex(probs, i - 1, numPOS);
                    float value = probs[index][i - 1];
                    for (int j = 0; j < numPOS; j++)
                    {
                        probs[j][i] = value + emission[j];
                        parts[j][i] = index;
                    }
                }
            }
        }
        // BUG FIX: the original compared only ADJACENT rows
        // (probs[i] vs probs[i-1]) so it did not compute the argmax of the
        // final column; use the existing helper for the true maximum.
        int POSIndex = maxInRowIndex(probs, length - 1, numPOS);
        return create_for_parse(results, probs, parts, length - 1, POSIndex);
    }

    /**
     * Helper method for parse.
     * Finds the row index of the maximum value in column {@code index2},
     * scanning rows {@code [0, size1)}.
     */
    private int maxInRowIndex(float[][] array, int index2, int size1)
    {
        float value = Float.NEGATIVE_INFINITY;
        int index = 0;
        for (int i = 0; i < size1; i++)
        {
            if (array[i][index2] > value)
            {
                index = i;
                value = array[i][index2];
            }
        }
        return index;
    }

    /**
     * Helper method for parse: follows the back-pointer table from the last
     * word to reconstruct the (word, POS) sequence in sentence order.
     * Iterative, so long sentences cannot overflow the stack (the original
     * recursed once per word).
     *
     * @param words  the original sentence
     * @param probs  probability table (unused; kept for signature compatibility)
     * @param parts  back-pointer table: parts[p][w] is the POS of word w-1 on
     *               the best path ending with POS p at word w
     * @param wIndex index of the last word
     * @param pIndex POS index chosen for the last word
     */
    private ArrayList<Pair<String, POS>> create_for_parse(ArrayList<String> words, float[][] probs,
                                                          int[][] parts, int wIndex, int pIndex)
        throws POSNotFoundException
    {
        ArrayList<Pair<String, POS>> sentence = new ArrayList<Pair<String, POS>>();
        for (int w = wIndex; w >= 0; w--)
        {
            sentence.add(new Pair<String, POS>(words.get(w), POS.getPOSbyIndex(pIndex)));
            pIndex = parts[pIndex][w];
        }
        Collections.reverse(sentence);
        return sentence;
    }
}