
Question about the value of cls token #186

Open
SELECT-FROM opened this issue Apr 21, 2024 · 0 comments

Thanks for your amazing work. I have a question about the value of the cls token. During pretraining, the expression value assigned to <cls> is pad_value (default -2), while during fine-tuning for integration, the value assigned to <cls> is 0. Is there a special purpose behind this design, given that the <cls> value differs between the pretraining stage and the fine-tuning stage?

scGPT/examples/pretrain.py

Lines 430 to 441 in 4068d67

def _map_append_cls(dataset: Dataset) -> Dataset:
    logger.info(f"Rank {args.local_rank}: Appending <cls> to dataset")
    dataset = dataset.map(
        lambda example: {
            "genes": [vocab["<cls>"]] + example["genes"],
            "expressions": [args.pad_value] + example["expressions"],
        },
        # batched=True, # not using since then the map func needs to loop
        num_proc=len(os.sched_getaffinity(0)),
    )
    return dataset

if append_cls:
    genes = np.insert(genes, 0, cls_id)
    values = np.insert(values, 0, 0)
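
To make the difference concrete, here is a minimal sketch (not from the repository; the gene ids, expression values, and cls_id are made up) of what the two code paths above produce for the expression vector:

    import numpy as np

    cls_id = 1            # hypothetical <cls> vocab id
    pad_value = -2.0      # scGPT default pad_value
    genes = np.array([12, 345, 678])    # toy gene ids
    values = np.array([1.5, 0.0, 3.2])  # toy expression values

    # Pretraining (_map_append_cls): the <cls> slot gets pad_value.
    pretrain_values = np.insert(values, 0, pad_value)  # [-2.0, 1.5, 0.0, 3.2]

    # Fine-tuning tokenizer (append_cls branch): the <cls> slot gets 0.
    finetune_values = np.insert(values, 0, 0)          # [ 0.0, 1.5, 0.0, 3.2]

    genes_with_cls = np.insert(genes, 0, cls_id)       # same in both cases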

During fine-tuning for batch integration, the model is trained in a self-supervised manner. When the gene expression values are masked, the value at the <cls> position may also be masked. This cannot happen during pretraining, because there the <cls> value equals pad_value and is excluded from the maskable positions. I would like to know why the <cls> value can also be masked in the batch-integration fine-tuning. What is the reason for this design?

for i in range(len(values)):
    row = values[i]
    non_padding_idx = np.nonzero(row - pad_value)[0]
    n_mask = int(len(non_padding_idx) * mask_ratio)
    mask_idx = np.random.choice(non_padding_idx, n_mask, replace=False)
    row[mask_idx] = mask_value

if keep_first_n_tokens > 0:
    result_ = self._mask(
        expressions[:, keep_first_n_tokens:],
        keep_first_n_tokens=0,
    )
    return torch.cat([expressions[:, :keep_first_n_tokens], result_], dim=1)
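
A minimal sketch (my own simplification of the two snippets above, using a single 1-D row and assumed mask_value/mask_ratio) of why the <cls> slot is maskable during fine-tuning but not during pretraining, and how keep_first_n_tokens protects it:

    import numpy as np

    pad_value, mask_value, mask_ratio = -2.0, -1.0, 0.5

    def random_mask(row, keep_first_n_tokens=0):
        # Mask a fraction of the non-padding positions, optionally protecting
        # the first keep_first_n_tokens positions (e.g. the <cls> slot).
        row = row.copy()
        maskable = row[keep_first_n_tokens:]
        non_padding_idx = np.nonzero(maskable - pad_value)[0]
        n_mask = int(len(non_padding_idx) * mask_ratio)
        mask_idx = np.random.choice(non_padding_idx, n_mask, replace=False)
        maskable[mask_idx] = mask_value
        return row

    # Pretraining-style row: the <cls> value is pad_value, so position 0 is
    # never in non_padding_idx and can never be masked, even with
    # keep_first_n_tokens=0.
    print(random_mask(np.array([pad_value, 1.5, 0.7, 3.2])))

    # Fine-tuning-style row: the <cls> value is 0, a valid non-padding value,
    # so position 0 can be masked unless keep_first_n_tokens >= 1 excludes it.
    print(random_mask(np.array([0.0, 1.5, 0.7, 3.2]), keep_first_n_tokens=0))
    print(random_mask(np.array([0.0, 1.5, 0.7, 3.2]), keep_first_n_tokens=1))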
