-
Notifications
You must be signed in to change notification settings - Fork 7
/
pargs.py
105 lines (83 loc) · 5.46 KB
/
pargs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# -*- coding: utf-8 -*-
# !/usr/bin/python
"""
# @Time : 2020/5/1
# @Author : Yongrui Chen
# @File : pargs.py
# @Software: PyCharm
"""
import os
import torch
import argparse
def generation_pargs():
parser = argparse.ArgumentParser(description='AQG generation args')
parser.add_argument('--n_epochs', type=int, default=30)
parser.add_argument('--lr', type=float, default=5e-5)
parser.add_argument('--bs', type=int, default=16)
parser.add_argument('--clip_grad', type=float, default=0.6, help='gradient clipping')
parser.add_argument("--d_emb", default=300, type=int)
parser.add_argument("--d_h", default=256, type=int)
parser.add_argument("--n_lstm_layers", default=1, type=int)
parser.add_argument("--n_gnn_blocks", default=3, type=int)
parser.add_argument("--dropout", default=0.1, type=float)
parser.add_argument("--heads", default=4, type=int)
parser.add_argument('--not_birnn', action='store_false', dest='birnn')
parser.add_argument('--beam_size', default=5, type=int)
parser.add_argument('--readout', default='identity', choices=['identity', 'non_linear'])
parser.add_argument('--att_type', default='affine', choices=['dot_prod', 'affine'])
parser.add_argument("--max_num_op", default=20, type=int, help='maximum number of time steps used '
'in decoding')
parser.add_argument('--use_small', action='store_true', help='use small data', dest='use_small')
parser.add_argument('--not_shuffle', action='store_false', help='do not '
'shuffle training data', dest='shuffle')
parser.add_argument('--use_kb_constraint', action='store_true', dest='kb_constraint')
parser.add_argument('--no_cuda', action='store_false', help='do not use CUDA', dest='cuda')
parser.add_argument('--gpu', type=int, default=0, help='GPU device to use')
parser.add_argument('--word_normalize', action='store_true')
parser.add_argument('--train_data', type=str, default=os.path.abspath('./data/processed_train.pkl'))
parser.add_argument('--valid_data', type=str, default=os.path.abspath('./data/processed_valid.pkl'))
parser.add_argument('--test_data', type=str, default=os.path.abspath('./data/processed_test.pkl'))
parser.add_argument('--wo_vocab', type=str, default=os.path.abspath('./vocab/generation_word_vocab.pkl'))
parser.add_argument('--not_glove', action='store_false', help='do not use GloVe', dest='glove')
parser.add_argument('--glove_path', type=str, default=os.path.abspath(''))
parser.add_argument('--random_init_words', type=str,
default=os.path.abspath('./vocab/generation_random_init_words.json'))
parser.add_argument('--emb_cache', type=str, default=os.path.abspath('./vocab/generation_word_embeddings_cache.pt'))
parser.add_argument('--cpt', type=str, default='')
parser.add_argument('--kb_endpoint', type=str, default='')
args = parser.parse_args()
return args
def ranking_pargs():
parser = argparse.ArgumentParser(description='query ranking args')
parser.add_argument('--n_epochs', type=int, default=30)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--bs', type=int, default=16)
parser.add_argument('--clip_grad', type=float, default=0.6, help='gradient clipping')
parser.add_argument('--ns', type=int, default=30)
parser.add_argument('--margin', type=float, default=0.1)
parser.add_argument("--d_emb_wo", default=300, type=int)
parser.add_argument("--d_h_wo", default=256, type=int)
parser.add_argument("--n_lstm_layers", default=1, type=int)
parser.add_argument("--dropout", default=0.1, type=float)
parser.add_argument('--not_birnn', action='store_false', dest='birnn')
parser.add_argument('--use_small', action='store_true', help='use small data', dest='use_small')
parser.add_argument('--not_shuffle', action='store_false', help='do not '
'shuffle training data', dest='shuffle')
parser.add_argument('--no_cuda', action='store_false', help='do not use CUDA', dest='cuda')
parser.add_argument('--gpu', type=int, default=0, help='GPU device to use')
parser.add_argument('--word_normalize', action='store_true')
parser.add_argument('--train_data', type=str, default=os.path.abspath('../data/ranking_processed_train.pkl'))
parser.add_argument('--valid_data', type=str, default=os.path.abspath('../data/ranking_processed_valid.pkl'))
parser.add_argument('--test_data', type=str, default=os.path.abspath('../data/ranking_processed_test.pkl'))
parser.add_argument('--wo_vocab', type=str, default=os.path.abspath('../vocab/ranking_word_vocab.pkl'))
parser.add_argument('--cand_pool', type=str, default=os.path.abspath('../vocab/cand_pool.pkl'))
parser.add_argument('--not_glove', action='store_false', help='do not use GloVe', dest='glove')
parser.add_argument('--glove_path', type=str, default=os.path.abspath(''))
parser.add_argument('--random_init_words', type=str,
default=os.path.abspath('../vocab/ranking_random_init_words.json'))
parser.add_argument('--emb_cache', type=str,
default=os.path.abspath('../vocab/ranking_word_embeddings_cache.pt'))
parser.add_argument('--cpt', type=str, default='')
parser.add_argument('--kb_endpoint', type=str, default='')
args = parser.parse_args()
return args