Merge pull request nltk#1837 from alvations/develop

Cleaning up RTE classification code

stevenbird committed Sep 28, 2017
2 parents d98eef0 + 81319eb commit 0477ceb
Showing 3 changed files with 150 additions and 83 deletions.
135 changes: 52 additions & 83 deletions nltk/classify/rte_classify.py
@@ -18,28 +18,9 @@
"""
from __future__ import print_function

import nltk
from nltk.classify.util import accuracy

def ne(token):
"""
This just assumes that words in all caps or titles are
named entities.
:type token: str
"""
if token.istitle() or token.isupper():
return True
return False

def lemmatize(word):
"""
Use morphy from WordNet to find the base form of verbs.
"""
lemma = nltk.corpus.wordnet.morphy(word, pos=nltk.corpus.wordnet.VERB)
if lemma is not None:
return lemma
return word
from nltk.tokenize import RegexpTokenizer
from nltk.classify.util import accuracy, check_megam_config
from nltk.classify.maxent import MaxentClassifier

class RTEFeatureExtractor(object):
"""
@@ -60,7 +41,6 @@ def __init__(self, rtepair, stop=True, use_lemmatize=False):
'denied'])
# Try to tokenize so that abbreviations, monetary amounts, email
# addresses, URLs are single tokens.
from nltk.tokenize import RegexpTokenizer
tokenizer = RegexpTokenizer('[\w.@:/]+|\w+|\$[\d.]+')

#Get the set of word types for text and hypothesis
@@ -70,8 +50,8 @@ def __init__(self, rtepair, stop=True, use_lemmatize=False):
self.hyp_words = set(self.hyp_tokens)

if use_lemmatize:
self.text_words = set(lemmatize(token) for token in self.text_tokens)
self.hyp_words = set(lemmatize(token) for token in self.hyp_tokens)
self.text_words = set(self._lemmatize(token) for token in self.text_tokens)
self.hyp_words = set(self._lemmatize(token) for token in self.hyp_tokens)

if self.stop:
self.text_words = self.text_words - self.stopwords
@@ -89,7 +69,7 @@ def overlap(self, toktype, debug=False):
:param toktype: distinguish Named Entities from ordinary words
:type toktype: 'ne' or 'word'
"""
ne_overlap = set(token for token in self._overlap if ne(token))
ne_overlap = set(token for token in self._overlap if self._ne(token))
if toktype == 'ne':
if debug:
print("ne overlap", ne_overlap)
@@ -108,14 +88,36 @@ def hyp_extra(self, toktype, debug=True):
:param toktype: distinguish Named Entities from ordinary words
:type toktype: 'ne' or 'word'
"""
ne_extra = set(token for token in self._hyp_extra if ne(token))
ne_extra = set(token for token in self._hyp_extra if self._ne(token))
if toktype == 'ne':
return ne_extra
elif toktype == 'word':
return self._hyp_extra - ne_extra
else:
raise ValueError("Type not recognized: '%s'" % toktype)

@staticmethod
def _ne(token):
"""
This just assumes that words in all caps or titles are
named entities.
:type token: str
"""
if token.istitle() or token.isupper():
return True
return False

@staticmethod
def _lemmatize(word):
"""
Use morphy from WordNet to find the base form of verbs.
"""
lemma = nltk.corpus.wordnet.morphy(word, pos=nltk.corpus.wordnet.VERB)
if lemma is not None:
return lemma
return word


def rte_features(rtepair):
extractor = RTEFeatureExtractor(rtepair)
@@ -130,62 +132,29 @@ def rte_features(rtepair):
return features


def rte_classifier(trainer, features=rte_features):
"""
Classify RTEPairs
"""
train = ((pair, pair.value) for pair in
nltk.corpus.rte.pairs(['rte1_dev.xml', 'rte2_dev.xml',
'rte3_dev.xml']))
test = ((pair, pair.value) for pair in
nltk.corpus.rte.pairs(['rte1_test.xml', 'rte2_test.xml',
'rte3_test.xml']))

# Train up a classifier.
print('Training classifier...')
classifier = trainer([(features(pair), label) for (pair, label) in train])
def rte_featurize(rte_pairs):
return [(rte_features(pair), pair.value) for pair in rte_pairs]

# Run the classifier on the test data.

def rte_classifier(algorithm):
from nltk.corpus import rte as rte_corpus
train_set = rte_corpus.pairs(['rte1_dev.xml', 'rte2_dev.xml', 'rte3_dev.xml'])
test_set = rte_corpus.pairs(['rte1_test.xml', 'rte2_test.xml', 'rte3_test.xml'])
featurized_train_set = rte_featurize(train_set)
featurized_test_set = rte_featurize(test_set)
# Train the classifier
print('Training classifier...')
if algorithm in ['megam', 'BFGS']: # MEGAM based algorithms.
# Ensure that MEGAM is configured first.
check_megam_config()
clf = MaxentClassifier.train(featurized_train_set, algorithm)
elif algorithm in ['GIS', 'IIS']: # Use default GIS/IIS MaxEnt algorithm
clf = MaxentClassifier.train(featurized_train_set, algorithm)
else:
err_msg = str("RTEClassifier only supports these algorithms:\n "
"'megam', 'BFGS', 'GIS', 'IIS'.\n")
raise Exception(err_msg)
print('Testing classifier...')
acc = accuracy(classifier, [(features(pair), label)
for (pair, label) in test])
acc = accuracy(clf, featurized_test_set)
print('Accuracy: %6.4f' % acc)

# Return the classifier
return classifier


def demo_features():
pairs = nltk.corpus.rte.pairs(['rte1_dev.xml'])[:6]
for pair in pairs:
print()
for key in sorted(rte_features(pair)):
print("%-15s => %s" % (key, rte_features(pair)[key]))


def demo_feature_extractor():
rtepair = nltk.corpus.rte.pairs(['rte3_dev.xml'])[33]
extractor = RTEFeatureExtractor(rtepair)
print(extractor.hyp_words)
print(extractor.overlap('word'))
print(extractor.overlap('ne'))
print(extractor.hyp_extra('word'))


def demo():
import nltk
try:
nltk.config_megam('/usr/local/bin/megam')
trainer = lambda x: nltk.MaxentClassifier.train(x, 'megam')
except ValueError:
try:
trainer = lambda x: nltk.MaxentClassifier.train(x, 'BFGS')
except ValueError:
trainer = nltk.MaxentClassifier.train
nltk.classify.rte_classifier(trainer)

if __name__ == '__main__':
demo_features()
demo_feature_extractor()
demo()

return clf
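
For orientation, here is a minimal usage sketch of the refactored module (a sketch only, assuming the NLTK RTE corpus data has been downloaded; the feature names and the 'IIS'/'GIS'/'megam'/'BFGS' algorithm strings come from the code above):

from nltk.corpus import rte as rte_corpus
from nltk.classify.rte_classify import rte_features, rte_classifier

# Inspect the features extracted for a single text/hypothesis pair.
pair = rte_corpus.pairs(['rte1_dev.xml'])[0]
for name, value in sorted(rte_features(pair).items()):
    print('%-15s => %s' % (name, value))

# Train and evaluate a MaxEnt model; 'IIS' and 'GIS' need no external binary,
# while 'megam' and 'BFGS' require the megam binary to be configured first.
classifier = rte_classifier('IIS')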
12 changes: 12 additions & 0 deletions nltk/classify/util.py
@@ -310,3 +310,15 @@ def wsd_demo(trainer, word, features, n=1000):
# Return the classifier
return classifier



def check_megam_config():
"""
Checks whether the MEGAM binary is configured.
"""
try:
_megam_bin
except NameError:
err_msg = str("Please configure your megam binary first, e.g.\n"
">>> nltk.config_megam('/usr/bin/local/megam')")
raise NameError(err_msg)
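
A hedged sketch of the workflow this helper guards, assuming a megam binary is installed locally (the path and the toy feature sets below are illustrative only): the binary is registered with nltk.config_megam before any megam-backed MaxEnt training is attempted.

import nltk
from nltk.classify.maxent import MaxentClassifier

# Register the external megam binary (example path; adjust to your install).
nltk.config_megam('/usr/local/bin/megam')

# Tiny made-up training data in the (featureset, label) shape produced by
# rte_featurize above.
toy_train = [
    ({'word_overlap': 3, 'ne_overlap': 1, 'neg_hyp': 0}, 1),
    ({'word_overlap': 0, 'ne_overlap': 0, 'neg_hyp': 1}, 0),
]
classifier = MaxentClassifier.train(toy_train, 'megam')
print(classifier.classify({'word_overlap': 2, 'ne_overlap': 1, 'neg_hyp': 0}))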
86 changes: 86 additions & 0 deletions nltk/test/unit/test_rte_classify.py
@@ -0,0 +1,86 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, unicode_literals

import unittest

import nltk

from nltk.corpus import rte as rte_corpus
from nltk.classify.rte_classify import RTEFeatureExtractor, rte_features, rte_classifier

expected_from_rte_feature_extration = """
alwayson => True
ne_hyp_extra => 0
ne_overlap => 1
neg_hyp => 0
neg_txt => 0
word_hyp_extra => 3
word_overlap => 3
alwayson => True
ne_hyp_extra => 0
ne_overlap => 1
neg_hyp => 0
neg_txt => 0
word_hyp_extra => 2
word_overlap => 1
alwayson => True
ne_hyp_extra => 1
ne_overlap => 1
neg_hyp => 0
neg_txt => 0
word_hyp_extra => 1
word_overlap => 2
alwayson => True
ne_hyp_extra => 1
ne_overlap => 0
neg_hyp => 0
neg_txt => 0
word_hyp_extra => 6
word_overlap => 2
alwayson => True
ne_hyp_extra => 1
ne_overlap => 0
neg_hyp => 0
neg_txt => 0
word_hyp_extra => 4
word_overlap => 0
alwayson => True
ne_hyp_extra => 1
ne_overlap => 0
neg_hyp => 0
neg_txt => 0
word_hyp_extra => 3
word_overlap => 1
"""


class RTEClassifierTest(unittest.TestCase):
# Test the feature extraction method.
def test_rte_feature_extraction(self):
pairs = rte_corpus.pairs(['rte1_dev.xml'])[:6]
test_output = ["%-15s => %s" % (key, rte_features(pair)[key])
for pair in pairs for key in sorted(rte_features(pair))]
expected_output = expected_from_rte_feature_extration.strip().split('\n')
# Remove null strings.
expected_output = list(filter(None, expected_output))
self.assertEqual(test_output, expected_output)
# Test the RTEFeatureExtractor object.
def test_feature_extractor_object(self):
rtepair = rte_corpus.pairs(['rte3_dev.xml'])[33]
extractor = RTEFeatureExtractor(rtepair)
self.assertEqual(extractor.hyp_words, {'member', 'China', 'SCO.'})
self.assertEqual(extractor.overlap('word'), set())
self.assertEqual(extractor.overlap('ne'), {'China'})
self.assertEqual(extractor.hyp_extra('word'), {'member'})
# Test the RTE classifier training.
def test_rte_classification_without_megam(self):
clf = rte_classifier('IIS')
clf = rte_classifier('GIS')
@unittest.skip("Skipping tests with dependencies on MEGAM")
def test_rte_classification_with_megam(self):
nltk.config_megam('/usr/local/bin/megam')
clf = rte_classifier('megam')
clf = rte_classifier('BFGS')
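
As a rough guide (assuming the RTE corpus data is available, since the classification tests train real models), the new module can be run on its own with the standard unittest loader; this is only a sketch of one way to invoke it:

import unittest

# Load and run the test module added in this commit; equivalent to
# `python -m unittest nltk.test.unit.test_rte_classify`.
suite = unittest.defaultTestLoader.loadTestsFromName('nltk.test.unit.test_rte_classify')
unittest.TextTestRunner(verbosity=2).run(suite)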
