Callbacks for LDAMultiCore #3481

Open
wants to merge 1 commit into
base: develop
@maciejskorski maciejskorski commented Jun 20, 2023

This PR extends the multi-core implementation of LDA to support callbacks 💪.

Callbacks are critical for model evaluation in general, and have been requested in the past for Gensim's models in particular 🙏.

A usage example on the News20 dataset:

from gensim.models import LdaMulticore
from gensim.models.callbacks import CoherenceMetric

# mm_corpus, dictionary and docs_tokenized are assumed to be prepared beforehand
callback1 = CoherenceMetric(corpus=mm_corpus, dictionary=dictionary, coherence='u_mass', title='u_mass')
callback2 = CoherenceMetric(corpus=mm_corpus, texts=docs_tokenized, dictionary=dictionary, coherence='c_v', title='c_v')
lda = LdaMulticore(mm_corpus, id2word=dictionary, num_topics=20, passes=20, batch=False, callbacks=[callback1, callback2])

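# evaluation: plot the per-epoch metrics

import pandas as pd
import matplotlib.pyplot as plt

# one column per callback title, one row per epoch
metrics = pd.DataFrame(lda.metrics)
metrics.reset_index(names=['epoch'], inplace=True)  # the `names` argument requires pandas >= 1.5
metrics['epoch'] = metrics['epoch'] + 1  # count epochs from 1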

fig, ax1 = plt.subplots()
ln1 = ax1.plot(metrics['epoch'], metrics['u_mass'], label='$U_{mass}$', color='tab:red')
ax1.set_xlabel('epoch')
ax1.set_ylabel('$U_{mass}$')
ax2 = ax1.twinx()  # second y-axis, since the two metrics live on different scales
ln2 = ax2.plot(metrics['epoch'], metrics['c_v'], label='$C_v$', color='tab:blue')
ax2.set_ylabel('$C_v$')
lines = ln1 + ln2
labels = [line.get_label() for line in lines]
ax2.legend(lines, labels, loc=0)
plt.show()

This illustrates the point of using callbacks: we can see how many epochs are sufficient for convergence 🆒
image
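The "how many epochs are sufficient" question can also be answered programmatically rather than by eye. Below is a minimal sketch, not part of this PR: `first_converged_epoch`, `rel_tol` and `patience` are hypothetical names, and the metric values are made up for illustration.

```python
from typing import Optional, Sequence


def first_converged_epoch(
    values: Sequence[float], rel_tol: float = 0.01, patience: int = 2
) -> Optional[int]:
    """Return the 1-based epoch at which the metric reached a plateau.

    A step is "flat" when the metric changes by at most `rel_tol` relative to
    the previous value. The plateau starts at the epoch just before the first
    run of `patience` consecutive flat steps. Returns None if no such run exists.
    """
    flat_run = 0
    for i in range(1, len(values)):
        prev, curr = values[i - 1], values[i]
        scale = max(abs(prev), 1e-12)  # guard against division-like blowup at zero
        if abs(curr - prev) <= rel_tol * scale:
            flat_run += 1
            if flat_run >= patience:
                # 0-based index of the plateau start is i - patience
                return i - patience + 1
        else:
            flat_run = 0
    return None


# Made-up per-epoch coherence values, for illustration only.
u_mass_history = [-3.0, -2.0, -1.5, -1.49, -1.489, -1.4889]
converged_at = first_converged_epoch(u_mass_history)
print(converged_at)
```

Assuming `lda.metrics` is keyed by callback title as in the example above, the same helper could be applied as `first_converged_epoch(lda.metrics['u_mass'])`. The tolerance and patience are judgment calls and depend on the scale of the metric.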

Also, the docstring has been made more accurate:

        callbacks : list of :class:`~gensim.models.callbacks.Callback`
            Metric callbacks to log evaluation metrics of the model at every training epoch.
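To make "at every training epoch" concrete, here is a toy sketch of the pattern the docstring describes. It is not Gensim's code: `ToyMetric` and `train` are hypothetical stand-ins, and the "model state" is just a number incremented once per pass.

```python
from collections import defaultdict
from typing import Callable, Dict, List


class ToyMetric:
    """Hypothetical metric callback: a title plus a function of the model state."""

    def __init__(self, title: str, fn: Callable[[float], float]) -> None:
        self.title = title
        self.fn = fn

    def get_value(self, state: float) -> float:
        return self.fn(state)


def train(passes: int, callbacks: List[ToyMetric]) -> Dict[str, List[float]]:
    """Run a dummy training loop and record every metric after each pass."""
    metrics: Dict[str, List[float]] = defaultdict(list)
    state = 0.0
    for _ in range(passes):
        state += 1.0  # placeholder for one full pass over the corpus
        for callback in callbacks:
            # one value per metric per epoch, keyed by the callback title
            metrics[callback.title].append(callback.get_value(state))
    return dict(metrics)


history = train(
    3,
    [ToyMetric("double", lambda s: 2 * s), ToyMetric("square", lambda s: s * s)],
)
print(history)
```

The resulting dictionary has one list per metric title and one entry per epoch, which is the shape that the plotting code in the usage example turns into a DataFrame.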

For a full example see this Kaggle notebook.

DISCLAIMER: this is a byproduct of the implementation for the purpose of a research paper.

@mpenkov
Collaborator

mpenkov commented Apr 8, 2024

@maciejskorski Looks like some tests in your PR are failing. Are you able to fix them?

@mpenkov mpenkov added this to the Spring 2024 release milestone Apr 8, 2024