-
-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy path_db_helpers.py
More file actions
242 lines (193 loc) · 8.23 KB
/
_db_helpers.py
File metadata and controls
242 lines (193 loc) · 8.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
"""Helper functions to interact with the database.
These helpers wrap the SQL queries used to communicate with the database.
Typically they query the database and return pydantic objects representing the data.
Some still use the legacy SQLAlchemy query style, while the newer ones use the
2.0-friendly style.
"""
import datetime as dt
import uuid
from collections import defaultdict
from typing import Any, Optional
import sqlalchemy as sa
import structlog
from pvsite_datamodel.read.generation import get_pv_generation_by_sites
from pvsite_datamodel.sqlmodels import ForecastSQL, ForecastValueSQL, InverterSQL, SiteSQL
from sqlalchemy.orm import Session, aliased
from .pydantic_models import (
Forecast,
MultiplePVActual,
PVActualValue,
PVSiteMetadata,
SiteForecastValues,
)
# SQLAlchemy result rows are awkward to type precisely; this alias keeps the
# signatures below readable.
Row = Any

logger = structlog.stdlib.get_logger()
def _get_forecasts_for_horizon(
    session: Session,
    site_uuids: list[str],
    start_utc: dt.datetime,
    end_utc: dt.datetime,
    horizon_minutes: int,
) -> list[Row]:
    """Fetch `(ForecastSQL, ForecastValueSQL)` rows at one fixed horizon for several sites.

    Only forecast values whose `start_utc` lies in `[start_utc, end_utc)` and whose
    `horizon_minutes` equals the requested horizon are kept. Rows are ordered by site and
    then by forecast timestamp.
    """
    # Also bounding the forecast's own `timestamp_utc` makes the query faster.
    earliest_forecast_time = start_utc - dt.timedelta(minutes=horizon_minutes)

    conditions = (
        ForecastSQL.site_uuid.in_(site_uuids),
        ForecastSQL.timestamp_utc >= earliest_forecast_time,
        ForecastSQL.timestamp_utc < end_utc,
        ForecastValueSQL.horizon_minutes == horizon_minutes,
        ForecastValueSQL.start_utc >= start_utc,
        ForecastValueSQL.start_utc < end_utc,
    )

    stmt = (
        sa.select(ForecastSQL, ForecastValueSQL)
        # DISTINCT ON (site, forecast timestamp) keeps a single row when two forecasts
        # were run for the same site at the same time; rare in practice.
        .distinct(ForecastSQL.site_uuid, ForecastSQL.timestamp_utc)
        .select_from(ForecastSQL)
        .join(ForecastValueSQL)
        .where(*conditions)
        .order_by(ForecastSQL.site_uuid, ForecastSQL.timestamp_utc)
    )
    return session.execute(stmt).all()
def _get_inverters_by_site(session: Session, site_uuid: str) -> list[Row]:
    """Return every `InverterSQL` row whose `site_uuid` matches the given site."""
    return session.query(InverterSQL).filter_by(site_uuid=site_uuid).all()
def _get_latest_forecast_by_sites(
    session: Session, site_uuids: list[str], start_utc: Optional[dt.datetime] = None
) -> list[Row]:
    """Get the values of the most recent forecast for each of the given sites.

    Args:
        session: database session.
        site_uuids: site UUIDs as strings; each is parsed with `uuid.UUID`.
        start_utc: if given, forecast values whose `start_utc` is earlier are dropped.

    Returns:
        `(ForecastSQL, ForecastValueSQL)` rows ordered by forecast timestamp, then by
        forecast value `start_utc`.
    """
    # DISTINCT ON (site_uuid) with ORDER BY timestamp_utc DESC keeps only the newest
    # forecast of each site.
    latest_forecasts = (
        session.query(ForecastSQL)
        .distinct(ForecastSQL.site_uuid)
        .filter(ForecastSQL.site_uuid.in_([uuid.UUID(su) for su in site_uuids]))
        .order_by(
            ForecastSQL.site_uuid,
            ForecastSQL.timestamp_utc.desc(),
        )
    ).subquery()
    # Naming the alias "ForecastSQL" lets result rows expose it as `row.ForecastSQL`.
    forecast_subq = aliased(ForecastSQL, latest_forecasts, name="ForecastSQL")

    # Join the forecast values.
    query = session.query(forecast_subq, ForecastValueSQL).join(ForecastValueSQL)

    # Only get future forecast values. This handles the case where a forecast was made
    # a day ago and no new forecast has been made since.
    if start_utc is not None:
        query = query.filter(ForecastValueSQL.start_utc >= start_utc)

    # `Query.order_by` returns a new query instead of modifying this one in place, so the
    # result must be reassigned or the ordering is lost.
    query = query.order_by(forecast_subq.timestamp_utc, ForecastValueSQL.start_utc)
    return query.all()
def _forecast_rows_to_pydantic(rows: list[Row]) -> list[Forecast]:
    """Group `(ForecastSQL, ForecastValueSQL)` rows into one pydantic `Forecast` per site.

    The site-level fields (site uuid, forecast uuid, creation datetime, forecast version)
    come from the first row seen for each site. A `ForecastValueSQL` whose
    `forecast_value_uuid` was already seen for that site is skipped, because the latest
    forecast and the past forecasts can overlap. Sites appear in the order they are first
    encountered.
    """
    metadata_by_site: dict[str, dict[str, Any]] = {}
    values_by_site: dict[str, list[SiteForecastValues]] = defaultdict(list)
    seen_value_uuids: dict[str, set[uuid.UUID]] = defaultdict(set)

    for row in rows:
        forecast = row.ForecastSQL
        site_uuid = str(forecast.site_uuid)

        if site_uuid not in metadata_by_site:
            metadata_by_site[site_uuid] = {
                "site_uuid": site_uuid,
                "forecast_uuid": str(forecast.forecast_uuid),
                "forecast_creation_datetime": forecast.timestamp_utc,
                "forecast_version": forecast.forecast_version,
            }

        forecast_value = row.ForecastValueSQL
        value_uuid = forecast_value.forecast_value_uuid
        already_seen = seen_value_uuids[site_uuid]
        if value_uuid in already_seen:
            continue

        values_by_site[site_uuid].append(
            SiteForecastValues(
                target_datetime_utc=forecast_value.start_utc,
                expected_generation_kw=forecast_value.forecast_power_kw,
            )
        )
        already_seen.add(value_uuid)

    return [
        Forecast(forecast_values=values_by_site[site_uuid], **site_metadata)
        for site_uuid, site_metadata in metadata_by_site.items()
    ]
def get_forecasts_by_sites(
    session: Session,
    site_uuids: list[str],
    start_utc: dt.datetime,
    horizon_minutes: int,
) -> list[Forecast]:
    """Return, per site, the past forecasts at a fixed horizon merged with the latest forecast.

    This is the data shown in the UI.

    Past values cover `[start_utc, now)` at `horizon_minutes`. The latest forecast adds its
    values from `start_utc` onwards. Both row sets are then grouped per site by
    `_forecast_rows_to_pydantic`, which drops repeated forecast values.
    """
    logger.info(f"Getting forecast for {len(site_uuids)} sites")

    now_utc = dt.datetime.utcnow()

    past_rows = _get_forecasts_for_horizon(
        session,
        site_uuids=site_uuids,
        start_utc=start_utc,
        end_utc=now_utc,
        horizon_minutes=horizon_minutes,
    )
    logger.debug("Found %s past forecasts", len(past_rows))

    future_rows = _get_latest_forecast_by_sites(
        session=session, site_uuids=site_uuids, start_utc=start_utc
    )
    logger.debug("Found %s future forecasts", len(future_rows))

    logger.debug("Formatting forecasts to pydantic objects")
    # Past rows go first, so the per-site metadata comes from the first past row when one exists.
    merged_forecasts = _forecast_rows_to_pydantic([*past_rows, *future_rows])
    logger.debug("Formatting forecasts to pydantic objects: done")
    return merged_forecasts
def get_generation_by_sites(
    session: Session, site_uuids: list[str], start_utc: dt.datetime
) -> list[MultiplePVActual]:
    """Get the generation since `start_utc` for a list of sites.

    Args:
        session: database session.
        site_uuids: site UUIDs as strings; each is parsed with `uuid.UUID`.
        start_utc: start of the period, passed through to `get_pv_generation_by_sites`.

    Returns:
        One `MultiplePVActual` per site that has at least one generation row. Sites
        without any row are not included.
    """
    logger.info(f"Getting generation for {len(site_uuids)} sites")

    rows = get_pv_generation_by_sites(
        session=session, start_utc=start_utc, site_uuids=[uuid.UUID(su) for su in site_uuids]
    )

    # Go through the rows and split the data by site.
    pv_actual_values_per_site: dict[str, list[PVActualValue]] = defaultdict(list)

    # TODO can we speed this up?
    logger.info("Formatting generation 1")
    for row in rows:
        site_uuid = str(row.site_uuid)
        pv_actual_values_per_site[site_uuid].append(
            PVActualValue(
                datetime_utc=row.start_utc,
                actual_generation_kw=row.generation_power_kw,
            )
        )

    logger.info("Formatting generation 2")
    multiple_pv_actuals = [
        MultiplePVActual(site_uuid=site_uuid, pv_actual_values=pv_actual_values)
        for site_uuid, pv_actual_values in pv_actual_values_per_site.items()
    ]

    # f-string so the site count is interpolated rather than logged literally.
    logger.debug(f"Getting generation for {len(site_uuids)} sites: done")
    return multiple_pv_actuals
def get_sites_by_uuids(session: Session, site_uuids: list[str]) -> list[PVSiteMetadata]:
    """Load the sites whose UUID is in `site_uuids` and convert each to `PVSiteMetadata`.

    UUIDs that match no site are skipped. No ORDER BY is applied, so the result order is
    not guaranteed to follow `site_uuids`.
    """
    matching_sites = session.query(SiteSQL).filter(SiteSQL.site_uuid.in_(site_uuids)).all()
    return [site_to_pydantic(matching_site) for matching_site in matching_sites]
def site_to_pydantic(site: SiteSQL) -> PVSiteMetadata:
    """Build the `PVSiteMetadata` model from a `SiteSQL` row.

    The site UUID is stringified, the client name is read through `site.client`, and the
    site's `capacity_kw` is exposed as `installed_capacity_kw`.
    """
    site_uuid = str(site.site_uuid)
    client_name = site.client.client_name

    return PVSiteMetadata(
        site_uuid=site_uuid,
        client_name=client_name,
        client_site_id=site.client_site_id,
        client_site_name=site.client_site_name,
        region=site.region,
        dno=site.dno,
        gsp=site.gsp,
        latitude=site.latitude,
        longitude=site.longitude,
        installed_capacity_kw=site.capacity_kw,
        created_utc=site.created_utc,
    )
def does_site_exist(session: Session, site_uuid: str) -> bool:
    """Return True if a `SiteSQL` row with the given `site_uuid` exists, else False."""
    lookup = sa.select(SiteSQL).where(SiteSQL.site_uuid == site_uuid)
    found = session.execute(lookup).one_or_none()
    return found is not None