Poisson naive Bayes classifier (PoissonNB) #3708

Status: Closed (wants to merge 1 commit)

1 change: 1 addition & 0 deletions doc/modules/classes.rst
@@ -918,6 +918,7 @@ Pairwise metrics
naive_bayes.GaussianNB
naive_bayes.MultinomialNB
naive_bayes.BernoulliNB
naive_bayes.PoissonNB


.. _neighbors_ref:
47 changes: 46 additions & 1 deletion doc/modules/naive_bayes.rst
@@ -83,7 +83,8 @@ classification. The likelihood of the features is assumed to be Gaussian:

.. math::

P(x_i \mid y) &= \frac{1}{\sqrt{2\pi\sigma^2_y}} \exp\left(-\frac{(x_i - \mu_y)^2}{2\sigma^2_y}\right)
P(x_i \mid y) &= \frac{1}{\sqrt{2\pi\sigma^2_y}}
\exp\left(-\frac{(x_i - \mu_y)^2}{2\sigma^2_y}\right)

The parameters :math:`\sigma_y` and :math:`\mu_y`
are estimated using maximum likelihood.
@@ -97,6 +98,7 @@ are estimated using maximum likelihood.
... % (iris.data.shape[0],(iris.target != y_pred).sum()))
Number of mislabeled points out of a total 150 points : 6


.. _multinomial_naive_bayes:

Multinomial Naive Bayes
@@ -178,6 +180,49 @@ It is advisable to evaluate both models, if time permits.
3rd Conf. on Email and Anti-Spam (CEAS).


.. _poisson_naive_bayes:

Poisson Naive Bayes
-------------------

:class:`PoissonNB` implements the naive Bayes algorithm for Poisson-distributed
features. Poisson random variables typically arise when counting the events of a
rate process with underlying rate :math:`\lambda` over a fixed interval.

The likelihood is given by:

.. math::

P(x_i \mid \lambda) = \frac{\lambda^{x_i} e^{-\lambda}}{x_i!}

and can be thought of as the limit of a binomial distribution with per-trial
success probability :math:`p = \lambda / n` as the number of trials :math:`n`
goes to infinity:

.. math::

P(x_i \mid \lambda) = \lim_{n \to \infty} \binom{n}{x_i}
\left(\frac{\lambda}{n}\right)^{x_i}
\left(1-\frac{\lambda}{n}\right)^{n-x_i}
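
Under the naive Bayes assumption, and dropping the :math:`\log(x_i!)` terms
(which do not depend on the class and cancel when comparing classes), the
decision rule implemented by :class:`PoissonNB` reduces to

.. math::

    \hat{y} = \arg\max_y \left[ \log P(y)
        + \sum_{i} \left( x_i \log \lambda_{yi} - \lambda_{yi} \right) \right]

where :math:`\lambda_{yi}` is the mean of feature :math:`i` over the training
samples of class :math:`y`.

Like the other naive Bayes variants, :class:`PoissonNB` follows the usual
fit/predict estimator API. A minimal usage sketch on toy count data (the six
samples below are the illustrative data from the class docstring, not taken
from the references):

>>> import numpy as np
>>> from sklearn.naive_bayes import PoissonNB
>>> X = np.array([[5, 2, 6, 1, 8, 1], [0, 0, 1, 3, 3, 1]]).T
>>> y = np.array([1, 1, 1, 2, 2, 2])
>>> clf = PoissonNB()
>>> print(clf.fit(X, y).predict(X))
[1 1 1 2 2 2]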

.. topic:: References:

* T. D. Sanger (1994).
`Probability Density Estimation for the Interpretation of Neural Population Codes.
<http://jn.physiology.org/content/jn/76/4/2790.full.pdf>`_
J Neurophys. 76(4):2790-3

* S. Kim, H. Seo and H. Rim. (2003)
`Poisson naive Bayes for text classification with feature weighting.
<http://dl.acm.org/citation.cfm?id=1118940>`_
6th Workshop on Information retrieval with Asian languages 11:33-40

* W. J. Ma et al. (2006).
`Bayesian inference with probabilistic population codes
<http://psych.stanford.edu/~jlm/pdfs/Ma%20et%20al%20with%20figs.pdf>`_
Nat. Neurosci. 9:1432-1438


Out-of-core naive Bayes model fitting
-------------------------------------

105 changes: 100 additions & 5 deletions sklearn/naive_bayes.py
@@ -29,7 +29,7 @@
from .utils.multiclass import _check_partial_fit_first_call
from .externals import six

__all__ = ['BernoulliNB', 'GaussianNB', 'MultinomialNB']
__all__ = ['BernoulliNB', 'GaussianNB', 'MultinomialNB', 'PoissonNB']


class BaseNB(six.with_metaclass(ABCMeta, BaseEstimator, ClassifierMixin)):
@@ -320,18 +320,113 @@ def _partial_fit(self, X, y, classes=None, _refit=False):

def _joint_log_likelihood(self, X):
X = check_array(X)
joint_log_likelihood = []
for i in range(np.size(self.classes_)):

joint_log_likelihood = np.zeros((np.shape(X)[0],
np.size(self.classes_)))
Member comment: np.size -> len

for i in range(len(self.classes_)):
jointi = np.log(self.class_prior_[i])
n_ij = - 0.5 * np.sum(np.log(2. * np.pi * self.sigma_[i, :]))
n_ij -= 0.5 * np.sum(((X - self.theta_[i, :]) ** 2) /
(self.sigma_[i, :]), 1)
joint_log_likelihood.append(jointi + n_ij)
joint_log_likelihood[:, i] = jointi + n_ij

joint_log_likelihood = np.array(joint_log_likelihood).T
return joint_log_likelihood


class PoissonNB(BaseNB):
"""
Poisson Naive Bayes (PoissonNB)

Attributes
----------
class_prior_ : array, shape (n_classes,)
probability of each class.

class_count_ : array, shape (n_classes,)
number of training samples observed in each class.

lambda_ : array, shape (n_classes, n_features)
mean of each feature per class

Examples
--------
>>> import numpy as np
>>> X = np.array([[5, 2, 6, 1, 8, 1], [0, 0, 1, 3, 3, 1]]).T
>>> y = np.array([1, 1, 1, 2, 2, 2])
>>> from sklearn.naive_bayes import PoissonNB
>>> clf = PoissonNB()
>>> clf.fit(X, y)
PoissonNB()
>>> print(clf.predict(X))
[1 1 1 2 2 2]
Member comment: Can you add here a reference paper which includes the derivation?


References
----------
T. D. Sanger. (1994) Probability Density Estimation for the Interpretation
of Neural Population Codes. J Neurophys. 76(4):2790-3 (Eq. 8)
http://jn.physiology.org/content/jn/76/4/2790.full.pdf
"""

def fit(self, X, y):
"""Fit Poisson Naive Bayes according to X, y

Parameters
----------
X : array-like, shape (n_samples, n_features)
Training vectors, where n_samples is the number of samples
and n_features is the number of features. X expects non-negative
integers, although in practice non-integer values may also work.
Negative counts will result in a ValueError.

y : array-like, shape (n_samples,)
Target values.

Returns
-------
self : object
Returns self.
"""
X, y = check_X_y(X, y)
Member comment: you should check that X entries are all positive and raise a meaningful error message if not

PoissonNB._check_non_negative(X)

n_samples, n_features = X.shape

self.classes_ = unique_y = np.unique(y)
n_classes = unique_y.shape[0]

epsilon = 1e-9
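# Small smoothing constant added to the per-class feature means below so
# that a feature which never occurs in a class does not produce log(0)
# in the joint log-likelihood.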

self.lambda_ = np.zeros((n_classes, n_features))
self.class_prior_ = np.zeros(n_classes)

for i, y_i in enumerate(unique_y):
Xi = X[y == y_i, :]
self.lambda_[i, :] = np.mean(Xi, axis=0) + epsilon
self.class_prior_[i] = float(Xi.shape[0]) / n_samples

return self

def _joint_log_likelihood(self, X):
X = check_array(X)
PoissonNB._check_non_negative(X)

joint_log_likelihood = np.zeros((np.shape(X)[0],
np.size(self.classes_)))
Member comment: np.size -> len


for i in range(len(self.classes_)):
jointi = np.log(self.class_prior_[i])
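# Poisson log-likelihood up to the additive -log(x_i!) terms, which are
# identical for every class and therefore dropped.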
n_ij = np.sum(X * np.log(self.lambda_[i, :]), axis=1)
n_ij -= np.sum(self.lambda_[i, :])
joint_log_likelihood[:, i] = jointi + n_ij

return joint_log_likelihood

@staticmethod
def _check_non_negative(X):
if np.any(X < 0.):
Member comment: this will not work with a sparse matrix (I think)

Author comment:
Much of the fit and _joint_log_likelihood methods won't work with sparse matrices. Currently, either will throw an error at check_X_y. Unfortunately, array-like and sparse-matrix overload multiplication differently, so adding sparse matrix support will require detecting what was passed to _joint_log_likelihood and branching on it or adding a new function similar to safe_sparse_dot but with the array-like (element-wise) multiplication behavior to extmath.py.

I'm happy to do either, or hold off on sparse matrix support for the time being. Let me know which of those three you prefer.

Member comment: it should not be much work to support sparse data so if you have the time it would be great.

raise ValueError("Input X must be non-negative")
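
Side note on the sparse-input discussion above: a rough sketch of the kind of
helper the author mentions (a hypothetical `safe_sparse_multiply` for
extmath.py, not part of this pull request) could look like this:

import numpy as np
from scipy import sparse

def safe_sparse_multiply(A, B):
    """Element-wise multiply accepting dense arrays or sparse matrices.

    Hypothetical helper sketched from the review discussion: sparse inputs
    use `.multiply` (keeping the result sparse where possible), dense
    inputs fall back to `np.multiply`.
    """
    if sparse.issparse(A):
        return A.multiply(B)
    if sparse.issparse(B):
        return B.multiply(A)
    return np.multiply(A, B)

With such a helper, `X * np.log(self.lambda_[i, :])` in
`_joint_log_likelihood` could be written as
`safe_sparse_multiply(X, np.log(self.lambda_[i, :]))` for both dense and
sparse `X`.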


class BaseDiscreteNB(BaseNB):
"""Abstract base class for naive Bayes on discrete/categorical data

13 changes: 7 additions & 6 deletions sklearn/tests/test_common.py
@@ -148,9 +148,11 @@ def test_classifiers():
yield check_classifiers_pickle, name, Classifier
# basic consistency testing
yield check_classifiers_train, name, Classifier
if (name not in ["MultinomialNB", "LabelPropagation", "LabelSpreading"]
if (name not in ["MultinomialNB", "PoissonNB",
"LabelPropagation", "LabelSpreading"]
# TODO some complication with -1 label
and name not in ["DecisionTreeClassifier", "ExtraTreeClassifier"]):
and name not in ["DecisionTreeClassifier",
"ExtraTreeClassifier"]):
# We don't raise a warning in these classifiers, as
# the column y interface is used by the forests.

@@ -316,8 +318,8 @@ def test_root_import_all_completeness():


def test_sparsify_estimators():
#Test if predict with sparsified estimators works.
#Tests regression, binary classification, and multi-class classification.
# Test if predict with sparsified estimators works.
# Tests regression, binary classification, and multi-class classification.
estimators = all_estimators()

# test regression and binary classification
@@ -361,8 +363,7 @@ def test_non_transformer_estimators_n_iter():

# Tested in test_transformer_n_iter below
elif name in CROSS_DECOMPOSITION or (
name in ['LinearSVC', 'LogisticRegression']
):
name in ['LinearSVC', 'LogisticRegression']):
continue

else:
51 changes: 44 additions & 7 deletions sklearn/tests/test_naive_bayes.py
@@ -11,10 +11,12 @@
from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_array_almost_equal
from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_not_equal
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_greater

from sklearn.naive_bayes import GaussianNB, BernoulliNB, MultinomialNB
from sklearn.naive_bayes import GaussianNB, BernoulliNB, MultinomialNB, \
PoissonNB

# Data is just 6 separable points in the plane
X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]])
@@ -66,6 +68,30 @@ def test_discrete_prior():
clf.class_log_prior_, 8)


def test_poissonnb():
clf = PoissonNB()
assert_raises(ValueError, clf.fit, -X2, y2)

y_pred = clf.fit(X2, y2).predict(X2)
assert_array_equal(y_pred, y2)


def test_poissonnb_prior():
"""Test whether class priors are properly set. """
Xp = X + np.array([2, 2])
clf = PoissonNB().fit(Xp, y)
assert_array_almost_equal(np.array([3, 3]) / 6.0,
clf.class_prior_, 8)
clf.fit(X2, y2)
assert_array_almost_equal(clf.class_prior_.sum(), 1)

# Verify that np.log(clf.predict_proba(X)) gives the same results as
# clf.predict_log_proba(X)
y_pred_proba = clf.predict_proba(X2)
y_pred_log_proba = clf.predict_log_proba(X2)
assert_array_almost_equal(np.log(y_pred_proba), y_pred_log_proba, 8)


def test_mnnb():
"""Test Multinomial Naive Bayes classification.

@@ -153,7 +179,7 @@ def test_gnb_partial_fit():
def test_discretenb_pickle():
"""Test picklability of discrete naive Bayes classifiers"""

for cls in [BernoulliNB, MultinomialNB, GaussianNB]:
for cls in [BernoulliNB, MultinomialNB, GaussianNB, PoissonNB]:
clf = cls().fit(X2, y2)
y_pred = clf.predict(X2)

@@ -163,9 +189,7 @@ def test_discretenb_pickle():

assert_array_equal(y_pred, clf.predict(X2))

if cls is not GaussianNB:
# TODO re-enable me when partial_fit is implemented for GaussianNB

if cls is not PoissonNB:
# Test pickling of estimator trained with partial_fit
clf2 = cls().partial_fit(X2[:3], y2[:3], classes=np.unique(y2))
clf2.partial_fit(X2[3:], y2[3:])
@@ -177,7 +201,7 @@

def test_input_check_fit():
"""Test input checks for the fit method"""
for cls in [BernoulliNB, MultinomialNB, GaussianNB]:
for cls in [BernoulliNB, MultinomialNB, GaussianNB, PoissonNB]:
# check shape consistency for number of samples at fit time
assert_raises(ValueError, cls().fit, X2, y2[:-1])

@@ -212,10 +236,16 @@ def test_discretenb_predict_proba():
"""Test discrete NB classes' probability scores"""

# The 100s below distinguish Bernoulli from multinomial.
# FIXME: write a test to show this.
X_bernoulli = [[1, 100, 0], [0, 1, 0], [0, 100, 1]]
X_multinomial = [[0, 1], [1, 3], [4, 0]]

# Confirm that the 100s above distinguish Bernoulli from multinomial
y = [0, 0, 1]
cls_b = BernoulliNB().fit(X_bernoulli, y)
cls_m = MultinomialNB().fit(X_bernoulli, y)
assert_not_equal(cls_b.predict(X_bernoulli)[-1],
cls_m.predict(X_bernoulli)[-1])

# test binary case (1-d output)
y = [0, 0, 2] # 2 is regression test for binary case, 02e673
for cls, X in zip([BernoulliNB, MultinomialNB],
@@ -348,3 +378,10 @@ def test_check_accuracy_on_digits():

scores = cross_val_score(GaussianNB(), X_3v8, y_3v8, cv=10)
assert_greater(scores.mean(), 0.86)

# Poisson NB
scores = cross_val_score(PoissonNB(), X, y, cv=10)
assert_greater(scores.mean(), 0.85)

scores = cross_val_score(PoissonNB(), X_3v8, y_3v8, cv=10)
assert_greater(scores.mean(), 0.95)
7 changes: 5 additions & 2 deletions sklearn/utils/estimator_checks.py
@@ -449,6 +449,8 @@ def check_classifiers_train(name, Classifier):
classifier = Classifier()
if name in ['BernoulliNB', 'MultinomialNB']:
X -= X.min()
if name in ['PoissonNB']:
X = np.floor((10 * X) ** 2)  # Forces non-negative integer features
set_fast_parameters(classifier)
# raises error on malformed input for fit
assert_raises(ValueError, classifier.fit, X, y[:-1])
@@ -461,7 +463,7 @@
y_pred = classifier.predict(X)
assert_equal(y_pred.shape, (n_samples,))
# training set performance
if name not in ['BernoulliNB', 'MultinomialNB']:
if name not in ['BernoulliNB', 'MultinomialNB', 'PoissonNB']:
assert_greater(accuracy_score(y, y_pred), 0.85)

# raises error on malformed input for predict
@@ -506,7 +508,8 @@ def check_classifiers_input_shapes(name, Classifier):
iris = load_iris()
X, y = iris.data, iris.target
X, y = shuffle(X, y, random_state=1)
X = StandardScaler().fit_transform(X)
if name is not 'PoissonNB':
X = StandardScaler().fit_transform(X)
# catch deprecation warnings
with warnings.catch_warnings(record=True):
classifier = Classifier()