/
main.cpp
38 lines (31 loc) · 1005 Bytes
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#include <cstdio>
#include <string>

#include "Vocabulary.hpp"
#include "Utils.hpp"
/**
 * Program entry point: builds a Vocabulary from the input file, runs the
 * requested number of training passes over it, and writes the results to
 * "<output>.bin", "<output>.pv" and "<output>.wv".
 *
 * The settings declared below are defaults; they are handed to
 * Utils::procArg together with argc/argv (presumably so the command line
 * can override them -- confirm against Utils.hpp).
 *
 * @param argc Number of command-line arguments.
 * @param argv Command-line argument strings.
 * @return 0 in all cases.
 */
int main(int argc, char** argv){
    // Default settings.
    int wordVecDim = 50;
    int paragraphVecDim = 50;
    int contextSize = 5;
    double learningRate = 0.025;
    int numNegative = 5;
    int minFreq = 10;
    int iteration = 1;
    int numThreads = 1;
    double shrink = 0.0;
    std::string input = "INPUT.txt";
    std::string output = "OUTPUT";

    Utils::procArg(argc, argv,
                   wordVecDim, paragraphVecDim, contextSize, learningRate,
                   numNegative, minFreq, iteration, numThreads,
                   input, output);

    Vocabulary voc(wordVecDim, contextSize, paragraphVecDim);
    voc.read(input, minFreq);

    // Amount subtracted from the learning rate after every pass, chosen so
    // that the rate would reach zero after `iteration` passes. Dividing by a
    // zero iteration count is undefined behaviour, so the decrement is only
    // computed when at least one pass will run; otherwise it stays at 0.0.
    if (iteration > 0){
        shrink = learningRate/iteration;
    } else {
        fprintf(stderr, "Warning: iteration = %d, no training pass will be run\n", iteration);
    }

    for (int i = 0; i < iteration; ++i){
        printf("Iteration %2d (current learning rate: %f)\n", i+1, learningRate);
        voc.train(input, learningRate, shrink, numNegative, numThreads);
        learningRate -= shrink;
    }

    voc.save(output+".bin");
    //voc.wordKnn(10);
    voc.outputParagraphVector(output+".pv");
    voc.outputWordVector(output+".wv");

    return 0;
}