NaN with mock data #10

Open
BlinkDL opened this issue Jun 4, 2022 · 1 comment

BlinkDL commented Jun 4, 2022

Hi lucidrains,

Try this script: it produces NaN within 100 steps (latest GitHub code). The loss looks fine before the NaN appears.

```python
import torch
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

import math
import random
import numpy as np
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

num_text_tokens = 10000
batch_sz = 12
text_seq_len = 256
visual_image_size = 256

# mock data: random token ids and Gaussian-noise "images"

data_sz = 1000
all_text = torch.randint(0, num_text_tokens, (data_sz, text_seq_len)).cuda()
all_images = torch.randn(data_sz, 3, visual_image_size, visual_image_size).cuda()

# preallocated batch buffers, refilled every step
text = torch.zeros((batch_sz, text_seq_len), dtype=torch.long).cuda()
images = torch.zeros((batch_sz, 3, visual_image_size, visual_image_size)).cuda()

##########################################################################################

import wandb
import datetime
wandb.init(project="Test", name=datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), save_code=False)

from x_clip import CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = num_text_tokens,
    text_enc_depth = 6,
    text_seq_len = text_seq_len,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = visual_image_size,
    visual_patch_size = 32,
    visual_heads = 8,
    use_all_token_embeds = False,           # whether to use fine-grained contrastive learning (FILIP)
    decoupled_contrastive_learning = True,  # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
    extra_latent_projection = True,         # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
    use_visual_ssl = True,                  # whether to do self-supervised learning on images
    visual_ssl_type = 'simclr',             # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
    use_mlm = False,                        # use masked language modeling (MLM) on text (DeCLIP)
    text_ssl_loss_weight = 0.05,            # weight for text MLM loss
    image_ssl_loss_weight = 0.05            # weight for image self-supervised learning loss
).cuda()

optimizer = torch.optim.Adam(clip.parameters(), lr=1e-4, betas=(0.9, 0.99))

for step in range(999999):
    # sample a batch with replacement; randrange excludes its upper bound,
    # so stop=data_sz makes every index 0..data_sz-1 reachable
    for i in range(batch_sz):
        data_id = random.randrange(data_sz)
        text[i] = all_text[data_id]
        images[i] = all_images[data_id]

    loss = clip(
        text,
        images,
        freeze_image_encoder = False,   # whether to freeze the image encoder when using a pretrained one, as proposed in the LiT paper
        return_loss = True              # must be True to return the contrastive loss
    )
    clip.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(clip.parameters(), 1.0)
    optimizer.step()

    now_loss = loss.item()
    wandb.log({"loss": now_loss}, step = step)
    print(step, now_loss)

    # stop at the first NaN loss
    if math.isnan(now_loss):
        break
```
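A note on the sampling loop: `random.randrange(start, stop)` draws from the half-open range [start, stop), so `stop` has to be `data_sz` for the last mock sample (index `data_sz - 1`) to ever be drawn. This does not explain the NaN, but it is cheap to check with the standard library alone (the seed and draw count here are arbitrary choices for the check):

```python
import random

data_sz = 1000
rng = random.Random(0)  # fixed seed so the check is reproducible

# Draw many indices with a stop of data_sz - 1 and with a stop of data_sz.
capped = {rng.randrange(0, data_sz - 1) for _ in range(100_000)}
full = {rng.randrange(data_sz) for _ in range(100_000)}

print("last index reachable with stop=data_sz - 1:", data_sz - 1 in capped)
print("last index reachable with stop=data_sz:", data_sz - 1 in full)
```

With `stop=data_sz - 1` the last index can never appear, no matter how many draws are made; with `stop=data_sz` every index is reachable.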

lucidrains (Owner) commented Jun 23, 2022

@BlinkDL Hey Peng Bo! I quickly checked the script and it does NaN, but not when `use_visual_ssl` is turned off.

I suspect it has something to do with augmenting the randomly generated images in the visual SSL branch, but I'm not completely sure.
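The suspicion points at the SimCLR branch (`visual_ssl_type = 'simclr'`). Two generic places where a contrastive SSL loss can become non-finite are the L2 normalization of the embeddings (0/0 if an embedding collapses to zero) and the exponentials of the softmax at a low temperature. The NumPy sketch below computes a SimCLR-style NT-Xent loss with a guard for each. It is not x-clip's implementation and does not confirm the cause of this NaN; `l2_normalize`, `logsumexp`, `nt_xent`, and the `temperature` default are illustrative assumptions.

```python
import numpy as np


def l2_normalize(x, eps=1e-8):
    # Clamp the norm from below so an all-zero embedding maps to zeros instead of 0/0 = NaN.
    norm = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.maximum(norm, eps)


def logsumexp(a, axis=-1):
    # Subtract the row max before exponentiating so large logits cannot overflow to inf.
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def nt_xent(z1, z2, temperature=0.1):
    """SimCLR-style NT-Xent loss for two augmented views z1, z2 of shape (n, d)."""
    n = z1.shape[0]
    z = l2_normalize(np.concatenate([z1, z2], axis=0))  # (2n, d) unit vectors
    sim = z @ z.T / temperature                          # (2n, 2n) scaled cosine similarities
    np.fill_diagonal(sim, -np.inf)                       # a sample is never its own negative
    # Row i's positive is its other view: i + n in the first half, i - n in the second.
    pos_idx = np.concatenate([np.arange(n, 2 * n), np.arange(n)])
    pos = sim[np.arange(2 * n), pos_idx]
    # Cross-entropy of picking the positive among the 2n - 1 other samples.
    return float(np.mean(logsumexp(sim, axis=1) - pos))


rng = np.random.default_rng(0)
z1 = rng.normal(size=(8, 32))
z2 = z1 + 0.05 * rng.normal(size=(8, 32))  # a slightly perturbed second "view"

print(nt_xent(z1, z2))                               # finite and non-negative
print(nt_xent(z1, z2, temperature=1e-4))             # still finite at a tiny temperature
print(nt_xent(np.zeros((4, 8)), np.zeros((4, 8))))   # collapsed embeddings give log(7)
```

With both guards in place, even two all-zero batches give a finite loss of log(2n - 1), whereas a plain `x / norm` normalization would return NaN there.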
