Cannot reproduce the evaluation results of the small model on the 6k multi-ref dataset #63

Open
liuslnlp opened this issue Feb 25, 2021 · 5 comments

@liuslnlp

I first extracted the contexts from test.refs.txt (6,000 lines):

cat test.refs.txt | cut -f 1 > test.source

and then extracted the multi-reference files (fields 2 to 15, i.e. up to 14 references per sample):

mkdir -p refs
for (( i=2; i<=15; i++ ))
do
    cut -f "$i" test.refs.txt > "refs/ref_$i.txt"
done
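
The same split can be done, and sanity-checked, in Python. The sketch below uses a hypothetical helper, `split_multi_ref`, and assumes test.refs.txt is tab-separated with the context in field 1 and references in the following fields (as the `cut` commands imply), and that every line contains at least one tab. Like `cut`, it writes an empty line to ref_<i>.txt when a sample has fewer fields than requested, and it returns the number of non-empty references per sample so you can see how many samples actually have all 14.

```python
from pathlib import Path


def split_multi_ref(refs_path, out_dir, first_ref_field=2, last_ref_field=15):
    """Write test.source and refs/ref_<i>.txt, mirroring the cut commands.

    Assumption: field 1 is the context, fields 2..N are references, and
    every line contains at least one tab. Returns, per line, the number of
    non-empty reference fields in the requested range.
    """
    out = Path(out_dir)
    (out / "refs").mkdir(parents=True, exist_ok=True)
    with open(refs_path, encoding="utf-8") as f:
        rows = [line.rstrip("\n").split("\t") for line in f]

    # field 1 -> test.source
    with open(out / "test.source", "w", encoding="utf-8") as f:
        f.writelines(row[0] + "\n" for row in rows)

    # fields 2..15 -> one reference file per field
    for i in range(first_ref_field, last_ref_field + 1):
        with open(out / "refs" / f"ref_{i}.txt", "w", encoding="utf-8") as f:
            # like cut, emit an empty line when a row has fewer fields
            f.writelines((row[i - 1] if len(row) >= i else "") + "\n" for row in rows)

    return [
        sum(1 for field in row[first_ref_field - 1:last_ref_field] if field.strip())
        for row in rows
    ]
```

For example, `counts = split_multi_ref("test.refs.txt", ".")` followed by `min(counts), max(counts)` shows the range of references per sample.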

Then I used the following script to generate responses for the 6k multi-reference dataset:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from nltk import word_tokenize  # needs NLTK's "punkt" tokenizer data
from tqdm import trange

model_path = '/path/to/DialoGPT-small'
file_path = '/path/to/test.source'
out_path = '/path/to/gpt_test.txt'

tokenizer = AutoTokenizer.from_pretrained(model_path)
# pad on the left so every prompt in a batch ends at the same position
tokenizer.padding_side = "left"
SEP = tokenizer.eos_token
# DialoGPT has no pad token; reuse EOS for padding
tokenizer.add_special_tokens({'pad_token': SEP})

model = AutoModelForCausalLM.from_pretrained(model_path)
model.eval()
batch_size = 64

# read contexts: keep the last 5 turns, joined and terminated by EOS
lines = []
with open(file_path, encoding='utf-8') as f:
    for line in f:
        new_line = SEP.join(line.strip().split(' EOS ')[-5:]) + SEP
        lines.append(new_line)

preds = []

# predict (no gradients needed at inference time)
with torch.no_grad():
    for i in trange(0, len(lines), batch_size):
        batch = lines[i:i + batch_size]
        batch_encoding = tokenizer(
            batch,
            max_length=256,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = batch_encoding['input_ids']
        attention_mask = batch_encoding['attention_mask']
        dyn_seq_len = input_ids.shape[1]  # padded prompt length
        preds_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=512,
            num_beams=1,  # no beam search
            pad_token_id=tokenizer.eos_token_id,
        )
        # drop the prompt, keep only the generated tokens
        preds_ids = preds_ids[:, dyn_seq_len:].tolist()
        batch_preds = [tokenizer.decode(ids, skip_special_tokens=True) for ids in preds_ids]
        preds.extend(batch_preds)

# write predictions as whitespace-joined NLTK tokens
with open(out_path, 'w', encoding='utf-8') as f:
    for pred in preds:
        f.write(' '.join(word_tokenize(pred)) + '\n')
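
The two data-handling steps in this script, building the prompt from the last five turns and cutting the prompt off the generated ids, can be checked without loading the model. A minimal sketch in plain Python, with lists standing in for tensors; `build_context` and `left_pad` are hypothetical helpers that mimic the corresponding lines of the script, and `<|endoftext|>` is assumed to be the tokenizer's EOS token (as it is for the GPT-2-based DialoGPT tokenizer).

```python
EOS = "<|endoftext|>"  # assumed EOS token of the GPT-2 tokenizer used by DialoGPT


def build_context(line, sep=EOS, max_turns=5):
    """Keep the last `max_turns` turns of a ' EOS '-separated context, join
    them with the EOS token and end with EOS so generation starts a reply."""
    turns = line.strip().split(" EOS ")[-max_turns:]
    return sep.join(turns) + sep


def left_pad(batch, pad_id):
    """Left-pad token-id lists to a common width. Every prompt then ends at
    column `width`, which is what `dyn_seq_len` holds in the script."""
    width = max(len(ids) for ids in batch)
    return [[pad_id] * (width - len(ids)) + ids for ids in batch], width


ctx = build_context("t1 EOS t2 EOS t3 EOS t4 EOS t5 EOS t6")
print(ctx)  # t2<|endoftext|>t3<|endoftext|>t4<|endoftext|>t5<|endoftext|>t6<|endoftext|>

padded, width = left_pad([[11, 12, 13], [21]], pad_id=0)
# pretend the model appended two new tokens to each row
outputs = [row + new for row, new in zip(padded, [[14, 15], [22, 23]])]
print([out[width:] for out in outputs])  # [[14, 15], [22, 23]]
```

With right padding, a shorter prompt would be followed by pad tokens instead of ending at column `width`, so the same slice would no longer isolate the response.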

However, my results differ considerably from those reported in the paper: my NIST, BLEU and METEOR scores are much higher, while my entropy and Dist scores are much lower.

My evaluation results

NIST: [3.372, 3.7761, 3.8364, 3.8455]
BLEU: [0.4679, 0.1924, 0.0928, 0.0505]
METEOR: 0.10545417931305287
Entropy: [4.9949875062421425, 7.123308932861081, 8.000309028686685, 8.413536358302238]
Distinct: [0.0619184959030736, 0.22404933196300103]
avg_len: 13.811166666666667

Reported in the paper:

| Experiment | NIST2 | NIST4 | BLEU2 | BLEU4 | METEOR | ENT-4 | DIST-1 | DIST-2 | Avg. Len |
|---|---|---|---|---|---|---|---|---|---|
| DialoGPT 117M | 2.39 | 2.41 | 10.54% | 1.55% | 7.53% | 10.78 | 8.60% | 39.90% | 12.8 |
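
The raw lists printed by the evaluation script hold the 1- to 4-gram scores (and Dist-1/Dist-2), while the paper's table reports only some of them, with BLEU, METEOR and Dist as percentages. A small helper that picks the matching entries, following the index choices in `dialogue_evaluation()` from the evaluation script posted later in this thread, makes the comparison explicit. `to_paper_columns` is a hypothetical name, and index k is assumed to be the (k+1)-gram score, as that function implies.

```python
def to_paper_columns(nist, bleu, meteor, entropy, distinct, avg_len):
    """Pick the paper's table columns out of the evaluation script's output.

    Index k of the NIST/BLEU/Entropy lists is taken as the (k+1)-gram score,
    following dialogue_evaluation(); BLEU, METEOR and Dist are scaled to %.
    """
    return {
        "NIST2": nist[1],
        "NIST4": nist[3],
        "BLEU2": 100 * bleu[1],
        "BLEU4": 100 * bleu[3],
        "METEOR": 100 * meteor,
        "ENT-4": entropy[3],
        "DIST-1": 100 * distinct[0],
        "DIST-2": 100 * distinct[1],
        "Avg. Len": avg_len,
    }


# numbers reported above for DialoGPT-small
small = to_paper_columns(
    nist=[3.372, 3.7761, 3.8364, 3.8455],
    bleu=[0.4679, 0.1924, 0.0928, 0.0505],
    meteor=0.10545417931305287,
    entropy=[4.9949875062421425, 7.123308932861081, 8.000309028686685, 8.413536358302238],
    distinct=[0.0619184959030736, 0.22404933196300103],
    avg_len=13.811166666666667,
)
for name, value in small.items():
    print(f"{name}: {value:.2f}")
```

For the small model this prints NIST2 3.78 and BLEU2 19.24, against 2.39 and 10.54% in the paper's table.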

Here are the predictions for the first 20 test samples:

I 'm not fasting , I 'm fasting because I 'm fasting .
I 'm waiting for someone to say something stupid and then I can see it over a r iamverysmart
I 'm not sure if I should be excited or scared .
I 'm going to be a millionaire by the end of this .
I love this post and the art . Do I 40 love it ? Well it does come framed , and it 's so absurd ... idk I just might .
I 'm not sure I trust him .
I have a few of those . I 'll have to check out the other ones .
I 'm watching the Oilers game on TV .
How hard is it to play snooker ?
Deshaun Watson is playing tonight .
What was your time ?
Artie Burns
What 's a screwdriver ?
I 'm not sure if I 'm missing something , but I do n't get it .
I think it 's a title defense .
I 'm not sure if it 's free , but I 've been to a few parks and they 're pretty cool .
I 'm not sure what you 're trying to say .
I 'm not sure what you 're trying to say .
I have the most chromosomes .
John Wick .
@liuslnlp
Author

Here are the evaluation results for the medium and large models. Again, the NIST and BLEU scores are well above the official ones, and the DIST scores also differ, though less so.

DialoGPT-medium

NIST: [3.6142, 4.1402, 4.2257, 4.2379]
BLEU: [0.5054, 0.2272, 0.1161, 0.0658]
METEOR: 0.11448456319410923
Entropy: [5.110969324425441, 7.4741025550415054, 8.487332812728265, 8.96638167676112]
Distinct: [0.063865246873529, 0.2401520577378657]
avg_len: 13.1005

DialoGPT-large

NIST: [3.9302, 4.5571, 4.6678, 4.6848]
BLEU: [0.5454, 0.2555, 0.1352, 0.0788]
METEOR: 0.11694036328599848
Entropy: [5.376659255260651, 8.038661195818934, 9.129731989024675, 9.630095839832428]
Distinct: [0.07617776246662647, 0.29050042408821036]
avg_len: 11.611

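Official results (from the paper):

| Experiment | NIST2 | NIST4 | BLEU2 | BLEU4 | METEOR | ENT-4 | DIST-1 | DIST-2 | Avg. Len |
|---|---|---|---|---|---|---|---|---|---|
| Human response | 3.41 | 4.25 | 17.90% | 7.48% | 10.64% | 11 | 14.50% | 63.00% | 13.1 |
| DialoGPT 117M | 2.39 | 2.41 | 10.54% | 1.55% | 7.53% | 10.78 | 8.60% | 39.90% | 12.8 |
| DialoGPT 345M | 3 | 3.06 | 16.96% | 4.56% | 9.81% | 9.13 | 6.80% | 26.30% | 12.2 |
| DialoGPT 762M | 2.84 | 2.9 | 18.66% | 5.25% | 9.66% | 9.72 | 7.76% | 29.93% | 11.2 |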

@marianafidalgo


Hello!
Could you share the evaluation code you used, please?

@liuslnlp
Author

My evaluation code is almost identical to the official one:

#  Copyright (c) Microsoft Corporation. 
#  Licensed under the MIT license. 

import argparse
import io
import os
import re
import subprocess
import sys
import time
from collections import defaultdict
from pathlib import Path

import numpy as np

py_version = sys.version.split('.')[0]
if py_version == '2':
    open = io.open
else:
    unicode = str

def makedirs(fld):
    if not os.path.exists(fld):
        os.makedirs(fld)

cur_dir = str(Path(__file__).parent)


def str2bool(s):
    # to avoid issue like this: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
    if s.lower() in ['t','true','1','y']:
        return True
    elif s.lower() in ['f','false','0','n']:
        return False
    else:
        raise ValueError

def calc_nist_bleu(path_refs, path_hyp, fld_out='temp', n_lines=None):
    # call mteval-v14c.pl
    # ftp://jaguar.ncsl.nist.gov/mt/resources/mteval-v14c.pl
    # you may need to cpan install XML:Twig Sort:Naturally String:Util 

    makedirs(fld_out)

    if n_lines is None:
        with open(path_refs[0], encoding='utf-8') as f:
            n_lines = len(f.readlines())
    _write_xml([''], fld_out + '/src.xml', 'src', n_lines=n_lines)
    _write_xml([path_hyp], fld_out + '/hyp.xml', 'hyp')#, n_lines=n_lines)
    _write_xml(path_refs, fld_out + '/ref.xml', 'ref')#, n_lines=n_lines)

    time.sleep(1)
    cmd = [
        'perl',f'{cur_dir}/mteval-v14c.pl',
        '-s', '%s/src.xml'%fld_out,
        '-t', '%s/hyp.xml'%fld_out,
        '-r', '%s/ref.xml'%fld_out,
        ]
    # capture stderr too; otherwise `error` is None and error.decode() below fails
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    output, error = process.communicate()

    lines = output.decode().split('\n')

    try:
        nist = lines[-6].strip('\r').split()[1:5]
        bleu = lines[-4].strip('\r').split()[1:5]
        return [float(x) for x in nist], [float(x) for x in bleu]

    except Exception:
        print('mteval-v14c.pl returns unexpected message')
        print('cmd = '+str(cmd))
        print(output.decode())
        print(error.decode())
        return [-1]*4, [-1]*4

    


def calc_cum_bleu(path_refs, path_hyp):
    # call multi-bleu.pl
    # https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/multi-bleu.perl
    # the 4-gram cum BLEU returned by this one should be very close to calc_nist_bleu
    # however multi-bleu.pl doesn't return cum BLEU of lower rank, so in nlp_metrics we prefer calc_nist_bleu
    # NOTE: this func doesn't support n_lines argument and output is not parsed yet

    process = subprocess.Popen(
            ['perl', f'{cur_dir}/multi-bleu.perl'] + path_refs, 
            stdout=subprocess.PIPE, 
            stdin=subprocess.PIPE
            )
    with open(path_hyp, encoding='utf-8') as f:
        lines = f.readlines()
    for line in lines:
        process.stdin.write(line.encode())
    output, error = process.communicate()
    return output.decode()


def calc_meteor(path_refs, path_hyp, fld_out='temp', n_lines=None, pretokenized=True):
    # Call METEOR code.
    # http://www.cs.cmu.edu/~alavie/METEOR/index.html

    makedirs(fld_out)
    path_merged_refs = fld_out + '/refs_merged.txt'
    _write_merged_refs(path_refs, path_merged_refs)
    cmd = [
            'java', '-Xmx1g',    # heapsize of 1G to avoid OutOfMemoryError
            '-jar', f'{cur_dir}/meteor-1.5/meteor-1.5.jar', 
            path_hyp, path_merged_refs, 
            '-r', '%i'%len(path_refs),     # refCount 
            '-l', 'en', '-norm'     # also supports language: cz de es fr ar
            ]
    # print(cmd)
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    output, error = process.communicate()
    for line in output.decode().split('\n'):
        if "Final score:" in line:
            return float(line.split()[-1])

    print('meteor-1.5.jar returns unexpected message')
    print("cmd = " + " ".join(cmd))
    print(output.decode())
    print(error.decode())
    return -1 


def calc_entropy(path_hyp, n_lines=None):
    # based on Yizhe Zhang's code
    etp_score = [0.0,0.0,0.0,0.0]
    counter = [defaultdict(int),defaultdict(int),defaultdict(int),defaultdict(int)]
    i = 0
    for line in open(path_hyp, encoding='utf-8'):
        i += 1
        words = line.strip('\n').split()
        for n in range(4):
            for idx in range(len(words)-n):
                ngram = ' '.join(words[idx:idx+n+1])
                counter[n][ngram] += 1
        if i == n_lines:
            break

    for n in range(4):
        total = sum(counter[n].values())
        for v in counter[n].values():
            etp_score[n] += - v /total * (np.log(v) - np.log(total))

    return etp_score


def calc_len(path, n_lines):
    l = []
    for line in open(path, encoding='utf8'):
        l.append(len(line.strip('\n').split()))
        if len(l) == n_lines:
            break
    return np.mean(l)


def calc_diversity(path_hyp):
    tokens = [0.0,0.0]
    types = [defaultdict(int),defaultdict(int)]
    for line in open(path_hyp, encoding='utf-8'):
        words = line.strip('\n').split()
        for n in range(2):
            for idx in range(len(words)-n):
                ngram = ' '.join(words[idx:idx+n+1])
                types[n][ngram] = 1
                tokens[n] += 1
    div1 = len(types[0].keys())/tokens[0]
    div2 = len(types[1].keys())/tokens[1]
    return [div1, div2]


def nlp_metrics(path_refs, path_hyp, fld_out='temp',  n_lines=None):
    nist, bleu = calc_nist_bleu(path_refs, path_hyp, fld_out, n_lines)
    meteor = calc_meteor(path_refs, path_hyp, fld_out, n_lines)
    entropy = calc_entropy(path_hyp, n_lines)
    div = calc_diversity(path_hyp)
    avg_len = calc_len(path_hyp, n_lines)
    return nist, bleu, meteor, entropy, div, avg_len


def _write_merged_refs(paths_in, path_out, n_lines=None):
    # prepare merged ref file for meteor-1.5.jar (calc_meteor)
    # lines[i][j] is the ref from i-th ref set for the j-th query

    lines = []
    for path_in in paths_in:
        lines.append([line.strip('\n') for line in open(path_in, encoding='utf-8')])

    with open(path_out, 'w', encoding='utf-8') as f:
        for j in range(len(lines[0])):
            for i in range(len(paths_in)):
                f.write(unicode(lines[i][j]) + "\n")



def _write_xml(paths_in, path_out, role, n_lines=None):
    # prepare .xml files for mteval-v14c.pl (calc_nist_bleu)
    # role = 'src', 'hyp' or 'ref'

    lines = [
        '<?xml version="1.0" encoding="UTF-8"?>',
        '<!DOCTYPE mteval SYSTEM "">',
        '<!-- generated by https://github.com/golsun/NLP-tools -->',
        '<!-- from: %s -->'%paths_in,
        '<!-- as inputs for ftp://jaguar.ncsl.nist.gov/mt/resources/mteval-v14c.pl -->',
        '<mteval>',
        ]

    for i_in, path_in in enumerate(paths_in):

        # header ----

        if role == 'src':
            lines.append('<srcset setid="unnamed" srclang="src">')
            set_ending = '</srcset>'
        elif role == 'hyp':
            lines.append('<tstset setid="unnamed" srclang="src" trglang="tgt" sysid="unnamed">')
            set_ending = '</tstset>'
        elif role == 'ref':
            lines.append('<refset setid="unnamed" srclang="src" trglang="tgt" refid="ref%i">'%i_in)
            set_ending = '</refset>'
        
        lines.append('<doc docid="unnamed" genre="unnamed">')

        # body -----

        if role == 'src':
            body = ['__src__'] * n_lines
        else:
            with open(path_in, 'r', encoding='utf-8') as f:
                body = f.readlines()
            if n_lines is not None:
                body = body[:n_lines]
        for i, b in enumerate(body):
            line = b.strip('\n')
            line = line.replace('&', ' ').replace('<', ' ')  # remove illegal XML chars
            lines.append('<p><seg id="%i"> %s </seg></p>' % (i + 1, line))

        # ending -----

        lines.append('</doc>')
        lines.append(set_ending)

    lines.append('</mteval>')
    with open(path_out, 'w', encoding='utf-8') as f:
        f.write(unicode('\n'.join(lines)))

def dialogue_evaluation(hyp_file, ref_file, fld_out):
    nist, bleu, meteor, entropy, div, avg_len = nlp_metrics([ref_file], hyp_file, fld_out)
    results = {
        'NIST-2': nist[1],
        'NIST-4': nist[3],
        'BLEU-2': bleu[1],
        'BLEU-4': bleu[3],
        'METEOR': meteor,
        'Entropy-4': entropy[3],
        'Dist-1': div[0],
        'Dist-2': div[1],
        'avg_len': avg_len
    }
    return results

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--refs_dir', default=None)
    parser.add_argument('--ref_file', default=None)
    parser.add_argument('--hyp_file', required=True)
    parser.add_argument('--fld_out', required=True)

    args = parser.parse_args()
    if args.ref_file is not None:
        refs_files = [args.ref_file]
    else:
        refs_files = list(map(str, Path(args.refs_dir).glob('ref_*.txt')))
    print("references: ", refs_files)
    nist, bleu, meteor, entropy, div, avg_len = nlp_metrics(refs_files, args.hyp_file, args.fld_out)
    print("NIST:", nist)
    print("BLEU:", bleu)
    print("METEOR:", meteor)
    print("Entropy:", entropy)
    print("Distinct:", div)
    print("avg_len:", avg_len)


if __name__ == "__main__":
    main()
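
The diversity-style metrics are easy to reproduce in isolation. Below is a minimal standalone sketch (the function names are illustrative, not from the official code) of what `calc_entropy` and `calc_diversity` compute: Ent-n is the natural-log entropy of the n-gram counts pooled over all hypotheses, and Dist-n is distinct n-grams divided by total n-grams, both over whitespace-split tokens, so they depend on how the hypotheses were tokenized. Using two of the predictions shown earlier, it shows how a repeated response lowers Dist-n.

```python
import math
from collections import Counter


def ngrams(words, n):
    """All contiguous n-grams of a token list, joined with spaces."""
    return [" ".join(words[i:i + n]) for i in range(len(words) - n + 1)]


def entropy_n(lines, n):
    """Ent-n: -sum p log p over the pooled n-gram distribution (natural log)."""
    counts = Counter(g for line in lines for g in ngrams(line.split(), n))
    total = sum(counts.values())
    return -sum(c / total * math.log(c / total) for c in counts.values())


def distinct_n(lines, n):
    """Dist-n: number of distinct n-grams divided by the total n-gram count."""
    grams = [g for line in lines for g in ngrams(line.split(), n)]
    return len(set(grams)) / len(grams)


same = "I 'm not sure what you 're trying to say ."
hyps = [same, same, "John Wick ."]
# the repeated response counts twice: Dist-1 = 13/25, Dist-2 = 12/22
print(distinct_n(hyps, 1), distinct_n(hyps, 2))
# without the duplicate, every bigram is distinct: Dist-2 = 1.0
print(distinct_n(hyps[1:], 2))
```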

@marianafidalgo

marianafidalgo commented May 31, 2021

Thank you so much!
One more question: where can I find the test.refs.txt file?

@dreasysnail self-assigned this Jun 9, 2021
@Mayanksoni20

Mayanksoni20 commented Nov 10, 2021

preds_ids = model.generate(input_ids, attention_mask=attention_mask, max_length=512, num_beams=1, pad_token_id=tokenizer.eos_token_id)

From the DialoGPT paper:

Beam search (with beam width 10) dramatically improves BLEU and DIST scores, and marginally improves NIST and METEOR.

The paper mentions that its results were obtained with beam width 10, whereas you ran the evaluation with beam width 1 (num_beams=1). Maybe try generating responses with num_beams=10 and see whether that makes a difference.
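
To see why the beam width can change the output at all, here is a toy, self-contained sketch of beam search over a hand-made next-token table. The table and `beam_search` are invented for illustration; this is not the Hugging Face implementation, which also applies options such as `length_penalty`. With `beam_width=1` it reduces to greedy decoding, while a wider beam can find a higher-probability response that greedy decoding misses because that response's first token was less likely. In the generation script above, the corresponding change is passing `num_beams=10` instead of `num_beams=1` to `model.generate`.

```python
import math

# Toy next-token distributions: P(next token | previous token).
PROBS = {
    "<s>": {"I": 0.6, "Yes": 0.4},
    "I": {"am": 0.4, "was": 0.3, "<eos>": 0.3},
    "Yes": {"<eos>": 0.9, "!": 0.1},
    "am": {"<eos>": 1.0},
    "was": {"<eos>": 1.0},
    "!": {"<eos>": 1.0},
}


def beam_search(probs, beam_width, max_len=10, bos="<s>", eos="<eos>"):
    """Return (tokens, probability) of the best hypothesis found.

    Hypotheses are scored by the sum of token log-probabilities. At each step
    only the `beam_width` best expansions are kept; those ending in `eos`
    move to the finished list, the rest are expanded further.
    """
    beams = [([bos], 0.0)]
    finished = []
    for _ in range(max_len):
        candidates = [
            (tokens + [tok], score + math.log(p))
            for tokens, score in beams
            for tok, p in probs[tokens[-1]].items()
        ]
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:beam_width]:
            if tokens[-1] == eos:
                finished.append((tokens, score))
            else:
                beams.append((tokens, score))
        if not beams:
            break
    finished.extend(beams)  # keep unfinished hypotheses if max_len was reached
    tokens, score = max(finished, key=lambda c: c[1])
    return [t for t in tokens[1:] if t != eos], math.exp(score)


print(beam_search(PROBS, beam_width=1))   # greedy picks "I am" (p ≈ 0.24)
print(beam_search(PROBS, beam_width=10))  # the wider beam finds "Yes" (p ≈ 0.36)
```

In this toy table the higher-scoring hypothesis also happens to be shorter; because scores are sums of log-probabilities, beam search without a length penalty tends to favor shorter outputs.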
