// iirc_joint_retrieval.jsonnet
// Training configuration for the IIRC joint retriever.

local bert_model = "bert-base-uncased";
local preprocessed_wiki_file_path = "data/iirc/preprocessed_context_articles.json";

// Used both by the dataset readers (`link_per_question`) and by the model
// (`beam_size_link`); defining it once keeps the two values in lockstep.
local top_k_link_per_question = 3;

// Dataset-reader settings shared by training and validation. Defining them
// once prevents the two reader configs from drifting out of sync; the
// validation reader below only layers its extra field on top.
local base_reader = {
  "type": "iirc-joint-retrieval-reader",
  "wiki_file_path": preprocessed_wiki_file_path,
  # "cache_directory": "data/iirc/__cache__",
  "transformer_model_name": bert_model,
  "q_max_tokens": 64,
  "c_max_tokens": 384,
  "skip_invalid_examples": false,
  "sent_n": 1,
  "padding_sent_n": 1,
  "stride": 1,
  "neg_n": 7,
  "include_main": false,
  "add_ctx_sep": false,
  "add_init_context": false,
  "link_per_question": top_k_link_per_question,
};

{
  "dataset_reader": base_reader,
  // Same as the training reader plus `max_neg_n`.
  // NOTE(review): presumably an upper bound on negatives used at evaluation
  // time — confirm against the iirc-joint-retrieval-reader implementation.
  "validation_dataset_reader": base_reader + {
    "max_neg_n": 500,
  },
  // NOTE(review): training and validation both read the tiny split; swap in
  // the full preprocessed splits for real experiments.
  "train_data_path": "data/iirc/preprocessed_iirc_tiny.json",
  "validation_data_path": "data/iirc/preprocessed_iirc_tiny.json",
  "vocabulary": {
    "type": "empty",
  },
  "model": {
    "type": "joint-retriever",
    "transformer_model_name": bert_model,
    "beam_size_link": top_k_link_per_question,
    "beam_size_context": 5,
    "print_trajectory": false,
    "load_model_weights": false,
    "link_predictor_weights_file": "",
    "context_retriever_weights_file": "",
    "use_joint_prob": false,
  },
  "data_loader": {
    "batch_size": 2,
    "shuffle": true,
  },
  "validation_data_loader": {
    "batch_size": 1,
    "shuffle": false,
  },
  "trainer": {
    "optimizer": {
      "type": "huggingface_adamw",
      "lr": 1.0e-5,
    },
    "num_epochs": 10,
    "cuda_device": 0,
    "validation_metric": "+jr_recall",
  },
}