Commit 4090d98

added code
1 parent 221a76a commit 4090d98

27 files changed

+3413
-24
lines changed

README.md

Lines changed: 74 additions & 24 deletions
@@ -1,33 +1,83 @@
1-
# Project
1+
# Hierarchical cross-entropy loss improves atlas-scale single-cell annotation models
22

3-
> This repo has been populated by an initial template to help get you started. Please
4-
> make sure to update the content to build a great experience for community-building.
3+
[![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE)
54

6-
As the maintainer of this project, please make a few updates:
5+
This repository contains the code used for "Hierarchical cross-entropy loss improves atlas-scale single-cell annotation models". The paper is available on [bioRxiv](https://doi.org/10.1101/2025.04.22.1234567).
76

8-
- Improving this README.MD file to provide a great experience
9-
- Updating SUPPORT.MD with content about this project's support experience
10-
- Understanding the security reporting process in SECURITY.MD
11-
- Remove this section from the README
7+
## Repository Information
8+
This repository is partially derived from the [scTab study](https://github.com/theislab/scTab). We have extended and modified the original codebase to implement the hierarchical cross-entropy loss and the experiments described in the paper.
129

13-
## Contributing
10+
## Training Data
11+
The model training uses the CELLxGENE census version "2023-05-15" preprocessed by [scTab](https://github.com/theislab/scTab), which must be downloaded manually from [this link](https://pklab.med.harvard.edu/felix/data/merlin_cxg_2023_05_15_sf-log1p.tar.gz).
1412

15-
This project welcomes contributions and suggestions. Most contributions require you to agree to a
16-
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
17-
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
13+
## Evaluation Data
14+
For model evaluation, we use the CELLxGENE census version "2023-12-15" as referenced in the paper. This census version is automatically fetched by the code directly from the [CELLxGENE](https://cellxgene.cziscience.com/) portal when needed.
1815

19-
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
20-
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
21-
provided by the bot. You will only need to do this once across all repos using our CLA.
16+
## Hierarchical Cross-Entropy Loss
17+
The hierarchical cross-entropy loss leverages the hierarchical structure inherent in many classification problems to improve model performance. Unlike standard cross-entropy, which treats each class independently, this loss accounts for inclusion relationships between classes: the probability assigned to a class is the sum of the predicted probabilities of that class and of every class reachable from it. Here we provide a standalone implementation that can be applied to any hierarchical classification task, regardless of the domain or model architecture.
2218

23-
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
24-
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
25-
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
19+
### Reachability Matrix
20+
The function relies on a **reachability matrix** that encodes the hierarchical structure as a directed acyclic graph (DAG). In this matrix:
21+
- Element (i,j) equals 1 if class j is reachable from class i (meaning j is either i itself or a direct or indirect subclass of i in the hierarchy)
22+
- Element (i,j) equals 0 otherwise
2623

27-
## Trademarks
24+
For example, consider this simple hierarchical structure:
```
    A
   ↙ ↘
  B   C
 ↙ ↘ ↙
D   E
```
2832

29-
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
30-
trademarks or logos is subject to and must follow
31-
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
32-
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
33-
Any use of third-party trademarks or logos are subject to those third-party's policies.
33+
The corresponding reachability matrix would be:
```
    A B C D E
A | 1 1 1 1 1
B | 0 1 0 1 1
C | 0 0 1 0 1
D | 0 0 0 1 0
E | 0 0 0 0 1
```
42+
43+
The reachability relation encoded in this matrix is a partial order and has the following mathematical properties:
44+
- **Reflexive**: Every class is reachable from itself (diagonal elements are 1)
45+
- **Antisymmetric**: If class i can reach j and j can reach i, then i equals j
46+
- **Transitive**: If class i can reach j and j can reach k, then i can reach k
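
The matrix above can also be derived programmatically. The following is a minimal sketch, not part of the repository: `build_reachability_matrix` is a hypothetical helper that assumes the hierarchy is given as a list of direct parent-to-child edges over integer class indices, and it computes the reflexive-transitive closure with a Warshall-style loop. It then checks the three properties listed above on the A to E example.

```python
import numpy as np


def build_reachability_matrix(num_classes, edges):
    """Hypothetical helper: reachability matrix of a DAG from direct edges.

    edges is a list of (parent_index, child_index) pairs.
    """
    # Start from the identity so that every class reaches itself (reflexive).
    reach = np.eye(num_classes, dtype=np.int64)
    for parent, child in edges:
        reach[parent, child] = 1

    # Warshall-style closure: if i reaches k and k reaches j, then i reaches j.
    for k in range(num_classes):
        reach |= np.outer(reach[:, k], reach[k, :])
    return reach


# Example hierarchy from the README: A -> B, A -> C, B -> D, B -> E, C -> E
classes = ["A", "B", "C", "D", "E"]
edges = [(0, 1), (0, 2), (1, 3), (1, 4), (2, 4)]
R = build_reachability_matrix(len(classes), edges)

# Partial-order checks on the resulting matrix
n = len(classes)
is_reflexive = bool(np.all(np.diag(R) == 1))
is_antisymmetric = bool(np.array_equal(R * R.T, np.eye(n, dtype=R.dtype)))
is_transitive = bool(np.array_equal((R @ R) > 0, R > 0))
```

The antisymmetry check only passes when the input graph is acyclic, so it can double as a sanity check on a hand-written hierarchy.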
47+
48+
### Implementation
```python
import torch
import torch.nn.functional as F


def hierarchical_cross_entropy_loss(logits, targets, reachability_matrix, weight=None):
    """
    Hierarchical Cross-Entropy loss

    Args:
        logits: Raw model predictions (batch_size, num_classes)
        targets: Ground truth class indices (batch_size)
        reachability_matrix: Matrix encoding hierarchical relationships (num_classes, num_classes);
            must be a floating-point tensor on the same device as logits
        weight: Optional class weights

    Returns:
        Hierarchical Cross-Entropy loss value
    """
    # Convert logits to probabilities using softmax
    cell_type_probs = torch.softmax(logits, dim=-1)

    # Propagate probabilities through the hierarchy using the reachability matrix:
    # each class receives the summed probability of itself and all its subclasses
    cell_type_probs = torch.matmul(cell_type_probs, reachability_matrix.T)

    # Apply log transform (with numerical stability term) for NLL loss calculation
    cell_type_probs = torch.log(cell_type_probs + 1e-6)

    # Calculate negative log-likelihood loss with optional class weights
    hce_loss = F.nll_loss(cell_type_probs, targets, weight=weight)
    return hce_loss
```
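
To make the effect of the propagation step concrete, here is a small NumPy restatement of the same arithmetic for a single sample. It is an illustrative sketch rather than the repository's code: `hce_per_sample` is a hypothetical name, the logits are made up, and class weights are omitted. It uses the A to E reachability matrix from the example above and a prediction that puts most of its mass on the leaf class D.

```python
import numpy as np

# Reachability matrix of the A..E example (rows: ancestor, columns: reachable class)
R = np.array([
    [1, 1, 1, 1, 1],  # A
    [0, 1, 0, 1, 1],  # B
    [0, 0, 1, 0, 1],  # C
    [0, 0, 0, 1, 0],  # D
    [0, 0, 0, 0, 1],  # E
], dtype=float)


def hce_per_sample(logits, target, reachability, eps=1e-6):
    """Hypothetical single-sample version of the loss, without class weights."""
    shifted = logits - logits.max()              # numerically stable softmax
    probs = np.exp(shifted) / np.exp(shifted).sum()
    # propagated[i] = probability of class i plus all classes reachable from i
    propagated = reachability @ probs
    return -np.log(propagated[target] + eps)


# Made-up logits: the model is confident that the sample belongs to D
logits = np.array([0.0, 0.0, 0.0, 5.0, 0.0])

loss_target_A = hce_per_sample(logits, 0, R)  # A is the root: propagated probability is 1
loss_target_B = hce_per_sample(logits, 1, R)  # B is a parent of D: small loss
loss_target_C = hce_per_sample(logits, 2, R)  # C does not reach D: large loss

# Standard cross-entropy for target B, which ignores the hierarchy
probs = np.exp(logits - logits.max()) / np.exp(logits - logits.max()).sum()
plain_ce_target_B = -np.log(probs[1])
```

With a coarse label such as B, a confident prediction of the finer class D is penalised only lightly, whereas standard cross-entropy would treat it as a plain misclassification.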
78+
79+
## Contact
80+
For questions or issues, please contact davide.dascenzo.work@gmail.com or davide.dascenzo@unimi.it (likely not active from 2026).
81+
82+
## Citation
83+
If you use this code or method in your research, please consider citing the [paper](https://doi.org/10.1101/2025.04.22.1234567).

model_evaluation/checkpoint_list.txt

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
/data/sebacultrera/merlin_cxg_2023_05_15_sf-log1p/tb_logs/cxg_2023_05_15_linear_hierarchical_loss/default/version_0/checkpoints/val_f1_macro_epoch=1_val_f1_macro=0.802.ckpt
/data/sebacultrera/merlin_cxg_2023_05_15_sf-log1p/tb_logs/cxg_2023_05_15_mlp_hierarchical_loss/default/version_0/checkpoints/val_f1_macro_epoch=1_val_f1_macro=0.798.ckpt
/data/sebacultrera/merlin_cxg_2023_05_15_sf-log1p/tb_logs/cxg_2023_05_15_tabnet_hierarchical_loss/default/version_0/checkpoints/val_f1_macro_epoch=0_val_f1_macro=0.789.ckpt

model_evaluation/model_evaluation.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
import argparse
import sys
import os
sys.path.append('../scTab')
sys.path.append('../model_evaluation')
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from utils import (
    data_preparation,
    run_model,
    print_clf_report_per_class
)

def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description='Evaluate cell type classification models')

    # Data paths
    parser.add_argument('--dataset_ids', type=str, help='Dataset ID or "diff_2023-05-15"')
    parser.add_argument('--features_file', type=str, required=True,
                        help='Path to features.parquet')
    parser.add_argument('--var_file', type=str, required=True,
                        help='Path to var.parquet')
    parser.add_argument('--cell_type_mapping_file', type=str, required=True,
                        help='Path to cell_type.parquet')
    parser.add_argument('--cell_type_hierarchy_file', type=str, required=True,
                        help='Path to child_matrix.npy')

    # Model paths and configuration
    parser.add_argument('--model_type', type=str, required=True,
                        choices=['tabnet', 'linear', 'mlp', 'celltypist'],
                        help='Type of model to evaluate')
    parser.add_argument('--checkpoint_path', type=str, required=True,
                        help='Path to model checkpoint')
    parser.add_argument('--hparams_file', type=str,
                        help='Path to hyperparameters file (not needed for CellTypist)')

    # Output configuration
    parser.add_argument('--output_dir', type=str, default='evaluation_results',
                        help='Directory to save evaluation results')
    parser.add_argument('--census_version', type=str, default='2023-05-15',
                        help='CellXGene census version')
    parser.add_argument('--force_download', action='store_true',
                        help='Force re-download of data')

    # Add output root argument
    parser.add_argument('--output_root', type=str, required=True,
                        help='Root directory for storing AnnData chunks and results')

    return parser.parse_args()

def save_results(clf_report_overall, clf_report_per_class, y_probs, y_pred, y_true, metadata, cell_type_mapping, args):
54+
"""Save evaluation results to files."""
55+
os.makedirs(args.output_dir, exist_ok=True)
56+
57+
# Save overall metrics
58+
overall_path = os.path.join(args.output_dir, f'{args.model_type}_overall_metrics.csv')
59+
clf_report_overall.to_csv(overall_path)
60+
print(f"\nOverall metrics saved to: {overall_path}")
61+
print("\nOverall Results:")
62+
print(clf_report_overall)
63+
64+
# Save per-class metrics
65+
per_class_path = os.path.join(args.output_dir, f'{args.model_type}_per_class_metrics.csv')
66+
clf_report_per_class.to_csv(per_class_path)
67+
print(f"\nPer-class metrics saved to: {per_class_path}")
68+
69+
# Generate and save visualization
70+
plt.figure(figsize=(20, 10))
71+
print_clf_report_per_class(
72+
clf_report_per_class,
73+
args.cell_type_mapping_file,
74+
title=f'{args.model_type.capitalize()} Performance by Cell Type'
75+
)
76+
plot_path = os.path.join(args.output_dir, f'{args.model_type}_performance_plot.png')
77+
plt.savefig(plot_path, bbox_inches='tight', dpi=300)
78+
plt.close()
79+
print(f"\nPerformance plot saved to: {plot_path}")
80+
81+
# Create and save detailed results dataframe
82+
print("\nCreating detailed results dataframe...")
83+
84+
# Convert numeric indices to cell type labels
85+
cell_type_mapping_df = pd.read_parquet(args.cell_type_mapping_file)
86+
cell_type_mapping_dict = dict(zip(range(len(cell_type_mapping_df)), cell_type_mapping_df['label']))
87+
88+
# Create the detailed results dataframe
89+
detailed_df = pd.DataFrame({
90+
'y_true': [cell_type_mapping_dict[idx] for idx in y_true],
91+
'y_pred': [cell_type_mapping_dict[idx] for idx in y_pred]
92+
})
93+
94+
# Add other columns from metadata
95+
detailed_df = pd.concat([detailed_df, metadata.reset_index(drop=True)], axis=1)
96+
97+
# Add probabilities as a column
98+
detailed_df['y_probs'] = list(y_probs)
99+
100+
# Save the detailed results
101+
detailed_path = os.path.join(args.output_dir, f'{args.model_type}_detailed_results.parquet')
102+
detailed_df.to_parquet(detailed_path, index=False)
103+
print(f"\nDetailed results saved to: {detailed_path}")
104+
105+
def main():
106+
"""Main execution function."""
107+
args = parse_args()
108+
109+
print(f"\nPreparing data for {args.model_type} model evaluation...")
110+
output_folder, genes, cell_mapping = data_preparation(
111+
args.dataset_ids,
112+
args.features_file,
113+
args.var_file,
114+
args.cell_type_mapping_file,
115+
census_version=args.census_version,
116+
force_download=args.force_download,
117+
output_root=args.output_root
118+
)
119+
120+
print(f"\nRunning evaluation for {args.model_type} model...")
121+
results = run_model(
122+
args.model_type,
123+
args.checkpoint_path,
124+
args.hparams_file,
125+
args.cell_type_hierarchy_file,
126+
genes,
127+
cell_mapping,
128+
output_folder
129+
)
130+
131+
# Unpack results (now including probabilities and additional metadata)
132+
clf_report_overall, clf_report_per_class, y_probs, y_pred, y_true, metadata = results
133+
134+
save_results(clf_report_overall, clf_report_per_class, y_probs, y_pred, y_true, metadata, cell_mapping, args)
135+
136+
if __name__ == '__main__':
137+
main()
Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
#!/bin/bash

# Common data paths
DATA_ROOT="/data/sebacultrera/merlin_cxg_2023_05_15_sf-log1p"

# Function to evaluate a single checkpoint
evaluate_checkpoint() {
    CHECKPOINT_PATH="$1"

    # Dynamically extract log and model folders from the checkpoint file path
    REL_PATH=${CHECKPOINT_PATH#${DATA_ROOT}/}
    LOG_DIR=$(echo "$REL_PATH" | cut -d'/' -f1)
    MODEL_NAME=$(echo "$REL_PATH" | cut -d'/' -f2)
    DEFAULT_DIR=$(echo "$REL_PATH" | cut -d'/' -f3)
    VERSION_DIR=$(echo "$REL_PATH" | cut -d'/' -f4)
    MODEL_PATH="${DATA_ROOT}/${LOG_DIR}/${MODEL_NAME}/${DEFAULT_DIR}/${VERSION_DIR}"

    CHECKPOINT_NAME=$(basename "$CHECKPOINT_PATH" .ckpt)

    # Determine model type based on checkpoint path
    if [[ $CHECKPOINT_PATH == *"mlp"* ]]; then
        MODEL_TYPE="mlp"
    elif [[ $CHECKPOINT_PATH == *"tabnet"* ]]; then
        MODEL_TYPE="tabnet"
    elif [[ $CHECKPOINT_PATH == *"linear"* ]]; then
        MODEL_TYPE="linear"
    else
        echo "Unknown model type in path: $CHECKPOINT_PATH"
        return 1
    fi

    HPARAMS_PATH="${MODEL_PATH}/hparams.yaml"
    OUTPUT_DIR="${MODEL_PATH}/checkpoints/${CHECKPOINT_NAME}"

    # Create output directory
    mkdir -p "$OUTPUT_DIR"

    # Run evaluation
    python model_evaluation.py \
        --dataset_ids "diff_2023-05-15" \
        --features_file "/home/sebacultrera/label_smoothing_celltype/label_smoothing_celltype/scTab/notebooks/store_creation/features.parquet" \
        --var_file "${DATA_ROOT}/var.parquet" \
        --cell_type_mapping_file "${DATA_ROOT}/categorical_lookup/cell_type.parquet" \
        --cell_type_hierarchy_file "${DATA_ROOT}/cell_type_hierarchy/child_matrix.npy" \
        --model_type "${MODEL_TYPE}" \
        --checkpoint_path "${CHECKPOINT_PATH}" \
        --hparams_file "${HPARAMS_PATH}" \
        --output_dir "${OUTPUT_DIR}" \
        --output_root "${DATA_ROOT}" \
        --census_version "2023-12-15" \
        > "${OUTPUT_DIR}/eval.out" 2> "${OUTPUT_DIR}/eval.err"
}

export -f evaluate_checkpoint

find_checkpoints() {
    local dir="$1"
    find "$dir" -name "*.ckpt"
}

# Check if checkpoint list file is provided
if [ $# -ne 1 ]; then
    echo "Usage: $0 <checkpoint_list_file>"
    exit 1
fi

CHECKPOINT_LIST="$1"
MAX_PARALLEL=8  # Maximum number of parallel processes
running=0       # Counter for running processes

# Read checkpoints and run evaluations
while IFS= read -r path; do
    # Skip empty lines
    [ -z "$path" ] && continue

    if [ -d "$path" ]; then
        # If path is a directory, find all checkpoints
        while IFS= read -r checkpoint; do
            # Wait if we've reached max parallel processes
            if [ $running -ge $MAX_PARALLEL ]; then
                wait -n
                running=$((running - 1))
            fi

            evaluate_checkpoint "$checkpoint" &
            running=$((running + 1))
        done < <(find_checkpoints "$path")
    else
        # Handle single checkpoint file
        if [ $running -ge $MAX_PARALLEL ]; then
            wait -n
            running=$((running - 1))
        fi

        evaluate_checkpoint "$path" &
        running=$((running + 1))
    fi
done < "$CHECKPOINT_LIST"

# Wait for remaining processes to finish
wait
