From 9b7411904457ab6834786ee36003faf0fdace467 Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Tue, 23 Jun 2026 11:57:59 -0500 Subject: [PATCH] =?UTF-8?q?refactor(ogc):=20reuse=20+=20unify=20the=20OGC?= =?UTF-8?q?=20engine=20=E2=80=94=20pager,=20aggregation,=20ambient=20state?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reuse and unify the OGC engine's HTTP pagination / aggregation / error-recovery / ambient-state plumbing instead of carrying parallel implementations. Net source reduction of ~66 LOC, behavior-preserving, plus one rate-limit correctness fix. wateruse reuse: - wateruse drives the engine's generic `_paginate` (with an injected `raise_for_status` for the NWDC `{detail}` envelope), `_run_sync` (anyio portal, Jupyter-safe), and `_combine_chunk_frames` / `_combine_chunk_responses` aggregators — replacing a hand-rolled pager, thread bridge, and bespoke aggregation. `_resolve_locations` becomes a `_LOCATION_BUILDERS` table dispatch, dropping a 3-way if/elif and a duplicated selector enumeration. Engine unification: - `planning._merge_response`: one low-level "fold N responses into one" behind both pagination (`_paginate`) and chunked/fan-out aggregation (`_combine_chunk_responses`), replacing two near-duplicate implementations; deletes `engine._aggregate_paginated_response`. - `utils.Ambient[T]`: a generic ContextVar-with-scope class collapsing each per-call ambient (`_row_cap`, `_ogc_base_url`, `_dialect`, the chunker's `_chunked_client`) from a var + hand-written `@contextmanager` setter pair into one declaration. `with _x(value):` call sites unchanged; readers shorten to `_x.get()`. - `_paginate`'s verbatim per-page progress block deduped into a `report_page` closure. - `_combine_chunk_responses`: dropped a dead single-response branch. - `_QUOTA_HEADER` moved to the base `planning` module — dedups the literal and fixes a layering inversion (planning had hard-coded it, unable to import from chunking). 
- `_cql2_param`: CQL2 filter list built as a comprehension. - `engine._check_id_format`: inlined into its only caller; dead re-export dropped. Rate-limit correctness fix: - `x-ratelimit-remaining` now reports the LOWEST value any concurrent sub-request saw (the quota actually left after a fan-out), via a shared `_lowest_remaining`, instead of the last-by-index — fixing a latent inaccuracy in the OGC chunker too. Behavior-preserving (live-verified); offline OGC/wateruse/utils/progress suites green; ruff + mypy --strict clean. Co-Authored-By: Claude Opus 4.8 (1M context) Claude-Session: https://claude.ai/code/session_01Sjb14HkwuCydKSKMsaXsgd --- dataretrieval/ogc/chunking.py | 51 ++----- dataretrieval/ogc/engine.py | 220 ++++++++++--------------------- dataretrieval/ogc/planning.py | 84 +++++++++--- dataretrieval/utils.py | 40 +++++- dataretrieval/waterdata/utils.py | 2 - dataretrieval/wateruse.py | 211 +++++++++++++---------------- tests/waterdata_chunking_test.py | 23 +++- tests/waterdata_test.py | 4 +- tests/waterdata_utils_test.py | 2 +- tests/wateruse_test.py | 14 +- 10 files changed, 297 insertions(+), 354 deletions(-) diff --git a/dataretrieval/ogc/chunking.py b/dataretrieval/ogc/chunking.py index b7fff335..00d2766c 100644 --- a/dataretrieval/ogc/chunking.py +++ b/dataretrieval/ogc/chunking.py @@ -69,15 +69,14 @@ import functools import os from collections.abc import Callable, Iterator -from contextlib import contextmanager -from contextvars import ContextVar, copy_context +from contextvars import copy_context from typing import Any, cast import httpx import pandas as pd from anyio.from_thread import start_blocking_portal -from dataretrieval.utils import HTTPX_DEFAULTS +from dataretrieval.utils import HTTPX_DEFAULTS, Ambient from . import progress as _progress from .interruptions import ( @@ -106,9 +105,6 @@ _OGC_URL_BYTE_LIMIT = 8000 -# Response header USGS uses to advertise remaining hourly quota. 
-_QUOTA_HEADER = "x-ratelimit-remaining" - # Fan-out concurrency cap, read at call time (not import) so test # ``monkeypatch.setenv`` applies. Value grammar in :func:`_read_concurrency_env`; # the concurrency model is in the module docstring. @@ -152,38 +148,11 @@ def _read_concurrency_env() -> int | None: return value -# Shared per-call ``httpx.AsyncClient``, published via :func:`_publish` -# during ``ChunkedCall._run`` so paginated-loop helpers (``_walk_pages``) -# reuse the same connection pool across every sub-request. ``None`` -# outside a chunked call — paginated helpers then open their own -# short-lived client. -_chunked_client: ContextVar[httpx.AsyncClient | None] = ContextVar( - "_chunked_client", default=None -) - - -@contextmanager -def _publish(client: httpx.AsyncClient) -> Iterator[None]: - """ - Publish ``client`` on the ``_chunked_client`` ContextVar so the - paginated-loop helpers can borrow it via :func:`get_active_client` - for the duration of the ``with`` block. - - Parameters - ---------- - client : httpx.AsyncClient - The client to publish. - - Yields - ------ - None - Yields once, for the duration of the bind. - """ - token = _chunked_client.set(client) - try: - yield - finally: - _chunked_client.reset(token) +# Shared per-call ``httpx.AsyncClient``, scoped via ``with _chunked_client(c):`` +# during ``ChunkedCall._run`` so paginated-loop helpers (``_walk_pages``) reuse +# the same connection pool across every sub-request. ``None`` outside a chunked +# call — paginated helpers then open their own short-lived client. +_chunked_client: Ambient[httpx.AsyncClient | None] = Ambient("_chunked_client", None) def get_active_client() -> httpx.AsyncClient | None: @@ -197,8 +166,8 @@ def get_active_client() -> httpx.AsyncClient | None: Returns ------- httpx.AsyncClient or None - The client published via :func:`_publish` if currently inside a - :class:`ChunkedCall` run; ``None`` otherwise. 
+ The client scoped via ``with _chunked_client(...)`` if currently inside + a :class:`ChunkedCall` run; ``None`` otherwise. """ return _chunked_client.get() @@ -541,7 +510,7 @@ async def _run(self, max_concurrent: int | None) -> tuple[pd.DataFrame, Any]: ) async with httpx.AsyncClient(limits=limits, **HTTPX_DEFAULTS) as client: - with _publish(client): + with _chunked_client(client): reporter = _progress.current() if reporter is not None: reporter.set_chunks(self.plan.total) diff --git a/dataretrieval/ogc/engine.py b/dataretrieval/ogc/engine.py index a806425c..8d67f486 100644 --- a/dataretrieval/ogc/engine.py +++ b/dataretrieval/ogc/engine.py @@ -24,7 +24,6 @@ from __future__ import annotations -import copy import functools import json import logging @@ -35,13 +34,10 @@ Awaitable, Callable, Iterable, - Iterator, Mapping, ) -from contextlib import asynccontextmanager, contextmanager -from contextvars import ContextVar +from contextlib import asynccontextmanager from dataclasses import dataclass, field -from datetime import timedelta from typing import Any, TypeVar, cast import httpx @@ -51,16 +47,14 @@ from dataretrieval.exceptions import DataRetrievalError from dataretrieval.ogc import chunking from dataretrieval.ogc import progress as _progress -from dataretrieval.ogc.chunking import ( - _QUOTA_HEADER, - get_active_client, -) +from dataretrieval.ogc.chunking import get_active_client from dataretrieval.ogc.dates import _DATE_RANGE_PARAMS, _format_api_dates from dataretrieval.ogc.errors import _paginated_failure_message, _raise_for_non_200 -from dataretrieval.ogc.planning import _safe_elapsed +from dataretrieval.ogc.planning import _QUOTA_HEADER, _merge_response, _safe_elapsed from dataretrieval.ogc.shaping import GEOPANDAS, _finalize_ogc, _get_resp_data from dataretrieval.utils import ( HTTPX_DEFAULTS, + Ambient, BaseMetadata, _default_headers, _get, @@ -237,12 +231,13 @@ def _cql2_param(args: dict[str, Any]) -> str: which roughly doubles how many 
monitoring-location ids fit in one sub-request and so halves the chunk count for large id lists. """ - filters = [] - for key, values in args.items(): - filters.append({"op": "in", "args": [{"property": key}, values]}) - - query = {"op": "and", "args": filters} - + query = { + "op": "and", + "args": [ + {"op": "in", "args": [{"property": key}, values]} + for key, values in args.items() + ], + } return json.dumps(query, separators=(",", ":")) @@ -275,7 +270,7 @@ def _check_ogc_requests(endpoint: str, req_type: str = "queryables") -> dict[str """ if req_type not in ("queryables", "schema"): raise ValueError(f"req_type must be 'queryables' or 'schema', got {req_type!r}") - url = f"{_ogc_base_url_var.get()}/collections/{endpoint}/{req_type}" + url = f"{_ogc_base_url.get()}/collections/{endpoint}/{req_type}" resp = _get(url, headers=_default_headers(), **HTTPX_DEFAULTS) _raise_for_non_200(resp) # ``Response.json`` is typed ``Any``; the OGC queryables/schema endpoints @@ -355,8 +350,8 @@ def _construct_api_requests( ----- - Date/time parameters are automatically formatted to ISO8601. """ - service_url = f"{_ogc_base_url_var.get()}/collections/{service}/items" - dialect = _dialect_var.get() + service_url = f"{_ogc_base_url.get()}/collections/{service}/items" + dialect = _dialect.get() # Format date/time parameters to ISO8601 first — both routing paths need it. for key in _DATE_RANGE_PARAMS: @@ -454,7 +449,7 @@ def _construct_cql_request( httpx.Request A POST request with ``Content-Type: application/query-cql-json``. """ - service_url = f"{_ogc_base_url_var.get()}/collections/{service}/items" + service_url = f"{_ogc_base_url.get()}/collections/{service}/items" params = _ogc_query_params( {}, properties=properties, @@ -586,108 +581,25 @@ async def _client_for( yield new -def _aggregate_paginated_response( - initial: httpx.Response, - last: httpx.Response, - total_elapsed: timedelta, -) -> httpx.Response: - """ - Build a single response covering a paginated call. 
- - Returns a shallow copy of ``initial`` with ``.headers`` set to the - LAST page's (so downstream sees current ``x-ratelimit-remaining``) - and ``.elapsed`` set to total wall-clock. The canonical - ``initial.url`` is preserved (it's the user's original query). - Both ``initial`` and ``last`` are left unmutated, mirroring the - convention of - :func:`dataretrieval.ogc.planning._combine_chunk_responses`. - - Parameters - ---------- - initial : httpx.Response - First-page response (the canonical one for ``md.url``). - last : httpx.Response - Last-page response — supplies the headers to copy over. - total_elapsed : datetime.timedelta - Cumulative wall-clock across every page, including ``initial``. - - Returns - ------- - httpx.Response - A shallow copy of ``initial`` with ``.headers`` set to a fresh - ``httpx.Headers`` and ``.elapsed`` set to the cumulative - wall-clock. ``initial.headers`` / ``initial.elapsed`` are - never mutated, so callers holding a pre-pagination reference - still see the original first-page values. - """ - final = copy.copy(initial) - final.headers = httpx.Headers(last.headers) - final.elapsed = total_elapsed - return final - - _Cursor = TypeVar("_Cursor") -# Optional cap on the total rows a single paginated call accumulates before it -# stops following ``next`` links. ``None`` (the default the data getters use) -# means "no cap — fetch the whole series". Set via :func:`_row_cap` so the deep -# ``_paginate`` loop can honor it without threading the value through the -# generic chunker; this mirrors the ``_progress`` ambient-reporter pattern. -_row_cap_var: ContextVar[int | None] = ContextVar("ogc_row_cap", default=None) - - -@contextmanager -def _row_cap(max_rows: int | None) -> Iterator[None]: - """Cap the rows any :func:`_paginate` under this context will - accumulate (``None`` = uncapped). 
Used by :func:`get_reference_table` - to preview large tables without downloading every page.""" - token = _row_cap_var.set(max_rows) - try: - yield - finally: - _row_cap_var.reset(token) - +# Ambient per-call state the generic chunker would otherwise have to thread +# through to the deep request builder / paginate loop. Each is read with +# ``.get()`` and scoped with ``with _x(value):``; the defaults leave every +# existing getter unaffected. (Mirrors the ``_progress`` ambient-reporter.) -# OGC base URL for the active request. ``get_ogc_data`` sets it per call so the -# shared request builder (:func:`_construct_api_requests`) can target either the -# main Water Data API or the NGWMN sub-API without threading the value through -# the generic chunker; this mirrors the ``_row_cap`` ambient pattern. The -# default is the main API, so every existing getter is unaffected. -_ogc_base_url_var: ContextVar[str] = ContextVar("ogc_base_url", default=OGC_API_URL) +# Optional cap on the rows one paginated call accumulates before it stops +# following ``next`` links (``None`` = uncapped). Set by :func:`get_reference_table` +# to preview large tables without downloading every page. +_row_cap: Ambient[int | None] = Ambient("ogc_row_cap", None) +# OGC base URL the shared request builder (:func:`_construct_api_requests`) +# targets — the main Water Data API or, for NGWMN collections, their own base. +_ogc_base_url: Ambient[str] = Ambient("ogc_base_url", OGC_API_URL) -@contextmanager -def _ogc_base_url(base_url: str) -> Iterator[None]: - """Point :func:`_construct_api_requests` (and the chunk planner that calls - it) at ``base_url`` for the duration of the block. Used by - :func:`get_ogc_data` to serve NGWMN collections from their own OGC base.""" - token = _ogc_base_url_var.set(base_url) - try: - yield - finally: - _ogc_base_url_var.reset(token) - - -# Per-call OGC dialect (which services need POST/CQL2, which use date-only time -# args). 
``get_ogc_data`` sets it so the shared request builder -# (:func:`_construct_api_requests`) can adapt to the active API without -# threading the value through the generic chunker; this mirrors the -# ``_ogc_base_url`` ambient pattern. The default is a plain OGC API. -_dialect_var: ContextVar[OgcDialect] = ContextVar( - "ogc_dialect", default=_DEFAULT_DIALECT -) - - -@contextmanager -def _dialect(dialect: OgcDialect) -> Iterator[None]: - """Make ``dialect`` the active :class:`OgcDialect` that - :func:`_construct_api_requests` reads for CQL2-vs-GET routing and - date-only formatting, for the duration of the block.""" - token = _dialect_var.set(dialect) - try: - yield - finally: - _dialect_var.reset(token) +# Per-call OGC dialect the request builder reads for CQL2-vs-GET routing and +# date-only formatting (default: a plain OGC API). +_dialect: Ambient[OgcDialect] = Ambient("ogc_dialect", _DEFAULT_DIALECT) async def _paginate( @@ -696,6 +608,7 @@ async def _paginate( parse_response: Callable[[httpx.Response], tuple[pd.DataFrame, _Cursor | None]], follow_up: Callable[[_Cursor, httpx.AsyncClient], Awaitable[httpx.Response]], client: httpx.AsyncClient | None = None, + raise_for_status: Callable[[httpx.Response], None] = _raise_for_non_200, ) -> tuple[pd.DataFrame, httpx.Response]: """ Drive a paginated request to completion over an @@ -726,6 +639,10 @@ async def _paginate( Caller-borrowed client. ``None`` (default) means use the chunker's shared client (if inside a chunked call) or open a temporary one. + raise_for_status : callable, optional + ``resp -> None``; raises the typed error for a non-OK response. + Defaults to :func:`_raise_for_non_200` (the OGC ``{code, description}`` + envelope); wateruse passes its own to surface the NWDC ``detail``. 
Returns ------- @@ -756,9 +673,19 @@ async def _paginate( """ logger.debug("Requesting: %s", initial_req.url) reporter = _progress.current() + + def report_page(page: httpx.Response, frame: pd.DataFrame) -> None: + """Tick the ambient progress reporter (a no-op when unset) for one page.""" + if reporter is not None: + reporter.set_rate_remaining( + page.headers.get(_QUOTA_HEADER), + limit=page.headers.get("x-ratelimit-limit"), + ) + reporter.add_page(rows=len(frame)) + async with _client_for(client) as sess: resp = await sess.send(initial_req) - _raise_for_non_200(resp) + raise_for_status(resp) initial_response = resp total_elapsed = _safe_elapsed(resp) @@ -775,28 +702,25 @@ async def _paginate( # Stop following ``next`` links once the optional row cap is reached # (see :func:`_row_cap`); ``None`` means uncapped. The concatenation # is sliced to the cap below so a final over-budget page can't exceed it. - cap = _row_cap_var.get() + cap = _row_cap.get() nrows = len(df) - if reporter is not None: - reporter.set_rate_remaining( - resp.headers.get(_QUOTA_HEADER), - limit=resp.headers.get("x-ratelimit-limit"), - ) - reporter.add_page(rows=len(df)) - while cursor is not None and (cap is None or nrows < cap): + # Guard a non-advancing or cyclic cursor (a server bug that would + # otherwise loop forever). OGC's next-URLs are unique, so this never + # fires for them; the Link-header pagers (e.g. wateruse) rely on it. 
+ seen: set[Any] = set() + report_page(resp, df) + while ( + cursor is not None and cursor not in seen and (cap is None or nrows < cap) + ): + seen.add(cursor) try: resp = await follow_up(cursor, sess) - _raise_for_non_200(resp) + raise_for_status(resp) df, cursor = parse_response(resp) dfs.append(df) nrows += len(df) total_elapsed += _safe_elapsed(resp) - if reporter is not None: - reporter.set_rate_remaining( - resp.headers.get(_QUOTA_HEADER), - limit=resp.headers.get("x-ratelimit-limit"), - ) - reporter.add_page(rows=len(df)) + report_page(resp, df) except Exception as e: # noqa: BLE001 logger.warning( "Request failed at cursor %r. Data download interrupted.", @@ -804,12 +728,13 @@ async def _paginate( ) raise DataRetrievalError(_paginated_failure_message(len(dfs), e)) from e - # Aggregate headers / elapsed onto a COPY of the initial - # response so the user's caller never sees an in-place - # mutation of the response object they may have inspected - # mid-pagination via a hook or test fixture. - final_response = _aggregate_paginated_response( - initial_response, resp, total_elapsed + # Fold the pages onto a COPY of the initial response so a caller that + # inspected it mid-pagination (a hook, a test fixture) never sees an + # in-place mutation. ``resp`` is the last page, whose headers carry the + # current ``x-ratelimit-remaining`` (monotonic, so the last page is the + # most depleted) — the same low-level merge the fan-out aggregation uses. + final_response = _merge_response( + initial_response, headers_from=resp, elapsed=total_elapsed ) result = pd.concat(dfs, ignore_index=True) if cap is not None: @@ -1047,7 +972,7 @@ def _run_sync( # through raw; mid-pagination failures are already typed. # Report the base URL actually targeted (NGWMN/sibling APIs # set their own via ``_ogc_base_url``), not a hardcoded host. 
- raise _network_error(_ogc_base_url_var.get(), exc) from exc + raise _network_error(_ogc_base_url.get(), exc) from exc # ``AGENCY-ID``: a hyphen-separated agency prefix and local id. The local id @@ -1168,19 +1093,14 @@ def _check_monitoring_location_id( if value is None: return None for item in (value,) if isinstance(value, str) else value: - _check_id_format(item) + if not _MONITORING_LOCATION_ID_RE.fullmatch(item): + raise ValueError( + f"Invalid monitoring_location_id: {item!r}. " + f"Expected 'AGENCY-ID' format, e.g., 'USGS-01646500'." + ) return value -def _check_id_format(value: str) -> None: - """Raise ``ValueError`` if ``value`` is not in ``AGENCY-ID`` format.""" - if not _MONITORING_LOCATION_ID_RE.fullmatch(value): - raise ValueError( - f"Invalid monitoring_location_id: {value!r}. " - f"Expected 'AGENCY-ID' format, e.g., 'USGS-01646500'." - ) - - def _get_args( local_vars: dict[str, Any], exclude: set[str] | None = None, diff --git a/dataretrieval/ogc/planning.py b/dataretrieval/ogc/planning.py index 23828e64..191bfafa 100644 --- a/dataretrieval/ogc/planning.py +++ b/dataretrieval/ogc/planning.py @@ -563,6 +563,56 @@ def _combine_chunk_frames(frames: list[pd.DataFrame]) -> pd.DataFrame: return combined +# Response header USGS uses to advertise remaining hourly quota. Lives in this +# base module so every layer (planning's ``_lowest_remaining``, the engine's +# per-page progress) reads it from one place rather than hard-coding the string. +_QUOTA_HEADER = "x-ratelimit-remaining" + + +def _lowest_remaining(responses: list[httpx.Response]) -> httpx.Response: + """The response reporting the lowest ``x-ratelimit-remaining``. + + The rate-limit counter decreases monotonically within a window, so the + smallest value any sub-request saw is the most-current "quota left after + this call" — the right thing to surface. 
Under concurrent fan-out the + last response *by index* need not be the one the server processed last, so + pick the minimum (falling back to the last response if none report it). + """ + best: httpx.Response | None = None + best_remaining: int | None = None + for response in responses: + try: + remaining = int(response.headers[_QUOTA_HEADER]) + except (KeyError, ValueError): + continue + if best_remaining is None or remaining < best_remaining: + best, best_remaining = response, remaining + return best if best is not None else responses[-1] + + +def _merge_response( + base: httpx.Response, + *, + headers_from: httpx.Response, + elapsed: timedelta, + url: str | httpx.URL | None = None, +) -> httpx.Response: + """Fold several responses into one: a shallow copy of ``base`` whose + ``.headers`` are rebuilt as a fresh ``httpx.Headers`` from ``headers_from``, + ``.elapsed`` set to ``elapsed``, and ``.url`` overridden when ``url`` is + given. ``base`` and ``headers_from`` are never mutated, and the fresh + ``httpx.Headers`` means downstream mutations don't back-propagate into any + underlying response — so callers may re-fold idempotently. This is the one + low-level merge behind both pagination (:func:`_paginate`) and the chunked / + fan-out aggregation (:func:`_combine_chunk_responses`).""" + merged = copy.copy(base) + merged.headers = httpx.Headers(headers_from.headers) + merged.elapsed = elapsed + if url is not None: + _set_response_url(merged, url) + return merged + + def _combine_chunk_responses( responses: list[httpx.Response], canonical_url: str | None ) -> httpx.Response: @@ -570,8 +620,9 @@ def _combine_chunk_responses( Fold per-sub-request responses into a single aggregated response. 
For a multi-response input, returns a shallow copy of - ``responses[0]`` with ``.headers`` set to the last response's (so - ``x-ratelimit-remaining`` reflects current state), ``.elapsed`` set + ``responses[0]`` with ``.headers`` set to those of the most-depleted + response (lowest ``x-ratelimit-remaining`` — the quota actually left + after the fan-out; see :func:`_lowest_remaining`), ``.elapsed`` set to total wall-clock across every response, and ``.url`` set to the canonical original-query URL (when supplied) so ``BaseMetadata`` reflects the user's full request rather than the first chunk. @@ -605,20 +656,15 @@ def _combine_chunk_responses( if len(responses) == 1 and canonical_url is None: return responses[0] - # ``copy.copy`` lets repeated calls re-sum elapsed from scratch - # rather than re-mutating ``responses[0]`` in place. The headers - # dict is then rewrapped in a fresh ``httpx.Headers`` so the - # aggregate's headers don't share identity with — or leak mutations - # back into — any underlying response on ``ChunkedCall._chunks``. - head = copy.copy(responses[0]) - if len(responses) > 1: - head.headers = httpx.Headers(responses[-1].headers) - head.elapsed = sum( - (_safe_elapsed(r) for r in responses[1:]), - start=_safe_elapsed(responses[0]), - ) - else: - head.headers = httpx.Headers(responses[0].headers) - if canonical_url is not None: - _set_response_url(head, canonical_url) - return head + # Headers come from the most-depleted response (lowest quota left after a + # concurrent fan-out; ``_lowest_remaining`` returns the lone response as-is + # for a single-element list). ``_merge_response`` re-sums elapsed onto a + # fresh copy, so repeated calls (e.g. via ``ChunkedCall.partial_response`` + # during resume) stay idempotent. 
+ elapsed = sum((_safe_elapsed(r) for r in responses), start=timedelta()) + return _merge_response( + responses[0], + headers_from=_lowest_remaining(responses), + elapsed=elapsed, + url=canonical_url, + ) diff --git a/dataretrieval/utils.py b/dataretrieval/utils.py index 20f38710..0fe1dcae 100644 --- a/dataretrieval/utils.py +++ b/dataretrieval/utils.py @@ -6,8 +6,10 @@ import os import warnings -from collections.abc import Callable, Iterable -from typing import Any +from collections.abc import Callable, Iterable, Iterator +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Any, Generic, TypeVar import httpx import pandas as pd @@ -29,6 +31,40 @@ "timeout": httpx.Timeout(60.0, connect=10.0), } +_T = TypeVar("_T") + + +class Ambient(Generic[_T]): + """A :class:`~contextvars.ContextVar` paired with a scoping contextmanager. + + Bundles the var and its set/reset-token dance into one object, so an ambient + value needs a single declaration instead of a ``var`` + setter-function pair. + Read the current value with :meth:`get`; set it for a ``with`` block by + *calling* the instance — the previous value is restored on exit (and can't + leak into a later call the way a hand-written ``try/finally`` can when its + ``reset`` is dropped):: + + _base_url = Ambient("ogc_base_url", DEFAULT) + with _base_url(other): # scoped to the block + _base_url.get() # -> other + """ + + def __init__(self, name: str, default: _T) -> None: + self._var: ContextVar[_T] = ContextVar(name, default=default) + + def get(self) -> _T: + """The current value — the default outside any active scope.""" + return self._var.get() + + @contextmanager + def __call__(self, value: _T) -> Iterator[None]: + """Set the value for the duration of the ``with`` block.""" + token = self._var.set(value) + try: + yield + finally: + self._var.reset(token) + def _default_headers() -> dict[str, str]: """Build the default HTTP headers for a USGS web-API request. 
diff --git a/dataretrieval/waterdata/utils.py b/dataretrieval/waterdata/utils.py index c876913a..77f2ea99 100644 --- a/dataretrieval/waterdata/utils.py +++ b/dataretrieval/waterdata/utils.py @@ -32,7 +32,6 @@ OGC_API_URL, OgcDialect, _as_str_list, - _check_id_format, _check_monitoring_location_id, _check_ogc_requests, _construct_api_requests, @@ -305,7 +304,6 @@ def _check_profiles( "_OUTPUT_ID_BY_SERVICE", "_arrange_cols", "_as_str_list", - "_check_id_format", "_check_monitoring_location_id", "_check_ogc_requests", "_check_profiles", diff --git a/dataretrieval/wateruse.py b/dataretrieval/wateruse.py index b4d4e968..b71baebe 100644 --- a/dataretrieval/wateruse.py +++ b/dataretrieval/wateruse.py @@ -10,10 +10,14 @@ Unlike the main Water Data getters (:mod:`dataretrieval.waterdata`) and NGWMN (:mod:`dataretrieval.ngwmn`), the NWDC is a plain CSV REST service rather than -an OGC API Features collection, so this module talks to it directly instead of -delegating to the shared OGC engine. It still follows the same conventions: -shared request headers (:func:`~dataretrieval.utils._default_headers`), -the typed :class:`~dataretrieval.exceptions.DataRetrievalError` taxonomy, and a +an OGC API Features collection. This module supplies the NWDC-specific bits — +request building, CSV parsing, the ``Link``-header cursor, and the ``{detail}`` +error envelope — but reuses the OGC engine's generic, API-agnostic pagination +and sync-from-async plumbing (:func:`~dataretrieval.ogc.engine._paginate` and +:func:`~dataretrieval.ogc.engine._run_sync`) rather than re-implementing it. It +follows the same conventions: shared request headers +(:func:`~dataretrieval.utils._default_headers`), the typed +:class:`~dataretrieval.exceptions.DataRetrievalError` taxonomy, and a ``(DataFrame, BaseMetadata)`` return. 
See https://api.water.usgs.gov/docs/nwaa-data/ for the API reference and @@ -40,11 +44,8 @@ from __future__ import annotations import asyncio -import copy import io -import logging -from collections.abc import Iterable -from concurrent.futures import ThreadPoolExecutor +from collections.abc import Callable, Iterable from typing import Any import httpx @@ -52,17 +53,16 @@ from dataretrieval.codes.states import to_state from dataretrieval.exceptions import DataRetrievalError +from dataretrieval.ogc.engine import _ogc_base_url, _paginate, _run_sync +from dataretrieval.ogc.planning import _combine_chunk_frames, _combine_chunk_responses from dataretrieval.utils import ( HTTPX_DEFAULTS, BaseMetadata, _default_headers, - _network_error, _raise_for_status, to_str, ) -logger = logging.getLogger(__name__) - WATERUSE_URL = "https://api.water.usgs.gov/nwaa-data/data" #: Water-use models (categories) served by the NWDC. The catalog at @@ -218,22 +218,43 @@ def get_wateruse( base_params = {k: v for k, v in base_params.items() if v is not None} # The NWDC queries one location per request, so fan a multi-value selector - # out into a request per location (concurrently — see ``_fan_out``) and - # concatenate the results. - locations = _resolve_locations(state, county, huc) - # Drive the async fan-out from a worker thread so it is safe even when - # called inside an already-running event loop (e.g. a Jupyter notebook), - # where a bare ``asyncio.run`` would raise. - with ThreadPoolExecutor(max_workers=1) as pool: - df, response = pool.submit( - lambda: asyncio.run(_fan_out(locations, base_params, ssl_check)) - ).result() + # out into one request per location, each paginated by the OGC engine's + # shared pager (``_paginate``), and concatenate the results. 
+ headers = _default_headers() + requests = [ + httpx.Request( + "GET", + WATERUSE_URL, + params={**base_params, "location": location}, + headers=headers, + ) + for location in _resolve_locations(state, county, huc) + ] + # ``_run_sync`` drives the async fan-out via an anyio portal, so it is safe + # even inside an already-running event loop (e.g. a Jupyter notebook); + # ``_ogc_base_url`` sets the host reported in any connection-error message. + with _ogc_base_url(WATERUSE_URL): + df, response = _run_sync( + lambda: _fan_out(requests, headers, ssl_check), service="wateruse" + ) return df, BaseMetadata(response) # Valid HUC code lengths (digits) → the hydrologic-unit level they query. _HUC_LENGTHS = (2, 4, 6, 8, 10, 12) +# Maps each selector to the NWDC ``location={type}:{code}`` value(s) it produces. +# A value may be a single code or a list; ``_as_list`` normalizes both (``state`` +# additionally normalizes to the two-letter postal code, and ``to_state`` may +# itself return a scalar or list, which ``_as_list`` flattens the same way). +# Since NWDC takes one location per request, a list value fans out — one request +# per location (see :func:`_fan_out`). +_LOCATION_BUILDERS: dict[str, Callable[[Any], list[str]]] = { + "state": lambda v: [f"stateCd:{c}" for c in _as_list(to_state(v, to="postal"))], + "county": lambda v: [f"countyCd:{_validate_county(c)}" for c in _as_list(v)], + "huc": lambda v: [f"huc{len(c)}:{c}" for c in map(_validate_huc, _as_list(v))], +} + def _resolve_locations( state: str | int | Iterable[str | int] | None, @@ -248,28 +269,18 @@ def _resolve_locations( ``huc`` code's length selects its level (``huc2`` … ``huc12``). Returns one location string per value — the caller issues one request per location. 
""" - provided = [ - name + selected = { + name: value for name, value in (("state", state), ("county", county), ("huc", huc)) if value is not None - ] - if len(provided) != 1: + } + if len(selected) != 1: raise ValueError( "Specify exactly one of state, county, or huc " - f"(got: {', '.join(provided) or 'none'})." + f"(got: {', '.join(selected) or 'none'})." ) - - if state is not None: - # to_state returns a str (scalar) or list[str] (iterable); _as_list - # normalizes both, keeping this branch the same shape as county/huc. - locations = [ - f"stateCd:{code}" for code in _as_list(to_state(state, to="postal")) - ] - elif county is not None: - locations = [f"countyCd:{_validate_county(c)}" for c in _as_list(county)] - else: - locations = [f"huc{len(c)}:{c}" for c in map(_validate_huc, _as_list(huc))] - + [(name, value)] = selected.items() + locations = _LOCATION_BUILDERS[name](value) if not locations: raise ValueError( "The chosen location selector is empty; pass at least one value." @@ -309,66 +320,52 @@ def _validate_huc(value: object) -> str: async def _fan_out( - locations: list[str], base_params: dict[str, Any], ssl_check: bool + requests: list[httpx.Request], headers: dict[str, str], ssl_check: bool ) -> tuple[pd.DataFrame, httpx.Response]: - """Fetch every location concurrently over one shared async client. - - Each location is an independent paginated request; concurrency is bounded by - a semaphore at :data:`MAX_CONCURRENT_REQUESTS`, and ``asyncio.gather`` - preserves input order so the concatenation is deterministic. The single - shared :class:`httpx.AsyncClient` keeps connections alive across pages and - locations. + """Fetch every request (each paginated) concurrently over one shared client. 
+ + Each request is paginated by the engine's + :func:`~dataretrieval.ogc.engine._paginate` with NWDC strategies: parse a CSV + page and read its ``Link`` header cursor (``parse``), follow that cursor + (``follow``), and raise the typed error carrying the NWDC ``detail`` + (``raise_for_status``). Concurrency is bounded by a semaphore at + :data:`MAX_CONCURRENT_REQUESTS`, and ``asyncio.gather`` preserves input + order, so the concatenation is deterministic. The shared + :class:`httpx.AsyncClient` keeps connections alive across pages and requests. """ - headers = _default_headers() - semaphore = asyncio.Semaphore(max(1, MAX_CONCURRENT_REQUESTS)) - - async with httpx.AsyncClient(verify=ssl_check, **HTTPX_DEFAULTS) as client: - async def _one(location: str) -> tuple[pd.DataFrame, list[httpx.Response]]: - async with semaphore: - return await _fetch_location(client, location, base_params, headers) + def parse(response: httpx.Response) -> tuple[pd.DataFrame, str | None]: + return _read_csv_page(response), _next_page_url(response) - results = await asyncio.gather(*(_one(loc) for loc in locations)) + async def follow(cursor: str, sess: httpx.AsyncClient) -> httpx.Response: + return await sess.get(cursor, headers=headers) - frames = [frame for frame, _ in results] - responses = [resp for _, page_responses in results for resp in page_responses] - df = frames[0] if len(frames) == 1 else pd.concat(frames, ignore_index=True) - return df, _aggregate_responses(responses) - - -async def _fetch_location( - client: httpx.AsyncClient, - location: str, - base_params: dict[str, Any], - headers: dict[str, str], -) -> tuple[pd.DataFrame, list[httpx.Response]]: - """Fetch and concatenate every page for one location over ``client``. - - The NWDC paginates large areas with an RFC 8288 ``Link: <...>; rel="next"`` - header (the cursor is a ``skip`` offset). The first request carries the - query params; each subsequent page is a fully-formed URL requested bare. 
The - ``seen`` set guards against a non-advancing or cyclic cursor (a server bug - that would otherwise loop forever, accumulating frames until OOM). - """ - frames: list[pd.DataFrame] = [] - responses: list[httpx.Response] = [] - seen: set[str] = set() - url: str | None = WATERUSE_URL - params: dict[str, Any] | None = {**base_params, "location": location} - while url is not None and url not in seen: - seen.add(url) - try: - response = await client.get(url, params=params, headers=headers) - except httpx.TransportError as exc: - raise _network_error(url, exc) from exc + def raise_for_status(response: httpx.Response) -> None: _raise_for_status(response, detail_from=_nwdc_error_detail) - logger.debug("Requested water-use page: %s", response.url) - responses.append(response) - frames.append(_read_csv_page(response)) - url, params = _next_page_url(response), None - df = frames[0] if len(frames) == 1 else pd.concat(frames, ignore_index=True) - return df, responses + async with httpx.AsyncClient(verify=ssl_check, **HTTPX_DEFAULTS) as client: + semaphore = asyncio.Semaphore(max(1, MAX_CONCURRENT_REQUESTS)) + + async def _one(request: httpx.Request) -> tuple[pd.DataFrame, httpx.Response]: + async with semaphore: + return await _paginate( + request, + parse_response=parse, + follow_up=follow, + client=client, + raise_for_status=raise_for_status, + ) + + results = await asyncio.gather(*(_one(req) for req in requests)) + + # Reuse the engine's combine helpers: drop empty frames and concat, and fold + # the per-location responses into one (lowest-remaining rate-limit headers + + # cumulative elapsed), keeping the first request's URL as the query identity. 
+ frames = [frame for frame, _ in results] + responses = [resp for _, resp in results] + return _combine_chunk_frames(frames), _combine_chunk_responses( + responses, str(requests[0].url) + ) def _read_csv_page(response: httpx.Response) -> pd.DataFrame: @@ -384,38 +381,6 @@ def _read_csv_page(response: httpx.Response) -> pd.DataFrame: ) from exc -def _aggregate_responses(responses: list[httpx.Response]) -> httpx.Response: - """Fold the per-page, per-location responses into one for metadata. - - Keeps the first request's URL (the query identity) but surfaces the *final* - rate-limit headers — those of the response that saw the lowest - ``x-ratelimit-remaining``, i.e. the quota left after the whole fan-out — and - the cumulative elapsed time. A single response is returned unchanged. - """ - first = responses[0] - if len(responses) == 1: - return first - final = copy.copy(first) - final.headers = httpx.Headers(_most_depleted(responses).headers) - final.elapsed = sum((r.elapsed for r in responses[1:]), start=first.elapsed) - return final - - -def _most_depleted(responses: list[httpx.Response]) -> httpx.Response: - """The response reporting the lowest ``x-ratelimit-remaining`` (the latest - server-side view of the quota), or the last response if none report it.""" - best: httpx.Response | None = None - best_remaining: int | None = None - for response in responses: - try: - remaining = int(response.headers["x-ratelimit-remaining"]) - except (KeyError, ValueError): - continue - if best_remaining is None or remaining < best_remaining: - best, best_remaining = response, remaining - return best if best is not None else responses[-1] - - def _next_page_url(response: httpx.Response) -> str | None: """Return the absolute URL of the next page, or None if this is the last. 
diff --git a/tests/waterdata_chunking_test.py b/tests/waterdata_chunking_test.py index 8895c54a..72426d2b 100644 --- a/tests/waterdata_chunking_test.py +++ b/tests/waterdata_chunking_test.py @@ -44,7 +44,6 @@ from dataretrieval.ogc import chunking as _chunking from dataretrieval.ogc import retry as _retry_mod from dataretrieval.ogc.chunking import ( - _QUOTA_HEADER, ChunkedCall, _chunked_client, get_active_client, @@ -59,6 +58,7 @@ _LIST_SEP, _NEVER_CHUNK, _OR_SEP, + _QUOTA_HEADER, ChunkPlan, _combine_chunk_frames, _combine_chunk_responses, @@ -1044,7 +1044,8 @@ def test_combine_chunk_responses_returns_independent_headers(): ) head = _combine_chunk_responses([r0, r1], canonical_url=None) - # Aggregate carries the last chunk's headers... + # Aggregate carries a chunk's headers (here the last, as the fallback when + # neither reports a rate limit)... assert head.headers["X-Foo"] == "1" # ...but mutating the aggregate must not back-propagate. head.headers["X-Trace-Id"] = "abc" @@ -1052,6 +1053,24 @@ def test_combine_chunk_responses_returns_independent_headers(): assert "X-Trace-Id" not in r0.headers +def test_combine_chunk_responses_surfaces_lowest_remaining(): + """``x-ratelimit-remaining`` reports the LOWEST any sub-request saw — the + quota actually left after the fan-out — not the last-by-index, which under + concurrency need not be the response the server processed last.""" + r0 = mock.Mock( + elapsed=datetime.timedelta(seconds=0.1), + headers={"x-ratelimit-remaining": "5"}, # lowest, but first by index + url="u0", + ) + r1 = mock.Mock( + elapsed=datetime.timedelta(seconds=0.2), + headers={"x-ratelimit-remaining": "99"}, # last by index, but higher + url="u1", + ) + head = _combine_chunk_responses([r0, r1], canonical_url=None) + assert head.headers["x-ratelimit-remaining"] == "5" + + def test_paginate_terminates_on_empty_string_cursor(): """``_paginate``'s loop predicate is ``while cursor is not None``. 
Parse-response wrappers in ``_walk_pages`` / ``stats.get_data`` diff --git a/tests/waterdata_test.py b/tests/waterdata_test.py index b48859c4..50d8663c 100644 --- a/tests/waterdata_test.py +++ b/tests/waterdata_test.py @@ -882,8 +882,8 @@ def test_get_daily_malformed_id_raises(self): def test_per_item_format_check_in_list(self): """The AGENCY-ID format check runs on EVERY element of an iterable, not just the first. Regression guard against a - future ``_check_id_format`` loop that bails after one valid - item or only checks the head.""" + future ``_check_monitoring_location_id`` loop that bails after one + valid item or only checks the head.""" with pytest.raises(ValueError, match="Invalid monitoring_location_id"): _check_monitoring_location_id(["USGS-01646500", "badformat"]) diff --git a/tests/waterdata_utils_test.py b/tests/waterdata_utils_test.py index 4d568d1f..b6fd2984 100644 --- a/tests/waterdata_utils_test.py +++ b/tests/waterdata_utils_test.py @@ -379,7 +379,7 @@ def test_next_req_url_stops_when_no_features(): def test_walk_pages_does_not_mutate_initial_response(): """The aggregated response returned from ``_walk_pages`` is built - via ``_aggregate_paginated_response``, which returns a fresh copy. + via ``_merge_response``, which returns a fresh copy. 
Any caller that inspected ``initial_response.headers`` / ``.elapsed`` before pagination completed (a Session response hook, a logging middleware) must continue to see the original first-page diff --git a/tests/wateruse_test.py b/tests/wateruse_test.py index 9838da92..756fa4e1 100644 --- a/tests/wateruse_test.py +++ b/tests/wateruse_test.py @@ -286,18 +286,8 @@ def test_fan_out_surfaces_final_rate_limit_header(httpx_mock): assert md.header["x-ratelimit-remaining"] == "850" -def test_most_depleted_picks_lowest_remaining(): - responses = [ - httpx.Response(200, headers={"x-ratelimit-remaining": "900"}), - httpx.Response(200, headers={"x-ratelimit-remaining": "850"}), - httpx.Response(200, headers={"x-ratelimit-remaining": "875"}), - ] - assert wateruse._most_depleted(responses) is responses[1] - - -def test_most_depleted_falls_back_to_last_when_header_absent(): - responses = [httpx.Response(200), httpx.Response(200)] - assert wateruse._most_depleted(responses) is responses[1] +# (response aggregation now reuses ogc.planning._combine_chunk_responses; the +# integration test above pins the rate-limit-header behavior end-to-end.) # --- _resolve_locations unit tests (no HTTP) -------------------------------