
Output column in GeoParquet table with statistics (mean, count, min/max) of an input band #185

Open

weiji14 opened this issue on Mar 18, 2024 · 0 comments
Labels: enhancement (New feature or request), help wanted (Extra attention is needed)

weiji14 (Contributor) commented on Mar 18, 2024

This is an idea that came up during our regular meetings: generalize the patch-level cloud-cover percentage info (i.e. extend #168) to other bands/channels, so that someone could apply filters on columns holding statistics (mean, count, min/max, percentage, etc.) derived from the input images. This would enable pre-filtering on attributes when performing a similarity search.

Example:

embedding             cloud_cover_percentage  mean_elevation  max_temperature  bbox
[0.1, 0.4, ... x768]  20%                     100m            25°C             POLYGON(...)
[0.2, 0.5, ... x768]  30%                     300m            20°C             POLYGON(...)
[0.3, 0.6, ... x768]  40%                     500m            15°C             POLYGON(...)
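Once columns like these exist in the GeoParquet file, pre-filtering before a similarity search could look roughly like the sketch below (the file name, column dtypes and thresholds are made up for illustration):

import geopandas as gpd

# Hypothetical GeoParquet file holding the columns from the table above
gdf = gpd.read_parquet(path="embeddings.gpq")

# Keep only mostly cloud-free, low-elevation patches before running a
# similarity search on the `embedding` column of the filtered subset
subset = gdf[(gdf.cloud_cover_percentage < 25) & (gdf.mean_elevation < 200)]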

This would involve generalizing the inference part of the code somehow, specifically the predict_step function here:

model/src/model_clay.py

Lines 855 to 921 in 0145e55

def predict_step(
    self, batch: dict[str, torch.Tensor | list[str]], batch_idx: int
) -> gpd.GeoDataFrame:
    """
    Logic for the neural network's prediction loop.
    """
    # Get image, bounding box, EPSG code, and date inputs
    # x: torch.Tensor = batch["pixels"]  # image of shape (1, 13, 512, 512) # BCHW
    bboxes: np.ndarray = batch["bbox"].cpu().__array__()  # bounding boxes
    epsgs: torch.Tensor = batch["epsg"]  # coordinate reference systems as EPSG code
    dates: list[str] = batch["date"]  # dates, e.g. ['2022-12-12', '2022-12-12']
    source_urls: list[str] = batch[  # URLs, e.g. ['s3://1.tif', 's3://2.tif']
        "source_url"
    ]

    # Forward encoder
    self.model.encoder.mask_ratio = 0.0  # disable masking
    outputs_encoder: dict = self.model.encoder(
        datacube=batch  # input (pixels, timestep, latlon)
    )

    # Get embeddings generated from encoder
    # (encoded_unmasked_patches, _, _, _) = outputs_encoder
    embeddings_raw: torch.Tensor = outputs_encoder[0]
    assert embeddings_raw.shape == torch.Size(
        [self.model.encoder.B, 1538, 768]  # (batch_size, seq_length, hidden_size)
    )
    assert not torch.isnan(embeddings_raw).any()  # ensure no NaNs in embedding

    # Take the mean of the embeddings along the sequence_length dimension
    # excluding the last two latlon_ and time_ embeddings, i.e. compute
    # mean over patch embeddings only
    embeddings_mean: torch.Tensor = embeddings_raw[:, :-2, :].mean(dim=1)
    assert embeddings_mean.shape == torch.Size(
        [self.model.encoder.B, 768]  # (batch_size, hidden_size)
    )

    # Create table to store the embeddings with spatiotemporal metadata
    unique_epsg_codes = set(int(epsg) for epsg in epsgs)
    if len(unique_epsg_codes) == 1:  # check that there's only 1 unique EPSG
        epsg: int = batch["epsg"][0]
    else:
        raise NotImplementedError(
            f"More than 1 EPSG code detected: {unique_epsg_codes}"
        )

    gdf = gpd.GeoDataFrame(
        data={
            "source_url": pd.Series(data=source_urls, dtype="string[pyarrow]"),
            "date": pd.to_datetime(arg=dates, format="%Y-%m-%d").astype(
                dtype="date32[day][pyarrow]"
            ),
            "embeddings": pa.FixedShapeTensorArray.from_numpy_ndarray(
                embeddings_mean.cpu().detach().__array__()
            ),
        },
        geometry=shapely.box(
            xmin=bboxes[:, 0],
            ymin=bboxes[:, 1],
            xmax=bboxes[:, 2],
            ymax=bboxes[:, 3],
        ),
        crs=f"EPSG:{epsg}",
    )
    gdf = gdf.to_crs(crs="OGC:CRS84")  # reproject from UTM to lonlat coordinates

    return gdf
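
One possible direction (a rough sketch only, not a concrete design; the band_names mapping is hypothetical and the band order would have to come from the datacube config) is to compute per-band statistics from batch["pixels"] inside predict_step and merge them as extra columns into the data= dict above:

import pandas as pd
import torch

def band_statistics(
    pixels: torch.Tensor, band_names: list[str]
) -> dict[str, pd.Series]:
    """
    Compute per-patch mean/min/max for each band of a BCHW image tensor.

    Returns one column per band and statistic (e.g. 'mean_elevation'),
    ready to be merged into the data= dict used to build the GeoDataFrame.
    """
    columns: dict[str, pd.Series] = {}
    for idx, name in enumerate(band_names):
        band = pixels[:, idx, :, :]  # shape (batch_size, height, width)
        columns[f"mean_{name}"] = pd.Series(data=band.mean(dim=(1, 2)).cpu().numpy())
        columns[f"min_{name}"] = pd.Series(data=band.amin(dim=(1, 2)).cpu().numpy())
        columns[f"max_{name}"] = pd.Series(data=band.amax(dim=(1, 2)).cpu().numpy())
    return columns

Count- or percentage-style statistics (like the cloud-cover column) would need band-specific logic, so a generic mean/min/max helper like this is only part of the picture.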

Some changes might also need to happen on the DataLoader side, so that these statistical measures are passed through. Parking this as an idea for now.
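As a very rough sketch of that DataLoader-side change (names are hypothetical), the Dataset's __getitem__ could attach the statistics to each sample dict so they arrive in the batch alongside "pixels", "bbox", etc., and the default collate would stack them into per-batch tensors:

import torch

def add_band_statistics(sample: dict, band_index: int, name: str) -> dict:
    """Attach mean/min/max of one band to a sample dict returned by __getitem__."""
    band: torch.Tensor = sample["pixels"][band_index]  # shape (height, width)
    sample[f"mean_{name}"] = band.mean()
    sample[f"min_{name}"] = band.amin()
    sample[f"max_{name}"] = band.amax()
    return sample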
