Skip to content

Commit

Permalink
Fix clab#3: allow limiting optimization by dev uas tolerance
Browse files Browse the repository at this point in the history
  • Loading branch information
danielhers committed Jul 28, 2015
1 parent 55ba994 commit 689b983
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions parser/lstm-parse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("lstm_input_dim", po::value<unsigned>()->default_value(60), "LSTM input dimension")
("train,t", "Should training be run?")
("maxit,M", po::value<unsigned>()->default_value(8000), "Maximum number of training iterations")
("tolerance", po::value<double>()->default_value(0.0), "Tolerance on dev uas for stopping training")
("words,w", po::value<string>(), "Pretrained word embeddings")
("help,h", "Help");
po::options_description dcmdline_options;
Expand Down Expand Up @@ -525,6 +526,8 @@ int main(int argc, char** argv) {
assert(unk_prob >= 0.); assert(unk_prob <= 1.);
const unsigned maxit = conf["maxit"].as<unsigned>();
cerr << "Maximum number of iterations: " << maxit << "\n";
const double tolerance = conf["tolerance"].as<double>();
cerr << "Optimization tolerance: " << tolerance << "\n";
ostringstream os;
os << "parser_" << (USE_POS ? "pos" : "nopos")
<< '_' << LAYERS
Expand Down Expand Up @@ -607,9 +610,12 @@ int main(int argc, char** argv) {
double llh = 0;
bool first = true;
unsigned iter = 0;
double uas = -1;
double prev_uas = -1;
time_t time_start = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
cerr << "TRAINING STARTED AT: " << put_time(localtime(&time_start), "%c %Z") << endl;
while(!requested_stop && iter < maxit) {
while(!requested_stop && iter < maxit &&
(uas < 0 || prev_uas < 0 || abs(prev_uas - uas) > tolerance)) {
for (unsigned sii = 0; sii < status_every_i_iterations; ++sii) {
if (si == corpus.nsentences) {
si = 0;
Expand Down Expand Up @@ -675,7 +681,9 @@ int main(int argc, char** argv) {
total_heads += sentence.size() - 1;
}
auto t_end = std::chrono::high_resolution_clock::now();
cerr << " **dev (iter=" << iter << " epoch=" << (tot_seen / corpus.nsentences) << ")\tllh=" << llh << " ppl: " << exp(llh / trs) << " err: " << (trs - right) / trs << " uas: " << (correct_heads / total_heads) << "\t[" << dev_size << " sents in " << std::chrono::duration<double, std::milli>(t_end-t_start).count() << " ms]" << endl;
prev_uas = uas;
uas = correct_heads / total_heads;
cerr << " **dev (iter=" << iter << " epoch=" << (tot_seen / corpus.nsentences) << ")\tllh=" << llh << " ppl: " << exp(llh / trs) << " err: " << (trs - right) / trs << " uas: " << uas << "\t[" << dev_size << " sents in " << std::chrono::duration<double, std::milli>(t_end-t_start).count() << " ms]" << endl;
if (correct_heads > best_correct_heads) {
best_correct_heads = correct_heads;
ofstream out(fname);
Expand All @@ -698,6 +706,8 @@ int main(int argc, char** argv) {
}
if (iter >= maxit) {
cerr << "\nMaximum number of iterations reached (" << iter << "), terminating optimization...\n";
} else if (!requested_stop) {
cerr << "\nScore tolerance reached (" << tolerance << "), terminating optimization...\n";
}
} // should do training?
if (true) { // do test evaluation
Expand Down

0 comments on commit 689b983

Please sign in to comment.