
Support MultiWorkerMirroredStrategy distributed training strategy for dynamic embeddings #365

Open
sivukhin opened this issue Oct 31, 2023 · 10 comments
Labels
enhancement New feature or request question Further information is requested


@sivukhin

sivukhin commented Oct 31, 2023

I explored the available approaches for distributed training of large-scale recommendation models with huge embedding tables and tried to use TFRA DynamicEmbedding combined with MultiWorkerMirroredStrategy.

  • The target task is to train a simple two-tower model over an online stream of events on multiple CPU workers (the model is simple, so there is no need to train on GPU).
  • At first sight, MultiWorkerMirroredStrategy seems to suit my needs: apart from the embeddings, the model has very few parameters, so they can be replicated across all workers.
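To put the "very few parameters apart from the embeddings" point in numbers, here is a back-of-the-envelope count for one tower of the model in the code below (a 64-dim embedding followed by Dense(64) and Dense(32); the Lambda layer has no weights). The key count of 100,000 is only the `init_capacity` used in the example, taken here as an assumption; a production table would usually hold far more keys, which only widens the gap.

```python
def dense_params(in_dim: int, out_dim: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return in_dim * out_dim + out_dim

EMBEDDING_DIM = 64
NUM_KEYS = 100_000  # assumption: equal to init_capacity in the example code

# One tower: embedding(64) -> Dense(64) -> Dense(32) -> l2_normalize (no weights).
tower_dense = dense_params(EMBEDDING_DIM, 64) + dense_params(64, 32)
tower_embedding = NUM_KEYS * EMBEDDING_DIM

print(f"dense params per tower:        {tower_dense:,}")
print(f"embedding params per tower:    {tower_embedding:,}")
print(f"embedding size in float32:     {tower_embedding * 4:,} bytes")
print(f"embedding / dense ratio:       {tower_embedding / tower_dense:.0f}x")
```

Even at this small capacity the embedding table is roughly three orders of magnitude larger than the dense layers, which is why replicating only the dense part is attractive.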

It seems the current implementation does not work well with MultiWorkerMirroredStrategy. My attempts to make it work failed with the following error:

    ValueError: `colocate_vars_with` must only be passed a variable created in this tf.distribute.Strategy.scope(), not: <tf.Variable 'DynamicEmbedding/user-embedding-shadow:0' shape=(0, 64) dtype=float32, numpy=array([], shape=(0, 64), dtype=float32)>

I launched the following training code on 2 workers with these commands:

TF_CONFIG='{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }' python3 main.py &
TF_CONFIG='{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 1} }' python3 main.py &
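The two commands differ only in `task.index`. A small helper like the one below (hypothetical, not part of the original code) can generate the TF_CONFIG for each worker, so that the cluster list is guaranteed to be identical on every worker, as MultiWorkerMirroredStrategy expects.

```python
import json
from typing import List

def make_tf_config(workers: List[str], index: int) -> str:
    """Build the TF_CONFIG JSON string for worker `index` in a worker-only cluster."""
    return json.dumps({
        "cluster": {"worker": workers},
        "task": {"type": "worker", "index": index},
    })

workers = ["localhost:12345", "localhost:23456"]
for i in range(len(workers)):
    # Prints one shell launch line per worker, matching the commands above.
    print(f"TF_CONFIG='{make_tf_config(workers, i)}' python3 main.py &")
```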
Source code
import dataclasses
from typing import Dict

import tensorflow as tf
import tensorflow_datasets as tfds
# tensorflow_recommenders_addons does some patching on TensorFlow, so it MUST be imported after importing TF
import tensorflow_recommenders as tfrs
import tensorflow_recommenders_addons as tfra
from tensorflow_recommenders_addons import dynamic_embedding as de

redis_config = tfra.dynamic_embedding.RedisTableConfig(redis_config_abs_dir="redis.config")
redis_creator = tfra.dynamic_embedding.RedisTableCreator(redis_config)
batch_size = 4096
seed = 2023


@dataclasses.dataclass(frozen=True)
class TrainingDatasets:
    train_ds: tf.data.Dataset
    validation_ds: tf.data.Dataset


@dataclasses.dataclass(frozen=True)
class RetrievalDatasets:
    training_datasets: TrainingDatasets
    candidate_dataset: tf.data.Dataset


def create_datasets():
    def split_train_validation_datasets(ratings_dataset: tf.data.Dataset) -> TrainingDatasets:
        train_size = int(len(ratings_dataset) * 0.9)
        validation_size = len(ratings_dataset) - train_size
        print(f"Train size: {train_size}")
        print(f"Validation size: {validation_size}")

        shuffled_dataset = ratings_dataset.shuffle(buffer_size=5 * batch_size, seed=seed)
        train_ds = shuffled_dataset.skip(validation_size).shuffle(buffer_size=10 * batch_size).apply(lambda dataset: dataset.padded_batch(batch_size))
        validation_ds = shuffled_dataset.take(validation_size).apply(lambda dataset: dataset.padded_batch(batch_size))

        return TrainingDatasets(train_ds=train_ds, validation_ds=validation_ds)

    ratings_dataset = tfds.load("movielens/1m-ratings", split="train")
    movies_dataset = tfds.load("movielens/1m-movies", split="train").map(lambda x: x["movie_title"])

    for item in ratings_dataset.take(3):
        print(item)

    for item in movies_dataset.take(3):
        print(item)

    training_datasets = split_train_validation_datasets(ratings_dataset)
    return RetrievalDatasets(training_datasets=training_datasets, candidate_dataset=movies_dataset.padded_batch(batch_size))

def train_multi_worker():
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    datasets = create_datasets()
    train_ds = strategy.experimental_distribute_dataset(datasets.training_datasets.train_ds)

    with strategy.scope() as scope:
        class TwoTowerModel(tfrs.Model):
            def __init__(self, user_model: tf.keras.Model, item_model: tf.keras.Model, task: tfrs.tasks.Retrieval):
                super().__init__()
                self.user_model = user_model
                self.item_model = item_model
                self.task = task

            def compute_loss(self, features: Dict[str, tf.Tensor], training=False) -> tf.Tensor:
                user_embeddings = self.user_model(features["user_id"])
                movie_embeddings = self.item_model(features["movie_title"])
                return self.task(user_embeddings, movie_embeddings)

        def create_de_two_tower_model(candidate_dataset: tf.data.Dataset) -> tf.keras.Model:
            user_model = tf.keras.Sequential([
                de.keras.layers.Embedding(
                    embedding_size=64,
                    key_dtype=tf.string,
                    initializer=tf.random_uniform_initializer(),
                    init_capacity=100_000,
                    restrict_policy=de.FrequencyRestrictPolicy,
                    name="user-embedding",
                    kv_creator=redis_creator,
                    distribute_strategy=strategy
                ),
                tf.keras.layers.Dense(64, activation="gelu"),
                tf.keras.layers.Dense(32),
                tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1))
            ], name='user_model')

            item_model = tf.keras.models.Sequential([
                de.keras.layers.Embedding(
                    embedding_size=64,
                    key_dtype=tf.string,
                    initializer=tf.random_uniform_initializer(),
                    init_capacity=100_000,
                    restrict_policy=de.FrequencyRestrictPolicy,
                    name="movie-embedding",
                    kv_creator=redis_creator,
                    distribute_strategy=strategy
                ),
                tf.keras.layers.Dense(64, activation="gelu"),
                tf.keras.layers.Dense(32),
                tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1))
            ], name='movie_model')

            current_model = TwoTowerModel(user_model, item_model, task=tfrs.tasks.Retrieval(
                metrics=tfrs.metrics.FactorizedTopK(candidate_dataset.map(item_model))
            ))
            current_optimizer = de.DynamicEmbeddingOptimizer(tf.keras.optimizers.Adam())
            return current_model, current_optimizer

        model, optimizer = create_de_two_tower_model(datasets.candidate_dataset)
        model.compile(optimizer=optimizer)
    history = model.fit(train_ds, epochs=1, steps_per_epoch=10)
    print(history)


if __name__ == '__main__':
    train_multi_worker()
Redis configuration
{
  "redis_connection_mode": 2,
  "redis_master_name": "master",
  "redis_host_ip": [
    "127.0.0.1"
  ],
  "redis_host_port": [
    6379
  ],
  "redis_user": "default",
  "redis_password": "",
  "redis_db": 0,
  "redis_read_access_slave": false,
  "redis_connect_keep_alive": false,
  "redis_connect_timeout": 1000,
  "redis_socket_timeout": 1000,
  "redis_conn_pool_size": 20,
  "redis_wait_timeout": 100000000,
  "redis_connection_lifetime": 100,
  "redis_sentinel_user": "default",
  "redis_sentinel_password": "",
  "redis_sentinel_connect_timeout": 1000,
  "redis_sentinel_socket_timeout": 1000,
  "storage_slice_import": 2,
  "storage_slice": 2,
  "using_hash_storage_slice": false,
  "keys_sending_size": 1024,
  "using_md5_prefix_name": false,
  "redis_hash_tags_hypodispersion": true,
  "model_tag_import": "test",
  "redis_hash_tags_import": [
    "{1}",
    "{2}"
  ],
  "model_tag_runtime": "movielens.v6",
  "redis_hash_tags_runtime": [
    "{1}",
    "{2}"
  ],
  "expire_model_tag_in_seconds": 604800,
  "table_store_mode": 2,
  "model_lib_abs_dir": "/tmp/"
}

Relevant information

  • Are you willing to contribute it? Not sure (it must be hard to add if this is not well supported yet)
  • Are you willing to maintain it going forward? no
  • Is there a relevant academic paper? (if so, where): no
  • Is there already an implementation in another framework? (if so, where): no
  • Was it part of tf.contrib? (if so, where): no

Which API type would this fall under (layer, metric, optimizer, etc.)?

  • model.fit

Who will benefit with this feature?

  • This would allow launching distributed training over multiple workers with a large external dynamic embedding table.
@rhdong
Member

rhdong commented Oct 31, 2023

Hi @sivukhin, thank you for the feedback! We will give a resolution after the discussion. Thank you!

@MoFHeka
Contributor

MoFHeka commented Nov 1, 2023

Hi @sivukhin, because of TF's resource locking, MirroredStrategy is not efficient for TFRA's multi-table setup. We recommend using Horovod for distributed training:
https://github.com/tensorflow/recommenders-addons/blob/master/docs/api_docs/tfra/dynamic_embedding/keras/layers/HvdAllToAllEmbedding.md


https://github.com/tensorflow/recommenders-addons/blob/master/demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py

Alternatively, you could help us improve the code so that each MirroredStrategy worker creates its own DEVariable object to hold its own table, and the workers exchange data through communication operators, just as HvdAllToAllEmbedding does.
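A toy, single-process simulation of the all-to-all lookup pattern this suggestion refers to: each worker owns one shard of the table (chosen here by `key % num_workers`, an assumption for illustration), sends each key to its owner, and gets the rows back in its original request order. One loop plays all the workers; in a real setup the exchange happens concurrently over the network via collective ops. This mirrors the idea only and is not TFRA's or Horovod's implementation.

```python
from typing import Dict, List

def owner(key: int, num_workers: int) -> int:
    # Hypothetical sharding rule: a key lives on worker (key % num_workers).
    return key % num_workers

def all_to_all_lookup(
    requests: List[List[int]],             # requests[w]: keys needed by worker w's batch
    shards: List[Dict[int, List[float]]],  # shards[w]: table rows owned by worker w
    dim: int,
) -> List[List[List[float]]]:
    n = len(shards)
    results: List[List[List[float]]] = []
    for keys in requests:
        # "Send" phase: group this worker's keys by owning worker, remembering positions.
        buckets: Dict[int, List[int]] = {w: [] for w in range(n)}
        for pos, key in enumerate(keys):
            buckets[owner(key, n)].append(pos)
        # "Receive" phase: each owner answers for its bucket and rows are put back in
        # request order. Missing keys become zero vectors here; a real dynamic
        # embedding would initialize and insert them instead.
        rows: List[List[float]] = [[] for _ in keys]
        for w, positions in buckets.items():
            for pos in positions:
                rows[pos] = list(shards[w].get(keys[pos], [0.0] * dim))
        results.append(rows)
    return results

shards = [{0: [1.0, 1.0], 2: [2.0, 2.0]},   # worker 0 owns even keys
          {1: [3.0, 3.0], 3: [4.0, 4.0]}]   # worker 1 owns odd keys
print(all_to_all_lookup([[3, 0], [2, 5]], shards, dim=2))
# [[[4.0, 4.0], [1.0, 1.0]], [[2.0, 2.0], [0.0, 0.0]]]
```

The point of the design is that no worker ever holds the full table: each one stores only its shard and pays for communication proportional to the keys in its batch.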

@sivukhin
Author

sivukhin commented Nov 1, 2023

@MoFHeka, thanks for the quick reply!
I will certainly try Horovod (I'm not familiar with it, but it looks like a mature library that everyone uses :-) ).

So far the only problem I have run into is that the latest TFRA release (0.6.0 on PyPI) doesn't include HvdAllToAllEmbedding. Do you plan to release a new version of the library? (It seems Horovod support was added only recently.)

@rhdong
Member

rhdong commented Nov 1, 2023


Yes, it is not in a released version yet, but you can install TFRA from source by following this guide: https://github.com/tensorflow/recommenders-addons#installing-from-source. It is easy to do, and if you run into any problems, you can get help here.
BTW, the next release will be published soon.

@sivukhin
Author

sivukhin commented Nov 1, 2023

Yes, thanks!
I managed to install a fresh TFRA version from source. I created a simple Dockerfile to automate these steps (maybe it will be helpful for someone): https://gist.github.com/sivukhin/da17615df0628a58e4680f7ab48ad8a2

Does HvdAllToAllEmbedding support training on CPU with the Redis kv_creator?
I tried replacing Embedding with HvdAllToAllEmbedding and running training under Horovod (horovodrun -np 2 python train.py) in my simple example, but got the following error:

[1,1]<stderr>: Expected shape [1682,64] for value, got [1001,64]
[1,1]<stderr>: 	 [[{{node Adam/Adam/update_5/None_lookup_table_insert_2/TFRA>RedisTableInsert}}]] [Op:__inference_train_function_2884]
Error stack trace
[1,0]<stderr>: 2023-11-01 16:22:45.510991: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at redis_table_op.cc:1502 : INVALID_ARGUMENT: Expected shape [1682,64] for value, got [998,64]
[1,0]<stderr>: 2023-11-01 16:22:45.511084: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at redis_table_op.cc:1502 : INVALID_ARGUMENT: Expected shape [1682,64] for value, got [998,64]
[1,0]<stderr>: 2023-11-01 16:22:45.511114: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at redis_table_op.cc:1502 : INVALID_ARGUMENT: Expected shape [1682,64] for value, got [998,64]
  1/100 [..............................] - ETA: 9:31 - factorized_top_k/top_1_categorical_accuracy: 7.3242e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0046 - factorized_top_k/top_10_categorical_accuracy: 0.0059 - factorized_top_k/top_50_categorical_accuracy: 0.0320 - factorized_top_k/top_100_categorical_accuracy: 0.0649 - loss: 34113.6
  1/100 [..............................] - ETA: 9:32 - factorized_top_k/top_1_categorical_accuracy: 2.4414e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0029 - factorized_top_k/top_10_categorical_accuracy: 0.0054 - factorized_top_k/top_50_categorical_accuracy: 0.0337 - factorized_top_k/top_100_categorical_accuracy: 0.0637 - loss: 34138.2383 - regularization_loss: 0.0000e+00 - total_loss: 34138.2383
2023-11-01 16:22:45.511462: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at redis_table_op.cc:1502 : INVALID_ARGUMENT: Expected shape [1682,64] for value, got [1001,64]
[1,1]<stderr>: Traceback (most recent call last):
[1,1]<stderr>:   File "/work/train.py", line 135, in <module>
[1,0]<stderr>: Traceback (most recent call last):
[1,0]<stderr>:   File "/work/train.py", line 135, in <module>
[1,1]<stderr>:     train_multi_worker()
[1,1]<stderr>:   File "/work/train.py", line 130, in train_multi_worker
[1,1]<stderr>:     history = model.fit(datasets.training_datasets.train_ds, epochs=1, steps_per_epoch=100, callbacks=callbacks_list, verbose=1)
[1,1]<stderr>:   File "/usr/local/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
[1,1]<stderr>:     raise e.with_traceback(filtered_tb) from None
[1,1]<stderr>:   File "/usr/local/lib/python3.9/site-packages/tensorflow/python/eager/execute.py", line 54, in quick_execute
[1,0]<stderr>:     train_multi_worker()
[1,0]<stderr>:   File "/work/train.py", line 130, in train_multi_worker
[1,0]<stderr>:     history = model.fit(datasets.training_datasets.train_ds, epochs=1, steps_per_epoch=100, callbacks=callbacks_list, verbose=1)
[1,0]<stderr>:   File "/usr/local/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
[1,0]<stderr>:     raise e.with_traceback(filtered_tb) from None
[1,0]<stderr>:   File "/usr/local/lib/python3.9/site-packages/tensorflow/python/eager/execute.py", line 54, in quick_execute
[1,0]<stderr>:     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
[1,1]<stderr>:     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
[1,0]<stderr>: tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:
[1,0]<stderr>: 
[1,0]<stderr>: Expected shape [1682,64] for value, got [998,64]
[1,0]<stderr>: 	 [[{{node Adam/Adam/update_5/None_lookup_table_insert_3/TFRA>RedisTableInsert}}]] [Op:__inference_train_function_2884]
[1,1]<stderr>: tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:
[1,1]<stderr>: 
[1,1]<stderr>: Expected shape [1682,64] for value, got [1001,64]
[1,1]<stderr>: 	 [[{{node Adam/Adam/update_5/None_lookup_table_insert_2/TFRA>RedisTableInsert}}]] [Op:__inference_train_function_2884]
Training source code
import dataclasses
import os
from typing import Dict

import horovod.tensorflow as hvd
import tensorflow as tf
import tensorflow_datasets as tfds
# tensorflow_recommenders_addons does some patching on TensorFlow, so it MUST be imported after importing TF
import tensorflow_recommenders as tfrs
import tensorflow_recommenders_addons as tfra
from tensorflow_recommenders_addons import dynamic_embedding as de

hvd.init()

redis_config = tfra.dynamic_embedding.RedisTableConfig(redis_config_abs_dir="redis.config")
redis_creator = tfra.dynamic_embedding.RedisTableCreator(redis_config)
cuckoo_creator = de.CuckooHashTableCreator(saver=de.FileSystemSaver(proc_size=1, proc_rank=0))

batch_size = 4096
seed = 2023


@dataclasses.dataclass(frozen=True)
class TrainingDatasets:
    train_ds: tf.data.Dataset
    validation_ds: tf.data.Dataset


@dataclasses.dataclass(frozen=True)
class RetrievalDatasets:
    training_datasets: TrainingDatasets
    candidate_dataset: tf.data.Dataset


def create_datasets():
    def split_train_validation_datasets(ratings_dataset: tf.data.Dataset) -> TrainingDatasets:
        train_size = int(len(ratings_dataset) * 0.9)
        validation_size = len(ratings_dataset) - train_size
        print(f"Train size: {train_size}")
        print(f"Validation size: {validation_size}")

        shuffled_dataset = ratings_dataset.shuffle(buffer_size=5 * batch_size, seed=seed)
        train_ds = shuffled_dataset.skip(validation_size).shuffle(buffer_size=10 * batch_size).apply(lambda dataset: dataset.padded_batch(batch_size))
        validation_ds = shuffled_dataset.take(validation_size).apply(lambda dataset: dataset.padded_batch(batch_size))

        return TrainingDatasets(train_ds=train_ds, validation_ds=validation_ds)

    ratings_dataset = tfds.load("movielens/100k-ratings", split="train").map(lambda x: {
        'user_id': tf.strings.to_number(x["user_id"], tf.int64),
        'movie_id': tf.strings.to_number(x["movie_id"], tf.int64)
    })
    movies_dataset = tfds.load("movielens/100k-movies", split="train").map(lambda x: tf.strings.to_number(x["movie_id"], tf.int64))

    for item in ratings_dataset.take(3):
        print(item)

    for item in movies_dataset.take(3):
        print(item)

    training_datasets = split_train_validation_datasets(ratings_dataset)
    return RetrievalDatasets(training_datasets=training_datasets, candidate_dataset=movies_dataset.padded_batch(batch_size))


class TwoTowerModel(tfrs.Model):
    def __init__(self, user_model: tf.keras.Model, item_model: tf.keras.Model, task: tfrs.tasks.Retrieval):
        super().__init__()
        self.user_model = user_model
        self.item_model = item_model
        self.task = task

    def compute_loss(self, features: Dict[str, tf.Tensor], training=False) -> tf.Tensor:
        user_embeddings = self.user_model(features["user_id"])
        movie_embeddings = self.item_model(features["movie_id"])
        return self.task(user_embeddings, movie_embeddings)


def build_two_tower_model(candidate_dataset: tf.data.Dataset) -> tf.keras.Model:
    user_model = tf.keras.Sequential([
        de.keras.layers.HvdAllToAllEmbedding(
            embedding_size=64,
            key_dtype=tf.int64,
            value_dtype=tf.float32,
            initializer=tf.random_uniform_initializer(),
            init_capacity=100_000,
            restrict_policy=de.FrequencyRestrictPolicy,
            name="user-embedding",
            kv_creator=redis_creator,
        ),
        tf.keras.layers.Dense(64, activation="gelu"),
        tf.keras.layers.Dense(32),
        tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1))
    ], name='user_model')

    item_model = tf.keras.models.Sequential([
        de.keras.layers.HvdAllToAllEmbedding(
            embedding_size=64,
            key_dtype=tf.int64,
            value_dtype=tf.float32,
            initializer=tf.random_uniform_initializer(),
            init_capacity=100_000,
            restrict_policy=de.FrequencyRestrictPolicy,
            name="movie-embedding",
            kv_creator=redis_creator,
        ),
        tf.keras.layers.Dense(64, activation="gelu"),
        tf.keras.layers.Dense(32),
        tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1))
    ], name='movie_model')

    model = TwoTowerModel(user_model, item_model, task=tfrs.tasks.Retrieval(
        metrics=tfrs.metrics.FactorizedTopK(candidate_dataset.map(item_model))
    ))
    optimize = de.DynamicEmbeddingOptimizer(tf.keras.optimizers.Adam())
    model.compile(optimizer=optimize)
    return model


def train_multi_worker():
    datasets = create_datasets()

    model_dir = f'model_dir_{hvd.rank()}'
    print(f'model_dir: {model_dir}')
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=model_dir)
    broadcast_callback = de.keras.callbacks.DEHvdBroadcastGlobalVariablesCallback(root_rank=0)
    checkpoint_callback = de.keras.callbacks.DEHvdModelCheckpoint(
        filepath=model_dir + '/weights_epoch{epoch:03d}_loss{loss:.4f}',
        options=tf.saved_model.SaveOptions(namespace_whitelist=['TFRA'])
    )
    callbacks_list = [tensorboard_callback, broadcast_callback, checkpoint_callback]
    model = build_two_tower_model(datasets.candidate_dataset)
    history = model.fit(datasets.training_datasets.train_ds, epochs=1, steps_per_epoch=100, callbacks=callbacks_list, verbose=1)
    print(history)


if __name__ == '__main__':
    train_multi_worker()

I also tried to launch the movielens-1m-keras-with-horovod demo, but it didn't work on the first attempt (the demo is a bit complicated, I don't need its full functionality, and I didn't set up the GPU properly, so the main problem is most likely on my side).

@MoFHeka
Contributor

MoFHeka commented Nov 1, 2023

Of course HvdAllToAllEmbedding supports training on CPU.

I ran your code successfully with CUDA_VISIBLE_DEVICES=-1 horovodrun -np 2 python hvd_two_tower_test.py, using both redis_creator and cuckoo_creator.
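If the launch is driven from Python rather than a shell, the same CPU-only run can be expressed by building the environment explicitly. This is a hypothetical helper (the function name is made up); `CUDA_VISIBLE_DEVICES=-1` hides every GPU from CUDA, so TensorFlow falls back to CPU. The sketch assumes a single-host launch, where the processes started by horovodrun inherit this environment.

```python
import os
import shlex
from typing import Dict, List, Tuple

def cpu_only_horovod_command(script: str, num_procs: int) -> Tuple[List[str], Dict[str, str]]:
    """Return (argv, env) for running `script` under horovodrun with GPUs hidden."""
    # Copy the current environment and hide all GPUs from CUDA.
    env = dict(os.environ, CUDA_VISIBLE_DEVICES="-1")
    argv = ["horovodrun", "-np", str(num_procs), "python", script]
    return argv, env

argv, env = cpu_only_horovod_command("hvd_two_tower_test.py", 2)
print("CUDA_VISIBLE_DEVICES=" + env["CUDA_VISIBLE_DEVICES"], shlex.join(argv))
# To actually launch it: subprocess.run(argv, env=env, check=True)
```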

Also, if your GPU failure comes from Horovod's all-to-all operator, it may be caused by a third-party package, which I suspect is tensorflow_recommenders, because I also failed to run your code on the GPU.

One more thing: if you want to train on CPU, a parameter server is your best choice. MirroredStrategy is much more efficient on a single multi-GPU machine.

Training source code
import dataclasses
import os
from typing import Dict

import horovod.tensorflow as hvd
import tensorflow as tf
import tensorflow_datasets as tfds
# tensorflow_recommenders_addons does some patching on TensorFlow, so it MUST be imported after importing TF
import tensorflow_recommenders as tfrs
import tensorflow_recommenders_addons as tfra
from tensorflow_recommenders_addons import dynamic_embedding as de

hvd.init()

redis_config = tfra.dynamic_embedding.RedisTableConfig(redis_config_abs_dir="redis.config")
redis_creator = tfra.dynamic_embedding.RedisTableCreator(redis_config)
cuckoo_creator = de.CuckooHashTableCreator(saver=de.FileSystemSaver(proc_size=hvd.size(), proc_rank=hvd.rank()))

batch_size = 4096
seed = 2023


@dataclasses.dataclass(frozen=True)
class TrainingDatasets:
    train_ds: tf.data.Dataset
    validation_ds: tf.data.Dataset


@dataclasses.dataclass(frozen=True)
class RetrievalDatasets:
    training_datasets: TrainingDatasets
    candidate_dataset: tf.data.Dataset


def create_datasets():
    def split_train_validation_datasets(ratings_dataset: tf.data.Dataset) -> TrainingDatasets:
        train_size = int(len(ratings_dataset) * 0.9)
        validation_size = len(ratings_dataset) - train_size
        print(f"Train size: {train_size}")
        print(f"Validation size: {validation_size}")

        shuffled_dataset = ratings_dataset.shuffle(buffer_size=5 * batch_size, seed=seed)
        train_ds = shuffled_dataset.skip(validation_size).shuffle(buffer_size=10 * batch_size).apply(lambda dataset: dataset.padded_batch(batch_size))
        validation_ds = shuffled_dataset.take(validation_size).apply(lambda dataset: dataset.padded_batch(batch_size))

        return TrainingDatasets(train_ds=train_ds, validation_ds=validation_ds)

    ratings_dataset = tfds.load("movielens/100k-ratings",
                                split="train",
                                data_dir=".",
                                download=False).map(lambda x: {
        'user_id': tf.strings.to_number(x["user_id"], tf.int64),
        'movie_id': tf.strings.to_number(x["movie_id"], tf.int64)
    })
    movies_dataset = tfds.load("movielens/100k-ratings",
                               split="train",
                               data_dir=".",
                               download=False).map(lambda x: tf.strings.to_number(x["movie_id"], tf.int64))

    for item in ratings_dataset.take(3):
        print(item)

    for item in movies_dataset.take(3):
        print(item)

    training_datasets = split_train_validation_datasets(ratings_dataset)
    return RetrievalDatasets(training_datasets=training_datasets, candidate_dataset=movies_dataset.padded_batch(batch_size))


class TwoTowerModel(tfrs.Model):
    def __init__(self, user_model: tf.keras.Model, item_model: tf.keras.Model, task: tfrs.tasks.Retrieval):
        super().__init__()
        self.user_model = user_model
        self.item_model = item_model
        self.task = task

    def compute_loss(self, features: Dict[str, tf.Tensor], training=False) -> tf.Tensor:
        user_embeddings = self.user_model(features["user_id"])
        movie_embeddings = self.item_model(features["movie_id"])
        return self.task(user_embeddings, movie_embeddings)


def build_two_tower_model(candidate_dataset: tf.data.Dataset) -> tf.keras.Model:
    user_model = tf.keras.Sequential([
        de.keras.layers.HvdAllToAllEmbedding(
            embedding_size=64,
            key_dtype=tf.int64,
            value_dtype=tf.float32,
            initializer=tf.random_uniform_initializer(),
            init_capacity=100_000,
            restrict_policy=de.FrequencyRestrictPolicy,
            name="user-embedding",
            devices=['CPU'],
            kv_creator=cuckoo_creator,
        ),
        tf.keras.layers.Dense(64, activation="gelu"),
        tf.keras.layers.Dense(32),
        tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1))
    ], name='user_model')

    item_model = tf.keras.models.Sequential([
        de.keras.layers.HvdAllToAllEmbedding(
            embedding_size=64,
            key_dtype=tf.int64,
            value_dtype=tf.float32,
            initializer=tf.random_uniform_initializer(),
            init_capacity=100_000,
            restrict_policy=de.FrequencyRestrictPolicy,
            name="movie-embedding",
            devices=['CPU'],
            kv_creator=cuckoo_creator,
        ),
        tf.keras.layers.Dense(64, activation="gelu"),
        tf.keras.layers.Dense(32),
        tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1))
    ], name='movie_model')

    model = TwoTowerModel(user_model, item_model, task=tfrs.tasks.Retrieval(
        metrics=tfrs.metrics.FactorizedTopK(candidate_dataset.map(item_model))
    ))
    optimize = de.DynamicEmbeddingOptimizer(tf.keras.optimizers.Adam())
    model.compile(optimizer=optimize)
    return model


def train_multi_worker():
    datasets = create_datasets()

    model_dir = f'model_dir_{hvd.rank()}'
    print(f'model_dir: {model_dir}')
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=model_dir)
    broadcast_callback = de.keras.callbacks.DEHvdBroadcastGlobalVariablesCallback(root_rank=0)
    checkpoint_callback = de.keras.callbacks.DEHvdModelCheckpoint(
        filepath=model_dir + '/weights_epoch{epoch:03d}_loss{loss:.4f}',
        options=tf.saved_model.SaveOptions(namespace_whitelist=['TFRA'])
    )
    callbacks_list = [tensorboard_callback, broadcast_callback, checkpoint_callback]
    model = build_two_tower_model(datasets.candidate_dataset)
    history = model.fit(datasets.training_datasets.train_ds, epochs=1, steps_per_epoch=100, callbacks=callbacks_list, verbose=1)
    print(history)


if __name__ == '__main__':
    train_multi_worker()

@sivukhin
Author

sivukhin commented Nov 1, 2023

Hm, OK. I still get this weird error about inconsistent shapes (with both Redis and Cuckoo), but maybe I need to dig into it more...

UPD: I looked more closely at your sample code and found a difference in HvdAllToAllEmbedding: I forgot to set the devices argument. With devices=['CPU'], Horovod started to work!

(Quoting @MoFHeka) "One more thing: if you want to train on CPU, a parameter server is your best choice. MirroredStrategy is much more efficient on a single multi-GPU machine."

Why is a parameter server better for CPU? We have a very simple model with very few weights (apart from the embedding table). I thought the multi-worker strategy would be more efficient, since it only requires occasional communication between workers to aggregate gradient updates.

With a parameter server, I'm just not sure what will be stored on it... If all the dense weights live there, it seems this could create a huge communication overhead, no?
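One way to reason about this question is to estimate per-step network traffic for the dense weights alone, using textbook cost models (an idealization that ignores latency, compression, and PS sharding): in ring all-reduce each worker sends about 2(N-1)/N times the gradient size, while with a parameter server each worker pushes its gradient and pulls the updated weights (about 2x the size), and an unsharded PS handles N of those exchanges. The dense size below is a rough count for the two-tower model in this thread: (64·64+64) + (64·32+32) = 6,240 float32 parameters per tower.

```python
def ring_allreduce_bytes_per_worker(param_bytes: float, n_workers: int) -> float:
    # Reduce-scatter plus all-gather: each phase sends (N-1)/N of the buffer.
    return 2 * (n_workers - 1) / n_workers * param_bytes

def ps_bytes_per_worker(param_bytes: float) -> float:
    # Push the full gradient, then pull the full updated weights.
    return 2 * param_bytes

def ps_bytes_at_server(param_bytes: float, n_workers: int) -> float:
    # A single, unsharded PS receives every gradient and sends weights back to everyone.
    return 2 * n_workers * param_bytes

dense_bytes = 2 * 6240 * 4  # two towers, float32
for n in (2, 8, 32):
    print(f"N={n:>2}  ring/worker={ring_allreduce_bytes_per_worker(dense_bytes, n):>9,.0f} B"
          f"  ps/worker={ps_bytes_per_worker(dense_bytes):>9,.0f} B"
          f"  ps/server={ps_bytes_at_server(dense_bytes, n):>11,.0f} B")
```

For a dense part this small, each worker moves only about 0.1 MB per step in either scheme, so under these assumptions dense-weight traffic is unlikely to be the bottleneck; the PS side grows linearly with the worker count unless it is sharded.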

My initial idea was to train the dense weights independently on multiple workers (for high throughput) and use Redis as external storage for the embedding table. As I imagine it, this setup implies the following communication for a single worker:

  1. The worker accumulates gradients for the embeddings and periodically applies them to Redis (with the bp_v2 feature enabled)
  2. After a couple of batches have been processed on every worker, the workers communicate with each other to update the dense parameters

If there is a way to control how often a worker syncs with Redis and how often the workers communicate with each other, I thought this scheme could work for fairly high-load scenarios (with infrequent syncs we trade convergence rate for throughput, which looks fine to me at the moment)...
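The scheme described above (local dense training with periodic averaging, plus embedding updates buffered locally and flushed to an external store) can be simulated in a few lines. Everything here is hypothetical: a plain dict stands in for Redis, `SYNC_EVERY` and `FLUSH_EVERY` are made-up knobs, and the gradients are constant toy values; it does not imply that TFRA exposes these controls.

```python
from collections import defaultdict
from typing import Dict, List

SYNC_EVERY = 4   # hypothetical: average dense weights across workers every 4 steps
FLUSH_EVERY = 2  # hypothetical: push buffered embedding deltas every 2 steps
LR = 0.1

class Worker:
    def __init__(self, dense: List[float]):
        self.dense = list(dense)
        self.pending: Dict[str, float] = defaultdict(float)  # buffered embedding deltas

    def local_step(self, dense_grad: List[float], emb_grads: Dict[str, float]) -> None:
        # Dense weights are updated locally, with no communication.
        self.dense = [w - LR * g for w, g in zip(self.dense, dense_grad)]
        # Embedding updates are accumulated instead of being written immediately.
        for key, g in emb_grads.items():
            self.pending[key] += -LR * g

    def flush(self, store: Dict[str, float]) -> None:
        # Apply accumulated deltas additively, so workers don't overwrite each other.
        for key, delta in self.pending.items():
            store[key] = store.get(key, 0.0) + delta
        self.pending.clear()

def average_dense(workers: List[Worker]) -> None:
    # Stand-in for an all-reduce: every worker ends up with the mean dense weights.
    n = len(workers)
    mean = [sum(ws) / n for ws in zip(*(w.dense for w in workers))]
    for w in workers:
        w.dense = list(mean)

store: Dict[str, float] = {}  # stand-in for the Redis table
workers = [Worker([0.0, 0.0]) for _ in range(2)]
for step in range(1, 9):
    for rank, w in enumerate(workers):
        w.local_step(dense_grad=[1.0 + rank, -1.0],
                     emb_grads={f"user{rank}": 1.0, "movie0": 0.5})
    if step % FLUSH_EVERY == 0:
        for w in workers:
            w.flush(store)
    if step % SYNC_EVERY == 0:
        average_dense(workers)

print(workers[0].dense, store)
```

Between syncs the dense replicas drift apart (here they differ after every step until the next average), which is exactly the convergence-for-throughput trade-off mentioned above.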

@MoFHeka
Contributor

MoFHeka commented Nov 1, 2023

Ring-AllReduce vs Parameter Server

The lower communication overhead of the multi-worker strategy applies to synchronous training. If many CPU nodes train asynchronously with a small batch size, a parameter server can get through all the samples faster for a given cluster size.

Semi-Synchronous Training = Ring-AllReduce + Parameter Server

Another option is semi-synchronous training: the dense-layer parameters are synchronized by Horovod, while the embedding parameters are trained asynchronously on the PS. You can refer to: semi-synchronous training with TF1 API. Although this demo uses the TF1 API, the same principles apply in TF2.
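A sequential toy of one semi-synchronous step, under stated assumptions: `allreduce_mean` plays Horovod's role for the dense gradients (every replica applies the same averaged update), and a dict plays the parameter server for embeddings, to which each worker pushes its own updates without averaging or a barrier. In a real asynchronous PS those pushes arrive in arbitrary order and may be computed from stale reads; this sketch does not model timing, and `LR` is an arbitrary learning rate.

```python
from typing import Dict, List

LR = 0.1

def allreduce_mean(grads: List[List[float]]) -> List[float]:
    # Synchronous part: average dense gradients across all workers.
    n = len(grads)
    return [sum(col) / n for col in zip(*grads)]

def semi_sync_step(dense: List[float], ps: Dict[str, float],
                   worker_emb_grads: List[Dict[str, float]],
                   worker_dense_grads: List[List[float]]) -> List[float]:
    # Asynchronous part: each worker pushes its embedding gradients directly to the PS;
    # a key touched by two workers simply receives both updates.
    for emb_grads in worker_emb_grads:
        for key, g in emb_grads.items():
            ps[key] = ps.get(key, 0.0) - LR * g
    # Dense weights: one averaged update, identical on every replica.
    avg = allreduce_mean(worker_dense_grads)
    return [w - LR * g for w, g in zip(dense, avg)]

ps: Dict[str, float] = {}
dense = [0.0, 0.0]
dense = semi_sync_step(
    dense, ps,
    worker_emb_grads=[{"user1": 1.0, "movie7": 1.0}, {"user2": 1.0, "movie7": 1.0}],
    worker_dense_grads=[[1.0, 0.0], [3.0, 2.0]],
)
print(dense, ps)
```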
[Image: semi-sync]

Semi-Synchronous Training with Redis

Redis is intended for serving, although you can certainly use it as an alternative solution for training. If you want to use a Redis-backed embedding in Horovod synchronous training, use the normal Embedding layer instead of HvdAllToAllEmbedding. In addition, enabling bp_v2 may improve model convergence (not guaranteed), and using bp_v2 with Redis requires compiling an additional Redis module.
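As I understand TFRA's description of `bp_v2`, the default optimizer path writes updated embedding values back (overwrite), so when several workers update the same key at the same time only the last write survives, while `bp_v2=True` applies the difference (new minus old) instead, so concurrent updates accumulate. The toy below contrasts the two write modes on a plain dict; it is not TFRA code and ignores the Redis module mentioned above.

```python
from typing import Dict, List, Tuple

Update = Tuple[str, float, float]  # (key, value the worker read, value it computed)

def apply_by_setting(table: Dict[str, float], updates: List[Update]) -> None:
    # Overwrite mode: each worker writes its new value; the last write wins.
    for key, _old, new in updates:
        table[key] = new

def apply_by_delta(table: Dict[str, float], updates: List[Update]) -> None:
    # Delta mode: each worker adds (new - old), so concurrent updates accumulate.
    for key, old, new in updates:
        table[key] = table[key] + (new - old)

# Two workers read the same row (value 1.0) in the same step and compute their own updates.
updates = [("movie7", 1.0, 0.75),  # worker 0 wants to move it by -0.25
           ("movie7", 1.0, 0.5)]   # worker 1 wants to move it by -0.5

t_set = {"movie7": 1.0}
apply_by_setting(t_set, updates)
t_delta = {"movie7": 1.0}
apply_by_delta(t_delta, updates)
print(t_set["movie7"], t_delta["movie7"])  # 0.5 0.25
```

With overwriting, worker 0's step is lost; with deltas, both steps are reflected in the stored value.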

@MoFHeka MoFHeka pinned this issue Nov 1, 2023
@MoFHeka MoFHeka added enhancement New feature or request question Further information is requested labels Nov 1, 2023
@sivukhin
Author

sivukhin commented Nov 2, 2023

Thanks @MoFHeka, got it!
One last question from my side: does recommenders-addons plan to support newer versions of TensorFlow in future releases?

@MoFHeka
Copy link
Contributor

MoFHeka commented Nov 3, 2023

@sivukhin
For now, TFRA will continue to integrate with and stay compatible with the latest versions of TensorFlow, but this is a lot of work, so it would be great if you could also contribute to the TFRA code.


3 participants