
No predictions in inference. #20

Open
Ajithbalakrishnan opened this issue Jan 2, 2023 · 16 comments
Comments

@Ajithbalakrishnan

I have trained on the CORD dataset as per the "example.yaml" file. The F1 scores seem to be excellent (with the CRF network).
But when I try to generate predictions, the model does not predict anything.
Can you provide an example of the OCR API? Currently I am using a custom PaddleOCR Flask server to get the OCR results, and I convert the outputs to the required format that you mention in the script.

If possible, please share the OCR script, or the exact format that the module needs.

@ZeningLin
Owner

Hello,

Sorry for my late reply.

The OCR engine I used is a private engine from IntSig, which is for in-house development only. The public version of the API is TextIn, and it is not free. PaddleOCR is a good choice. To use other APIs, some modification is required. Here's how the OCR pipeline works:

The following code calls the OCR API and obtains the result. You may refer to the documentation of the API you use and modify the request operations accordingly. In my code, the OCR API requires the image bytes and returns the result in JSON format, so the image_bytes is passed to the server as the data argument in line 130, and the result is parsed as JSON if the status code is 200.

from typing import Dict

import requests


def ocr_extraction(image_bytes, ocr_url: str, parse_mode: str) -> Dict:
    # map each parse mode to its result-parsing function
    PARSE_MODE_DICT = {
        "eng_line": ocr_parsing_eng_line,
        "eng_word": ocr_parsing_eng_word,
        "chn_char": ocr_parsing_chn_char,
        "chn_ltp": ocr_parsing_chn_ltp,
    }
    headers = {"Content-Type": "application/octet-stream", "accept": "application/json"}
    api_return_result = dict()
    api_return_result["code"] = -1  # default failure code, kept if the request fails
    try:
        res = requests.post(url=ocr_url, data=image_bytes, headers=headers)
        if res.status_code == 200:
            api_return_result = res.json()
    except Exception as e:
        print(f"[ERROR] ocr engine failed, {e}")
    return PARSE_MODE_DICT[parse_mode](api_return_result)

The OCR raw result is then passed to the result-parsing function selected from PARSE_MODE_DICT. Four parsing pipelines are provided, corresponding to different languages and different modeling levels. The parsing function returns the status code (optional), a text list, and a coordinate list. You may create your own parsing function to convert the raw result into the required format. For example, to create English word-level input (you can refer to the ocr_parsing_eng_word function), we should first obtain all the texts & bboxes recognized by the OCR engine, then split them into word-level results. The words and word-level coordinates are appended to return_text_list and return_coor_list and returned. The converted results are then passed to the generate_batch function for the remaining preprocessing steps. A sketch of such a custom parser is shown below.
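For illustration only, a custom word-level parser for a PaddleOCR-style server might look like the sketch below. The JSON layout assumed here ({"code": 200, "result": [{"text": ..., "box": ...}, ...]}) is hypothetical; adapt the keys to whatever your server actually returns.

from typing import Dict, List, Tuple


def ocr_parsing_paddle_word(api_return_result: Dict) -> Tuple[List[str], List[List[int]]]:
    # Hypothetical parser: converts a PaddleOCR-style JSON response into the
    # (text list, coordinate list) format the pipeline expects.
    return_text_list: List[str] = []
    return_coor_list: List[List[int]] = []
    if api_return_result.get("code") != 200:
        return return_text_list, return_coor_list  # OCR failed, return empty lists
    for item in api_return_result["result"]:
        text = item["text"]
        box = item["box"]  # assumed [x1, y1, x2, y2]
        # naive word-level split: each word inherits the whole line's box;
        # a real parser should subdivide the box per word
        for word in text.split():
            return_text_list.append(word)
            return_coor_list.append(box)
    return return_text_list, return_coor_list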

You may check your code to see whether the parsing function generates the correct result or not.

If you have any further questions, please feel free to contact me.

@Ajithbalakrishnan
Author

Hi, thanks for the reply.
I have created an API using PaddleOCR, along with a wrapper that converts the data into the list format.

The wrapper therefore returns data in the format below.

return_text_list :  ['****Thank You.Please Come Again.****', '***COPY***', '0.00', '000000111', ]     
return_coor_list : [[44, 863, 44, 885], [140, 96, 140, 117], [320, 675, 320, 693], [25, 549, 25, 566]]
ocr_corpus :  [1008, 1008, 1008, 1008, 4067]
seg_indices : [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5]

But the results are still empty. I am not able to understand where I am failing.

@ZeningLin
Owner

Hi,

Something seems wrong in your returned data. ocr_corpus holds the token ids produced by the tokenizer, and seg_indices records the index of the segment each token falls in. Hence ocr_corpus and seg_indices should have the same length.

In your example, the text string ****Thank You.Please Come Again.**** should first be tokenized by the tokenizer, then converted to token ids and appended to the ocr_corpus list. seg_indices is also generated during a traversal of all the tokens. The whole process is shown below:

ocr_tokens = []
seg_indices = []
return_text_list_ = []
return_coor_list_ = []
seg_index = 0
for text, coor in zip(return_text_list, return_coor_list):
    # skip empty or whitespace-only strings
    if text == "" or text.isspace():
        continue
    curr_tokens = tokenizer.tokenize(text)
    if len(curr_tokens) == 0:
        continue
    return_text_list_.append(text)
    return_coor_list_.append(coor)
    # all tokens produced from this text segment share the same segment index
    for i in range(len(curr_tokens)):
        ocr_tokens.append(curr_tokens[i])
        seg_indices.append(seg_index)
    seg_index += 1
ocr_corpus = tokenizer.convert_tokens_to_ids(ocr_tokens)
ocr_corpus = torch.tensor(ocr_corpus, dtype=torch.long, device=device)
mask = torch.ones(ocr_corpus.shape, dtype=torch.int, device=device)
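As a quick sanity check after running the block above, the lengths should line up as follows (a small illustrative snippet, not part of the repo):

# every token must carry exactly one segment index, and every kept text
# segment must keep its coordinate box
assert len(ocr_corpus) == len(seg_indices)
assert len(return_text_list_) == len(return_coor_list_)
assert max(seg_indices) == len(return_text_list_) - 1  # segment indices cover all kept segments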

I think you can modify your wrapper so that it only returns return_text_list and return_coor_list, then pass these variables to the generate_batch function mentioned above and let it handle all the following steps.

@Ajithbalakrishnan
Author

Ajithbalakrishnan commented Jan 10, 2023

Sorry, the data I shared earlier was only part of the actual data. I am sharing the full version here.

return_text_list :  ['****Thank You.Please Come Again.****', '***COPY***', '0.00', '000000111', '1193.00', '15/01/201911:0516AM', '193.00', '193.00 SR', '81750 MASAI,JOHOR', ':', 'Address', 'Amount', 'Approval Code:000', 'BANDAR SERI ALAM,', 'Bill To', 'Cashier', 'Date', 'Description', 'Email:ng@ojcgroup.com', 'Goods Sold Are Not Returnable & Refundable', 'Invoice NoPEGIV-1030765', 'KINGS SAFETY SHOES KWD 805', 'NG CHUAN MIN', 'NO2&4,JALAN BAYU 4', 'OJCMARKETING SDN BHD', 'Qty.....Price.', 'Qty:1', 'ROCNO:538358-H', 'Round Amt:', 'Sales Persor:FATIN', 'TAX INVOICE', 'THE PEAK QUARRYWORKS', 'TOTAL:', 'Tel:07-3882218Fax07-3888218', 'Total Exclude GST', 'Total GST@6%', 'Total Inclusive GST:', 'VISA CARD', 'X0004318', 'tan chay yee']     
return_coor_list : [[44, 863, 44, 885], [140, 96, 140, 117], [320, 675, 320, 693], [25, 549, 25, 566], [198, 549, 198, 567], [122, 348, 122, 365], [302, 730, 302, 747], [302, 550, 302, 567], [123, 209, 123, 225], [123, 449, 123, 459], [27, 444, 27, 462], [297, 529, 297, 547], [88, 772, 88, 788], [126, 187, 126, 205], [27, 412, 27, 430], [27, 369, 27, 387], [26, 345, 26, 368], [28, 529, 28, 547], [114, 250, 114, 267], [36, 848, 36, 862], [26, 326, 26, 342], [26, 576, 26, 593], [121, 369, 121, 386], [110, 166, 110, 184], [91, 121, 91, 138], [183, 527, 183, 548], [28, 614, 28, 633], [130, 144, 130, 161], [161, 672, 161, 696], [27, 391, 27, 408], [147, 291, 147, 311], [122, 412, 122, 429], [173, 698, 173, 720], [76, 229, 76, 246], [107, 613, 107, 630], [125, 635, 125, 652], [102, 655, 102, 671], [145, 729, 145, 746], [93, 753, 93, 766], [103, 27, 103, 62]]
seg_indices : [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 14, 14, 15, 15, 16, 17, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 21, 21, 21, 21, 21, 21, 21, 22, 22, 22, 22, 23, 23, 23, 23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 26, 26, 26, 26, 27, 27, 27, 27, 27, 27, 27, 27, 27, 28, 28, 28, 28, 29, 29, 29, 29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 31, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 34, 34, 34, 34, 35, 35, 35, 35, 35, 35, 36, 36, 36, 36, 36, 37, 37, 38, 38, 38, 38, 38, 39, 39, 39, 39, 39]
ocr_tokens : ['*', '*', '*', '*', 'thank', 'you', '.', 'please', 'come', 'again', '.', '*', '*', '*', '*', '*', '*', '*', 'copy', '*', '*', '*', '0', '.', '00', '000', '##00', '##01', '##11', '119', '##3', '.', '00', '15', '/', '01', '/', '2019', '##11', ':', '05', '##16', '##am', '193', '.', '00', '193', '.', '00', 'sr', '81', '##75', '##0', 'mas', '##ai', ',', 'johor', ':', 'address', 'amount', 'approval', 'code', ':', '000', 'banda', '##r', 'ser', '##i', 'alam', ',', 'bill', 'to', 'cash', '##ier', 'date', 'description', 'email', ':', 'ng', '@', 'o', '##j', '##c', '##group', '.', 'com', 'goods', 'sold', 'are', 'not', 'return', '##able', '&', 'ref', '##unda', '##ble', 'in', '##vo', '##ice', 'nope', '##gi', '##v', '-', '103', '##0', '##7', '##65', 'kings', 'safety', 'shoes', 'kw', '##d', '80', '##5', 'ng', 'chu', '##an', 'min', 'no', '##2', '&', '4', ',', 'jalan', 'bay', '##u', '4', 'o', '##j', '##cm', '##ark', '##eti', '##ng', 'sd', '##n', 'b', '##hd', 'q', '##ty', '.', '.', '.', '.', '.', 'price', '.', 'q', '##ty', ':', '1', 'roc', '##no', ':', '53', '##8', '##35', '##8', '-', 'h', 'round', 'am', '##t', ':', 'sales', 'per', '##sor', ':', 'fat', '##in', 'tax', 'in', '##vo', '##ice', 'the', 'peak', 'quarry', '##works', 'total', ':', 'tel', ':', '07', '-', '38', '##8', '##22', '##18', '##fa', '##x', '##0', '##7', '-', '38', '##8', '##8', '##21', '##8', 'total', 'exclude', 'gs', '##t', 'total', 'gs', '##t', '@', '6', '%', 'total', 'inclusive', 'gs', '##t', ':', 'visa', 'card', 'x', '##00', '##0', '##43', '##18', 'tan', 'cha', '##y', 'ye', '##e']
ocr_corpus :  [1008, 1008, 1008, 1008, 4067, 2017, 1012, 3531, 2272, 2153, 1012, 1008, 1008, 1008, 1008, 1008, 1008, 1008, 6100, 1008, 1008, 1008, 1014, 1012, 4002, 2199, 8889, 24096, 14526, 13285, 2509, 1012, 4002, 2321, 1013, 5890, 1013, 10476, 14526, 1024, 5709, 16048, 3286, 19984, 1012, 4002, 19984, 1012, 4002, 5034, 6282, 23352, 2692, 16137, 4886, 1010, 25268, 1024, 4769, 3815, 6226, 3642, 1024, 2199, 24112, 2099, 14262, 2072, 26234, 1010, 3021, 2000, 5356, 3771, 3058, 6412, 10373, 1024, 12835, 1030, 1051, 3501, 2278, 17058, 1012, 4012, 5350, 2853, 2024, 2025, 2709, 3085, 1004, 25416, 18426, 3468, 1999, 6767, 6610, 16780, 5856, 2615, 1011, 9800, 2692, 2581, 26187, 5465, 3808, 6007, 6448, 2094, 3770, 2629, 12835, 14684, 2319, 8117, 2053, 2475, 1004, 1018, 1010, 28410, 3016, 2226, 1018, 1051, 3501, 27487, 17007, 20624, 3070, 17371, 2078, 1038, 14945, 1053, 3723, 1012, 1012, 1012, 1012, 1012, 3976, 1012, 1053, 3723, 1024, 1015, 21326, 3630, 1024, 5187, 2620, 19481, 2620, 1011, 1044, 2461, 2572, 2102, 1024, 4341, 2566, 21748, 1024, 6638, 2378, 4171, 1999, 6767, 6610, 1996, 4672, 12907, 9316, 2561, 1024, 10093, 1024, 5718, 1011, 4229, 2620, 19317, 15136, 7011, 2595, 2692, 2581, 1011, 4229, 2620, 2620, 17465, 2620, 2561, 23329, 28177, 2102, 2561, 28177, 2102, 1030, 1020, 1003, 2561, 18678, 28177, 2102, 1024, 9425, 4003, 1060, 8889, 2692, 23777, 15136, 9092, 15775, 2100, 6300, 2063]

The length of each list is given below:

return_text_list :  40     
return_coor_list : 40
seg_indices : 224
ocr_tokens : 224
ocr_corpus :  224

@ZeningLin
Owner

The returns of your wrapper seem correct this time.

Now we should try to figure out which part of the pipeline failed, and I need more information. Try adding some breakpoints in the model's inference function to see whether it returns anything. If the debugger doesn't work in a Flask service, just remove the Flask part and run the model's inference function separately.
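As an illustration, bypassing Flask for debugging might look like the sketch below. model, image_tensor, and the inference signature are hypothetical placeholders; substitute the repo's actual inference entry point and its inputs.

# a minimal debugging sketch, assuming `model` was already loaded from a checkpoint
import torch

model.eval()
with torch.no_grad():
    breakpoint()  # drops into pdb so intermediate tensors can be inspected
    pred = model.inference(image_tensor, ocr_corpus, seg_indices, mask)
print(pred)  # an empty result here means the failure is in the model, not the Flask wrapper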

@Ajithbalakrishnan
Author

Ajithbalakrishnan commented Jan 19, 2023

Hi, sorry for the delay in responding.
This is the current status now. Even when training on a standard dataset (SROIE) with the configuration below:

bert_version: "bert-base-uncased"
backbone: "resnet_34_fpn_pretrained"
layer_mode: "multi" 
image_max_size: 512
classifier_mode: "full"
eval_mode: "seq_and_str" 

The rest of the hyperparameters are unchanged.
The model gives poor average metrics even after 200 epochs at batch size 2: the maximum F1 score is 0.3, and most of the time the results were empty.
I then tried other configurations, given below:

classifier_mode: "crf"
eval_mode: "seqeval" 

Here I got a good F1 score (~0.91), but most of the time the outputs for test images were still empty.
I believe the PaddleOCR API is not the problem in inference, since it works for some fields.

Your model implementation seems very promising; I guess there are some issues in the pipeline.
Please share your comments regarding the above problem.

Thanks,
AB

@ZeningLin
Owner

Hello,

It seems that you set the classifier to "full" mode, which has the same architecture as described in the paper. I also found that this setting gives poor results, as I mentioned in the README section Adaptation and Exploration in This Implementation -> 3. Field Type Classification Head.

Thus, I implemented a simplified classifier, which directly applies multi-category classification using a single linear layer, and it gives good results. Just change classifier_mode to simp and see whether it works. A sketch of this kind of head is shown below.
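For illustration only, a simp-style head reduces to something like the following sketch; the embedding size (1024) and class count (5 SROIE fields + background) are assumptions, not the repo's exact values.

import torch
import torch.nn as nn

# single linear layer performing multi-category classification over fused token features
num_classes = 5 + 1  # illustrative: 5 SROIE fields + 1 background class
simp_head = nn.Linear(1024, num_classes)

token_features = torch.randn(224, 1024)  # placeholder fused features, one row per token
logits = simp_head(token_features)       # (224, num_classes)
pred = logits.argmax(dim=-1)             # predicted field type per token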

The crf classifier has an architecture similar to the simp classifier; it applies BIO tagging using a single multi-category classifier with context information. In my experiments, it also gives good results.

As for the empty-output problem during inference, there might be several reasons. Could you please share your inference configuration file with me? I will also try to re-train the model and debug the pipeline in the coming days.

@ZeningLin
Owner

Hello~

A pre-trained weight is available on Google Drive; try loading it with the simp classifier mode and see whether it gives results.

@Ajithbalakrishnan
Author

I am attaching my config file here. Please let me know your comments.

comment: " SROIE "


########################################## distributed training stuff ##################################
device: 'cuda'
syncBN: True
amp: True


##################################### training and optimizer hyper-parameters ##########################
start_epoch: 0
end_epoch: 200
batch_size: 6

optimizer_cnn_hyp: # SGD -> CNN
 learning_rate: 0.005
 min_learning_rate: 0.00001
 warm_up_epoches: 0
 warm_up_init_lr: 0.00001
 momentum: 0.9
 weight_decay: 0.005
 min_weight_decay: 0.005

optimizer_bert_hyp: # AdamW -> BERT
 learning_rate: 0.00005
 min_learning_rate: 0.0000001
 warm_up_epoches: 0
 warm_up_init_lr: 0.0000001
 beta1: 0.9
 beta2: 0.999
 epsilon: 0.00000001
 weight_decay: 0.01
 min_weight_decay: 0.01


loss_weights: 


############################################## OHEM trick config #######################################
num_hard_positive_main_1: 16    # number of hard positive samples in field-type-classification head loss 1
num_hard_negative_main_1: 16    # number of hard negative samples in field-type-classification head loss 1
num_hard_positive_main_2: 32    # number of hard positive samples in field-type-classification head loss 2
num_hard_negative_main_2: 32    # number of hard negative samples in field-type-classification head loss 2
loss_aux_sample_list:           # list of the number of samples in semantic segmentation head loss 1
 - 256
 - 512
 - 256
num_hard_positive_aux: 256      # number of hard positive samples in semantic segmentation head loss 2
num_hard_negative_aux: 256      # number of hard negative samples in semantic segmentation head loss 2
ohem_random: True               # apply random sampling before OHEM or not


######################################### model structure config #######################################
classifier_mode: "simp"                   # classifier mode, "simp", "full" or "crf"
eval_mode: "seq_and_str"                      # type of evaluation tool, 
                                          # "seqeval" uses seqeval package, calculates token-level result, works for all tag mode
                                          # "strcmp" joins the result and compare the final strings (official SROIE eval method)
                                          # "seq_and_str" uses both of "seqeval" and "strcmp"

tag_mode: "B"                             # tagging mode, "B" for direct prediction, "BIO" for BIO prediction
bert_version: "bert-base-uncased"         # [preferred] SROIE & FUNSD
# bert_version: "roberta-base"            # SROIE & FUNSD
# bert_version: "bert-base-chinese"       # EPHOIE
backbone: "resnet_34_fpn_pretrained"      # type of CNN backbone
# backbone: "resnet_18_fpn_pretrained"
grid_mode: "mean"                         # [does not need to change]mode of aggregating token features
early_fusion_downsampling_ratio: 8        # [does not need to change]
roi_shape: 7                              # [does not need to change]
p_fuse_downsampling_ratio: 4              # [does not need to change]
roi_align_output_reshape: False           # [does not need to change]
late_fusion_fuse_embedding_channel: 1024  # [does not need to change]
layer_mode: "single"                      # type of classifier, single for single layer perceptron, multi for MLP
loss_control_lambda: 1                    # set 0 to discard the auxiliary semantic segmentation head
add_pos_neg: True                         # use an additional positive-negative classifier in the simp mode


############################################# saving stuff #############################################
save_top: "/media/AB/4TBHDD1/vi-bertgrid/training-1"                    # dir to save weights of top performance models
save_log: "./log/"                        # dir to save logs

############################################ load check points #########################################
weights: ""

########################################### dataset loading stuff ######################################
num_workers: 4
data_root: "./Dataset/SROIE_DATA"                  # root of the raw dataset

## SROIE
num_classes: 5
image_mean:
 - 0.9248
 - 0.9224
 - 0.9215
image_std:
 - 0.1532
 - 0.1545
 - 0.1536
image_min_size:
 - 320
 - 416
 - 512
 - 608
 - 704 
image_max_size: 704
test_image_min_size: 704

ocr_url: sample_url
parse_mode: sample_parse_mode


@Ajithbalakrishnan
Author

Hello~

A pre-train weight is available at google drive, try to load it using the simp classifier mode and see whether it gives results or not.

I tried this model. The results seem fine (checked qualitatively). I can extract data (not all the fields; false negatives still exist).

@ZeningLin
Owner

Hello~

Sorry for my delayed response. I found that the weights arg in your config file is empty, which means no checkpoint will be loaded and the randomly initialized model will be used. Did you fill it with the path to your trained weights during inference?

@ZeningLin
Owner

In my experiments, I also found that the model gives poor results on the total field (you can see the per-category evaluation metrics for each epoch in the terminal output, which is stored in the log folder), since a receipt in SROIE may contain multiple strings similar to the total entity, and the algorithm mixes them up. But the model works well on the other three key fields.

@Ajithbalakrishnan
Author

Ajithbalakrishnan commented Jan 29, 2023

Hello~

Sorry for my delayed response. I found that the weights arg in your config file is empty, which will not load any checkpoint and will use the randomly initialized model. Have you filled it with the path to your pretrained weight during inference?

Hi, the config file shown is configured for training. When doing inference, I provide the absolute path to the checkpoint in the weights argument.

@ZeningLin
Owner

Hello,

Does the model give results in inference mode when using the pretrained weights I provide?

@Ajithbalakrishnan
Author

Hello,

Does the model give results in inference mode when using the pretrained weights I provide?

Yes, it does. But as you said, it gives poor performance on numeric fields like total. I also got some predictions from my own trained model, but they were bad: the true positives look good, but there are far too many false negatives.

@ZeningLin
Owner

I think a configuration mismatch may cause the error in your case. What kind of architecture did you use when training the model: crf, simp, or full? When doing inference, the classifier_mode and tag_mode in the inference config file must be consistent with the pretrained weights. Check whether this kind of mismatch occurred.

The poor results on the total field are a shortcoming of the ViBERTgrid model and hard to optimize, since it uses a grid encoding that may confuse neighboring features. Using a weighted loss that gives more attention to the total field may slightly improve the performance, but it is still far from satisfactory (a sketch of such a weighted loss is shown below). I assume that a larger input image size may help, but it will take up more time during training and inference.
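For illustration, a class-weighted loss of this kind might look like the sketch below; the class indices and weight values are assumptions for SROIE's five classes, not the repo's actual loss code.

import torch
import torch.nn as nn

# up-weight the class assumed to be `total` (index 4 here, illustrative)
class_weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 2.0])
criterion = nn.CrossEntropyLoss(weight=class_weights)

logits = torch.randn(8, 5)          # (num_tokens, num_classes), placeholder predictions
labels = torch.randint(0, 5, (8,))  # placeholder ground-truth class indices
loss = criterion(logits, labels)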
