# Source code for message_ix_models.tests.test_util

"""Tests of :mod:`message_ix_models.util`."""

import logging
import re
from importlib.metadata import version
from pathlib import Path

import numpy as np
import pandas as pd
import pytest
from iam_units import registry
from ixmp.testing import assert_logs
from message_ix import Scenario, make_df
from message_ix.testing import make_dantzig
from packaging.version import parse
from pandas.testing import assert_series_equal

from message_ix_models import ScenarioInfo
from message_ix_models.util import (
    MESSAGE_DATA_PATH,
    MESSAGE_MODELS_PATH,
    as_codes,
    broadcast,
    check_support,
    convert_units,
    copy_column,
    ffill,
    iter_parameters,
    load_package_data,
    load_private_data,
    local_data_path,
    make_source_tech,
    maybe_query,
    package_data_path,
    path_fallback,
    private_data_path,
    replace_par_data,
    same_node,
    same_time,
    series_of_pint_quantity,
    strip_par_data,
)

# The "data" directory located one level above the directory containing this file
_actual_package_data = Path(__file__).parents[1] / "data"


def test_as_codes():
    """A forward reference to a child code is silently dropped."""
    raw = {
        "foo": {"child": ["bar"]},
        "bar": {"name": "Bar!"},
    }

    codes = as_codes(raw)

    # The second code is not attached as a child of the first, although "foo"
    # names "bar" in its "child" list
    assert codes[1] not in codes[0].child

    # Input that already consists of codes is passed through unchanged
    assert codes == as_codes(codes)
def test_broadcast(caplog):
    """:func:`.broadcast` fills empty dimensions from keyword and/or positional labels."""
    # Frame to be broadcast: 2 rows, with dimension "a" populated (length 2) and
    # dimensions "b", "c", "d" present but empty
    n_rows = 2
    populated = pd.DataFrame([["a0", 1.2], ["a1", 3.4]], columns=["a", "value"])
    base = populated.assign(b=None, c=None, d=None)

    # Keyword arguments only, invoked through DataFrame.pipe()
    out = base.pipe(
        broadcast,
        b=["b0", "b1", "b2"],
        c=["c0", "c1", "c2", "c3"],
        d=["d0"],
    )
    # Length is the original × the cartesian product of 3, 4, and 1 labels
    assert n_rows * 3 * 4 * 1 == len(out)
    # Every cell is filled
    assert not out.isna().any(axis=None)

    # Zero-length labels for one dimension: that dimension is skipped and a debug
    # message is logged
    with caplog.at_level(logging.DEBUG, logger="message_ix_models"):
        out = base.pipe(broadcast, b=["b0", "b1"], c=["c0", "c1"], d=[])

    assert "Don't broadcast over 'd'; labels [] have length 0" in caplog.messages
    caplog.clear()

    # Expected length
    assert n_rows * 2 * 2 * 1 == len(out)
    # Dimension "d" is still empty…
    assert out["d"].isna().all()
    # …while all other columns are full
    assert not out.drop("d", axis=1).isna().any(axis=None)

    # A DataFrame as the sole positional argument, combined with keyword arguments
    labels = pd.DataFrame({"b": ["b0", "b1", "b2"], "c": ["c0", "c1", "c2"]})
    out = base.pipe(broadcast, labels, d=["d0", "d1"])
    # (b, c) are linked: 3 label pairs rather than a 3 × 3 product
    assert n_rows * 3 * 2 == len(out)
    assert not out.isna().any(axis=None)

    # A positional argument with a single column
    out = base.pipe(broadcast, labels[["b"]], c=["c0", "c1", "c2", "c3"], d=["d0"])
    assert n_rows * 3 * 4 * 1 == len(out)
    assert not out.isna().any(axis=None)

    # Column "c" supplied both positionally and by keyword → ValueError
    with pytest.raises(ValueError):
        base.pipe(broadcast, labels, c=["c0", "c1", "c2", "c3"], d=["d0"])

    # A dimension "e" that `base` does not have → ValueError, via keyword…
    with pytest.raises(ValueError):
        base.pipe(broadcast, b=["b0", "b1", "b2"], c=["c0", "c1", "c2", "c3"], e=["e0"])

    # …and via a column of the positional argument
    labels["e"] = ["e0", "e1", "e2"]
    with pytest.raises(ValueError):
        base.pipe(broadcast, labels, d=["d0"])
@pytest.mark.parametrize(
    "data",
    [
        # Not a supported container at all
        set(),
        # Mapping with a value that is neither a str nor a nested mapping
        {"foo": "foo", "bar": [1, 2, 3]},
    ],
)
def test_as_codes_invalid(data):
    """:func:`.as_codes` raises :class:`TypeError` for invalid data."""
    with pytest.raises(TypeError):
        as_codes(data)
def test_check_support(test_context):
    """:func:`.check_support` raises an exception for missing/non-matching values."""
    supported = {"regions": ["R11", "R14"]}
    description = "Test data available"

    # A setting that is absent from the context → KeyError naming it
    with pytest.raises(KeyError, match="baz"):
        check_support(test_context, {"baz": ["baz"]}, "Baz is not set")

    # A supported value passes without raising
    test_context.regions = "R11"
    check_support(test_context, supported, description)

    # An unsupported value → NotImplementedError listing the supported ones
    test_context.regions = "FOO"
    expected = re.escape("Test data available for ['R11', 'R14']; got 'FOO'")
    with pytest.raises(NotImplementedError, match=expected):
        check_support(test_context, supported, description)
def test_convert_units(recwarn):
    """:func:`.convert_units` and :func:`.series_of_pint_quantity` work."""
    # Arguments shared by every convert_units() call below
    data = pd.Series([1.1, 10.2, 100.3], name="bar")
    unit_info = {"bar": (10.0, "lb", "kg")}

    expected_qty = series_of_pint_quantity(
        [registry("4.9895 kg"), registry("46.2664 kg"), registry("454.9531 kg")],
    )

    # store="quantity" → a series of pint.Quantity
    converted = convert_units(data, unit_info, store="quantity")

    # Dividing expected by actual gives a dimensionless value only when the units
    # agree; it should also be close to 1
    ratios = [e / c for e, c in zip(expected_qty.values, converted.values)]
    for r in ratios:
        assert r.dimensionless and np.isclose(r, 1, atol=1e-4)

    # store="magnitude" → a series of plain floats
    expected_mag = pd.Series([q.magnitude for q in expected_qty.values], name="bar")
    assert_series_equal(
        expected_mag,
        convert_units(data, unit_info, store="magnitude"),
        check_dtype=False,
    )

    # Any other value for store= is rejected
    with pytest.raises(ValueError, match="store = 'foo'"):
        convert_units(data, unit_info, store="foo")

    # No warnings were recorded during the test
    assert 0 == len(recwarn)
def test_copy_column():
    """:func:`.copy_column` works as a value for :meth:`pandas.DataFrame.assign`."""
    frame = pd.DataFrame({"a": [0, 2], "b": [1, 3]})

    # "c" is filled from column "a"; "d" is an ordinary scalar assignment
    frame = frame.assign(c=copy_column("a"), d=4)

    assert (frame["c"] == [0, 2]).all()
    assert (frame["d"] == 4).all()
def test_ffill():
    """:func:`.ffill` extends data along a dimension and rejects non-dimension columns."""
    years = list(range(6))

    # 3 (technology, year) rows, broadcast over 3 nodes → 9 rows with distinct values
    sparse = make_df(
        "fix_cost",
        year_act=[0, 2, 4],
        year_vtg=[0, 2, 4],
        technology=["foo", "bar", "baz"],
        unit="USD",
    )
    per_node = broadcast(sparse, node_loc=["A", "B", "C"])
    df = per_node.assign(value=[float(i) for i in range(9)])

    # Function completes
    filled = ffill(df, "year_vtg", years, "year_act = year_vtg")

    # Twice as many rows as the input, covering every year in `years`
    assert 2 * len(df) == len(filled)
    assert years == sorted(filled["year_vtg"].unique())

    # Cannot ffill on "value" and "unit" dimensions
    with pytest.raises(ValueError, match="value"):
        ffill(df, "value", [])

    # TODO test some specific values
@pytest.mark.skipif(
    parse("3.7.0") < parse(version("ixmp")), reason="Not used with ixmp > 3.7.0"
)
def test_iter_parameters(test_context):
    """Parameters indexed by set 'node' can be retrieved."""
    names = list(iter_parameters("node"))

    # First and last entries of the returned sequence
    assert names[0] == "abs_cost_activity_soft_lo"
    assert names[-1] == "var_cost"

    # The length of this list depends on message_ix. Changes in message_ix may
    # increase the number of parameters, so only a lower bound is checked. See the
    # method comments.
    assert len(names) >= 99
@pytest.mark.parametrize("path", _actual_package_data.rglob("*.yaml"))
def test_load_package_data(path):
    """Existing package data can be loaded."""
    # Express the file as path components relative to the package data directory
    parts = path.relative_to(_actual_package_data).parts
    load_package_data(*parts)
def test_load_package_data_twice(caplog):
    """Loading the same data twice logs a message."""
    caplog.set_level(logging.DEBUG, logger="message_ix_models")

    # Two identical requests, one after the other
    for _ in range(2):
        load_package_data("node", "R11")

    assert "'node R11' already loaded; skip" in caplog.messages
def test_load_package_data_invalid():
    """:func:`.load_package_data` raises :class:`ValueError` for an unsupported file type."""
    unsupported = "test.xml"
    with pytest.raises(ValueError):
        load_package_data(unsupported)
@pytest.mark.xfail(
    condition=MESSAGE_DATA_PATH is None,
    reason="Requires message_data to be installed.",
)
def test_load_private_data(*parts, suffix=None):
    """Private data can be loaded; expected to fail when ``MESSAGE_DATA_PATH`` is None."""
    # NOTE(review): `parts` and `suffix` are never used in the body; the signature
    # looks copied from elsewhere. Confirm before simplifying it.
    load_private_data("sources.yaml")
# Label arguments passed as the second positional argument to make_source_tech() in
# the tests below
_MST_COMMON = {
    "commodity": "commodity",
    "level": "level",
    "mode": "mode",
    "technology": "technology",
    "time": "time",
    "time_dest": "time",
    "unit": "unit",
}

# Keyword values passed to make_source_tech() in the tests below
_MST_VALUES = {
    "capacity_factor": 1.0,
    "output": 2.0,
    "var_cost": 3.0,
    "technical_lifetime": 4.0,
}
def test_make_source_tech0():
    """Test make_source_tech() with a ScenarioInfo object as input."""
    info = ScenarioInfo()
    info.set["node"] = ["World", "node0", "node1"]
    info.set["year"] = [1, 2, 3]

    # Work on a copy so the module-level values are not modified below
    kwargs = dict(_MST_VALUES)

    # Code runs
    data = make_source_tech(info, _MST_COMMON, **kwargs)

    # The result is a dict with one entry per supplied value
    assert isinstance(data, dict)
    assert set(data) == set(kwargs)

    # "World" node does not appear in results
    assert set(data["output"]["node_loc"].unique()) == set(info.N[1:])

    for frame in data.values():
        # 2 nodes × 3 years
        assert len(frame) == 2 * 3
        # No empty values
        assert not frame.isna().any(axis=None)

    # Omitting one of the values is an error that names it
    kwargs.pop("var_cost")
    with pytest.raises(ValueError, match=re.escape("needs values for {'var_cost'}")):
        make_source_tech(info, _MST_COMMON, **kwargs)
def test_make_source_tech1(test_mp):
    """Test make_source_tech() with a Scenario object as input."""
    scen = Scenario(test_mp, model="model", scenario="scenario", version="new")

    # Minimal structure: nodes, one technology, and a time horizon
    scen.add_set("node", ["World", "node0", "node1"])
    scen.add_set("technology", ["t"])
    scen.add_horizon([1, 2, 3])
    scen.commit("")

    # Only checks that the call completes without raising
    make_source_tech(scen, _MST_COMMON, **_MST_VALUES)
def test_maybe_query():
    """:func:`.maybe_query` works as intended."""
    # 4 values on a 2 × 2 index with levels "foo" and "bar"
    index = pd.MultiIndex.from_product(
        [["a", "b"], ["c", "d"]], names=["foo", "bar"]
    )
    series = pd.Series([0, 1, 2, 3], index=index)

    # With None as the query, the input comes back unchanged
    assert_series_equal(series, maybe_query(series, None))

    # A query on one level selects a subset of rows
    selected = maybe_query(series, "bar == 'c'")
    assert 2 == len(selected)
def test_local_data_path(tmp_path_factory, session_context):
    """:func:`.local_data_path` resolves below "data0" in the pytest base temp directory."""
    expected = tmp_path_factory.getbasetemp().joinpath("data0", "foo", "bar")
    assert expected == local_data_path("foo", "bar")
def test_package_data_path():
    """:func:`.package_data_path` resolves below ``MESSAGE_MODELS_PATH / "data"``."""
    expected = MESSAGE_MODELS_PATH.joinpath("data", "foo", "bar")
    assert expected == package_data_path("foo", "bar")
def test_path_fallback(caplog):
    """:func:`.path_fallback` returns the first existing path and logs the misses."""
    # `where` may be a list mixing strings and paths
    search = ["private", package_data_path()]
    found = path_fallback("test", "macro", "kgdp.csv", where=search)

    assert package_data_path("test", "macro", "kgdp.csv") == found
    # At least one location was tried without success and logged
    assert len(caplog.messages) >= 1
    assert caplog.messages[-1].startswith("Not found: ")
    caplog.clear()

    # `where` may also be a single space-separated string of location names; when
    # none of them contains the file, ValueError is raised
    with pytest.raises(ValueError, match=r"'foo.bar' not found in any of \["):
        path_fallback("foo", "bar", where="cache local package private test")

    assert len(caplog.messages) >= 4
    assert caplog.messages[-1].startswith("Not found: ")
    caplog.clear()

    # Without `where`, no directories are identified: an exception is raised and
    # nothing is logged
    with pytest.raises(ValueError, match="No directories identified among ''"):
        path_fallback("foo", "bar")

    assert not caplog.messages
@pytest.mark.xfail(
    condition=MESSAGE_DATA_PATH is None,
    reason="Requires message_data to be installed.",
)
def test_private_data_path():
    """:func:`.private_data_path` resolves below ``MESSAGE_DATA_PATH / "data"``."""
    expected = MESSAGE_DATA_PATH.joinpath("data", "foo", "bar")
    assert expected == private_data_path("foo", "bar")
@pytest.mark.parametrize(
    "name, func, col",
    [
        ("node", same_node, "node_loc"),
        ("time", same_time, "time"),
    ],
)
def test_same(name, func, col):
    """Test both :func:`.same_node` and :func:`.same_time`."""
    dest, origin = f"{name}_dest", f"{name}_origin"

    # Source column is populated; the "_dest" and "_origin" columns start empty
    df_in = pd.DataFrame(
        {
            col: ["foo", "bar", "baz"],
            dest: None,
            origin: None,
            "value": [1.1, 2.2, 3.3],
        }
    )

    df_out = func(df_in)

    # Nothing is left empty
    assert not df_out.isna().any(axis=None)

    # Both target columns carry the same labels as the source column
    for target in (dest, origin):
        assert_series_equal(df_out[target], df_in[col], check_names=False)
def test_replace_par_data(caplog, test_context):
    """Test :func:`.replace_par_data`."""
    # Generate a scenario. This scenario has 3 data points in each of "input" and
    # "output" with technology="transport_from_seattle".
    scen = make_dantzig(test_context.get_platform())

    # The new label must exist as a set element before data can be renamed to it
    with scen.transact("Add a new set element, to which values will be renamed"):
        scen.add_set("technology", "tfs")

    # Arguments to replace_par_data()
    par_names = ["input", "output"]
    selection = {"mode": ["to_chicago", "to_topeka"]}
    renames = {"technology": {"transport_from_seattle": "tfs"}}

    # Function runs
    replace_par_data(scen, par_names, filters=selection, to_replace=renames)

    for par_name in par_names:
        data = scen.par(par_name, filters={"node_loc": "seattle"})

        # Data points selected by `filters` have been relabeled
        assert 2 == len(data.query("technology == 'tfs'"))
        # Data points not selected by `filters` are not affected
        assert 1 == len(data.query("technology == 'transport_from_seattle'"))
def test_strip_par_data(caplog, test_context):
    """Test the "dry run" feature of :func:`.strip_par_data`.

    With ``dry_run=True`` the function is expected to log what it would remove and
    leave the scenario's data untouched.
    """
    s = make_dantzig(test_context.get_platform())

    # Number of "output" rows before the call, to compare afterwards
    N = len(s.par("output"))

    # Capture records from the package logger regardless of its configured level
    caplog.set_level(logging.DEBUG, logger="message_ix_models")

    # assert_logs() is a context manager: it only checks the messages when it wraps
    # the statement that emits them. Calling it as a plain function (as this test
    # previously did) creates the manager and discards it without checking anything.
    with assert_logs(
        caplog,
        [
            "Remove data with technology='canning_plant' (DRY RUN)",
            "2 rows in 'output'",
            "with commodity=['cases']",
            "with level=['supply']",
        ],
    ):
        strip_par_data(s, "technology", "canning_plant", dry_run=True, dump=dict())

    # Nothing was actually removed
    assert N == len(s.par("output"))