Skip to content

Commit 126b309

Browse files
Merge pull request #127 from mahendra-918/feature/add-eia-data-script
feat: Add script to fetch US solar data from EIA (Issue #109)
2 parents 203d866 + 64b2a5b commit 126b309

2 files changed

Lines changed: 330 additions & 0 deletions

File tree

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
import os
2+
import logging
3+
import requests
4+
import pandas as pd
5+
import xarray as xr
6+
from typing import Optional, List, Dict, Any
7+
8+
logger = logging.getLogger(__name__)
9+
10+
class EIAData:
11+
"""
12+
Class to handle interactions with the EIA API v2.
13+
"""
14+
def __init__(self, api_key: Optional[str] = None):
15+
self.api_key = api_key or os.getenv("EIA_API_KEY")
16+
if not self.api_key:
17+
logger.warning("EIA_API_KEY environment variable is not set. You must provide an API key to fetch data.")
18+
self.base_url = "https://api.eia.gov/v2"
19+
20+
def get_data(
21+
self,
22+
route: str,
23+
start_date: str,
24+
end_date: str,
25+
frequency: str = "hourly",
26+
data_cols: List[str] = ["value"],
27+
facets: Optional[Dict[str, Any]] = None,
28+
offset: int = 0,
29+
length: int = 5000,
30+
region: str = "US48",
31+
) -> Optional[pd.DataFrame]:
32+
"""
33+
Fetch data from the EIA API.
34+
35+
Args:
36+
route: API route (e.g. 'electricity/rto/daily-fuel-type-data')
37+
frequency: Data frequency (e.g. 'daily', 'hourly')
38+
start_date: Start date string
39+
end_date: End date string
40+
data_cols: List of data columns to retrieve
41+
facets: Dictionary of facets to filter by
42+
offset: Pagination offset
43+
length: Number of results to return
44+
region: Region identifier (default: "US48")
45+
46+
Returns:
47+
pd.DataFrame: Data returned from the API, or None if error/empty
48+
"""
49+
if not self.api_key:
50+
raise ValueError("API Key is missing")
51+
52+
if region:
53+
if facets is None:
54+
facets = {}
55+
if "respondent" not in facets and region == "US48":
56+
facets["respondent"] = ["US48"]
57+
58+
url = f"{self.base_url}/{route}/data"
59+
60+
params = {
61+
"api_key": self.api_key,
62+
"frequency": frequency,
63+
"start": start_date,
64+
"end": end_date,
65+
"offset": offset,
66+
"length": length,
67+
}
68+
69+
for i, col in enumerate(data_cols):
70+
params[f"data[{i}]"] = col
71+
72+
if facets:
73+
for key, value in facets.items():
74+
if isinstance(value, list):
75+
for i, v in enumerate(value):
76+
params[f"facets[{key}][{i}]"] = v
77+
else:
78+
params[f"facets[{key}][]"] = value
79+
80+
all_data = []
81+
82+
try:
83+
current_offset = offset
84+
while True:
85+
# Create a fresh copy of params for each request to avoid mutating history
86+
request_params = params.copy()
87+
request_params["offset"] = current_offset
88+
89+
logger.info(f"Fetching data from {url}, offset={current_offset}...")
90+
response = requests.get(url, params=request_params)
91+
response.raise_for_status()
92+
93+
payload = response.json()
94+
if "response" in payload and "data" in payload["response"]:
95+
data = payload["response"]["data"]
96+
if not data:
97+
logger.info("No more data returned from API.")
98+
break
99+
100+
all_data.extend(data)
101+
102+
if len(data) < length:
103+
break
104+
105+
current_offset += length
106+
else:
107+
logger.error(f"Unexpected API response format: {payload.keys()}")
108+
break
109+
110+
if not all_data:
111+
logger.warning("No data retrieved.")
112+
return None
113+
114+
return pd.DataFrame(all_data)
115+
116+
except requests.exceptions.RequestException as e:
117+
logger.error(f"Request failed: {e}")
118+
if 'response' in locals() and response is not None:
119+
logger.error(f"Response: {response.text}")
120+
return None
121+
122+
def get_dataset(
123+
self,
124+
route: str,
125+
start_date: str,
126+
end_date: str,
127+
frequency: str = "hourly",
128+
data_cols: List[str] = ["value"],
129+
facets: Optional[Dict[str, Any]] = None,
130+
region: str = "US48",
131+
) -> Optional[xr.Dataset]:
132+
"""
133+
Fetch data and convert to xarray Dataset compatible with ocf-data-sampler.
134+
135+
Args:
136+
route: API route
137+
start_date: Start date string
138+
end_date: End date string
139+
frequency: Data frequency
140+
data_cols: List of data columns
141+
facets: Dictionary of facets
142+
region: Region identifier
143+
144+
Returns:
145+
xr.Dataset: Dataset with datetime_gmt index, or None if no data
146+
"""
147+
df = self.get_data(
148+
route=route,
149+
start_date=start_date,
150+
end_date=end_date,
151+
frequency=frequency,
152+
data_cols=data_cols,
153+
facets=facets,
154+
region=region
155+
)
156+
157+
if df is None or df.empty:
158+
return None
159+
160+
# Process for ocf-data-sampler format
161+
if "period" in df.columns:
162+
df["datetime_gmt"] = pd.to_datetime(df["period"], utc=True)
163+
df = df.drop(columns=["period"])
164+
165+
index_cols = ["datetime_gmt"]
166+
if "respondent" in df.columns:
167+
index_cols.append("respondent")
168+
elif "region" in df.columns:
169+
index_cols.append("region")
170+
171+
if not df.index.is_unique:
172+
df = df.drop_duplicates(subset=index_cols)
173+
174+
df = df.set_index(index_cols)
175+
176+
ds = xr.Dataset.from_dataframe(df)
177+
178+
return ds
179+
180+
if __name__ == "__main__":
181+
# Basic test execution
182+
logging.basicConfig(level=logging.INFO)
183+
eia = EIAData()
184+
print("EIAData initialized. Set EIA_API_KEY and call get_data() to test.")

tests/test_eia.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import pytest
2+
import pandas as pd
3+
from unittest.mock import Mock, patch
4+
from open_data_pvnet.scripts.fetch_eia_data import EIAData
5+
6+
@pytest.fixture
7+
def mock_response():
8+
"""Fixture to mock a successful API response."""
9+
mock = Mock()
10+
mock.json.return_value = {
11+
"response": {
12+
"data": [
13+
{"period": "2023-01-01T00", "value": 100, "fueltype": "SUN"},
14+
{"period": "2023-01-01T01", "value": 150, "fueltype": "SUN"},
15+
]
16+
}
17+
}
18+
mock.raise_for_status.return_value = None
19+
return mock
20+
21+
def test_init_with_key():
22+
eia = EIAData(api_key="test_key")
23+
assert eia.api_key == "test_key"
24+
25+
def test_init_without_key(mocker):
26+
mocker.patch.dict("os.environ", {}, clear=True)
27+
eia = EIAData()
28+
assert eia.api_key is None
29+
30+
def test_get_data_success(mock_response):
31+
with patch("requests.get", return_value=mock_response) as mock_get:
32+
eia = EIAData(api_key="test_key")
33+
34+
df = eia.get_data(
35+
route="test/route",
36+
start_date="2023-01-01",
37+
end_date="2023-01-02",
38+
frequency="hourly",
39+
data_cols=["value"],
40+
facets={"fueltype": "SUN"}
41+
)
42+
43+
assert isinstance(df, pd.DataFrame)
44+
assert len(df) == 2
45+
assert "value" in df.columns
46+
47+
# Verify API call parameters
48+
mock_get.assert_called_once()
49+
args, kwargs = mock_get.call_args
50+
assert kwargs["params"]["api_key"] == "test_key"
51+
assert kwargs["params"]["facets[fueltype][]"] == "SUN"
52+
assert kwargs["params"]["data[0]"] == "value"
53+
54+
def test_get_data_missing_key():
55+
eia = EIAData(api_key=None)
56+
with pytest.raises(ValueError, match="API Key is missing"):
57+
eia.get_data("route", "start", "end", frequency="hourly")
58+
59+
def test_get_data_api_error():
60+
mock_resp = Mock()
61+
import requests
62+
mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError("API Error")
63+
64+
with patch("requests.get", return_value=mock_resp):
65+
eia = EIAData(api_key="test_key")
66+
df = eia.get_data("route", "start", "end", frequency="hourly")
67+
assert df is None
68+
69+
def test_get_data_empty_response():
70+
mock_resp = Mock()
71+
mock_resp.json.return_value = {"response": {"data": []}}
72+
mock_resp.raise_for_status.return_value = None
73+
74+
with patch("requests.get", return_value=mock_resp):
75+
eia = EIAData(api_key="test_key")
76+
df = eia.get_data("route", "start", "end", frequency="hourly")
77+
assert df is None
78+
79+
def test_get_data_with_region(mock_response):
80+
with patch("requests.get", return_value=mock_response) as mock_get:
81+
eia = EIAData(api_key="test_key")
82+
eia.get_data("route", "start", "end")
83+
84+
args, kwargs = mock_get.call_args
85+
assert kwargs["params"]["facets[respondent][0]"] == "US48"
86+
87+
def test_get_data_without_region(mock_response):
88+
with patch("requests.get", return_value=mock_response) as mock_get:
89+
eia = EIAData(api_key="test_key")
90+
eia.get_data("route", "start", "end", region=None)
91+
92+
args, kwargs = mock_get.call_args
93+
assert not any("facets[respondent]" in k for k in kwargs["params"].keys())
94+
95+
def test_get_dataset_success(mock_response):
96+
import xarray as xr
97+
with patch("requests.get", return_value=mock_response) as mock_get:
98+
eia = EIAData(api_key="test_key")
99+
100+
ds = eia.get_dataset(
101+
route="test/route",
102+
start_date="2023-01-01",
103+
end_date="2023-01-02"
104+
)
105+
106+
assert isinstance(ds, xr.Dataset)
107+
assert "datetime_gmt" in ds.coords or "datetime_gmt" in ds.indexes
108+
assert "value" in ds.data_vars
109+
assert len(ds.datetime_gmt) == 2
110+
111+
def test_get_data_pagination():
112+
page1 = {
113+
"response": {
114+
"data": [
115+
{"period": "2023-01-01T00", "value": 100},
116+
{"period": "2023-01-01T01", "value": 150},
117+
]
118+
}
119+
}
120+
page2 = {
121+
"response": {
122+
"data": [
123+
{"period": "2023-01-01T02", "value": 200},
124+
]
125+
}
126+
}
127+
128+
mock_resp1 = Mock()
129+
mock_resp1.json.return_value = page1
130+
mock_resp1.raise_for_status.return_value = None
131+
132+
mock_resp2 = Mock()
133+
mock_resp2.json.return_value = page2
134+
mock_resp2.raise_for_status.return_value = None
135+
136+
with patch("requests.get", side_effect=[mock_resp1, mock_resp2]) as mock_get:
137+
eia = EIAData(api_key="test_key")
138+
139+
df = eia.get_data("route", "start", "end", length=2)
140+
141+
assert len(df) == 3
142+
assert mock_get.call_count == 2
143+
144+
call_args_list = mock_get.call_args_list
145+
assert call_args_list[0][1]["params"]["offset"] == 0
146+
assert call_args_list[1][1]["params"]["offset"] == 2

0 commit comments

Comments
 (0)