
Feature Request: Implementing Persistent Speaker Embeddings Across Conversations #227

Open
DmitriyG228 opened this issue Jan 3, 2024 · 3 comments
Labels
feature New feature or request

Comments

@DmitriyG228

Feature Description

I propose adding a feature to the DIART project that allows speaker embeddings to be persisted and reused across multiple conversations. I am willing to contribute to this feature.

Expected Benefit

This would be particularly useful in scenarios where speakers need to be identified across multiple conversations over time.

Implementation Feasibility

Given the complexity of the speaker embeddings obtained during a conversation, I seek guidance on the technical feasibility of this feature. Specifically, I'm interested in understanding whether the current architecture and design of DIART can support the persistence of speaker embeddings across conversations.
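As a minimal sketch of what persistence could look like, the centroid matrix held by the clustering block could be serialized to disk and reloaded in a later session. The helpers below (`save_centroids` and `load_centroids` are hypothetical names, not part of DIART) assume centroids are available as a `(num_speakers, dim)` NumPy array, as in `OnlineSpeakerClustering.centers`:

```python
import json
import numpy as np

def save_centroids(centers: np.ndarray, path: str) -> None:
    """Serialize a (num_speakers, dim) centroid matrix to JSON."""
    with open(path, "w") as f:
        json.dump(centers.tolist(), f)

def load_centroids(path: str) -> np.ndarray:
    """Reload centroids previously written by save_centroids."""
    with open(path) as f:
        return np.asarray(json.load(f))
```

JSON is chosen here only for readability; `np.save`/`np.load` would work just as well.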

Suggested Integration Points

Could you provide insights on which parts of the DIART codebase would be most relevant for integrating this mechanism? Any pointers or suggestions on how to approach this enhancement would be greatly appreciated.

Additional Context

I have reviewed the paper implemented by DIART and believe that, although challenging, this feature could be a feasible and valuable addition.

I am eager to contribute to this aspect of the project and align it with DIART's overall goals and design.

Thank you for considering this feature request and for any guidance you can provide.


DmitriyG228 commented Jan 3, 2024

My own solution is the following.

Patch `OnlineSpeakerClustering` with:


```python
from typing import Dict, List

def get_speaker_id_to_centroid_mapping(self) -> Dict[int, List[float]]:
    """Return the mapping of global speaker IDs to their centroids."""
    if self.centers is None:
        return {}
    speaker_id_to_centroid = {}
    for g_spk in self.active_centers:
        # .tolist() makes the centroid JSON-serializable
        speaker_id_to_centroid[g_spk] = self.centers[g_spk].tolist()
    return speaker_id_to_centroid
```

Then, in `SpeakerDiarization.__call__`, attach the mapping to each output:

```python
def __call__(self, waveforms):
    # ... existing code ...
    for wav, seg, emb in zip(waveforms, segmentations, embeddings):
        # ... existing code ...
        mapping = self.clustering.get_speaker_id_to_centroid_mapping()
        outputs.append((agg_prediction, agg_waveform, mapping))  # new third element
```
    

```python
import json
from typing import Text, Tuple, Union

from pyannote.core import Annotation
from rx.core import Observer


class RedisWriter(Observer):
    def __init__(self, uri: Text, redis_client, patch_collar: float = 0.05):
        super().__init__()
        self.uri = uri
        self.redis_client = redis_client
        # Use the URI as a unique identifier for the conversation
        self.conversation_id = uri
        self.patch_collar = patch_collar

    def on_next(self, value: Union[Tuple, Annotation]):
        # Only tuple values carry centroids; plain Annotations are ignored
        if isinstance(value, tuple):
            prediction, _, centroids = value
            # Write each predicted segment, together with the current
            # centroids, to a per-conversation Redis queue
            for segment, _, label in prediction.itertracks(yield_label=True):
                diarization_data = {
                    "start": segment.start,
                    "end": segment.end,
                    "speaker_id": label,
                    "centroids": centroids,
                }
                self.redis_client.rpush(
                    f"diarization_{self.conversation_id}",
                    json.dumps(diarization_data),
                )

    def on_error(self, error: Exception):
        # Handle error (optional)
        pass

    def on_completed(self):
        # Handle completion (optional)
        pass
```

I run this the following way:

```python
import redis

from diart import SpeakerDiarization
from diart.sources import FileAudioSource
from diart.inference import StreamingInference
from diart.sinks import RedisWriter  # the new sink defined above

# Initialize the speaker diarization pipeline
pipeline = SpeakerDiarization()

sample_rate = 16000
audio_file_path = "conversation.wav"  # path to the input recording
file_source = FileAudioSource(audio_file_path, sample_rate)

redis_client = redis.Redis()

# Create a StreamingInference instance with the file source
inference = StreamingInference(pipeline, file_source, do_plot=False)
# Attach the RedisWriter instead of an RTTMWriter
inference.attach_observers(RedisWriter(file_source.uri, redis_client))

# Run the inference
prediction = inference()
```

The above writes the pipeline output, including the global speaker centroids, to a Redis queue.
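On the consumer side, the queue entries can be drained and decoded back into Python dicts. A minimal sketch (`read_diarization` is a hypothetical helper, not part of DIART; it only assumes the `lpop` method of a redis-py client):

```python
import json

def read_diarization(redis_client, conversation_id):
    """Drain and decode the per-conversation diarization queue."""
    key = f"diarization_{conversation_id}"
    entries = []
    while True:
        raw = redis_client.lpop(key)
        if raw is None:
            break  # queue is empty
        entries.append(json.loads(raw))
    return entries
```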

Note that it's probably suboptimal to save centroids with every iteration, as they quickly converge and later writes become redundant.
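One way to avoid those redundant writes would be to report a centroid only when it has moved more than some threshold since it was last written. A sketch of such a filter (`CentroidChangeFilter` is a hypothetical helper, with an L2 threshold chosen arbitrarily):

```python
import numpy as np

class CentroidChangeFilter:
    """Pass through only centroids that moved more than `eps` (L2 norm)
    since they were last reported."""

    def __init__(self, eps: float = 1e-3):
        self.eps = eps
        self._last = {}  # speaker id -> last reported centroid

    def filter(self, centroids):
        changed = {}
        for spk, center in centroids.items():
            center = np.asarray(center, dtype=float)
            prev = self._last.get(spk)
            if prev is None or np.linalg.norm(center - prev) > self.eps:
                changed[spk] = center
                self._last[spk] = center.copy()
        return changed
```

The `RedisWriter` could then push only the `filter(...)` result instead of the full mapping on every chunk.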

I would appreciate your feedback!

@juanmc2005 juanmc2005 added the feature New feature or request label Jan 4, 2024
@juanmc2005
Owner

Hey @DmitriyG228! Thanks for this feature request, your implementation looks very nice! I would only change some minor things. For example, I would prefer not to have a speaker ID mapping mechanism in the clustering block; the speaker IDs are already numbered according to their centroid if I'm not mistaken (e.g. speaker_0 == centroid 0). However, if we decide to include a mapping structure (I'm willing to be persuaded on this, since I see some advantages), I'd prefer to put it in SpeakerDiarization as part of the pipeline state.

Apart from that, I also really like the idea of a RedisWriter! Could you open a PR with your code so we can discuss the details there?

I would prefer not to add unnecessary dependencies to diart, so I would implement the RedisWriter to throw an error if redis isn't installed. For an example of this pattern, check the imports in models.py.
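The optional-dependency pattern described above can be sketched as follows (`require_optional` is a hypothetical helper for illustration; DIART's actual models.py may structure its guarded imports differently):

```python
import importlib

def require_optional(module_name: str, pip_name: str = None):
    """Import an optional dependency, raising a helpful error if missing."""
    try:
        return importlib.import_module(module_name)
    except ImportError as e:
        pip_name = pip_name or module_name
        raise ImportError(
            f"'{module_name}' is an optional dependency. "
            f"Install it with `pip install {pip_name}`."
        ) from e
```

A `RedisWriter.__init__` could then call `require_optional("redis")` so that the error only fires when the sink is actually used.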

Thank you!

@DmitriyG228
Author

Hey @juanmc2005, thanks for your feedback! Please find the PR.
