Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
e11b5e8
renamed forecaster name from blend to nl_blend
PavanRaghavendraKulkarni May 1, 2026
8de9294
nl_blend_adjuster
PavanRaghavendraKulkarni May 1, 2026
16e2bb8
lint error
PavanRaghavendraKulkarni May 1, 2026
ffe67b9
fix: add missing trailing newline to config.yaml
PavanRaghavendraKulkarni May 1, 2026
8159188
Merge branch 'main' into nl_blend_adjuster
PavanRaghavendraKulkarni May 1, 2026
1eeba51
added unit test cases
PavanRaghavendraKulkarni May 1, 2026
ab439a7
description enhancement
PavanRaghavendraKulkarni May 1, 2026
8de06a9
test: update test_run_blend_app_success to verify double execution of…
PavanRaghavendraKulkarni May 1, 2026
eabf121
suggested changes
PavanRaghavendraKulkarni May 1, 2026
f4260b0
comment fixes
PavanRaghavendraKulkarni May 1, 2026
d5281d3
lint fixes
PavanRaghavendraKulkarni May 1, 2026
86bb11f
use actual data instead of made-up data; removed redundant tests
PavanRaghavendraKulkarni May 1, 2026
9a86a68
lint fixes
PavanRaghavendraKulkarni May 1, 2026
4d834cb
refactoring of blend app and test for rename columns
PavanRaghavendraKulkarni May 1, 2026
e9d316c
suggested fixes
PavanRaghavendraKulkarni May 4, 2026
e82788f
feat: implement regional blend pipeline using distinct candidate mode…
PavanRaghavendraKulkarni May 4, 2026
508289c
lint fixes
PavanRaghavendraKulkarni May 4, 2026
aec5e31
Merge branch 'main' into nl_blend_region
PavanRaghavendraKulkarni May 5, 2026
1e8d85b
filter by state rather than prefix
PavanRaghavendraKulkarni May 5, 2026
e56f236
lint
PavanRaghavendraKulkarni May 5, 2026
836b1e1
fix tests
PavanRaghavendraKulkarni May 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 72 additions & 28 deletions site_forecast_app/blend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
build_forecast_value_objects,
)
from site_forecast_app.blend.init_times import load_nl_mae_scorecard
from site_forecast_app.blend.weights import get_blend_weights
from site_forecast_app.blend.weights import get_blend_weights, get_regional_blend_weights
from site_forecast_app.save.data_platform import (
create_forecaster_if_not_exists,
fetch_dp_location_map,
get_dataplatform_client,
)

Expand All @@ -33,12 +32,13 @@ async def run_blend_app() -> None:
1. Determine blend reference time (t0)
2. Fetch full location map from Data Platform
3. Load the MAE scorecard
4. Calculate blend weights and run blend for main models
5. Save main forecast under {forecaster_name}
6. If use_adjuster=True:
- Calculate blend weights and run blend for adjuster models
({model_name}_adjust) — full pipeline runs unchanged
- Save adjuster blend under {forecaster_name}_adjust
4. Calculate blend weights and run blend for national location
5. Save national forecast under {forecaster_name}
6. For each regional location (all non-national keys in the location map):
- Calculate regional blend weights and run blend
- Save under {forecaster_name}
7. If use_adjuster=True: repeat steps 4-6 using {model}_adjust forecasters
and save under {forecaster_name}_adjust
"""
_cfg = load_blend_config()
logger.info(
Expand All @@ -63,7 +63,9 @@ async def run_blend_app() -> None:
# -------------------------------------------------------------- #
logger.info("Fetching location map from Data Platform.")
try:
dp_loc_map = await fetch_dp_location_map(client)
resp = await client.list_locations(dp.ListLocationsRequest())
dp_locations = resp.locations
dp_loc_map = {loc.location_name: loc.location_uuid for loc in dp_locations}
if not dp_loc_map:
logger.error("Data Platform returned an empty location map. Cannot continue.")
return
Expand Down Expand Up @@ -113,6 +115,31 @@ async def run_blend_app() -> None:
forecaster_name=_cfg.forecaster_name,
)

# -------------------------------------------------------------- #
# Regional blends (all locations except national) #
# -------------------------------------------------------------- #
regional_locations = {
loc.location_name: loc.location_uuid
for loc in dp_locations
if loc.location_name != NL_NATIONAL_LOCATION_KEY
and loc.location_type == dp.LocationType.STATE
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've moved dp.LocationType.STATE into the config, as for other countries/models this won't be the case

}
logger.info(
f"Running regional blend for {len(regional_locations)} region(s): "
f"{list(regional_locations.keys())}",
)
for location_key, location_uuid in regional_locations.items():
await _run_blend_pass(
client=client,
t0=t0,
location_uuid=location_uuid,
location_key=location_key,
df_mae=df_mae,
max_horizon=max_horizon,
forecaster_name=_cfg.forecaster_name,
use_regional_weights=True,
)

# -------------------------------------------------------------- #
# Adjuster blend (only if use_adjuster=True in config) #
# Weights are computed from the same module-level constants. #
Expand All @@ -132,6 +159,18 @@ async def run_blend_app() -> None:
forecaster_name=_cfg.adjuster_forecaster_name,
use_adjuster=True,
)
for location_key, location_uuid in regional_locations.items():
await _run_blend_pass(
client=client,
t0=t0,
location_uuid=location_uuid,
location_key=location_key,
df_mae=df_mae,
max_horizon=max_horizon,
forecaster_name=_cfg.adjuster_forecaster_name,
use_adjuster=True,
use_regional_weights=True,
)


def rename_columns_with_adjuster(weights_df: pd.DataFrame) -> pd.DataFrame:
Expand All @@ -150,39 +189,44 @@ async def _run_blend_pass(
max_horizon: pd.Timedelta,
forecaster_name: str,
use_adjuster: bool = False,
use_regional_weights: bool = False,
) -> None:
"""Runs the full blend pipeline for one set of models and saves the result.

Shared by the main blend pass and the adjuster blend pass.
Shared by national and regional blend passes, and the adjuster variants.

Blend weights are always computed from the module-level constants in
``weights.py`` (NL_BACKUP_MODEL / NL_NATIONAL_CANDIDATE_MODELS).
When *use_adjuster* is True, the weight column names are renamed with an
``_adjust`` suffix before fetching, so that
:func:`get_blend_forecast_values_latest` fetches ``{model}_adjust``
forecasters from the Data Platform instead of the base model forecasters.
Blend weights are computed from the module-level constants in
``weights.py``. When *use_regional_weights* is True,
:func:`get_regional_blend_weights` is used (NL_REGIONAL_CANDIDATE_MODELS)
instead of the national candidate set.
When *use_adjuster* is True, weight column names are renamed with an
``_adjust`` suffix so that :func:`get_blend_forecast_values_latest`
fetches ``{model}_adjust`` forecasters from the Data Platform.

Args:
client: Active Data Platform gRPC client stub.
t0: Blend reference time (UTC).
location_uuid: DP location UUID to blend and save for.
location_key: Human-readable location identifier (for logging).
df_mae: (horizon x model) MAE scorecard.
max_horizon: Maximum scorecard horizon.
forecaster_name: Forecaster tag to save under.
use_adjuster: When True, fetches {model}_adjust forecasters and
saves under {forecaster_name} (caller sets the
correct adjuster forecaster name).
client: Active Data Platform gRPC client stub.
t0: Blend reference time (UTC).
location_uuid: DP location UUID to blend and save for.
location_key: Human-readable location identifier (for logging).
df_mae: (horizon x model) MAE scorecard.
max_horizon: Maximum scorecard horizon.
forecaster_name: Forecaster tag to save under.
use_adjuster: When True, fetches {model}_adjust forecasters.
use_regional_weights: When True, uses regional candidate models for
weight optimisation instead of national candidates.
"""
log_prefix = "adjuster" if use_adjuster else "blend"
weight_label = "regional" if use_regional_weights else "national"
logger.info(
f"[{log_prefix}] Starting blend pass for '{location_key}' "
f"[{log_prefix}] Starting {weight_label} blend pass for '{location_key}' "
f"(forecaster='{forecaster_name}', use_adjuster={use_adjuster})",
)

# Weights are always computed from the module-level constants.
# Regional locations use the regional candidate model set.
weight_fn = get_regional_blend_weights if use_regional_weights else get_blend_weights
try:
weights_df = await get_blend_weights(
weights_df = await weight_fn(
t0=t0,
location_uuid=location_uuid,
df_mae=df_mae,
Expand Down
1 change: 0 additions & 1 deletion site_forecast_app/blend/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ blend:
- nl_national_pv_ecmwf_sat_small

# Candidate models for regional blends (subset of national candidates).
# These values are not currently being used.
regional_candidate_models:
- nl_regional_48h_pv_ecmwf
- nl_regional_pv_ecmwf_mo_sat
Expand Down
33 changes: 33 additions & 0 deletions site_forecast_app/blend/weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,4 +345,37 @@ async def get_blend_weights(
)


async def get_regional_blend_weights(
    t0: pd.Timestamp,
    location_uuid: str,
    df_mae: pd.DataFrame,
    max_horizon: pd.Timedelta,
    client: dp.DataPlatformDataServiceStub,
) -> pd.DataFrame:
    """Computes blend weights for a single regional location at t0.

    Runs the same weight-optimisation pipeline as :func:`get_blend_weights`,
    but restricts the candidate set to NL_REGIONAL_CANDIDATE_MODELS
    (typically a subset of the national candidates; see config.yaml).

    Args:
        t0: Blend reference time (UTC, floored to 15 min).
        location_uuid: Data Platform location UUID for the specific region.
        df_mae: (horizon x model) MAE scorecard.
        max_horizon: Maximum horizon in the scorecard.
        client: Authenticated Data Platform gRPC client stub.

    Returns:
        Wide DataFrame indexed by absolute UTC target time, with weights
        summing to 1.0 at every horizon. Empty DataFrame when the shifted
        MAE frame is empty.
    """
    # Delegate to the shared pipeline with the regional candidate set.
    regional_weights = await _compute_weights(
        t0=t0,
        candidate_models=NL_REGIONAL_CANDIDATE_MODELS,
        label="Regional",
        location_uuid=location_uuid,
        df_mae=df_mae,
        max_horizon=max_horizon,
        client=client,
    )
    return regional_weights
33 changes: 33 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import logging
import os
import random
import time
from importlib.metadata import version
from uuid import uuid4

import numpy as np
Expand All @@ -24,6 +26,7 @@
LocationSQL,
)
from sqlalchemy import create_engine
from testcontainers.core.container import DockerContainer
from testcontainers.postgres import PostgresContainer

from site_forecast_app.data.gencast import get_latest_6hr_init_time
Expand Down Expand Up @@ -627,3 +630,33 @@ def small_satellite_data(tmp_path_factory, init_timestamp):
def use_satellite():
"""Set use satellite env var"""
os.environ["USE_SATELLITE"] = "true"


@pytest.fixture(scope="module")
def dp_address():
    """
    Fixture to spin up a PostgreSQL container and Data Platform container for each test module.
    Yields (host, port) for the Data Platform server.
    """
    # NOTE(review): reviewer asked why a live container fixture is needed here
    # ("I thought you mocked most things out?") — confirm before keeping.

    # Postgres image version is pinned to the installed dp_sdk version so the
    # DB schema matches the Data Platform server started below.
    with PostgresContainer(
        f"ghcr.io/openclimatefix/data-platform-pgdb:{version('dp_sdk')}",
        username="postgres",
        password="postgres",  # noqa: S106
        dbname="postgres",
        env={"POSTGRES_HOST": "db"},
    ) as postgres:
        database_url = postgres.get_connection_url()
        # The server container uses a plain "postgres://" scheme, and must reach
        # the host-published Postgres port via host.docker.internal.
        database_url = database_url.replace("postgresql+psycopg2", "postgres")
        database_url = database_url.replace("localhost", "host.docker.internal")

        with DockerContainer(
            image=f"ghcr.io/openclimatefix/data-platform:{version('dp_sdk')}",
            env={"DATABASE_URL": database_url},
            ports=[50051],
        ) as data_platform_server:
            time.sleep(1)  # Give some time for the server to start

            port = data_platform_server.get_exposed_port(50051)
            host = data_platform_server.get_container_host_ip()
            yield host, port
Loading
Loading