diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 495bcd223f6d2..c17ffc8255ffe 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -918,6 +918,7 @@ Pairwise metrics
    naive_bayes.GaussianNB
    naive_bayes.MultinomialNB
    naive_bayes.BernoulliNB
+   naive_bayes.PoissonNB

 .. _neighbors_ref:

diff --git a/doc/modules/naive_bayes.rst b/doc/modules/naive_bayes.rst
index 0e60479eddbff..afa2a8b1a7f7a 100644
--- a/doc/modules/naive_bayes.rst
+++ b/doc/modules/naive_bayes.rst
@@ -83,7 +83,8 @@ classification. The likelihood of the features is assumed to be Gaussian:

 .. math::

-   P(x_i \mid y) &= \frac{1}{\sqrt{2\pi\sigma^2_y}} \exp\left(-\frac{(x_i - \mu_y)^2}{2\sigma^2_y}\right)
+   P(x_i \mid y) &= \frac{1}{\sqrt{2\pi\sigma^2_y}}
+                    \exp\left(-\frac{(x_i - \mu_y)^2}{2\sigma^2_y}\right)

 The parameters :math:`\sigma_y` and :math:`\mu_y`
 are estimated using maximum likelihood.
@@ -97,6 +98,7 @@ are estimated using maximum likelihood.
     ...       % (iris.data.shape[0],(iris.target != y_pred).sum()))
     Number of mislabeled points out of a total 150 points : 6

+
 .. _multinomial_naive_bayes:

 Multinomial Naive Bayes
@@ -178,6 +180,49 @@ It is advisable to evaluate both models, if time permits.
        3rd Conf. on Email and Anti-Spam (CEAS).


+.. _poisson_naive_bayes:
+
+Poisson Naive Bayes
+-------------------
+
+:class:`PoissonNB` implements the naive Bayes algorithm for Poisson-distributed
+features. Poisson random variables typically arise as counts of the events of
+a rate process with underlying rate :math:`\lambda`, observed over a finite
+interval.
+
+The likelihood of the features is given by:
+
+.. math::
+
+   P(x_i \mid \lambda) = \frac{\lambda^{x_i} e^{-\lambda}}{x_i!}
+
+where the rate :math:`\lambda` is estimated for each class and feature by
+maximum likelihood, i.e. as the per-class sample mean of the feature. The
+Poisson distribution can be thought of as the limit of a binomial distribution
+with :math:`n` trials and per-trial success probability
+:math:`p = \lambda / n` as the number of trials :math:`n` goes to infinity:
+
+.. math::
+
+   P(x_i \mid \lambda) = \lim_{n \to \infty} \binom{n}{x_i}
+                         \left(\frac{\lambda}{n}\right)^{x_i}
+                         \left(1-\frac{\lambda}{n}\right)^{n-x_i}
+
+.. topic:: References:
+
+   * T. D. Sanger (1994).
+     `Probability Density Estimation for the Interpretation of Neural Population Codes.
+     <http://jn.physiology.org/content/jn/76/4/2790.full.pdf>`_
+     J Neurophys. 76(4):2790-3
+
+   * S. Kim, H. Seo and H. Rim (2003).
+     `Poisson naive Bayes for text classification with feature weighting.
+     `_
+     6th Workshop on Information retrieval with Asian languages 11:33-40
+
+   * W. J. Ma, et al. (2006).
+     `Bayesian inference with probabilistic population codes
+     `_
+     Nat. Neurosci. 9:1432-1438
+
+
 Out-of-core naive Bayes model fitting
 -------------------------------------
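As a quick numerical sanity check on the two formulas above (illustrative only,
not part of the patch): the binomial pmf with per-trial probability
p = lambda / n should converge to the Poisson pmf as n grows. A minimal sketch
using scipy.stats, with an arbitrary rate of 3.0:

    import numpy as np
    from scipy.stats import binom, poisson

    lam, x = 3.0, np.arange(10)      # an arbitrary rate and a few counts

    # Poisson pmf: lambda**x * exp(-lambda) / x!
    target = poisson.pmf(x, lam)

    # The binomial pmf with n trials and p = lambda / n approaches the
    # Poisson pmf as n -> infinity (the limit shown in the docs above).
    for n in (10, 100, 10000):
        print(n, np.abs(binom.pmf(x, n, lam / n) - target).max())

The maximum absolute error shrinks roughly like 1/n, which is why the Poisson
model is a good description of counts of rare events over many opportunities.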
diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py
index fdd09cbad5a62..1423b770a7708 100644
--- a/sklearn/naive_bayes.py
+++ b/sklearn/naive_bayes.py
@@ -29,7 +29,7 @@
 from .utils.multiclass import _check_partial_fit_first_call
 from .externals import six

-__all__ = ['BernoulliNB', 'GaussianNB', 'MultinomialNB']
+__all__ = ['BernoulliNB', 'GaussianNB', 'MultinomialNB', 'PoissonNB']


 class BaseNB(six.with_metaclass(ABCMeta, BaseEstimator, ClassifierMixin)):
@@ -320,18 +320,113 @@ def _partial_fit(self, X, y, classes=None, _refit=False):

     def _joint_log_likelihood(self, X):
         X = check_array(X)
-        joint_log_likelihood = []
-        for i in range(np.size(self.classes_)):
+
+        joint_log_likelihood = np.zeros((np.shape(X)[0],
+                                         np.size(self.classes_)))
+        for i in range(len(self.classes_)):
             jointi = np.log(self.class_prior_[i])
             n_ij = - 0.5 * np.sum(np.log(2. * np.pi * self.sigma_[i, :]))
             n_ij -= 0.5 * np.sum(((X - self.theta_[i, :]) ** 2) /
                                  (self.sigma_[i, :]), 1)
-            joint_log_likelihood.append(jointi + n_ij)
+            joint_log_likelihood[:, i] = jointi + n_ij

-        joint_log_likelihood = np.array(joint_log_likelihood).T
         return joint_log_likelihood


+class PoissonNB(BaseNB):
+    """
+    Poisson Naive Bayes (PoissonNB)
+
+    Attributes
+    ----------
+    class_prior_ : array, shape (n_classes,)
+        Probability of each class.
+
+    class_count_ : array, shape (n_classes,)
+        Number of training samples observed in each class.
+
+    lambda_ : array, shape (n_classes, n_features)
+        Mean of each feature per class.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> X = np.array([[5, 2, 6, 1, 8, 1], [0, 0, 1, 3, 3, 1]]).T
+    >>> y = np.array([1, 1, 1, 2, 2, 2])
+    >>> from sklearn.naive_bayes import PoissonNB
+    >>> clf = PoissonNB()
+    >>> clf.fit(X, y)
+    PoissonNB()
+    >>> print(clf.predict(X))
+    [1 1 1 2 2 2]
+
+    References
+    ----------
+    T. D. Sanger (1994). Probability Density Estimation for the
+    Interpretation of Neural Population Codes. J Neurophys. 76(4):2790-3
+    (Eq. 8). http://jn.physiology.org/content/jn/76/4/2790.full.pdf
+    """
+
+    def fit(self, X, y):
+        """Fit Poisson Naive Bayes according to X, y
+
+        Parameters
+        ----------
+        X : array-like, shape (n_samples, n_features)
+            Training vectors, where n_samples is the number of samples
+            and n_features is the number of features. X is expected to
+            hold non-negative integer counts, although non-negative
+            real values also work in practice. Negative values raise
+            a ValueError.
+
+        y : array-like, shape (n_samples,)
+            Target values.
+
+        Returns
+        -------
+        self : object
+            Returns self.
+        """
+        X, y = check_X_y(X, y)
+        self._check_non_negative(X)
+
+        n_samples, n_features = X.shape
+
+        self.classes_ = unique_y = np.unique(y)
+        n_classes = unique_y.shape[0]
+
+        # Small smoothing term so that an all-zero feature does not yield
+        # a zero rate, which would make log(lambda_) diverge.
+        epsilon = 1e-9
+
+        self.lambda_ = np.zeros((n_classes, n_features))
+        self.class_count_ = np.zeros(n_classes)
+
+        for i, y_i in enumerate(unique_y):
+            Xi = X[y == y_i, :]
+            # The maximum likelihood estimate of the Poisson rate is the
+            # per-class sample mean of each feature.
+            self.lambda_[i, :] = np.mean(Xi, axis=0) + epsilon
+            self.class_count_[i] = Xi.shape[0]
+
+        self.class_prior_ = self.class_count_ / n_samples
+
+        return self
+
+    def _joint_log_likelihood(self, X):
+        X = check_array(X)
+        self._check_non_negative(X)
+
+        joint_log_likelihood = np.zeros((np.shape(X)[0],
+                                         np.size(self.classes_)))
+
+        for i in range(len(self.classes_)):
+            jointi = np.log(self.class_prior_[i])
+            # log P(x | y) = sum_i [x_i log(lambda_i) - lambda_i - log(x_i!)];
+            # the log(x_i!) term is constant across classes for a given
+            # sample, so it is dropped. This leaves predict, predict_proba
+            # and predict_log_proba unchanged.
+            n_ij = np.sum(X * np.log(self.lambda_[i, :]), axis=1)
+            n_ij -= np.sum(self.lambda_[i, :])
+            joint_log_likelihood[:, i] = jointi + n_ij
+
+        return joint_log_likelihood
+
+    @staticmethod
+    def _check_non_negative(X):
+        if np.any(X < 0.):
+            raise ValueError("Input X must be non-negative")
+
+
 class BaseDiscreteNB(BaseNB):
     """Abstract base class for naive Bayes on discrete/categorical data
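A quick cross-check of the likelihood computed above (illustrative only;
assumes this patch is applied so that PoissonNB is importable):
_joint_log_likelihood drops the log(x_i!) term, so predictions recomputed
against scipy's full logpmf, constant term included, must agree exactly:

    import numpy as np
    from scipy.stats import poisson
    from sklearn.naive_bayes import PoissonNB

    rng = np.random.RandomState(0)
    # Two classes with clearly different rates, two Poisson features each
    X = np.vstack([rng.poisson(2.0, size=(10, 2)),
                   rng.poisson(5.0, size=(10, 2))])
    y = np.repeat([0, 1], 10)

    clf = PoissonNB().fit(X, y)

    # Full log-joint: log P(y) + sum_i log Poisson(x_i; lambda_{y,i}),
    # with the log(x_i!) term included via poisson.logpmf.
    full = np.log(clf.class_prior_) + np.array(
        [poisson.logpmf(X, lam).sum(axis=1) for lam in clf.lambda_]).T

    print(np.array_equal(clf.predict(X), full.argmax(axis=1)))  # True

The per-sample constant shifts every class column by the same amount, so the
argmax, and hence the predicted class, is unaffected.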
diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py
index 7244b85778a81..ccb781bd8d13a 100644
--- a/sklearn/tests/test_common.py
+++ b/sklearn/tests/test_common.py
@@ -148,9 +148,11 @@ def test_classifiers():
         yield check_classifiers_pickle, name, Classifier
         # basic consistency testing
         yield check_classifiers_train, name, Classifier
-        if (name not in ["MultinomialNB", "LabelPropagation", "LabelSpreading"]
+        if (name not in ["MultinomialNB", "PoissonNB",
+                         "LabelPropagation", "LabelSpreading"]
                 # TODO some complication with -1 label
-                and name not in ["DecisionTreeClassifier", "ExtraTreeClassifier"]):
+                and name not in ["DecisionTreeClassifier",
+                                 "ExtraTreeClassifier"]):
             # We don't raise a warning in these classifiers, as
             # the column y interface is used by the forests.
@@ -316,8 +318,8 @@ def test_root_import_all_completeness():


 def test_sparsify_estimators():
-    #Test if predict with sparsified estimators works.
-    #Tests regression, binary classification, and multi-class classification.
+    # Test if predict with sparsified estimators works.
+    # Tests regression, binary classification, and multi-class classification.
     estimators = all_estimators()

     # test regression and binary classification
@@ -361,8 +363,7 @@ def test_non_transformer_estimators_n_iter():
             # Tested in test_transformer_n_iter below

         elif name in CROSS_DECOMPOSITION or (
-            name in ['LinearSVC', 'LogisticRegression']
-        ):
+                name in ['LinearSVC', 'LogisticRegression']):
             continue

         else:
diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py
index 7cb0715a574c1..fc66bfb84761a 100644
--- a/sklearn/tests/test_naive_bayes.py
+++ b/sklearn/tests/test_naive_bayes.py
@@ -11,10 +11,12 @@
 from sklearn.utils.testing import assert_array_equal
 from sklearn.utils.testing import assert_array_almost_equal
 from sklearn.utils.testing import assert_equal
+from sklearn.utils.testing import assert_not_equal
 from sklearn.utils.testing import assert_raises
 from sklearn.utils.testing import assert_greater

-from sklearn.naive_bayes import GaussianNB, BernoulliNB, MultinomialNB
+from sklearn.naive_bayes import (GaussianNB, BernoulliNB, MultinomialNB,
+                                 PoissonNB)

 # Data is just 6 separable points in the plane
 X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]])
@@ -66,6 +68,30 @@ def test_discrete_prior():
                               clf.class_log_prior_, 8)


+def test_poissonnb():
+    clf = PoissonNB()
+    assert_raises(ValueError, clf.fit, -X2, y2)
+
+    y_pred = clf.fit(X2, y2).predict(X2)
+    assert_array_equal(y_pred, y2)
+
+
+def test_poissonnb_prior():
+    """Test whether class priors are properly set."""
+    Xp = X + np.array([2, 2])  # shift the points so all counts are >= 0
+    clf = PoissonNB().fit(Xp, y)
+    assert_array_almost_equal(np.array([3, 3]) / 6.0,
+                              clf.class_prior_, 8)
+    clf.fit(X2, y2)
+    assert_array_almost_equal(clf.class_prior_.sum(), 1)
+
+    # Verify that np.log(clf.predict_proba(X)) gives the same results as
+    # clf.predict_log_proba(X)
+    y_pred_proba = clf.predict_proba(X2)
+    y_pred_log_proba = clf.predict_log_proba(X2)
+    assert_array_almost_equal(np.log(y_pred_proba), y_pred_log_proba, 8)
+
+
 def test_mnnb():
     """Test Multinomial Naive Bayes classification.

@@ -153,7 +179,7 @@ def test_gnb_partial_fit():

 def test_discretenb_pickle():
     """Test picklability of discrete naive Bayes classifiers"""
-    for cls in [BernoulliNB, MultinomialNB, GaussianNB]:
+    for cls in [BernoulliNB, MultinomialNB, GaussianNB, PoissonNB]:
         clf = cls().fit(X2, y2)
         y_pred = clf.predict(X2)

@@ -163,9 +189,7 @@
         assert_array_equal(y_pred, clf.predict(X2))

-        if cls is not GaussianNB:
-            # TODO re-enable me when partial_fit is implemented for GaussianNB
-
+        if cls is not PoissonNB:
             # Test pickling of estimator trained with partial_fit
             clf2 = cls().partial_fit(X2[:3], y2[:3], classes=np.unique(y2))
             clf2.partial_fit(X2[3:], y2[3:])
@@ -177,7 +201,7 @@

 def test_input_check_fit():
     """Test input checks for the fit method"""
-    for cls in [BernoulliNB, MultinomialNB, GaussianNB]:
+    for cls in [BernoulliNB, MultinomialNB, GaussianNB, PoissonNB]:
         # check shape consistency for number of samples at fit time
         assert_raises(ValueError, cls().fit, X2, y2[:-1])

@@ -212,10 +236,16 @@
 def test_discretenb_predict_proba():
     """Test discrete NB classes' probability scores"""

     # The 100s below distinguish Bernoulli from multinomial.
-    # FIXME: write a test to show this.
     X_bernoulli = [[1, 100, 0], [0, 1, 0], [0, 100, 1]]
     X_multinomial = [[0, 1], [1, 3], [4, 0]]

+    # Confirm that the 100s above distinguish Bernoulli from multinomial
+    y = [0, 0, 1]
+    cls_b = BernoulliNB().fit(X_bernoulli, y)
+    cls_m = MultinomialNB().fit(X_bernoulli, y)
+    assert_not_equal(cls_b.predict(X_bernoulli)[-1],
+                     cls_m.predict(X_bernoulli)[-1])
+
     # test binary case (1-d output)
     y = [0, 0, 2]  # 2 is regression test for binary case, 02e673
     for cls, X in zip([BernoulliNB, MultinomialNB],
@@ -348,3 +378,10 @@ def test_check_accuracy_on_digits():

     scores = cross_val_score(GaussianNB(), X_3v8, y_3v8, cv=10)
     assert_greater(scores.mean(), 0.86)
+
+    # Poisson NB
+    scores = cross_val_score(PoissonNB(), X, y, cv=10)
+    assert_greater(scores.mean(), 0.85)
+
+    scores = cross_val_score(PoissonNB(), X_3v8, y_3v8, cv=10)
+    assert_greater(scores.mean(), 0.95)
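To see why the toy tests above should pass (illustrative only; assumes this
patch is applied): lambda_ is just the per-class feature mean, so on data that
is genuinely Poisson per class the fitted rates track the generating rates.
A minimal sketch with arbitrary rates:

    import numpy as np
    from sklearn.naive_bayes import PoissonNB

    rng = np.random.RandomState(42)
    rates = np.array([[1.0, 4.0],    # class 0: per-feature rates
                      [6.0, 2.0]])   # class 1: per-feature rates

    # 500 samples per class; each feature an independent Poisson count
    X = np.vstack([rng.poisson(r, size=(500, 2)) for r in rates])
    y = np.repeat([0, 1], 500)

    clf = PoissonNB().fit(X, y)
    print(clf.lambda_)       # close to `rates` (up to the 1e-9 smoothing)
    print(clf.class_prior_)  # [ 0.5  0.5]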
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 701a0a5efd27a..e65d5903e1f15 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -449,6 +449,8 @@ def check_classifiers_train(name, Classifier):
         classifier = Classifier()
         if name in ['BernoulliNB', 'MultinomialNB']:
             X -= X.min()
+        if name in ['PoissonNB']:
+            X = np.floor((10 * X) ** 2)  # forces non-negative integers
         set_fast_parameters(classifier)
         # raises error on malformed input for fit
         assert_raises(ValueError, classifier.fit, X, y[:-1])
@@ -461,7 +463,7 @@
         y_pred = classifier.predict(X)
         assert_equal(y_pred.shape, (n_samples,))
         # training set performance
-        if name not in ['BernoulliNB', 'MultinomialNB']:
+        if name not in ['BernoulliNB', 'MultinomialNB', 'PoissonNB']:
             assert_greater(accuracy_score(y, y_pred), 0.85)

         # raises error on malformed input for predict
@@ -506,7 +508,8 @@ def check_classifiers_input_shapes(name, Classifier):
     iris = load_iris()
     X, y = iris.data, iris.target
     X, y = shuffle(X, y, random_state=1)
-    X = StandardScaler().fit_transform(X)
+    if name != 'PoissonNB':
+        # StandardScaler produces negative values, which PoissonNB rejects
+        X = StandardScaler().fit_transform(X)
     # catch deprecation warnings
     with warnings.catch_warnings(record=True):
         classifier = Classifier()
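The PoissonNB guard in check_classifiers_input_shapes uses `!=` rather than
`is not`: identity comparison against a string literal only works when the
interpreter happens to intern both strings. A minimal illustration of the
pitfall (CPython behaviour; string interning is an implementation detail, and
recent Pythons emit a SyntaxWarning for `is` on literals):

    # `is` compares identity; `==` compares values. Strings built at
    # runtime are generally not interned, so an identity check against a
    # literal can fail even though the values are equal.
    name = ''.join(['Poisson', 'NB'])  # new string object, not interned
    print(name == 'PoissonNB')  # True  -- value equality
    print(name is 'PoissonNB')  # False -- different objects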