Skip to content

Commit efa74a4

Browse files
committed
Support for "lagged"
1 parent a46a7f1 commit efa74a4

File tree

2 files changed

+89
-52
lines changed

2 files changed

+89
-52
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,5 @@ cython_debug/
165165
*.swp
166166
*.npy
167167
*.download
168+
?
169+
?.*

ai_models/model.py

Lines changed: 87 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# granted to it by virtue of its status as an intergovernmental organisation
66
# nor does it submit to any jurisdiction.
77

8+
import datetime
89
import logging
910
import os
1011
import sys
@@ -19,80 +20,76 @@
1920
LOG = logging.getLogger(__name__)
2021

2122

22-
class MarsInput:
23+
class RequestBasedInput:
2324
def __init__(self, owner, **kwargs):
2425
self.owner = owner
2526

2627
@cached_property
2728
def fields_sfc(self):
28-
LOG.info("Loading surface fields from MARS")
29-
request = dict(
30-
date=self.owner.date,
31-
time=self.owner.time,
32-
param=self.owner.param_sfc,
33-
grid=self.owner.grid,
34-
area=self.owner.area,
35-
levtype="sfc",
29+
LOG.info(f"Loading surface fields from {self.WHERE}")
30+
return cml.load_source(
31+
"multi",
32+
[
33+
self.sfc_load_source(
34+
date=date,
35+
time=time,
36+
param=self.owner.param_sfc,
37+
grid=self.owner.grid,
38+
area=self.owner.area,
39+
)
40+
for date, time in self.owner.datetimes()
41+
],
3642
)
37-
return cml.load_source("mars", request)
3843

3944
@cached_property
4045
def fields_pl(self):
41-
LOG.info("Loading pressure fields from MARS")
46+
LOG.info(f"Loading pressure fields from {self.WHERE}")
4247
param, level = self.owner.param_level_pl
43-
request = dict(
44-
date=self.owner.date,
45-
time=self.owner.time,
46-
param=param,
47-
level=level,
48-
grid=self.owner.grid,
49-
area=self.owner.area,
50-
levtype="pl",
48+
return cml.load_source(
49+
"multi",
50+
[
51+
self.pl_load_source(
52+
date=date,
53+
time=time,
54+
param=param,
55+
level=level,
56+
grid=self.owner.grid,
57+
area=self.owner.area,
58+
)
59+
for date, time in self.owner.datetimes()
60+
],
5161
)
52-
return cml.load_source("mars", request)
5362

5463
@cached_property
5564
def all_fields(self):
5665
return self.fields_sfc + self.fields_pl
5766

5867

59-
class CdsInput:
68+
class MarsInput(RequestBasedInput):
69+
WHERE = "MARS"
70+
6071
def __init__(self, owner, **kwargs):
6172
self.owner = owner
6273

63-
@cached_property
64-
def fields_sfc(self):
65-
LOG.info("Loading surface fields from the CDS")
66-
request = dict(
67-
product_type="reanalysis",
68-
date=self.owner.date,
69-
time=self.owner.time,
70-
param=self.owner.param_sfc,
71-
grid=self.owner.grid,
72-
area=self.owner.area,
73-
levtype="sfc",
74-
)
75-
return cml.load_source("cds", "reanalysis-era5-single-levels", request)
74+
def pl_load_source(self, **kwargs):
75+
kwargs["levtype"] = "pl"
76+
logging.debug("load source mars %s", kwargs)
77+
return cml.load_source("mars", kwargs)
7678

77-
@cached_property
78-
def fields_pl(self):
79-
LOG.info("Loading pressure fields from the CDS")
80-
param, level = self.owner.param_level_pl
81-
request = dict(
82-
product_type="reanalysis",
83-
date=self.owner.date,
84-
time=self.owner.time,
85-
param=param,
86-
level=level,
87-
grid=self.owner.grid,
88-
area=self.owner.area,
89-
levtype="pl",
90-
)
91-
return cml.load_source("cds", "reanalysis-era5-pressure-levels", request)
79+
def sfc_load_source(self, **kwargs):
80+
kwargs["levtype"] = "sfc"
81+
logging.debug("load source mars %s", kwargs)
82+
return cml.load_source("mars", kwargs)
9283

93-
@cached_property
94-
def all_fields(self):
95-
return self.fields_sfc + self.fields_pl
84+
85+
class CdsInput(RequestBasedInput):
86+
WHERE = "CDS"
87+
88+
def pl_load_source(self, **kwargs):
89+
return cml.load_source("cds", "reanalysis-era5-pressure-levels", kwargs)
90+
91+
def sfc_load_source(self, **kwargs):
92+
return cml.load_source("cds", "reanalysis-era5-single-levels", kwargs)
9693

9794

9895
class FileInput:
@@ -200,6 +197,7 @@ def __exit__(self, *args):
200197

201198

202199
class Model:
200+
lagged = False
203201
assets_extra_dir = None
204202

205203
def __init__(self, input, output, download_assets, **kwargs):
@@ -305,6 +303,43 @@ def timer(self, title):
305303
def stepper(self, step):
306304
return Stepper(step, self.lead_time)
307305

306+
def datetimes(self):
307+
date = self.date
308+
assert isinstance(date, int)
309+
if date <= 0:
310+
date = datetime.datetime.utcnow() + datetime.timedelta(days=date)
311+
date = date.year * 10000 + date.month * 100 + date.day
312+
313+
time = self.time
314+
assert isinstance(time, int)
315+
if time < 100:
316+
time *= 100
317+
assert time in (0, 600, 1200, 1800), time
318+
319+
lagged = self.lagged
320+
if not lagged:
321+
lagged = [0]
322+
323+
full = datetime.datetime(
324+
date // 10000,
325+
date % 10000 // 100,
326+
date % 100,
327+
time // 100,
328+
time % 100,
329+
)
330+
331+
result = []
332+
for lag in lagged:
333+
date = full + datetime.timedelta(hours=lag)
334+
result.append(
335+
(
336+
date.year * 10000 + date.month * 100 + date.day,
337+
date.hour,
338+
),
339+
)
340+
341+
return result
342+
308343
def print_fields(self):
309344
param, level = self.param_level_pl
310345
print("Grid:", self.grid)

0 commit comments

Comments
 (0)