diff --git a/.gitignore b/.gitignore index bbd11b6e..70e259c2 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ SNT Process/ .vscode/ venv/ uv.lock +__pycache__/ # Jupyter stuff ------------------------- diff --git a/pipelines/snt_map_extracts/reporting/snt_map_extracts_report.ipynb b/pipelines/snt_map_extracts/reporting/snt_map_extracts_report.ipynb index 2115b5a6..5d536c29 100644 --- a/pipelines/snt_map_extracts/reporting/snt_map_extracts_report.ipynb +++ b/pipelines/snt_map_extracts/reporting/snt_map_extracts_report.ipynb @@ -1,216 +1,161 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "5777b72c-d87e-47c5-87b1-2698a6510b2f", - "metadata": {}, - "source": [ - "# **Cartes extraites du Malaria Atlas Project (MAP)**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6397ab91-1ae4-4db7-b6c3-061c453a7b03", - "metadata": { - "vscode": { - "languageId": "r" - } - }, - "outputs": [], - "source": [ - "# Set SNT Paths\n", - "SNT_ROOT_PATH <- \"~/workspace\"\n", - "CODE_PATH <- file.path(SNT_ROOT_PATH, \"code\")\n", - "CONFIG_PATH <- file.path(SNT_ROOT_PATH, \"configuration\")\n", - "PIPELINE_PATH <- file.path(SNT_ROOT_PATH, \"pipelines\", \"snt_map_extracts\")\n", - "\n", - "# load util functions\n", - "source(file.path(CODE_PATH, \"snt_utils.r\"))\n", - "source(file.path(PIPELINE_PATH, \"utils\", \"snt_map_extracts_report.r\"))\n", - "\n", - "# List required packages\n", - "required_packages <- c(\"dplyr\", \"tidyr\", \"terra\", \"ggplot2\", \"stringr\", \"lubridate\", \"viridis\", \"patchwork\", \"zoo\", \"purrr\", \"arrow\", \"sf\", \"reticulate\")\n", - "\n", - "# Execute function\n", - "install_and_load(required_packages)\n", - "\n", - "# Set environment to load openhexa.sdk from the right environment\n", - "Sys.setenv(RETICULATE_PYTHON = \"/opt/conda/bin/python\")\n", - "reticulate::py_config()$python\n", - "openhexa <- import(\"openhexa.sdk\")\n", - "\n", - "# Load SNT config\n", - "config_json <- tryCatch({ jsonlite::fromJSON(file.path(CONFIG_PATH, \"SNT_config.json\"))},\n", - " error = function(e) {\n", - " msg <- paste0(\"Error while loading configuration\", conditionMessage(e))\n", - " cat(msg)\n", - " stop(msg)\n", - " })\n", - "\n", - "# Required environment for the sf packages\n", - "Sys.setenv(PROJ_LIB = \"/opt/conda/share/proj\")\n", - "Sys.setenv(GDAL_DATA = \"/opt/conda/share/gdal\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "93e04996-13ba-4855-a1d1-46e70ba4640e", - "metadata": { - "vscode": { - "languageId": "r" - } - }, - "outputs": [], - "source": [ - "# Configuration variables\n", - "DATASET_NAME <- config_json$SNT_DATASET_IDENTIFIERS$SNT_MAP_EXTRACT\n", - "COUNTRY_CODE <- config_json$SNT_CONFIG$COUNTRY_CODE\n", - "ADM_2 <- toupper(config_json$SNT_CONFIG$DHIS2_ADMINISTRATION_2)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a1ee21d1-c7d1-4893-ac56-91abb92926ea", - "metadata": { - "vscode": { - "languageId": "r" - } - }, - "outputs": [], - "source": [ - "# printdim() loaded from utils/snt_map_extracts_report.r" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f7de799e-896c-4237-a9f6-9dafc0f30bde", - "metadata": { - "vscode": { - "languageId": "r" - } - }, - "outputs": [], - "source": [ - "# import seasonality data\n", - "map_data <- load_map_report_input(\n", - " dataset_name = DATASET_NAME,\n", - " filename = paste0(COUNTRY_CODE, \"_map_data.parquet\"),\n", - " label = \"MAP extracted data\"\n", - ")\n", - "\n", - "# import DHIS2 shapes data\n", - "DATASET_DHIS2 <- config_json$SNT_DATASET_IDENTIFIERS$DHIS2_DATASET_FORMATTED\n", - "shapes_data <- load_map_report_input(\n", - " dataset_name = DATASET_DHIS2,\n", - " filename = paste0(COUNTRY_CODE, \"_shapes.geojson\"),\n", - " label = \"DHIS2 shapes data\"\n", - ")\n", - "\n", - "printdim(map_data)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "84031c7e-e9c6-4496-896f-7f7f3403d951", - "metadata": { - "vscode": { - "languageId": "r" - } - }, - "outputs": [], - "source": [ - "names(map_data)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "66b05b53-3f65-424a-af22-0686238a06c9", - "metadata": { - "vscode": { - "languageId": "r" - } - }, - "outputs": [], - "source": [ - "unique(map_data$METRIC_NAME)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9d0515ab-4dc2-4671-8c3c-236578a840d8", - "metadata": { - "vscode": { - "languageId": "r" - } - }, - "outputs": [], - "source": [ - "# Merge geometry with map data\n", - "map_data_joined <- dplyr::left_join(shapes_data, map_data, by = c(\"ADM2_ID\" = \"ADM2_ID\"))\n", - "\n", - "# Get list of metrics\n", - "metrics <- unique(map_data$METRIC_NAME)\n", - "\n", - "# Create one map per metric\n", - "plots <- build_metric_plots(map_data_joined = map_data_joined, metrics = metrics)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0f1dc1df-211d-4174-83ae-e8ae974fa790", - "metadata": { - "vscode": { - "languageId": "r" - } - }, - "outputs": [], - "source": [ - "# Set plot size for individual display\n", - "options(repr.plot.width = 10, repr.plot.height = 8)\n", - "\n", - "# Loop through plots and print one by one\n", - "for (p in plots) {\n", - " print(p)\n", - " Sys.sleep(1) # Optional: short pause between plots\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "50c843e4-9157-480d-acde-80887410d156", - "metadata": { - "vscode": { - "languageId": "r" - } - }, - "outputs": [], - "source": [] + "cells": [ + { + "cell_type": "markdown", + "id": "5777b72c-d87e-47c5-87b1-2698a6510b2f", + "metadata": {}, + "source": [ + "# **Cartes extraites du Malaria Atlas Project (MAP)**" + ] + }, + { + "cell_type": "markdown", + "id": "f0860675-9819-4a0a-b0ce-23f1b16d40c3", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dfb0dee9-0da8-4afd-9f3a-7e8f0a830f85", + "metadata": {}, + "outputs": [], + "source": [ + "source(file.path(\"~/workspace/pipelines/snt_map_extracts/utils/snt_map_extracts_report.r\"))\n", + "setup_var <- get_setup_variables(packages= c(\"arrow\", \"dplyr\", \"tidyr\", \"stringr\", \"stringi\", \"jsonlite\", \"httr\", \"reticulate\", \"glue\"))\n", + "config_json <- load_snt_config(file.path(setup_var$CONFIG_PATH, \"SNT_config.json\"))\n", + "\n", + "# Save config variables\n", + "DATASET_MAP <- config_json$SNT_DATASET_IDENTIFIERS$SNT_MAP_EXTRACTS\n", + "DATASET_FORMATTED <- config_json$SNT_DATASET_IDENTIFIERS$DHIS2_DATASET_FORMATTED\n", + "COUNTRY_CODE <- config_json$SNT_CONFIG$COUNTRY_CODE" + ] + }, + { + "cell_type": "markdown", + "id": "6b3456da-ddfd-4b9b-9c9b-506f583aab3e", + "metadata": {}, + "source": [ + "## Load Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f7de799e-896c-4237-a9f6-9dafc0f30bde", + "metadata": { + "vscode": { + "languageId": "r" } - ], - "metadata": { - "kernelspec": { - "display_name": "R", - "language": "R", - "name": "ir" - }, - "language_info": { - "codemirror_mode": "r", - "file_extension": ".r", - "mimetype": "text/x-r-source", - "name": "R", - "pygments_lexer": "r", - "version": "4.4.3" + }, + "outputs": [], + "source": [ + "# Load parameters\n", + "map_parameters <- load_dataset_file(dataset_id = DATASET_MAP, filename = paste0(COUNTRY_CODE, \"_parameters.json\"))\n", + "cat(jsonlite::toJSON(map_parameters, pretty = TRUE, auto_unbox = TRUE), \"\\n\")\n", + "\n", + "# Load latest year file available\n", + "map_data <- load_dataset_file(dataset_id = DATASET_MAP, filename = glue(\"{COUNTRY_CODE}_map_data_{map_parameters$YEAR_END}.parquet\"))\n", + "\n", + "# import DHIS2 shapes data\n", + "shapes_data <- load_dataset_file(dataset_id = DATASET_FORMATTED, filename = paste0(COUNTRY_CODE, \"_shapes.geojson\"))\n", + "\n", + "printdim(map_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "66b05b53-3f65-424a-af22-0686238a06c9", + "metadata": { + "vscode": { + "languageId": "r" } + }, + "outputs": [], + "source": [ + "print(glue(\"Year file selection: {map_parameters$YEAR_END}\"))\n", + "print(glue(\"Indicators: {unique(map_data$METRIC_NAME)}\"))" + ] + }, + { + "cell_type": "markdown", + "id": "0ea5587e-b7d3-47c1-a4b3-c69061d24bca", + "metadata": {}, + "source": [ + "## Create plots \n", + "\n", + "We select the latest year of the selected years based on the parameters run." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9d0515ab-4dc2-4671-8c3c-236578a840d8", + "metadata": { + "vscode": { + "languageId": "r" + } + }, + "outputs": [], + "source": [ + "# Merge geometry with map data\n", + "map_data_joined <- dplyr::left_join(shapes_data, map_data, by = c(\"ADM2_ID\" = \"ADM2_ID\"))\n", + "\n", + "# Get list of metrics\n", + "metrics <- unique(map_data$METRIC_NAME)\n", + "\n", + "# Create one map per metric\n", + "plots <- build_metric_plots(map_data_joined = map_data_joined, metrics = metrics, year=map_parameters$YEAR_END)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f1dc1df-211d-4174-83ae-e8ae974fa790", + "metadata": { + "vscode": { + "languageId": "r" + } + }, + "outputs": [], + "source": [ + "# Set plot size for individual display\n", + "options(repr.plot.width = 10, repr.plot.height = 8)\n", + "\n", + "# Loop through plots and print one by one\n", + "for (p in plots) {\n", + " print(p)\n", + " Sys.sleep(1) # Optional: short pause between plots\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8eb9cc86-f3ed-45f2-84f3-9e395ef85106", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "R", + "language": "R", + "name": "ir" }, - "nbformat": 4, - "nbformat_minor": 5 + "language_info": { + "codemirror_mode": "r", + "file_extension": ".r", + "mimetype": "text/x-r-source", + "name": "R", + "pygments_lexer": "r", + "version": "4.5.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/pipelines/snt_map_extracts/utils/snt_map_extracts_report.r b/pipelines/snt_map_extracts/utils/snt_map_extracts_report.r index 359134ff..f57cc780 100644 --- a/pipelines/snt_map_extracts/utils/snt_map_extracts_report.r +++ b/pipelines/snt_map_extracts/utils/snt_map_extracts_report.r @@ -1,3 +1,50 @@ + +# Load base utils +source(file.path("~/workspace/code", "snt_utils.r")) + + +#' Get Setup Variables for SNT Workspace +#' Initializes workspace paths, loads R packages, and imports OpenHEXA SDK. +#' +#' @param SNT_ROOT_PATH Character. Root path of the SNT workspace. Default: '~/workspace' +#' @param packages Character vector. R packages to install and load. +#' @return List with SNT paths. +#' +#' @export +get_setup_variables <- function( + SNT_ROOT_PATH='~/workspace', + packages=c("arrow", "dplyr", "tidyr", "stringr", "stringi", "jsonlite", "httr", "glue") +) { + + # List required pcks + required_packages <- unique(c(packages, "reticulate")) + install_and_load(required_packages) + + # Set environment to load openhexa.sdk from the right environment + Sys.setenv(RETICULATE_PYTHON = "/opt/conda/bin/python") + + # Attempt to import the SDK + tryCatch({ + sdk <- reticulate::import("openhexa.sdk") + assign("openhexa", sdk, envir = .GlobalEnv) + }, error = function(e) { + log_msg("Could not import openhexa.sdk. Ensure it is installed in /opt/conda/bin/python", "warning") + }) + + # Set paths (add paths here) + paths_to_check = list( + CONFIG_PATH = file.path(SNT_ROOT_PATH, "configuration"), + UPLOADS_PATH = file.path(SNT_ROOT_PATH, "uploads"), + DATA_PATH = file.path(SNT_ROOT_PATH, "data") + ) + + # create if they do not exist + lapply(paths_to_check, dir.create, recursive = TRUE, showWarnings = FALSE) + + return(paths_to_check) +} + + #' Print dataframe dimensions with a readable label. #' #' @param df Data frame-like object. @@ -7,27 +54,47 @@ printdim <- function(df, name = deparse(substitute(df))) { cat("Dimensions of", name, ":", nrow(df), "rows x", ncol(df), "columns\n\n") } -#' Load a MAP report input file from dataset with logging/error handling. -#' -#' @param dataset_name Dataset identifier/name. -#' @param filename File name to download from latest dataset version. -#' @param label Human-readable label for logs/errors. -#' @return Loaded dataset object (data frame / sf depending on source file). -load_map_report_input <- function(dataset_name, filename, label = "dataset file") { - data <- tryCatch( - { - get_latest_dataset_file_in_memory(dataset_name, filename) - }, - error = function(e) { - msg <- paste("Error while loading", label, "for file:", filename, conditionMessage(e)) - cat(msg) - stop(msg) - } - ) - log_msg(paste0(label, " loaded from dataset: ", dataset_name, " dataframe dimensions: ", paste(dim(data), collapse = ", "))) - data +#' Load SNT Configuration File +#' Reads and parses a JSON configuration file. +#' @param snt_config_path Character. Path to the configuration JSON file. +#' @return List containing parsed configuration. +#' +#' @export +load_snt_config <- function(snt_config_path) { + + # config file path + config_json <- tryCatch({ fromJSON(snt_config_path) }, + error = function(e) { + stop(glue::glue("[ERROR] Error while loading configuration: {snt_config_path}")) + }) + + log_msg(paste0("SNT configuration loaded from : ", snt_config_path)) + return(config_json) } +#' Load Dataset File from OpenHEXA +#' Retrieves the latest version of a file from an OpenHEXA dataset. +#' +#' @param dataset_id Character. OpenHEXA dataset identifier. +#' @param filename Character. Name of file to load. +#' @param verbose Bool. Log messages +#' @return Dataframe containing the loaded data. +#' +#' @export +load_dataset_file <- function (dataset_id, filename, verbose=TRUE) { + data <- tryCatch({ + get_latest_dataset_file_in_memory(dataset_id, filename) + }, error = function(e) { + stop(glue::glue("[ERROR] Error while loading {filename} file from dataset: {dataset_id}")) + }) + + if (verbose) { + log_msg(glue::glue("{filename} data loaded from dataset : {dataset_id} dataframe dimensions: [{paste(dim(data), collapse=', ')}]")) + } + return(data) +} + + #' Build one choropleth plot per MAP metric. #' @@ -37,13 +104,13 @@ load_map_report_input <- function(dataset_name, filename, label = "dataset file" #' @param map_data_joined Spatial table containing `METRIC_NAME` and `VALUE`. #' @param metrics Character vector of metric names to plot. #' @return List of `ggplot` objects, one per metric. -build_metric_plots <- function(map_data_joined, metrics) { +build_metric_plots <- function(map_data_joined, metrics, year) { purrr::map(metrics, function(metric) { ggplot2::ggplot(map_data_joined %>% dplyr::filter(METRIC_NAME == metric)) + ggplot2::geom_sf(ggplot2::aes(fill = VALUE), color = "white") + ggplot2::scale_fill_viridis_c(option = "C", na.value = "lightgrey") + ggplot2::labs( - title = paste0(metric), + title = paste0(metric , " - ", year), fill = "Valeur" ) + ggplot2::theme_minimal(base_size = 16) + diff --git a/snt_map_extracts/malariaAtlasProject/map.py b/snt_map_extracts/malariaAtlasProject/map.py index 673cf694..eb7bbf50 100644 --- a/snt_map_extracts/malariaAtlasProject/map.py +++ b/snt_map_extracts/malariaAtlasProject/map.py @@ -27,7 +27,7 @@ def __init__( ): """Initialize the MAPRasterExtractor.""" if category not in self.SUPPORTED_CATEGORIES: - raise ValueError(f"Supported categories: {self.SUPPORTED_CATEGORIES}.") + raise MAPExtractorError(f"Supported categories: {self.SUPPORTED_CATEGORIES}.") self.logger = logger self.base_url = base_url self.category = category @@ -79,7 +79,7 @@ def _log_message(self, message: str, level: str = "info", exc: Exception | None } if level not in logger_methods: - raise ValueError(f"Unsupported logging level: {level}") + raise MAPExtractorError(f"Unsupported logging level: {level}") # Log to standard logger if self.logger: @@ -178,16 +178,16 @@ def _get_band_names(self, coverage_id: str | None, category: str | None) -> list """Retrieve the band names (mean, mask, LCI, UCI) of a WCS coverage via DescribeCoverage. Args: - coverage_id: Coverage ID to query. If None, raises ValueError. + coverage_id: Coverage ID to query. If None, raises MAPExtractorError. category: Category name ('Malaria' or 'Interventions'). Returns: List of band names/layers available for the coverage id. """ if coverage_id is None: - raise ValueError("coverage_id must be provided.") + raise MAPExtractorError("coverage_id must be provided.") if category not in self.SUPPORTED_CATEGORIES: - raise ValueError(f"Supported categories: {self.SUPPORTED_CATEGORIES}.") + raise MAPExtractorError(f"Supported categories: {self.SUPPORTED_CATEGORIES}.") if category: url = f"{self.base_url}/{category}/ows" else: @@ -233,9 +233,9 @@ def _download_raster( Path to the downloaded raster file. """ if output_fname is None: - raise ValueError("output_fname must be provided.") + raise MAPExtractorError("output_fname must be provided.") if not output_fname.parent.exists(): - raise ValueError("Provided output_fname's parent directory does not exist.") + raise MAPExtractorError("Provided output_fname's parent directory does not exist.") year = time_position.split("-", maxsplit=1)[0] params = self._build_raster_query(coverage_id, bbox, time_position) @@ -262,14 +262,14 @@ def get_band_names(self, coverage_id: str | None, category: str | None = None) - """Public: Retrieve the band names (mean, mask, LCI, UCI) of a WCS coverage via DescribeCoverage. Args: - coverage_id: Coverage ID to query. If None, raises ValueError. + coverage_id: Coverage ID to query. If None, raises MAPExtractorError. category: Category name ('Malaria' or 'Interventions'). Returns: List of band names/layers available for the coverage id. """ if coverage_id is None: - raise ValueError("coverage_id must be provided.") + raise MAPExtractorError("coverage_id must be provided.") if category is None: category = self.category return self._get_band_names(coverage_id=coverage_id, category=category) @@ -300,25 +300,28 @@ def download_indicator_raster( Path to the downloaded raster file. """ if output_path is None: - raise ValueError("output_path must be provided.") + raise MAPExtractorError("output_path must be provided.") if category and category != self.category: if category not in self.SUPPORTED_CATEGORIES: - raise ValueError(f"Supported categories: {self.SUPPORTED_CATEGORIES}.") + raise MAPExtractorError(f"Supported categories: {self.SUPPORTED_CATEGORIES}.") self._log_message(f"Switching category from '{self.category}' to '{category}'") self.category = category self.coverage_ids = self._list_coverage_ids_for_category() latest_coverage_id = self._latest_version_for_indicator(indicator) if latest_coverage_id is None: - raise ValueError(f"No coverage found for indicator '{indicator}' in category '{category}'.") + raise MAPExtractorError( + f"No coverage found for indicator '{indicator}' in category '{category}'." + ) self._log_message(f"Latest coverage ID for indicator '{indicator}': {latest_coverage_id}") available_times: dict = self._get_time_positions_for_coverage(latest_coverage_id) if target_year not in available_times: - raise ValueError( - f"Year {target_year} is not available for indicator '{indicator}'." - f" Available years: {available_times if available_times else 'No years available!'}" + years_present = sorted(available_times.keys(), reverse=True) if available_times else [] + raise MAPExtractorError( + f"Year {target_year} not available for indicator '{indicator}'. " + f"Available years: {years_present or 'No years available!'}" ) if shapes is not None: @@ -326,21 +329,20 @@ def download_indicator_raster( minx, miny, maxx, maxy = shapes.total_bounds bbox = [minx, miny, maxx, maxy] if bbox is None: - raise ValueError("Either bbox or shapes must be provided to define the area of interest.") - + raise MAPExtractorError("Either bbox or shapes must be provided to define the area of interest.") + + # Avoid downloading the same file in the provided folder + raster_fname = output_path / f"{latest_coverage_id}_{target_year}.tif" + if raster_fname.exists(): + if replace_file: + raster_fname.unlink() # delete existing file + self._log_message( + f"Raster exists, deleting and re-downloading: {raster_fname.name}", level="warning" + ) + else: + self._log_message(f"Raster already exists: {raster_fname.name}, skipping download.") + return raster_fname try: - # Avoid downloading the same file in the provided folder - raster_fname = output_path / f"{latest_coverage_id}_{target_year}.tif" - if raster_fname.exists(): - if replace_file: - raster_fname.unlink() # delete existing file - self._log_message( - f"Raster exists, deleting and re-downloading: {raster_fname.name}", level="warning" - ) - else: - self._log_message(f"Raster already exists: {raster_fname.name}, skipping download.") - return raster_fname - raster_path = self._download_raster( coverage_id=latest_coverage_id, bbox=bbox, diff --git a/snt_map_extracts/malariaAtlasProject/map_utils.py b/snt_map_extracts/malariaAtlasProject/map_utils.py index 35419bc2..5f54c13f 100644 --- a/snt_map_extracts/malariaAtlasProject/map_utils.py +++ b/snt_map_extracts/malariaAtlasProject/map_utils.py @@ -41,21 +41,20 @@ def load_tiff_bands( bands = {} try: - with rasterio.open(tif_path) as src: - count = src.count - transform = src.transform - crs = src.crs - nodata = src.nodata - - for idx in range(1, count + 1): - if band_names and idx <= len(band_names): - name = band_names[idx - 1] - else: - name = f"band_{idx}" - bands[name] = src.read(idx) - + src = rasterio.open(tif_path) except rasterio.errors.RasterioIOError as e: - raise RuntimeError(f"Failed to read TIFF file: {tif_file}") from e + raise RuntimeError(f"Failed to read TIFF file: {tif_file}. Details: {e}") from e + + with src: + count = src.count + transform = src.transform + crs = src.crs + nodata = src.nodata + data = src.read() + + for idx in range(1, count + 1): + name = band_names[idx - 1] if band_names and idx <= len(band_names) else f"band_{idx}" + bands[name] = data[idx - 1] return bands, transform, crs, nodata diff --git a/snt_map_extracts/pipeline.py b/snt_map_extracts/pipeline.py index aecb447b..8b3bb03b 100644 --- a/snt_map_extracts/pipeline.py +++ b/snt_map_extracts/pipeline.py @@ -1,54 +1,53 @@ from pathlib import Path + import geopandas as gpd import numpy as np -import pandas as pd -from datetime import datetime - -import logging -from openhexa.sdk import current_run, parameter, pipeline, workspace, File -import rasterio -from rasterio.warp import reproject, Resampling -from affine import Affine +import polars as pl +from malariaAtlasProject.map import MAPExtractorError, MAPRasterExtractor +from malariaAtlasProject.map_utils import ( + load_tiff_bands, + parse_raster_filename_vars, +) +from openhexa.sdk import current_run, parameter, pipeline, workspace from rasterstats import zonal_stats - -# from owslib.wcs import WebCoverageService from snt_lib.snt_pipeline_utils import ( - pull_scripts_from_repository, add_files_to_dataset, + get_file_from_dataset, load_configuration_snt, + pull_scripts_from_repository, run_report_notebook, - get_file_from_dataset, - validate_config, save_pipeline_parameters, + validate_config, ) -from malariaAtlasProject.map import MAPRasterExtractor, MAPExtractorError -from malariaAtlasProject.map_utils import ( - load_tiff_bands, - parse_raster_filename_vars, +from utils import ( + compute_population_weighted_metric, + generate_population_table_from_raster, + get_extract_periods, + load_raw_population_raster, ) +from worlpopclient import WorldPopClient # Ticket: # https://bluesquare.atlassian.net/browse/SNT25-143 (old pipeline) # https://bluesquare.atlassian.net/browse/SNT25-259 (old pipeline) # https://bluesquare.atlassian.net/browse/SNT25-284 +# https://bluesquare.atlassian.net/browse/SNT25-518 (include periods) @pipeline("snt_map_extracts") @parameter( - code="pop_raster_selection", - name="Population raster selection (.tif)", - type=File, - help="Select the population raster (.tif) used for population-weighted calculations.", - required=False, + code="year_start", + name="Year start", + help="Start year of indicators selection (e.g. 2022).", + type=int, default=None, + required=True, ) @parameter( - code="target_year", - name="Target Year", - help=( - "Target year for indicator selection (e.g. 2022). Defaults to latest if unavailable or not specified." - ), - type=str, + code="year_end", + name="Year end", + help="End year of indicators selection (e.g. 2023).", + type=int, default=None, required=True, ) @@ -67,14 +66,21 @@ default=False, required=False, ) -def snt_map_extracts( - pop_raster_selection: File, target_year: str, run_report_only: bool, pull_scripts: bool -) -> None: +def snt_map_extracts(year_start: int, year_end: int, run_report_only: bool, pull_scripts: bool) -> None: """Main function to get raster data for a dhis2 country.""" root_path = Path(workspace.files_path) pipeline_path = root_path / "pipelines" / "snt_map_extracts" pipeline_path.mkdir(parents=True, exist_ok=True) - logger = create_file_logger(log_path=pipeline_path / "logs") + + if year_start > year_end: + msg = f"Start period ({year_start}) must be less than or equal to end period ({year_end})." + current_run.log_warning(msg) + raise ValueError(msg) + + try: + validate_worldpop_periods(year_start, year_end) + except ValueError as e: + current_run.log_warning(f"Invalid period configuration: {e}") # pop is optional # Define indicators to download snt_indicators = { @@ -92,160 +98,302 @@ def snt_map_extracts( } if pull_scripts: - log_message(logger, "Pulling pipeline scripts from repository.") + current_run.log_info("Pulling pipeline scripts from repository.") pull_scripts_from_repository( pipeline_name="snt_map_extracts", report_scripts=["snt_map_extracts_report.ipynb"], code_scripts=[], ) - try: - # Load configuration - snt_config = load_configuration_snt(config_path=root_path / "configuration" / "SNT_config.json") - validate_config(snt_config) - country_code = snt_config["SNT_CONFIG"].get("COUNTRY_CODE") - dataset_id = snt_config["SNT_DATASET_IDENTIFIERS"].get("SNT_MAP_EXTRACTS") - - if not run_report_only: - output_path = root_path / "data" / "map" - output_path.mkdir(parents=True, exist_ok=True) - - # Validate population raster (optional) - if pop_raster_selection: - log_message(logger, f"Population raster selected: {pop_raster_selection.path}") - if not Path(pop_raster_selection.path).exists(): - raise FileNotFoundError(f"Population raster file not found: {pop_raster_selection.path}") - if Path(pop_raster_selection.path).suffix.lower() != ".tif": - raise ValueError("Population raster must be a '.tif' file.") - - make_table( - coverage_categories=snt_indicators, - snt_config=snt_config, - pop_raster_path=Path(pop_raster_selection.path) if pop_raster_selection else None, - target_year=target_year, - output_path=output_path, - logger=logger, - ) + # Load configuration + snt_config = load_configuration_snt(config_path=root_path / "configuration" / "SNT_config.json") + validate_config(snt_config) + country_code = snt_config["SNT_CONFIG"].get("COUNTRY_CODE") - parameters_file = save_pipeline_parameters( - pipeline_name="snt_map_extracts", - parameters={ - "pop_raster_selection": pop_raster_selection.path if pop_raster_selection else None, - "target_year": target_year, - "run_report_only": run_report_only, - "pull_scripts": pull_scripts, - }, - output_path=output_path, + shapes = retrieve_shapes(snt_config=snt_config) + if shapes is None: + current_run.log_error("No valid shapes available. Processing stopped.") + raise ValueError + + if not run_report_only: + output_path = root_path / "data" / "map" + output_path.mkdir(parents=True, exist_ok=True) + + parameters_file = save_pipeline_parameters( + pipeline_name="snt_map_extracts", + parameters={ + "year_start": year_start, + "year_end": year_end, + "run_report_only": run_report_only, + "pull_scripts": pull_scripts, + }, + output_path=output_path, + country_code=country_code, + ) + + periods = get_extract_periods(start=str(year_start), end=str(year_end)) + files_to_dataset = [] + for year in periods: + pop_table, pop_raster_path = get_or_download_population_table( + year=year, country_code=country_code, + shapes=shapes, + wpop_repo_path=root_path / "data" / "worldpop", + output_path=root_path / "data" / "map" / "aggregated_populations", ) - add_files_to_dataset( - dataset_id=dataset_id, + files_to_dataset += build_map_statistics_table( + coverage_categories=snt_indicators, + population_totals=pop_table, + pop_raster_path=pop_raster_path, + shapes=shapes, + target_year=year, country_code=country_code, - file_paths=[ - output_path / "formatted" / country_code / f"{country_code}_map_data.parquet", - output_path / "formatted" / country_code / f"{country_code}_map_data.csv", - parameters_file, - ], + output_path=output_path, ) - else: - log_message(logger, "Skipping calculations, running only the reporting.") + add_files_to_dataset( + dataset_id=snt_config["SNT_DATASET_IDENTIFIERS"].get("SNT_MAP_EXTRACTS"), + country_code=country_code, + file_paths=files_to_dataset + [parameters_file], + ) + + else: + current_run.log_info("Skipping calculations, running reporting.") + + run_report_notebook( + nb_file=pipeline_path / "reporting" / "snt_map_extracts_report.ipynb", + nb_output_path=pipeline_path / "reporting" / "outputs", + country_code=country_code, + ) + + current_run.log_info("Pipeline completed successfully!") + - run_report_notebook( - nb_file=pipeline_path / "reporting" / "snt_map_extracts_report.ipynb", - nb_output_path=pipeline_path / "reporting" / "outputs", +def get_or_download_population_table( + year: str, country_code: str, shapes: gpd.GeoDataFrame, wpop_repo_path: Path, output_path: Path +) -> tuple[pl.DataFrame | None, Path | None]: + """Check if population raster exists for the given year and country, if not, download it. + + Parameters + ---------- + year : str + The year for which to retrieve the population data. (e.g.: "2020"). + country_code : str + The 3-letter ISO code of the country (e.g.: "COD", "BFA"). + shapes : gpd.GeoDataFrame + GeoDataFrame containing the shapes for zonal statistics. + wpop_repo_path : Path + Path to the worldpop pipeline directory where rasters and population tables are stored. + output_path : Path + Path to save the generated population table if it needs to be created. + + Returns + ------- + Tuple[pl.DataFrame | None, Path | None] + A tuple containing: + - The population table as a Polars DataFrame if it was generated successfully, otherwise None. + - The path to the population raster used for generating the table, or None if not retrieved. + """ + pop_raster_path = list((wpop_repo_path / "rasters").glob(f"{country_code.lower()}_pop_{year}_*.tif")) + if pop_raster_path: + current_run.log_info(f"Population raster found for {year}: {pop_raster_path[0]}.") + pop_raster_path = pop_raster_path[0] + + else: + current_run.log_info(f"No population raster found for {year}. Attempting to download.") + pop_raster_path = retrieve_population_data( country_code=country_code, + year=year, + output_path=wpop_repo_path / "rasters", + overwrite=False, ) - log_message(logger, "Pipeline completed successfully!") + if not pop_raster_path: + current_run.log_warning( + f"Population raster could not be retrieved for {year}. Skipping population table generation." + ) + return None, None + current_run.log_info(f"Generating population table for {year} from raster: {pop_raster_path}.") + pop_table = generate_population_table_from_raster(raster_path=pop_raster_path, shapes=shapes) + + if pop_table is None: + current_run.log_warning(f"Population table could not be generated for {year}.") + return None, None + + output_path.mkdir(parents=True, exist_ok=True) + pop_table.write_parquet(output_path / f"{country_code}_worldpop_population_{year}.parquet") + current_run.log_info( + f"Population table generated and saved for {year} at " + f"{output_path / f'{country_code}_worldpop_population_{year}.parquet'}." + ) + return pop_table, pop_raster_path + + +def retrieve_population_data( + country_code: str, year: str, output_path: Path, overwrite: bool = False +) -> Path: + """Retrieve raster population data from worldpop. + + Parameters + ---------- + country_code : str + The 3-letter ISO code of the country (e.g.: "COD", "BFA"). + year : str, optional + The year for which to retrieve the population data. (e.g.: "2020"). + overwrite : bool, optional + Whether to overwrite existing files. Defaults to False. + output_path : Path + The directory where the population data will be saved. + + Returns + ------- + Path + The full path to the saved population raster file. + + """ + current_run.log_info("Retrieving population data grid from WorldPop.") + + wpop_client = WorldPopClient() + output_path.mkdir(parents=True, exist_ok=True) + country = country_code.upper() + + try: + pop_file_path = wpop_client.download_data_for_country( + country_iso3=country, + year=year, + output_dir=output_path, + overwrite=overwrite, + ) + current_run.log_info(f"Population raster successfully downloaded under: {pop_file_path}.") + return pop_file_path except Exception as e: - log_message(logger, f"Pipeline error: {e}", level="error") - raise e + raise Exception(f"Error retrieving WorldPop data for {country} {year}: {e}") from e -def make_table( - coverage_categories: dict, - snt_config: str, - pop_raster_path: Path, - target_year: str, - output_path: Path, - logger: logging.Logger, -) -> None: - """Generate a table of zonal statistics for given coverage indicators and save the results. +def retrieve_shapes(snt_config: dict) -> gpd.GeoDataFrame | None: + """Retrieve and validate shapes for the specified country. Parameters ---------- - coverage_categories : dict - Dictionary mapping categories to indicator layer names. - snt_config : str + snt_config : dict SNT configuration file. - pop_raster_path : Path - Path to the selected raster directory. - target_year : str - Target year for selecting indicator versions. - output_path : Path - Path to save the output files. - logger : logging.Logger - Logger for logging messages. + + Returns + ------- + gpd.GeoDataFrame + GeoDataFrame containing the valid shapes for the specified country. + """ country_code = snt_config["SNT_CONFIG"].get("COUNTRY_CODE") dataset_shapes_id = snt_config.get("SNT_DATASET_IDENTIFIERS", {}).get("DHIS2_DATASET_FORMATTED") shapes = get_file_from_dataset(dataset_shapes_id, f"{country_code}_shapes.geojson") - log_message(logger, f"Shapes loaded from dataset: {dataset_shapes_id}.") - # Check shapes: drop rows with null or None geometry (zonal_stats fails on None). - # Dropped ADM2 will not appear in map_data.parquet; in assemble_results they get NA for map - # indicators (left join on ADM2_ID). The shapes file in the dataset is not modified. - invalid_shapes = shapes[shapes.geometry.isna() | shapes.geometry.apply(lambda g: g is None)] + if shapes is None or shapes.shape[0] == 0: + current_run.log_warning("No shapes found in dataset.") + return None + + current_run.log_info(f"Shapes loaded from dataset: {dataset_shapes_id}.") + + # Drop None geometries — zonal_stats fails on None. + invalid_shapes = shapes[shapes.geometry.isna()] if len(invalid_shapes) > 0: - log_message( - logger, f"Dropping {len(invalid_shapes)} organisation units without geometry.", level="warning" - ) + current_run.log_warning(f"Dropping {len(invalid_shapes)} organisation units without geometry.") shapes = shapes[shapes.geometry.notna() & shapes.geometry.apply(lambda g: g is not None)] - # Drop empty geometries so rasterstats/shapely don't get invalid geometries + # Drop empty geometries (geometry not None but empty) empty_shapes = shapes[shapes.geometry.is_empty] if len(empty_shapes) > 0: - log_message( - logger, f"Dropping {len(empty_shapes)} organisation units with empty geometry.", level="warning" - ) + current_run.log_warning(f"Dropping {len(empty_shapes)} organisation units with empty geometry.") shapes = shapes[~shapes.geometry.is_empty] if len(shapes) == 0: - return + return None - rasters_path = output_path / "raster_files" / country_code + return shapes + + +def build_map_statistics_table( + coverage_categories: dict, + population_totals: pl.DataFrame | None, + pop_raster_path: Path | None, + shapes: gpd.GeoDataFrame, + target_year: str, + country_code: str, + output_path: Path, +) -> list[Path]: + """Generate a table of zonal statistics for given coverage indicators and save the results. + + Parameters + ---------- + coverage_categories : dict + Dictionary mapping categories to indicator layer names. + population_totals : pl.DataFrame | None + DataFrame containing total population values for each shape, or None if not available. + pop_raster_path : Path + Path to the selected raster directory. + shapes : gpd.GeoDataFrame + GeoDataFrame containing the shapes for zonal statistics. + target_year : str + Target year for selecting indicator versions. + country_code : str + The 3-letter ISO code of the country (e.g.: "COD", "BFA"). + output_path : Path + Path to save the output files. + + Returns + ------- + list[Path] + List of paths to the generated output files (parquet and csv). + """ + rasters_path = output_path / "raster_files" rasters_path.mkdir(parents=True, exist_ok=True) raster_files = retrieve_rasters( coverage_categories=coverage_categories, target_year=target_year, shapes=shapes, - logger=logger, rasters_path=rasters_path, ) if len(raster_files) == 0: - log_message(logger, "No raster files were downloaded. Exiting table generation.", level="warning") - return + current_run.log_warning( + f"No raster files were downloaded for year {target_year}. Exiting table generation." + ) + return [] - run_aggregations( - raster_files=raster_files, - shapes=shapes, - pop_raster_path=pop_raster_path, - snt_config=snt_config, - output_path=output_path / "formatted" / country_code, - logger=logger, - ) + try: + map_indicators = compute_zonal_statistics( + raster_files=raster_files, + shapes=shapes, + population_totals=population_totals, + pop_raster_path=pop_raster_path, + ) + except Exception as e: + current_run.log_error(f"Error during aggregation: {e}") + return [] + + if map_indicators.is_empty(): + current_run.log_warning(f"No valid statistics were computed for year {target_year}.") + return [] + + # Save file + out_dir = output_path / "formatted" + out_dir.mkdir(parents=True, exist_ok=True) + file_parquet = out_dir / f"{country_code}_map_data_{target_year}.parquet" + file_csv = out_dir / f"{country_code}_map_data_{target_year}.csv" + map_indicators.write_parquet(file_parquet) + map_indicators.write_csv(file_csv) + current_run.log_info(f"Output file saved under : {file_csv}") + + return [file_parquet, file_csv] def retrieve_rasters( coverage_categories: dict, target_year: str, shapes: gpd.GeoDataFrame, - logger: logging.Logger, rasters_path: Path, ) -> list[Path]: """Retrieve raster files for specified coverage categories and indicators. @@ -255,11 +403,11 @@ def retrieve_rasters( """ downloaded_rasters = [] for category, indicators in coverage_categories.items(): - log_message(logger, f"Processing category: {category}.") - map_extractor = MAPRasterExtractor(category=category, logger=logger) + current_run.log_info(f"Processing category: {category}.") + map_extractor = MAPRasterExtractor(category=category) for indicator in indicators: try: - log_message(logger, f"Downloading raster for indicator: {indicator}.") + current_run.log_info(f"Downloading raster for indicator: {indicator}.") raster_path = map_extractor.download_indicator_raster( indicator=indicator, target_year=target_year, @@ -269,58 +417,50 @@ def retrieve_rasters( ) downloaded_rasters.append(raster_path) except MAPExtractorError as e: - log_message(logger, f"Error downloading raster for {indicator}.", level="error", exc=e) + current_run.log_error(f"Error downloading raster for {indicator}. Details: {e}") continue return downloaded_rasters -def run_aggregations( +def compute_zonal_statistics( raster_files: list[Path], shapes: gpd.GeoDataFrame, + population_totals: pl.DataFrame | None, pop_raster_path: Path | None, - snt_config: str, - output_path: Path, - logger: logging.Logger, -): - """Run zonal statistics aggregations on the downloaded rasters.""" - country_code = snt_config["SNT_CONFIG"].get("COUNTRY_CODE") +) -> pl.DataFrame: + """Run zonal statistics aggregations on the downloaded rasters. + Returns: + A Polars DataFrame containing the aggregated statistics for each indicator and shape. + """ # 1. Load population raster (if available) - if not pop_raster_path: - log_message(logger, "Population raster file not provided.", level="warning") - pop_data = None + pop_data = pop_transform = pop_crs = pop_nodata = None + if pop_raster_path is None: + current_run.log_warning("Population raster file not provided.") else: pop_data, pop_transform, pop_crs, pop_nodata = load_raw_population_raster( - file_pattern=pop_raster_path.name, - raster_path=pop_raster_path.parent, - logger=logger, + raster_path=pop_raster_path, ) - pop_total = compute_total_populations( - shapes, data=pop_data, transform=pop_transform, crs=pop_crs, nodata=pop_nodata, logger=logger - ) - # Set nodata to np.nan if pop_data is not None: pop_data = pop_data.astype(float) pop_data[pop_data == pop_nodata] = np.nan # 2. Process each raster file - final_df = pd.DataFrame() + final_df = pl.DataFrame() for raster_file in raster_files: file_vars = parse_raster_filename_vars(raster_file) coverage_id = ( f"{file_vars['category']}__{file_vars['version']}_{file_vars['region']}_{file_vars['indicator']}" ) - bands = MAPRasterExtractor(category=file_vars["category"], logger=logger).get_band_names( - coverage_id=coverage_id - ) + bands = MAPRasterExtractor(category=file_vars["category"]).get_band_names(coverage_id=coverage_id) raster_data, raster_transform, raster_crs, raster_nodata = load_tiff_bands( raster_file, band_names=bands ) - log_message(logger, f"Computing {raster_file.name} statistics...") + current_run.log_info(f"Computing {raster_file.name} statistics...") ref_columns = ["ADM1_NAME", "ADM1_ID", "ADM2_NAME", "ADM2_ID"] bands_for_statistics = ["Data", "LCI", "UCI", "GRAY_INDEX"] stats_results = [] @@ -328,7 +468,7 @@ def run_aggregations( # Compute Zonal Statistics per layer for band in bands: if band in bands_for_statistics: - log_message(logger, f"Processing {file_vars['indicator']} band: {band}.") + current_run.log_info(f"Processing {file_vars['indicator']} band: {band}.") zstats = zonal_stats( vectors=shapes, raster=raster_data[band], @@ -357,415 +497,90 @@ def run_aggregations( pop_data=pop_data, pop_transform=pop_transform, pop_crs=pop_crs, - total_population=pop_total, + population_totals=population_totals, shapes=shapes, indicator=file_vars["indicator"], - logger=logger, ) if weighted_metric is not None: - # We can add population if we need it 'total_population' - melt_df = melt_df.merge( - weighted_metric[["ADM2_ID", "population_weighted"]], - on="ADM2_ID", - how="left", + melt_df = ( + pl.from_pandas(melt_df) + .join( + weighted_metric.select(["ADM2_ID", "population_weighted"]), + on="ADM2_ID", + how="left", + ) + .with_columns( + pl.col("population_weighted").cast(pl.Float64, strict=False), + pl.col("value").cast(pl.Float64, strict=False), + ) ) else: - melt_df["population_weighted"] = None # default + melt_df = pl.from_pandas(melt_df).with_columns( + pl.lit(None).cast(pl.Float64).alias("population_weighted") + ) + else: - melt_df["population_weighted"] = None # default + melt_df = pl.from_pandas(melt_df).with_columns( + pl.lit(None).cast(pl.Float64).alias("population_weighted") + ) stats_results.append(melt_df) - # Log missing bands + # Log missing bands for raster_file # NOTE: This is not an error, just info about the available layers per coverage, # some of them only have a GRAY_INDEX band missing = [s for s in ["Data", "LCI", "UCI"] if s not in bands] if bands == ["GRAY_INDEX"]: - log_message( - logger, + current_run.log_warning( f"{file_vars['indicator']} contains only the 'GRAY_INDEX' band; " f"no main indicator bands found.", - level="warning", ) elif missing: - log_message( - logger, + current_run.log_warning( f"{file_vars['indicator']} is missing bands: {missing}. Using available band(s): {bands}.", - level="warning", ) if len(stats_results) > 0: # Format results, add metadata - stats = pd.concat(stats_results, ignore_index=True) - stats["metric_category"] = file_vars["category"] - stats["metric_name"] = file_vars["indicator"] - stats["version"] = file_vars["version"] - stats["year"] = int(file_vars["year"]) - stats["value"] = pd.to_numeric(stats["value"], errors="coerce") + stats = pl.concat(stats_results) + stats = stats.with_columns( + [ + pl.lit(file_vars["category"]).alias("metric_category"), + pl.lit(file_vars["indicator"]).alias("metric_name"), + pl.lit(file_vars["version"]).alias("version"), + pl.lit(int(file_vars["year"])).alias("year"), + pl.col("value").cast(pl.Float64, strict=False), + ] + ) + final_df = pl.concat([final_df, stats]) # concat final table - # concat final table - final_df = pd.concat([final_df, stats], ignore_index=True) + if final_df.shape[0] == 0: + return pl.DataFrame() # No valid statistics computed, return empty DataFrame # SNT format - final_df.columns = [col.strip().upper() for col in final_df.columns] - final_df["METRIC_NAME"] = final_df["METRIC_NAME"].str.strip() - - # Save Output - output_path.mkdir(parents=True, exist_ok=True) - - # Save file - final_df.to_parquet(output_path / f"{country_code}_map_data.parquet", index=False) - final_df.to_csv(output_path / f"{country_code}_map_data.csv", index=False) - log_message(logger, f"Output file saved under : {output_path / f'{country_code}_map_data.csv'}") - - -def align_raster_to_reference( - data: np.ndarray, - crs: str, - transform: Affine, - reference_data: np.ndarray, - reference_crs: str, - reference_transform: Affine, - resampling: Resampling = Resampling.bilinear, -) -> np.ndarray: - """Align a metric raster to match a reference raster (CRS and shape). - - Parameters - ---------- - data : np.ndarray - 2D array of the metric raster. - crs : rasterio.crs.CRS or str - CRS of the metric raster. - transform : Affine - Affine transform of the metric raster. - reference_data : np.ndarray - 2D array of the reference raster. - reference_crs : rasterio.crs.CRS or str - CRS of the reference raster. - reference_transform : Affine - Affine transform of the reference raster. - resampling : rasterio.enums.Resampling - Resampling method (default: bilinear). - - Returns - ------- - np.ndarray - Metric raster reprojected and resampled to reference grid. - """ - reference_shape = reference_data.shape - aligned = np.empty(reference_shape, dtype=data.dtype) - - # Only reproject if CRS or shape/transform differ - if (crs != reference_crs) or (data.shape != reference_shape): - reproject( - source=data, - destination=aligned, - src_transform=transform, - src_crs=crs, - dst_transform=reference_transform, - dst_crs=reference_crs, - resampling=resampling, - ) - else: - # Already aligned - aligned[:] = data - - return aligned - - -def compute_population_weighted_metric( - metric_data: np.ndarray, - metric_transform: Affine, - metric_crs: str, - metric_nodata: float, - pop_data: np.ndarray, - pop_transform: Affine, - pop_crs: str, - total_population: pd.DataFrame, - shapes: gpd.GeoDataFrame, - indicator: str, - logger: logging.Logger, -) -> pd.Series: - """Compute weighted metric values for given shapes using population data. - - Parameters - ---------- - metric_data : np.ndarray - 2D array of the metric raster, nodata values set to np.nan. - metric_transform : Affine - Affine transform of the metric raster. - metric_crs : str - CRS of the metric raster. - metric_nodata : float - NoData value of the metric raster. - pop_data: - 2D array of the population raster, nodata values set to np.nan. - pop_transform: - Affine transform of the population raster. - pop_crs: - CRS of the population raster. - total_population: - DataFrame containing total populations for each shape. - shapes : gpd.GeoDataFrame - GeoDataFrame containing the shapes for zonal statistics. - indicator : str - Name of the indicator being processed. - logger : logging.Logger - Logger for logging messages. - - Returns - ------- - pd.Series - Series containing the weighted metric values for each shape or None if population data is unavailable. - """ - if any( - x is None - for x in (shapes, metric_data, metric_transform, metric_crs, pop_data, pop_transform, pop_crs) - ): - log_message( - logger, f"Population-weighted computation skipped for metric: {indicator}.", level="warning" - ) - return None - - log_message(logger, f"Computing population-weighted for metric: {indicator}.") - # Align metric raster to population raster (resolution and CRS) - metric_aligned = align_raster_to_reference( - data=metric_data, - crs=metric_crs, - transform=metric_transform, - reference_data=pop_data, - reference_crs=pop_crs, - reference_transform=pop_transform, - resampling=Resampling.nearest, # nearest repeats metric values - ) - - metric_aligned = metric_aligned.astype(float) - metric_aligned[metric_aligned == metric_nodata] = np.nan - - # Multiply - weighted_raster = pop_data * metric_aligned - zstats_w = zonal_stats( - vectors=shapes, - raster=weighted_raster, - affine=pop_transform, - stats=["sum"], - geojson_out=True, - nodata=np.nan, - ) - result_w = pd.DataFrame( - [ - { - "ADM2_ID": f["properties"].get("ADM2_ID"), - "weighted_sum": f["properties"]["sum"], - } - for f in zstats_w - ] - ) - result_w["ADM2_ID"] = result_w["ADM2_ID"].astype(str) - result = result_w.merge(total_population, on="ADM2_ID", how="left") - result["population_weighted"] = result["weighted_sum"] / result["total_population"] - return result - + final_df = final_df.rename({col: col.strip().upper() for col in final_df.columns}) + return final_df.with_columns(pl.col("METRIC_NAME").str.strip_chars()) -def load_raw_population_raster(file_pattern: str, raster_path: Path, logger: logging.Logger) -> tuple: - """Load raw population raster from the specified path. - Parameters - ---------- - file_pattern : str - Pattern to match the population raster file. - raster_path : Path - Path to the population raster file. - logger : logging.Logger - Logger for logging messages. +def validate_worldpop_periods(start: int, end: int) -> None: + """Validate that start and end periods are in the correct format and logical. - Returns - ------- - tuple | None - The loaded raster dataset or None if loading fails. + Raises + ------ + ValueError + If start or end are not valid integers or if start is greater than end. """ - raster_file = list(raster_path.glob(file_pattern)) - if not raster_file: - log_message(logger, f"Population raster not found: {raster_path}.", level="warning") - return None, None, None, None - - if len(raster_file) > 1: - log_message( - logger, - f"Expected 1 file but found {len(raster_file)}: {raster_file}. Using first match.", - level="warning", - ) - - try: - with rasterio.open(raster_file[0]) as src: - raster = src.read(1) - transform = src.transform # affine - crs = src.crs - nodata = src.nodata - log_message(logger, f"Population raster loaded: {raster_file[0]}.") - return raster, transform, crs, nodata - except Exception as e: - log_message(logger, f"Could not load population raster {raster_file[0]}", level="error", exc=e) - return None, None, None, None - - -def compute_total_populations( - shapes: gpd.GeoDataFrame, - data: np.ndarray, - transform: Affine, - crs: str, - nodata: float, - logger: logging.Logger, -) -> pd.DataFrame: - """Compute total populations for given shapes using population data. - - Parameters - ---------- - shapes : gpd.GeoDataFrame - GeoDataFrame containing the shapes for zonal statistics. - data : np.ndarray - 2D array of the population raster. - transform : Affine - Affine transform of the population raster. - crs : str - CRS of the population raster. - nodata : float - NoData value of the population raster. - logger : logging.Logger - Logger for logging messages. - - Returns - ------- - pd.Series - Series containing the total populations for each shape or None if data is unavailable. - """ - if any(x is None for x in (shapes, data, crs)): - return None - - # Ensure CRS matches the raster & reproject if necessary - if shapes.crs is None: - raise ValueError("Shapes GeoDataFrame must have a defined CRS.") - # Reproject shapes if CRS is different (consistent to wpop pipeline calculation check) - if shapes.crs.to_string() != crs: - log_message( - logger, - f"The CRS data differs from the provided shapes file. Reprojecting shapes with {crs}", - level="warning", - ) - shapes = shapes.to_crs(crs) - - # get statistics - log_message(logger, f"Computing ADM2 spatial aggregation for {len(shapes)} shapes.") - pop_total = zonal_stats( - vectors=shapes, - raster=data, - affine=transform, - stats=["sum"], - geojson_out=True, - nodata=nodata, - ) - result = pd.DataFrame( - [ - {"ADM2_ID": f["properties"].get("ADM2_ID"), "total_population": f["properties"]["sum"]} - for f in pop_total - ] - ) - - try: - result["total_population"] = result["total_population"].round(0).astype(int) - except Exception: - log_message( - logger, - "Could not convert total_population to int, is possible that all results are None.", - level="warning", + if not (2015 <= start <= 2030): + raise ValueError( + f"Start year {start} is out of range for population rasters available in repository (2015-2030)." + " (see: https://data.worldpop.org/GIS/Population/Global_2015_2030/R2025A/)" ) - result["ADM2_ID"] = result["ADM2_ID"].astype(str) - - return result - - -def create_file_logger(log_path: Path, level: int = logging.INFO) -> logging.Logger: - """Create a logger that writes messages to a file. - - Args: - log_path: Path to the log file. - level: Logging level (default INFO). - - Returns: - Configured logger. - """ - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - log_file = log_path / f"map_extractor_{timestamp}.log" - logger = logging.getLogger(str(log_file)) # unique name per file - logger.setLevel(level) - - # Avoid adding multiple handlers if logger already exists - if not logger.handlers: - # Ensure parent folder exists - log_file.parent.mkdir(parents=True, exist_ok=True) - - # Create file handler - fh = logging.FileHandler(log_file, mode="a", encoding="utf-8") - fh.setLevel(level) - - # Optional: also log to console - ch = logging.StreamHandler() - ch.setLevel(level) - - # Formatter - formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") - fh.setFormatter(formatter) - ch.setFormatter(formatter) - - # Add handlers - logger.addHandler(fh) - logger.addHandler(ch) - - return logger - - -def log_message( - logger: logging.Logger, message: str, level: str = "info", exc: Exception | None = None -) -> None: - """Log a message to both a standard logger and OpenHexa current_run logger.""" - if not message: - return - - level = level.lower() - logger_methods = { - "info": "info", - "warning": "warning", - "error": "error", - "debug": "debug", - } - run_methods = { - "info": "log_info", - "warning": "log_warning", - "error": "log_error", - "debug": "log_debug", - } - - if level not in logger_methods: + if not (2015 <= end <= 2030): raise ValueError( - f"Unsupported logging level: {level}. Supported levels: {list(logger_methods.keys())}" + f"End year {end} is out of range for population rasters available in repository (2015-2030)." + " (see: https://data.worldpop.org/GIS/Population/Global_2015_2030/R2025A/)" ) - # File logger - if logger: - if exc: - logger.error(message, exc_info=exc) - else: - getattr(logger, logger_methods[level])(message) - - # ---- OpenHexa UI logger (NO exception details) - try: - run_fn = getattr(current_run, run_methods[level], None) - if run_fn: - run_fn(message) - except Exception: - # Never let UI logging break the pipeline - pass - if __name__ == "__main__": snt_map_extracts() diff --git a/snt_map_extracts/utils.py b/snt_map_extracts/utils.py new file mode 100644 index 00000000..a28268a6 --- /dev/null +++ b/snt_map_extracts/utils.py @@ -0,0 +1,314 @@ +from openhexa.toolbox.dhis2.periods import period_from_string +from openhexa.sdk import current_run +import geopandas as gpd +import numpy as np +from pathlib import Path +import polars as pl +from rasterstats import zonal_stats +from pyproj import CRS +from affine import Affine +import rasterio +from rasterio.warp import Resampling, reproject + + +def get_extract_periods(start: str, end: str) -> list[str]: + """Generates a list of periods between start and end. + + Returns + ------- + list[str] + List of periods as strings (e.g. "2020", "202501"). + """ + try: + # Get periods + p1 = period_from_string(start) + p2 = period_from_string(end) + periods = [p1] if p1 == p2 else p1.get_range(p2) + return [str(p) for p in periods] + except Exception as e: + raise Exception(f"Error in start/end date configuration: {e!s}") from e + + +def compute_total_populations( + shapes: gpd.GeoDataFrame, + data: np.ndarray, + transform: Affine, + crs: rasterio.crs.CRS, + nodata: float, +) -> pl.DataFrame | None: + """Compute total populations for given shapes using population data. + + Parameters + ---------- + shapes : gpd.GeoDataFrame + GeoDataFrame containing the shapes for zonal statistics. + data : np.ndarray + 2D array of the population raster. + transform : Affine + Affine transform of the population raster. + crs : rasterio.crs.CRS + CRS of the population raster. + nodata : float + NoData value of the population raster. + + Returns + ------- + pl.DataFrame + DataFrame with ADM2_ID (Utf8) and total_population (Int64, nullable) columns. + """ + if any(x is None for x in (shapes, data, crs)): + current_run.log_warning("Total population computation skipped due to missing data or shapes.") + return None + + # Ensure CRS matches the raster & reproject if necessary + if shapes.crs is None: + raise ValueError("Shapes GeoDataFrame must have a defined CRS.") + + # Reproject shapes if CRS is different (consistent to wpop pipeline calculation check) + if shapes.crs != CRS.from_user_input(crs): + current_run.log_warning( + f"The CRS data differs from the provided shapes file. Reprojecting shapes with {crs}", + ) + shapes = shapes.to_crs(crs) + + # get statistics + current_run.log_info(f"Computing ADM2 spatial aggregation for {len(shapes)} shapes.") + pop_total = zonal_stats( + vectors=shapes, + raster=data, + affine=transform, + stats=["sum"], + geojson_out=True, + nodata=nodata, + ) + result = pl.DataFrame( + [ + {"ADM2_ID": f["properties"].get("ADM2_ID"), "total_population": f["properties"].get("sum")} + for f in pop_total + ], + schema={"ADM2_ID": pl.Utf8, "total_population": pl.Float64}, + ) + + return result.with_columns(pl.col("total_population").round(0).cast(pl.Int64, strict=False)) + + +def align_raster_to_reference( + data: np.ndarray, + crs: str, + transform: Affine, + reference_data: np.ndarray, + reference_crs: str, + reference_transform: Affine, + resampling: Resampling = Resampling.bilinear, +) -> np.ndarray: + """Align a metric raster to match a reference raster (CRS and shape). + + Parameters + ---------- + data : np.ndarray + 2D array of the metric raster. + crs : rasterio.crs.CRS or str + CRS of the metric raster. + transform : Affine + Affine transform of the metric raster. + reference_data : np.ndarray + 2D array of the reference raster. + reference_crs : rasterio.crs.CRS or str + CRS of the reference raster. + reference_transform : Affine + Affine transform of the reference raster. + resampling : rasterio.enums.Resampling + Resampling method (default: bilinear). + + Returns + ------- + np.ndarray + Metric raster reprojected and resampled to reference grid. + """ + reference_shape = reference_data.shape + aligned = np.empty(reference_shape, dtype=data.dtype) + + # Only reproject if CRS or shape/transform differ + if (crs != reference_crs) or (data.shape != reference_shape) or (transform != reference_transform): + reproject( + source=data, + destination=aligned, + src_transform=transform, + src_crs=crs, + dst_transform=reference_transform, + dst_crs=reference_crs, + resampling=resampling, + ) + else: + # Already aligned + aligned[:] = data + + return aligned + + +def compute_population_weighted_metric( + metric_data: np.ndarray, + metric_transform: Affine, + metric_crs: str, + metric_nodata: float, + pop_data: np.ndarray, + pop_transform: Affine, + pop_crs: str, + population_totals: pl.DataFrame | None, + shapes: gpd.GeoDataFrame, + indicator: str, +) -> pl.DataFrame | None: + """Compute weighted metric values for given shapes using population data. + + Parameters + ---------- + metric_data : np.ndarray + 2D array of the metric raster, nodata values set to np.nan. + metric_transform : Affine + Affine transform of the metric raster. + metric_crs : str + CRS of the metric raster. + metric_nodata : float + NoData value of the metric raster. + pop_data: + 2D array of the population raster, nodata values set to np.nan. + pop_transform: + Affine transform of the population raster. + pop_crs: + CRS of the population raster. + population_totals: + pl.DataFrame containing total populations for each shape. + If None, an empty column will be added to the final table with null values. + shapes : gpd.GeoDataFrame + GeoDataFrame containing the shapes for zonal statistics. + indicator : str + Name of the indicator being processed. + + Returns + ------- + pl.DataFrame | None + DataFrame with ADM2_ID, weighted_sum, total_population, and population_weighted columns. + """ + if any( + x is None + for x in (shapes, metric_data, metric_transform, metric_crs, pop_data, pop_transform, pop_crs) + ): + current_run.log_warning(f"Population-weighted computation skipped for metric: {indicator}.") + return None + + current_run.log_info(f"Computing population-weighted for metric: {indicator}.") + # Align metric raster to population raster (resolution and CRS) + metric_aligned = align_raster_to_reference( + data=metric_data, + crs=metric_crs, + transform=metric_transform, + reference_data=pop_data, + reference_crs=pop_crs, + reference_transform=pop_transform, + resampling=Resampling.nearest, # nearest repeats metric values + ) + + metric_aligned = metric_aligned.astype(float) + metric_aligned[metric_aligned == metric_nodata] = np.nan + + # Multiply + weighted_raster = pop_data * metric_aligned + zstats_w = zonal_stats( + vectors=shapes, + raster=weighted_raster, + affine=pop_transform, + stats=["sum"], + geojson_out=True, + nodata=np.nan, + ) + result_w = pl.DataFrame( + [ + { + "ADM2_ID": f["properties"].get("ADM2_ID"), + "weighted_sum": f["properties"].get("sum"), + } + for f in zstats_w + ] + ).with_columns( + pl.col("ADM2_ID").cast(pl.Utf8), + pl.col("weighted_sum").cast(pl.Float64), + ) + + if population_totals is None or population_totals.shape[0] == 0: + current_run.log_warning( + "Population totals not available. Population-weighted metric will set to null values." + ) + return result_w.with_columns( + pl.lit(None).cast(pl.Float64).alias("total_population"), + pl.lit(None).cast(pl.Float64).alias("population_weighted"), + ) + + return result_w.join(population_totals, on="ADM2_ID", how="left").with_columns( + (pl.col("weighted_sum") / pl.col("total_population")).alias("population_weighted") + ) + + +def load_raw_population_raster(raster_path: Path) -> tuple: + """Load raw population raster from the specified path. + + Parameters + ---------- + raster_path : Path + Path to the population raster file. + + Returns + ------- + tuple | None + The loaded raster dataset or None if loading fails. + """ + if not (raster_path).exists(): + current_run.log_warning(f"Population raster not found: {raster_path}.") + return None, None, None, None + + try: + with rasterio.open(raster_path) as src: + raster = src.read(1) + transform = src.transform # affine + crs = src.crs + nodata = src.nodata + except Exception as e: + current_run.log_warning(f"Could not load population raster {raster_path}. Error: {e}") + return None, None, None, None + + return raster, transform, crs, nodata + + +def generate_population_table_from_raster(raster_path: Path, shapes: gpd.GeoDataFrame) -> pl.DataFrame | None: + """Generate a population table from the given raster and shapes. + + Parameters + ---------- + raster_path : Path + Path to the population raster file. + shapes : gpd.GeoDataFrame + GeoDataFrame containing the shapes for zonal statistics. + + Returns + ------- + pl.DataFrame + Polars DataFrame containing the population data for each shape. + """ + # Load raster data (this is a placeholder, implement as needed) + with rasterio.open(raster_path) as src: + pop_data = src.read(1) + pop_transform = src.transform + pop_crs = src.crs + pop_nodata = src.nodata + + if pop_crs is None: + current_run.log_warning(f"Raster {raster_path} has no CRS defined, skipping population computation.") + return None + + # Compute total populations for each shape using zonal statistics + return compute_total_populations( + shapes=shapes, + data=pop_data, + transform=pop_transform, + crs=pop_crs, + nodata=pop_nodata, + ) diff --git a/snt_map_extracts/worlpopclient.py b/snt_map_extracts/worlpopclient.py new file mode 100644 index 00000000..98e56cde --- /dev/null +++ b/snt_map_extracts/worlpopclient.py @@ -0,0 +1,204 @@ +from pathlib import Path +import re +import requests +from openhexa.sdk import current_run +import logging + + +class WorldPopClient: + """Mini client for the WorldPop REST API. + + Source: https://data.worldpop.org/GIS/Population + """ + + def __init__( + self, url: str = "https://data.worldpop.org/GIS/Population", logger: logging.Logger | None = None + ) -> None: + """Initialize the client. + + Parameters + ---------- + url : str + The base URL for the WorldPop data download. + logger : logging.Logger, optional + A logger instance to use for logging messages. If None, a default logger will be created + """ + self.base_url = url + self.logger = logger or logging.getLogger(__name__) + + def download_data_for_country( + self, + country_iso3: str, + year: str, + output_dir: Path, + overwrite: bool = False, + filename: str | None = None, + ) -> Path: + """Download and save the WorldPop raster dataset for a given country and year. + + This operation is atomic. A partial download will not result in a corrupt + final file. + + Parameters + ---------- + country_iso3 : str + 3-letter ISO code of the country (e.g., "COD", "BFA"). + year : str + Year to filter the dataset (e.g., "2020"). + output_dir : Path + Directory to save the GeoTIFF file. + overwrite : bool, optional + Whether to overwrite the file if it already exists. Defaults to False. + filename : str, optional + Filename to save the raster data. If None, defaults to + "{country_iso3}_worldpop_population_{year}.tif". + + Returns + ------- + Path + Full path to the saved GeoTIFF file. + + Raises + ------ + ValueError + If the country_iso3 code is invalid. + IOError + If the file download or disk write fails. + """ + if not (isinstance(country_iso3, str) and len(country_iso3) == 3): + raise ValueError("country_iso3 must be a 3-letter string.") + + year_int = int(year) + if year_int < 2015 or year_int > 2030: # NOTE: We might want to change the url repo in the future. + raise ValueError( + f"WorldPop data not available for {year} " + "(see: https://data.worldpop.org/GIS/Population/Global_2015_2030/R2025A/)" + ) + + country_iso3 = country_iso3.upper() + candidate_url = self._build_url(country_iso3, year) + + # Determine the filename to save as + if filename: + fname = filename + else: + fname = Path(candidate_url).name + + destination_path = output_dir / fname + + if not overwrite and destination_path.exists(): + self._log(f"File {destination_path.name} already exists. Skipping download.", level="info") + return destination_path + + self._download_file(candidate_url, destination_path) + return destination_path + + def _build_url(self, country_iso3: str, year: str) -> str: + """Build download URL candidates. + + Parameters + ---------- + country_iso3 : str + Country ISO A3 code. + year : str, optional + Year of interest. + + Returns + ------- + Path + download URL candidate. + """ + # select latest release available + releases = self._list_remote_directories(url=f"{self.base_url}/Global_2015_2030/") + if not releases: + raise ValueError(f"No releases found at {self.base_url}/Global_2015_2030/") + latest_release = releases[0] + return ( + f"{self.base_url}/Global_2015_2030/{latest_release}/{year}/{country_iso3.upper()}/" + f"v1/100m/constrained/{country_iso3.lower()}_pop_{year}_CN_100m_{latest_release}_v1.tif" + ) + + def _list_remote_directories(self, url: str) -> list[str]: + """List folder names available at an HTTP directory listing URL. + + Returns + ------- + list[str] + Directory names found at the URL, sorted in reverse alphabetical order. + """ + response = requests.get(url, timeout=10) + response.raise_for_status() + return sorted(set(re.findall(r'href="([^/"]+)/"', response.text)), reverse=True) + + def _download_file(self, url: str, destination_path: Path) -> None: + """Download a WorldPop raster from URL.""" + try: + self._log(f"Download WorldPop raster data from URL: {url}") + self._atomic_download(url, destination_path) + return + except OSError as err: + raise OSError(f"WorldPop URL '{url}' failed to download. Details: {err}") from err + + def _atomic_download( + self, url: str, destination_path: Path, session: requests.Session | None = None + ) -> None: + """Downloads a file from a URL to a destination path atomically. + + It downloads to a temporary file first and renames it upon success, + preventing partial/corrupt files. + + Parameters + ---------- + url : str + The URL of the file to download. + destination_path : Path + The final path to save the file. + session : requests.Session, optional + An existing requests session to use for the download (in the case of reusing connections). + + Raises + ------ + requests.HTTPError + If the download fails with a non-200 status code. + OSError + If the file cannot be written to disk. + """ + # Download to a temporary path + temp_path = destination_path.with_suffix(destination_path.suffix + ".part") + destination_path.parent.mkdir(parents=True, exist_ok=True) + http_client = session or requests + + try: + with http_client.get(url, stream=True, timeout=30) as response: + response.raise_for_status() # Raises HTTPError for bad responses (4xx or 5xx) + with Path.open(temp_path, "wb") as f: + for chunk in response.iter_content(chunk_size=1024 * 1024): # 1 MB chunks + f.write(chunk) + # If download is successful, rename the temp file to the final destination + temp_path.rename(destination_path) + + except (requests.RequestException, OSError) as e: + raise OSError(f"Failed to download or write file from {url}: {e}") from e + finally: + if temp_path.exists(): # Clean up the partial file + try: + temp_path.unlink() + except OSError as e: + self._log(f"Failed to remove partial file {temp_path}: {e}", level="warning") + + def _log(self, message: str, level: str = "info") -> None: + """Log a message using the Python logger and/or the OpenHEXA current_run, if available.""" + if self.logger: + log_method = getattr(self.logger, level, self.logger.info) + log_method(message) + if current_run is not None: + if level == "info": + current_run.log_info(message) + elif level == "warning": + current_run.log_warning(message) + elif level == "error": + current_run.log_error(message) + elif level == "debug": + current_run.log_debug(message) + elif level == "critical": + current_run.log_critical(message)