@@ -83,12 +83,12 @@ def test_save_document_normal():
83
83
with tempfile .TemporaryDirectory () as tmpdirname :
84
84
csv_filename = os .path .join (tmpdirname , "test.csv" )
85
85
with open (csv_filename , 'w' ) as f :
86
- writer = csv .DictWriter (f , fieldnames = ["document_id" , "author_id" , "L1" , "english_proficiency" ])
86
+ # writer = csv.DictWriter(f, fieldnames=["document_id", "author_id", "L1", "english_proficiency"])
87
87
test_document , test_author = make_mocks ("1234" , "1234" , "hindi" )
88
88
m .get ("https://www.italki.com/api/notebook/1234" , text = json .dumps (test_document ))
89
89
m .get ("https://www.italki.com/api/user/1234" , text = json .dumps (test_author ))
90
- save_document ("1234" , tmpdirname , writer )
91
- assert open ( csv_filename ). read () == "1234,1234, hindi,0 \n "
90
+ doc = save_document ("1234" , tmpdirname )
91
+ assert doc == { "document_id" : test_document [ "data" ][ "id" ], "author_id" : test_author [ "data" ][ "id" ], "L1" : " hindi" , "english_proficiency" : 0 }
92
92
assert open (os .path .join (tmpdirname , "1234.txt" )).read () == test_document ["data" ]["content" ]
93
93
94
94
@@ -100,8 +100,8 @@ def test_save_document_404():
100
100
with open (csv_filename , 'w' ) as f :
101
101
writer = csv .DictWriter (f , fieldnames = ["document_id" , "author_id" , "L1" , "english_proficiency" ])
102
102
m .get ("https://www.italki.com/api/notebook/1234" , status_code = 404 )
103
- save_document ("1234" , tmpdirname , writer )
104
- assert open ( csv_filename ). read () == ""
103
+ doc = save_document ("1234" , tmpdirname )
104
+ assert doc is None
105
105
assert not os .path .isfile (os .path .join (tmpdirname , "1234.txt" ))
106
106
107
107
@@ -135,15 +135,16 @@ def test_recreate():
135
135
m .get ("https://www.italki.com/api/notebook/4" , text = json .dumps (test_document4 ))
136
136
main (SimpleNamespace (
137
137
command = "recreate" ,
138
- agents = 1 ,
138
+ num_agents = 1 ,
139
139
output_dir = os .path .join (tmpdirname , "output" ),
140
140
id_file = open (os .path .join (tmpdirname , "test_ids.txt" ))
141
141
))
142
142
assert open (os .path .join (tmpdirname , "output" , "1.txt" )).read () == test_document1 ["data" ]["content" ]
143
143
assert open (os .path .join (tmpdirname , "output" , "2.txt" )).read () == test_document2 ["data" ]["content" ]
144
144
assert open (os .path .join (tmpdirname , "output" , "3.txt" )).read () == test_document3 ["data" ]["content" ]
145
145
assert open (os .path .join (tmpdirname , "output" , "4.txt" )).read () == test_document4 ["data" ]["content" ]
146
- assert open (os .path .join (tmpdirname , "output" , "labels.train.csv" )).read () == "document_id,author_id,L1,english_proficiency\n 1,1234,hindi,0\n 4,12344,french,5\n "
146
+ print (open (os .path .join (tmpdirname , "output" , "labels.train.csv" )).readlines ())
147
+ assert set (open (os .path .join (tmpdirname , "output" , "labels.train.csv" )).readlines ()) == set (["document_id,author_id,L1,english_proficiency\n " , "1,1234,hindi,0\n " , "4,12344,french,5\n " ])
147
148
assert open (os .path .join (tmpdirname , "output" , "labels.test.csv" )).read () == "document_id,author_id,L1,english_proficiency\n 2,1234,hindi,0\n "
148
149
assert open (os .path .join (tmpdirname , "output" , "labels.dev.csv" )).read () == "document_id,author_id,L1,english_proficiency\n 3,1234,hindi,0\n "
149
150
0 commit comments