Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 50 additions & 23 deletions swvo/io/omni/omni_high_res.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import calendar
import logging
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timedelta, timezone
from typing import List, Optional, Tuple

Expand Down Expand Up @@ -46,6 +47,8 @@ class OMNIHighRes(BaseIO):

START_YEAR = 1981
LABEL = "omni"
PARALLEL_DOWNLOAD_THRESHOLD = 10
MAX_PARALLEL_DOWNLOADS = 10

def download_and_process(
self,
Expand Down Expand Up @@ -83,40 +86,64 @@ def download_and_process(

file_paths, time_intervals = self._get_processed_file_list(start_time, end_time, cadence_min)

download_tasks = []
for file_path, time_interval in zip(file_paths, time_intervals):
if file_path.exists() and not reprocess_files:
continue

# Create directory structure if it doesn't exist
file_path.parent.mkdir(parents=True, exist_ok=True)
download_tasks.append((file_path, time_interval))

tmp_path = file_path.with_suffix(file_path.suffix + ".tmp")
if len(download_tasks) > self.PARALLEL_DOWNLOAD_THRESHOLD:
max_workers = min(self.MAX_PARALLEL_DOWNLOADS, len(download_tasks))
logger.info(f"Downloading {len(download_tasks)} OMNI high resolution files with {max_workers} workers.")
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [
executor.submit(self._download_and_process_single_file, file_path, time_interval, cadence_min)
for file_path, time_interval in download_tasks
]
for future in as_completed(futures):
future.result()
return

try:
data = self._get_data_from_omni(
start=time_interval[0],
end=time_interval[1],
cadence=cadence_min,
)
for file_path, time_interval in download_tasks:
self._download_and_process_single_file(file_path, time_interval, cadence_min)

logger.debug("Processing file ...")
def _download_and_process_single_file(
self,
file_path,
time_interval: Tuple[datetime, datetime],
cadence_min: int,
) -> None:
"""Download and process one monthly OMNI High Resolution file."""

processed_df = self._process_single_month(data, original_end=time_interval[1], cadence_min=cadence_min)
# Create directory structure if it doesn't exist
file_path.parent.mkdir(parents=True, exist_ok=True)

# Do not save empty DataFrames — no data available for this interval
if processed_df.empty:
logger.warning(f"Skipping save for {file_path}: no data available for this interval.")
continue
tmp_path = file_path.with_suffix(file_path.suffix + ".tmp")

processed_df.to_csv(tmp_path, index=True, header=True)
tmp_path.replace(file_path)
try:
data = self._get_data_from_omni(
start=time_interval[0],
end=time_interval[1],
cadence=cadence_min,
)

except Exception as e:
logger.error(f"Failed to process {file_path}: {e}")
if tmp_path.exists():
tmp_path.unlink()
pass
continue
logger.debug("Processing file ...")

processed_df = self._process_single_month(data, original_end=time_interval[1], cadence_min=cadence_min)

# Do not save empty DataFrames — no data available for this interval
if processed_df.empty:
logger.warning(f"Skipping save for {file_path}: no data available for this interval.")
return

processed_df.to_csv(tmp_path, index=True, header=True)
tmp_path.replace(file_path)

except Exception as e:
logger.error(f"Failed to process {file_path}: {e}")
if tmp_path.exists():
tmp_path.unlink()

def read(
self,
Expand Down
18 changes: 15 additions & 3 deletions swvo/io/solar_wind/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,19 @@
from swvo.io.solar_wind.swift import SWSWIFTEnsemble as SWSWIFTEnsemble
from swvo.io.solar_wind.dscovr import DSCOVR as DSCOVR

# This has to be imported after the models to avoid a circular import
from swvo.io.solar_wind.read_solar_wind_from_multiple_models import (
AVERAGE_VALUES_TO_FILL: dict[str, float] = {
"bavg": 5.7501048842758955,
"bx_gsm": -0.0008639005912272984,
"by_gsm": -0.12753220183522668,
"bz_gsm": -0.10594003748277739,
"speed": 425.7842473380121,
"proton_density": 6.593453185227736,
"temperature": 91260.37300814023,
"pdyn": 2.1816079947051628,
"sym-h": -11.375495589373424,
}

# This has to be imported after the models and constants to avoid a circular import
from swvo.io.solar_wind.read_solar_wind_from_multiple_models import ( # noqa: E402
read_solar_wind_from_multiple_models as read_solar_wind_from_multiple_models,
) # noqa: I001
)
41 changes: 30 additions & 11 deletions swvo/io/solar_wind/read_solar_wind_from_multiple_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from scipy.interpolate import UnivariateSpline

from swvo.io.exceptions import ModelError
from swvo.io.solar_wind import DSCOVR, SWACE, SWOMNI, SWSWIFTEnsemble
from swvo.io.solar_wind import AVERAGE_VALUES_TO_FILL, DSCOVR, SWACE, SWOMNI, SWSWIFTEnsemble
from swvo.io.utils import (
any_nans,
construct_updated_data_frame,
Expand All @@ -39,6 +39,7 @@ def read_solar_wind_from_multiple_models( # noqa: PLR0913
*,
download: bool = False,
recurrence: bool = False,
fill_average: bool = False,
rec_model_order: list[DSCOVR | SWACE | SWOMNI] | None = None,
do_interpolation: bool = False,
) -> pd.DataFrame | list[pd.DataFrame]:
Expand All @@ -65,6 +66,9 @@ def read_solar_wind_from_multiple_models( # noqa: PLR0913
recurrence : bool, optional
If True, fill missing values using 27-day recurrence from historical models (OMNI, ACE, SWIFT).
Defaults to False.
fill_average : bool, optional
If True, keep the final dataframe through the requested end time for average-based filling.
Defaults to False.
rec_model_order : list[DSCOVR | SWACE | SWOMNI], optional
The order in which historical models will be used for 27-day recurrence filling.
Defaults to [DSCOVR, SWACE, SWOMNI].
Expand All @@ -87,6 +91,8 @@ def read_solar_wind_from_multiple_models( # noqa: PLR0913

assert reduce_ensemble in (None, "mean", "median"), "reduce_ensemble must be None, `mean` or `median`"

assert not (recurrence and fill_average), "Cannot use both recurrence and average filling at the same time"

if start_time > end_time:
msg = "start_time must be before end_time"
raise ValueError(msg)
Expand Down Expand Up @@ -149,23 +155,35 @@ def read_solar_wind_from_multiple_models( # noqa: PLR0913
if not any_nans(data_out):
break

# Apply 27-day recurrence if requested

if recurrence:
if rec_model_order is None:
rec_model_order = [m for m in model_order if isinstance(m, (DSCOVR, SWACE, SWOMNI))]
for i, df in enumerate(data_out):
if not df.empty:
data_out[i] = _recursive_fill_27d_historical(df, download, rec_model_order)

# Ensure continuous dataframe and handle SWIFT unavailability
data_out = _ensure_continuous_dataframe(
data_out,
start_time,
end_time,
historical_data_cutoff_time,
swift_data_available,
truncate=not (recurrence or fill_average),
)
# Apply 27-day recurrence if requested
if recurrence:
logger.info("Applying 27-day recurrence filling to missing values in historical data")
if rec_model_order is None:
rec_model_order = [m for m in model_order if isinstance(m, (DSCOVR, SWACE, SWOMNI))]
for i, df in enumerate(data_out):
if not df.empty:
data_out[i] = _recursive_fill_27d_historical(df, download, rec_model_order)
if fill_average:
logging.info("Filling future values with 10-year average values.")

for i, df in enumerate(data_out):
if not df.empty:
for col, avg_value in AVERAGE_VALUES_TO_FILL.items():
if col in df.columns:
future_mask = df.index > historical_data_cutoff_time
df.loc[future_mask, col] = avg_value
df.loc[future_mask, "model"] = "10_year_average_filled"
df.loc[future_mask, "file_name"] = "10_year_average_filled"
Comment thread
sahiljhawar marked this conversation as resolved.
Outdated
data_out[i] = df

if len(data_out) == 1:
data_out = data_out[0]
Expand Down Expand Up @@ -536,6 +554,7 @@ def _ensure_continuous_dataframe(
end_time: datetime,
historical_data_cutoff_time: datetime,
swift_data_available: bool,
truncate: bool = True,
) -> list[pd.DataFrame]:
"""
Ensure the dataframe is continuous from start to end time, handling gaps and SWIFT unavailability.
Expand Down Expand Up @@ -573,7 +592,7 @@ def _ensure_continuous_dataframe(
break

# Determine actual end time based on SWIFT availability
if (not swift_data_available or swift_data_all_nan) and historical_data_cutoff_time < end_time:
if ((not swift_data_available or swift_data_all_nan) and (historical_data_cutoff_time < end_time)) and truncate:
actual_end_time = historical_data_cutoff_time
logger.info(
f"Since SWIFT is not available for future dates, final dataframe truncated to {historical_data_cutoff_time}"
Expand Down
46 changes: 46 additions & 0 deletions tests/io/omni/test_omni_high_res.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
import os
import shutil
from concurrent.futures import Future
from datetime import datetime, timezone
from pathlib import Path
from unittest.mock import patch
Expand Down Expand Up @@ -101,6 +102,51 @@ def test_download_and_process_calls_get_data_per_month(self, omni_high_res, mock
omni_high_res.download_and_process(start_time, end_time)
assert omni_high_res._get_data_from_omni.call_count == 12

def test_download_and_process_uses_parallel_for_more_than_10_files(self, tmp_path, mocker):
omni_high_res = OMNIHighRes(data_dir=tmp_path)
start_time = datetime(2023, 1, 1, tzinfo=timezone.utc)
end_time = datetime(2023, 12, 31, tzinfo=timezone.utc)
executor_max_workers = []

class RecordingExecutor:
def __init__(self, max_workers=None):
executor_max_workers.append(max_workers)

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
return False

def submit(self, fn, *args, **kwargs):
future = Future()
try:
future.set_result(fn(*args, **kwargs))
except Exception as exc:
future.set_exception(exc)
return future

mocker.patch("swvo.io.omni.omni_high_res.ThreadPoolExecutor", RecordingExecutor)
process_single_file = mocker.patch.object(omni_high_res, "_download_and_process_single_file")

omni_high_res.download_and_process(start_time, end_time)

assert executor_max_workers == [10]
assert process_single_file.call_count == 12

def test_download_and_process_stays_sequential_for_10_files(self, tmp_path, mocker):
omni_high_res = OMNIHighRes(data_dir=tmp_path)
start_time = datetime(2023, 1, 1, tzinfo=timezone.utc)
end_time = datetime(2023, 10, 31, tzinfo=timezone.utc)

executor = mocker.patch("swvo.io.omni.omni_high_res.ThreadPoolExecutor")
process_single_file = mocker.patch.object(omni_high_res, "_download_and_process_single_file")

omni_high_res.download_and_process(start_time, end_time)

executor.assert_not_called()
assert process_single_file.call_count == 10

def test_invalid_cadence(self, omni_high_res):
start_time = datetime(2022, 1, 1, tzinfo=timezone.utc)
end_time = datetime(2022, 12, 31, tzinfo=timezone.utc)
Expand Down
76 changes: 76 additions & 0 deletions tests/io/solar_wind/test_read_solar_wind_from_multiple_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

import importlib
import os
from datetime import datetime, timedelta, timezone
from pathlib import Path
Expand All @@ -13,6 +14,7 @@

from swvo.io.exceptions import ModelError
from swvo.io.solar_wind import (
AVERAGE_VALUES_TO_FILL,
DSCOVR,
SWACE,
SWOMNI,
Expand All @@ -26,6 +28,7 @@

TEST_DIR = os.path.dirname(__file__)
DATA_DIR = Path(os.path.join(TEST_DIR, "data/"))
READ_SW_MODULE = importlib.import_module("swvo.io.solar_wind.read_solar_wind_from_multiple_models")


class TestReadSolarWindFromMultipleModels:
Expand Down Expand Up @@ -205,6 +208,79 @@ def test_27_day_recurrence_basic(self, sample_times, expected_columns):
assert data.index.is_monotonic_increasing
assert data.index.freq == "1min"

@pytest.mark.parametrize(("recurrence", "fill_average"), [(True, False), (False, True)])
def test_fill_modes_do_not_truncate_final_dataframe(self, monkeypatch, recurrence, fill_average):
start_time = datetime(2024, 11, 25, 0, 0, tzinfo=timezone.utc)
historical_data_cutoff_time = start_time + timedelta(minutes=5)
end_time = start_time + timedelta(minutes=10)
index = pd.date_range(start_time, end_time, freq="1min", tz="UTC")
data = pd.DataFrame(
{
"speed": [400.0] * 6 + [np.nan] * 5,
"model": ["omni"] * 6 + [None] * 5,
"file_name": ["test_file.txt"] * 6 + [None] * 5,
},
index=index,
)

monkeypatch.setattr(READ_SW_MODULE, "_read_from_model", lambda *args, **kwargs: data)
monkeypatch.setattr(READ_SW_MODULE, "_recursive_fill_27d_historical", lambda df, *_args: df)

result = read_solar_wind_from_multiple_models(
start_time=start_time,
end_time=end_time,
model_order=[SWOMNI(), SWSWIFTEnsemble()],
historical_data_cutoff_time=historical_data_cutoff_time,
Comment on lines +204 to +210
recurrence=recurrence,
fill_average=fill_average,
)

assert result.index.max() == end_time

def test_average_fill_uses_expected_values(self, monkeypatch):
start_time = datetime(2024, 11, 25, 0, 0, tzinfo=timezone.utc)
historical_data_cutoff_time = start_time + timedelta(minutes=5)
end_time = start_time + timedelta(minutes=10)
index = pd.date_range(start_time, end_time, freq="1min", tz="UTC")
average_values = AVERAGE_VALUES_TO_FILL
data = pd.DataFrame(
{
**{col: [1.0] * 6 + [np.nan] * 5 for col in average_values},
"model": ["omni"] * 6 + [None] * 5,
"file_name": ["test_file.txt"] * 6 + [None] * 5,
},
index=index,
)

monkeypatch.setattr(READ_SW_MODULE, "_read_from_model", lambda *args, **kwargs: data)

result = read_solar_wind_from_multiple_models(
start_time=start_time,
end_time=end_time,
model_order=[SWOMNI()],
historical_data_cutoff_time=historical_data_cutoff_time,
fill_average=True,
)

future_mask = result.index > historical_data_cutoff_time
assert result.index.max() == end_time
for col, avg_value in average_values.items():
assert result.loc[historical_data_cutoff_time, col] == 1.0
np.testing.assert_allclose(result.loc[future_mask, col].to_numpy(), avg_value)
assert (result.loc[future_mask, "model"] == "10_year_average_filled").all()
assert (result.loc[future_mask, "file_name"] == "10_year_average_filled").all()

def test_recurrence_and_average_fill_are_mutually_exclusive(self, sample_times):
with pytest.raises(AssertionError, match="Cannot use both recurrence and average filling"):
read_solar_wind_from_multiple_models(
start_time=sample_times["past_start"],
end_time=sample_times["future_end"],
model_order=[SWOMNI()],
historical_data_cutoff_time=sample_times["test_time_now"],
recurrence=True,
fill_average=True,
)

def test_3_hour_interpolation_with_recurrence(self, sample_times, expected_columns):
# Use a longer time range to increase chances of gaps that need interpolation
extended_start = sample_times["past_start"] - timedelta(days=2)
Expand Down
Loading