import logging
from collections import defaultdict
from functools import lru_cache
from itertools import product
from typing import TYPE_CHECKING
from warnings import warn
import pandas as pd
import xarray as xr
from iam_units import registry
from sdmx.model.v21 import Code
from message_ix_models import Context
from message_ix_models.model.structure import get_codes
from message_ix_models.util import load_package_data, package_data_path
log = logging.getLogger(__name__)
if TYPE_CHECKING:
from message_ix_models import ScenarioInfo
#: Configuration files; each tuple holds path parts under the package data
#: directory.  Loaded by :func:`read_config`, which appends ``.yaml`` to the
#: last part to form the filename.
METADATA = [
    # Information about MESSAGE-water
    ("water", "config"),
    ("water", "set"),
    ("water", "technology"),
]

# Conversion factors used in the water module.
# NOTE: "MCM" (million cubic metres) is not a registered pint unit, so the
# 1e6 factors relating MCM to m^3 are applied manually in each expression.

# MCM/day -> MCM/month, using a 30-day month (dimensionless magnitude)
MONTHLY_CONVERSION = (
    (30 * registry.day / registry.month).to_base_units().magnitude
)  # MCM/day to MCM/month
# Convert USD/(m³/day) to USD/MCM: m³/day * 365 days/year / 1e6 m³/MCM
USD_M3DAY_TO_USD_MCM = (registry("m^3/day").to("m^3/year").magnitude) / 1e6
# USD/km³ -> USD/MCM (1 km³ = 1000 MCM)
USD_KM3_TO_USD_MCM = registry("USD/km^3").to("USD/m^3").magnitude * 1e6
# GWa/km³ -> GWa/MCM
GWa_KM3_TO_GWa_MCM = registry("GWa/km^3").to("GWa/m^3").magnitude * 1e6
ANNUAL_CAPACITY_FACTOR = 5  # Convert 5-year capacity to annual
# Convert km³ to MCM: 1 km³ = 1e9 m³, 1 MCM = 1e6 m³, so factor = 1000
KM3_TO_MCM = registry("1 km^3").to("meter^3").magnitude / 1e6  # km³ to MCM conversion
# kWh/m³ -> GWa/MCM
kWh_m3_TO_GWa_MCM = registry("kWh/m^3").to("GWa/m^3").magnitude * 1e6
# Convert m3/GJ to MCM/GWa
m3_GJ_TO_MCM_GWa = registry("m^3/GJ").to("m^3/GWa").magnitude / 1e6
def read_config(context: Context | None = None):
    """Read the water model configuration / metadata from file.

    Each entry of :data:`METADATA` is loaded from its YAML file and stored on
    the Context under a space-joined key (e.g. ``"water set"``).  If the
    configuration is already present, the Context is returned unchanged.

    Parameters
    ----------
    context : Context, optional
        Context to populate; defaults to the most recent instance.

    Returns
    -------
    .Context
        The current Context, with the loaded configuration.
    """
    context = context or Context.get_instance(-1)

    if "water set" in context:
        # Already loaded
        return context

    # Load water configuration
    for parts in METADATA:
        # Key for storing in the context
        key = " ".join(parts)

        # Actual filename parts; ends with YAML
        _parts = list(parts)
        _parts[-1] += ".yaml"

        context[key] = load_package_data(*_parts)

    return context
def filter_basins_by_region(
    df_basins: pd.DataFrame,
    context: Context | None = None,
    n_per_region: int = 3,
) -> pd.DataFrame:
    """Filter basins based on context configuration.

    Selection is two-step:

    1. **Automatic selection** — either ``"first_k"`` (head *n* per region) or
       ``"stress"`` (diversity-sampled across the demand/supply ratio
       spectrum), controlled by ``context.basin_selection``.
    2. **filter_list augmentation** — if ``context.filter_list`` is set, those
       basins are *added* to the automatic set (union, not replacement).

    Parameters
    ----------
    df_basins : pd.DataFrame
        DataFrame with basin data including 'REGION' and 'BCU_name' columns.
    context : Context, optional
        Context object that may contain:

        - ``reduced_basin`` (bool): enable filtering (default False).
        - ``basin_selection`` (str): ``"first_k"`` or ``"stress"``
          (default ``"first_k"``).
        - ``num_basins`` (int): override *n_per_region*.
        - ``filter_list`` (list[str]): additional BCU_name values to include
          on top of the automatic selection.
    n_per_region : int, default 3
        Default number of basins to keep per region (used as fallback).

    Returns
    -------
    pd.DataFrame
        Filtered DataFrame based on configuration.
    """
    if not context:
        context = Context.get_instance(-1)

    # Check if reduced basin filtering is enabled
    reduced_basin = getattr(context, "reduced_basin", False)
    if not reduced_basin:
        # No filtering, return original dataframe
        log.info("Basin filtering disabled, returning all basins")
        return df_basins

    # Basin filtering is enabled — run automatic selection, then augment with
    # filter_list if provided.
    filter_list = getattr(context, "filter_list", None)
    num_basins = getattr(context, "num_basins", None)
    basin_selection = getattr(context, "basin_selection", "first_k")

    if num_basins is None:
        log.info(f"num_basins not set, using default n_per_region={n_per_region}")
    elif num_basins < 3:
        # Too few basins per region risks unrepresentative results.
        log.warning(
            f"num_basins={num_basins} is below 3; results may not capture "
            f"sufficient basin diversity per region"
        )

    # Step 1: automatic selection (stress or first_k)
    if "REGION" not in df_basins.columns:
        log.info("REGION column not found, cannot filter by region")
        return df_basins

    if basin_selection == "stress":
        n = num_basins if num_basins is not None else n_per_region
        ssp = getattr(context, "ssp", "SSP2")
        stress_df = compute_basin_demand_ratio(context.regions, ssp=ssp)
        selected = _select_by_stress(stress_df, n_per_region=n)
        filtered = df_basins[df_basins["BCU_name"].isin(selected)]
        log.info(
            f"Stress-based selection: {len(df_basins)} -> {len(filtered)} basins "
            f"(n_per_region={n})"
        )
    else:
        if num_basins is not None:
            n_per_region = num_basins
        # Keep the first n rows of each region group (input order).
        mask = df_basins.groupby("REGION").cumcount() < n_per_region
        filtered = df_basins[mask]
        log.info(
            f"first_k selection: {len(df_basins)} -> {len(filtered)} basins "
            f"(n_per_region={n_per_region})"
        )

    # Step 2: augment with filter_list (additive — union with automatic selection)
    if filter_list:
        extra = df_basins[
            df_basins["BCU_name"].isin(filter_list)
            & ~df_basins["BCU_name"].isin(filtered["BCU_name"])
        ]
        if len(extra):
            log.info(
                f"filter_list adds {len(extra)} basins on top of automatic selection"
            )
            filtered = pd.concat([filtered, extra], ignore_index=True)

    return filtered.reset_index(drop=True)
def compute_basin_demand_ratio(
    regions: str = "R12",
    ssp: str = "SSP2",
    demand_year: int = 2050,
) -> pd.DataFrame:
    """Compute basin-level demand/supply ratio from pre-build CSV data.

    Demand = urban + rural + manufacturing withdrawals (MCM/year).
    Supply = (surface water + groundwater recharge) mean across years
    (km3 -> MCM).

    Parameters
    ----------
    regions : str
        Region codelist (e.g. "R12").
    ssp : str
        SSP scenario for demand file naming.
    demand_year : int
        Year to use for demand values (later years show higher stress).

    Returns
    -------
    pd.DataFrame
        Columns: BCU_name, REGION, supply_mcm, demand_mcm, demand_ratio.
    """
    # File names use the lowercase scenario label, e.g. SSP2 -> ssp2
    ssp_label = ssp.lower()

    basins = pd.read_csv(
        package_data_path(
            "water", "delineation", f"basins_by_region_simpl_{regions}.csv"
        )
    )

    # Supply: surface + groundwater, mean across year columns, km3 -> MCM
    qtot = pd.read_csv(
        package_data_path(
            "water", "availability", f"qtot_5y_no_climate_low_{regions}.csv"
        )
    ).drop(columns=["Unnamed: 0"], errors="ignore")
    qr = pd.read_csv(
        package_data_path(
            "water", "availability", f"qr_5y_no_climate_low_{regions}.csv"
        )
    ).drop(columns=["Unnamed: 0"], errors="ignore")
    supply_mcm = (qtot.mean(axis=1) + qr.mean(axis=1)) * KM3_TO_MCM

    # Demand: urban + rural + manufacturing withdrawals at demand_year
    demand_path = package_data_path("water", "demands", "harmonized", regions)
    demand_files = [
        f"{ssp_label}_regional_urban_withdrawal2_baseline.csv",
        f"{ssp_label}_regional_rural_withdrawal_baseline.csv",
        f"{ssp_label}_regional_manufacturing_withdrawal_baseline.csv",
    ]
    total_demand = pd.Series(0.0, index=basins["BCU_name"].astype(str))
    for fname in demand_files:
        df = pd.read_csv(demand_path / fname)
        # First column is the year; remaining columns are basin names.
        row = df[df.iloc[:, 0] == demand_year]
        if row.empty:
            log.warning(f"Year {demand_year} not found in {fname}")
            continue
        vals = row.iloc[0, 1:].astype(float)
        # Align by basin name: add only basins present in both indexes
        # (vectorized; replaces a per-basin Python loop).
        common = total_demand.index.intersection(vals.index)
        total_demand.loc[common] += vals.loc[common]

    result = pd.DataFrame(
        {
            "BCU_name": basins["BCU_name"],
            "REGION": basins["REGION"],
            "supply_mcm": supply_mcm.values,
        }
    )
    result["demand_mcm"] = result["BCU_name"].astype(str).map(total_demand).fillna(0.0)
    # Guard against division by zero: zero supply yields a zero ratio.
    safe_supply = result["supply_mcm"].replace(0, float("inf"))
    result["demand_ratio"] = result["demand_mcm"] / safe_supply
    return result
def _diversity_select(group_sorted: pd.DataFrame, n_per_region: int) -> set[str]:
"""Select basins spanning a range via evenly spaced quantile positions.
Parameters
----------
group_sorted : pd.DataFrame
Single-region subset, pre-sorted by the target metric.
n_per_region : int
Target number of basins.
Returns
-------
set[str]
Selected BCU_name values.
"""
n = len(group_sorted)
if n <= n_per_region:
return set(group_sorted["BCU_name"])
if n_per_region == 1:
return {group_sorted.iloc[n // 2]["BCU_name"]}
if n_per_region == 2:
return {
group_sorted.iloc[0]["BCU_name"],
group_sorted.iloc[-1]["BCU_name"],
}
positions = [i / (n_per_region - 1) for i in range(n_per_region)]
indices = {int(round(p * (n - 1))) for p in positions}
return {group_sorted.iloc[i]["BCU_name"] for i in indices}
def _select_by_stress(
    stress_df: pd.DataFrame,
    n_per_region: int = 3,
) -> set[str]:
    """Select basins spanning the demand/supply ratio range per region.

    Ensures the reduced model includes basins across the stress spectrum:
    low-stress (demand << supply) through high-stress (demand ~ supply).

    Parameters
    ----------
    stress_df : pd.DataFrame
        Output of compute_basin_demand_ratio().
    n_per_region : int
        Target number of basins per region.
    """
    chosen: set[str] = set()
    for region, sub in stress_df.groupby("REGION"):
        # Rank the region's basins by stress, then diversity-sample them.
        ranked = sub.sort_values("demand_ratio").reset_index(drop=True)
        picks = _diversity_select(ranked, n_per_region)
        chosen |= picks
        log.info(f"{region}: {len(picks)} basins selected")
    return chosen
@lru_cache()
def map_add_on(rtype=Code):
    """Map addon & type_addon in ``sets.yaml``.

    Parameters
    ----------
    rtype :
        Either :class:`Code` (default), returning a sorted list of combined
        codes, or the string ``"indexers"``, returning a dict of
        :class:`xarray.DataArray` indexers along a ``consumer_group``
        dimension.

    Raises
    ------
    ValueError
        If *rtype* is neither of the recognized values.
    """
    dims = ["add_on", "type_addon"]

    # Retrieve configuration
    context = read_config()

    # Assemble group information
    result = defaultdict(list)
    for indices in product(*[context["water set"][d]["add"] for d in dims]):
        # Create a new code by combining two
        result["code"].append(
            Code(
                id="".join(str(c.id) for c in indices),
                name=", ".join(str(c.name) for c in indices),
            )
        )

        # Tuple of the values along each dimension
        result["index"].append(tuple(c.id for c in indices))

    if rtype == "indexers":
        # Three tuples of members along each dimension
        indexers = zip(*result["index"])
        indexers = {
            d: xr.DataArray(list(i), dims="consumer_group")
            for d, i in zip(dims, indexers)
        }
        indexers["consumer_group"] = xr.DataArray(
            [c.id for c in result["code"]],
            dims="consumer_group",
        )
        return indexers
    elif rtype is Code:
        return sorted(result["code"], key=str)
    else:
        raise ValueError(rtype)
def add_commodity_and_level(df: pd.DataFrame, default_level=None):
    """Add ``commodity`` and ``level`` columns to *df* from its ``technology``.

    For each row, the technology's ``input`` annotation supplies the
    commodity; the level falls back from the annotation, to the commodity
    code's ``level`` annotation, to *default_level*.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain a ``technology`` column.
    default_level : optional
        Level used when neither annotation specifies one.
    """
    # Add input commodity and level
    t_info: list = Context.get_instance()["water set"]["technology"]["add"]
    c_info: list = get_codes("commodity")

    @lru_cache()
    def t_cl(t):
        # "input_info" rather than "input", to avoid shadowing the builtin.
        input_info = t_info[t_info.index(t)].annotations["input"]
        # Commodity must be specified
        commodity = input_info["commodity"]

        # Use the default level for the commodity in the RES (per
        # commodity.yaml)
        level = (
            input_info.get("level", "water_supply")
            or c_info[c_info.index(commodity)].annotations.get("level", None)
            or default_level
        )

        return commodity, level

    def func(row: pd.Series):
        row[["commodity", "level"]] = t_cl(row["technology"])
        return row

    return df.apply(func, axis=1)
[docs]
def get_vintage_and_active_years(
info: "ScenarioInfo",
technical_lifetime: int | None = None,
same_year_only: bool = False,
) -> pd.DataFrame:
"""Get valid vintage and active year combinations.
This implements similar logic as scenario.vintage_and_active_years() but
uses the technical lifetime data directly instead of requiring it to be in
the scenario first.
Parameters
----------
info : ScenarioInfo
Contains the base yv_ya combinations and duration_period data
technical_lifetime : int, optional
Technical lifetime in years. If None, returns all combinations.
same_year_only : bool, optional
If True, returns only combinations where year_vtg == year_act.
Useful for dummy technologies where vintage doesn't matter.
Returns
-------
pd.DataFrame
DataFrame with columns ['year_vtg', 'year_act'] containing valid combinations
"""
# Get base yv_ya from ScenarioInfo property
yv_ya = info.yv_ya
# If same_year_only is requested, return only year_vtg == year_act combinations
if same_year_only:
same_year_mask = yv_ya["year_vtg"] == yv_ya["year_act"]
return yv_ya[same_year_mask].reset_index(drop=True)
# If no technical lifetime specified or is nan, default to same year
if technical_lifetime is None or pd.isna(technical_lifetime):
warn(
"no technical_lifetime provided, defaulting to same year",
UserWarning,
)
same_year_mask = yv_ya["year_vtg"] == yv_ya["year_act"]
return yv_ya[same_year_mask].reset_index(drop=True)
# Memory optimization: use same-year logic for short-lived technologies
# to reduce unused equations. Time steps are 5-year intervals pre-2060,
# 10-year intervals post-2060. Short lifetimes don't benefit from
# advance construction.
kink_year = 2060
has_post_kink = (yv_ya["year_act"] >= kink_year).any()
short_lived = technical_lifetime <= 5
medium_lived = technical_lifetime <= 10 and has_post_kink
if short_lived or medium_lived:
# Pre-2060: use same-year if lifetime <= 5
# Post-2060: use same-year if lifetime <= 10
if short_lived:
# Same-year for entire horizon
same_year_mask = yv_ya["year_vtg"] == yv_ya["year_act"]
return yv_ya[same_year_mask].reset_index(drop=True)
else:
# Same-year only for post-2060, normal logic for pre-2060
pre_kink = yv_ya[yv_ya["year_act"] < kink_year]
post_kink = yv_ya[yv_ya["year_act"] >= kink_year]
# Pre-2060: normal lifetime filtering
age = pre_kink["year_act"] - pre_kink["year_vtg"]
pre_kink_filtered = pre_kink[age <= technical_lifetime]
# Post-2060: same-year only
same_yr = post_kink["year_vtg"] == post_kink["year_act"]
post_kink_same_year = post_kink[same_yr]
result = pd.concat(
[pre_kink_filtered, post_kink_same_year], ignore_index=True
)
return result.reset_index(drop=True)
# Apply simple lifetime logic: year_act - year_vtg <= technical_lifetime
condition_values = yv_ya["year_act"] - yv_ya["year_vtg"]
valid_mask = condition_values <= technical_lifetime
result = yv_ya[valid_mask].reset_index(drop=True)
return result