Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions petab/v2/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,9 @@ class Observable(BaseModel):
#: Observable name.
name: str | None = Field(alias=C.OBSERVABLE_NAME, default=None)
#: Observable formula.
formula: sp.Basic | None = Field(alias=C.OBSERVABLE_FORMULA, default=None)
formula: sp.Basic = Field(alias=C.OBSERVABLE_FORMULA)
#: Noise formula.
noise_formula: sp.Basic | None = Field(alias=C.NOISE_FORMULA, default=None)
noise_formula: sp.Basic = Field(alias=C.NOISE_FORMULA)
#: Noise distribution.
noise_distribution: NoiseDistribution = Field(
alias=C.NOISE_DISTRIBUTION, default=NoiseDistribution.NORMAL
Expand Down
117 changes: 95 additions & 22 deletions petab/v2/lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
"CheckPriorDistribution",
"CheckUndefinedExperiments",
"CheckInitialChangeSymbols",
"CheckMappingTable",
"lint_problem",
"default_validation_tasks",
]
Expand Down Expand Up @@ -445,31 +446,31 @@ def run(self, problem: Problem) -> ValidationIssue | None:

# check for uniqueness of all primary keys
counter = Counter(c.id for c in problem.conditions)
duplicates = {id_ for id_, count in counter.items() if count > 1}
duplicates = sorted(id_ for id_, count in counter.items() if count > 1)

if duplicates:
return ValidationError(
f"Condition table contains duplicate IDs: {duplicates}"
)

counter = Counter(o.id for o in problem.observables)
duplicates = {id_ for id_, count in counter.items() if count > 1}
duplicates = sorted(id_ for id_, count in counter.items() if count > 1)

if duplicates:
return ValidationError(
f"Observable table contains duplicate IDs: {duplicates}"
)

counter = Counter(e.id for e in problem.experiments)
duplicates = {id_ for id_, count in counter.items() if count > 1}
duplicates = sorted(id_ for id_, count in counter.items() if count > 1)

if duplicates:
return ValidationError(
f"Experiment table contains duplicate IDs: {duplicates}"
)

counter = Counter(p.id for p in problem.parameters)
duplicates = {id_ for id_, count in counter.items() if count > 1}
duplicates = sorted(id_ for id_, count in counter.items() if count > 1)

if duplicates:
return ValidationError(
Expand Down Expand Up @@ -508,7 +509,9 @@ def run(self, problem: Problem) -> ValidationIssue | None:
for experiment in problem.experiments:
# Check that there are no duplicate timepoints
counter = Counter(period.time for period in experiment.periods)
duplicates = {time for time, count in counter.items() if count > 1}
duplicates = sorted(
time for time, count in counter.items() if count > 1
)
if duplicates:
messages.append(
f"Experiment {experiment.id} contains duplicate "
Expand Down Expand Up @@ -551,7 +554,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:

class CheckAllParametersPresentInParameterTable(ValidationTask):
"""Ensure all required parameters are contained in the parameter table
with no additional ones."""
with no additional ones.
"""

def run(self, problem: Problem) -> ValidationIssue | None:
if problem.model is None:
Expand Down Expand Up @@ -825,17 +829,17 @@ def run(self, problem: Problem) -> ValidationIssue | None:

if parameter.prior_distribution not in PRIOR_DISTRIBUTIONS:
messages.append(
f"Prior distribution `{parameter.prior_distribution}' "
f"for parameter `{parameter.id}' is not valid."
f"Prior distribution `{parameter.prior_distribution}` "
f"for parameter `{parameter.id}` is not valid."
)
continue

if (
exp_num_par := self._num_pars[parameter.prior_distribution]
) != len(parameter.prior_parameters):
messages.append(
f"Prior distribution `{parameter.prior_distribution}' "
f"for parameter `{parameter.id}' requires "
f"Prior distribution `{parameter.prior_distribution}` "
f"for parameter `{parameter.id}` requires "
f"{exp_num_par} parameters, but got "
f"{len(parameter.prior_parameters)} "
f"({parameter.prior_parameters})."
Expand All @@ -848,8 +852,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
_ = parameter.prior_dist.sample(1)
except Exception as e:
messages.append(
f"Prior parameters `{parameter.prior_parameters}' "
f"for parameter `{parameter.id}' are invalid "
f"Prior parameters `{parameter.prior_parameters}` "
f"for parameter `{parameter.id}` are invalid "
f"(hint: {e})."
)

Expand All @@ -874,16 +878,16 @@ def run(self, problem: Problem) -> ValidationIssue | None:
continue

messages.append(
f"Measurement `{measurement}' does not have a model ID, "
f"Measurement `{measurement}` does not have a model ID, "
"but there are multiple models available. "
"Please specify the model ID in the measurement table."
)
continue

if measurement.model_id not in available_models:
messages.append(
f"Measurement `{measurement}' has model ID "
f"`{measurement.model_id}' which does not match "
f"Measurement `{measurement}` has model ID "
f"`{measurement.model_id}` which does not match "
"any of the available models: "
f"{available_models}."
)
Expand All @@ -894,6 +898,79 @@ def run(self, problem: Problem) -> ValidationIssue | None:
return None


class CheckMappingTable(ValidationTask):
    """Validation task for the (optional) mapping table.

    An error is reported if

    * an ID occurs more than once across the PEtab-entity and model-entity
      columns,
    * the PEtab ID of an alias is also an ID in the condition, experiment
      or observable table, or
    * the PEtab ID of an alias exists as an entity in one of the models.

    Identity mappings (PEtab ID equal to model ID) are only counted once in
    the uniqueness check and are skipped by the other two checks, because
    they are permitted for annotation.
    """

    def run(self, problem: Problem) -> ValidationIssue | None:
        """Run the mapping table checks on ``problem``.

        Returns ``None`` if there are no mappings or no issue was found, and
        a ``ValidationError`` otherwise. Non-unique IDs are reported
        immediately, without running the remaining checks; the findings of
        the remaining checks are joined into one newline-separated message.
        """
        mappings = problem.mappings
        # The mapping table is optional, so its absence is not an issue.
        if not mappings:
            return None

        # Count every ID that has to be unique across both columns.
        id_occurrences = Counter()
        for entry in mappings:
            petab_id, model_id = entry.petab_id, entry.model_id
            is_identity = petab_id == model_id
            if petab_id:
                id_occurrences[petab_id] += 1
            # For identity mappings the ID was already counted above.
            if model_id and not is_identity:
                id_occurrences[model_id] += 1

        repeated_ids = [
            id_ for id_, n_seen in id_occurrences.items() if n_seen > 1
        ]
        repeated_ids.sort()
        if repeated_ids:
            return ValidationError(
                f"Mapping table contains non-unique IDs: {repeated_ids}."
            )

        messages = []

        # PEtab IDs introduced by aliases (identity mappings are ignored)
        # must not be defined by any other PEtab table.
        alias_ids = set()
        for entry in mappings:
            if entry.petab_id != entry.model_id:
                alias_ids.add(entry.petab_id)

        table_ids = set()
        for table in (
            problem.conditions,
            problem.experiments,
            problem.observables,
        ):
            table_ids.update(row.id for row in table)

        clashing_ids = sorted(alias_ids & table_ids)
        if clashing_ids:
            messages.append(
                f"PEtab IDs `{clashing_ids}` are "
                "defined in the mapping table but also defined through "
                "other PEtab tables."
            )

        # The PEtab ID of an alias must not exist as an entity in a model.
        for entry in mappings:
            is_alias = entry.petab_id != entry.model_id
            if not is_alias:
                continue
            for model in problem.models:
                if model.has_entity_with_id(entry.petab_id):
                    messages.append(
                        f"`{entry.petab_id}` is used in the mapping "
                        "table and referenced directly in the model "
                        f"`{model.model_id}`."
                    )

        return ValidationError("\n".join(messages)) if messages else None


def get_valid_parameters_for_parameter_table(
problem: Problem,
) -> set[str]:
Expand Down Expand Up @@ -984,13 +1061,9 @@ def get_required_parameters_for_parameter_table(
for change in cond.changes
}

# Add parameters from measurement table, unless they are fixed parameters
def append_overrides(overrides):
parameter_ids.update(
str_p
for p in overrides
if isinstance(p, sp.Symbol)
and (str_p := str(p)) not in condition_targets

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `condition_targets` filter is not necessary here: the parameter IDs are collected into a set, and the `condition_targets` are removed from that set further down.

str(p) for p in overrides if isinstance(p, sp.Symbol)
)

for m in problem.measurements:
Expand Down Expand Up @@ -1033,7 +1106,7 @@ def append_overrides(overrides):
if not problem.model.has_entity_with_id(str(p))
)

# parameters that are overridden via the condition table are not allowed
# Parameters that are overridden via the condition table are not allowed
parameter_ids -= condition_targets

return parameter_ids
Expand Down Expand Up @@ -1090,5 +1163,5 @@ def get_placeholders(
CheckUnusedConditions(),
CheckPriorDistribution(),
CheckInitialChangeSymbols(),
# TODO validate mapping table
CheckMappingTable(),
]
18 changes: 13 additions & 5 deletions tests/v2/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def test_measurments():


def test_observable():
Observable(id="obs1", formula=x + y)
Observable(id="obs1", formula=x + y, noiseFormula=1)
Observable(id="obs1", formula="x + y", noise_formula="x + y")
Observable(id="obs1", formula=1, noise_formula=2)
Observable(
Expand All @@ -198,9 +198,17 @@ def test_observable():
observable_parameters=[sp.Symbol("p1")],
noise_parameters=[sp.Symbol("n1")],
)
assert Observable(id="obs1", formula="x + y", non_petab=1).non_petab == 1
assert (
Observable(
id="obs1",
formula="x + y",
noise_formula="x + y",
non_petab=1,
).non_petab
== 1
)

o = Observable(id="obs1", formula=x + y)
o = Observable(id="obs1", formula=x + y, noise_formula=1)
assert o.observable_placeholders == []
assert o.noise_placeholders == []

Expand Down Expand Up @@ -492,14 +500,14 @@ def test_modify_problem():
problem.condition_df, exp_condition_df, check_dtype=False
)

problem.add_observable("observable1", "1")
problem.add_observable("observable1", "1", noise_formula=1)
problem.add_observable("observable2", "2", noise_formula=2.2)

exp_observable_df = pd.DataFrame(
data={
OBSERVABLE_ID: ["observable1", "observable2"],
OBSERVABLE_FORMULA: [1, 2],
NOISE_FORMULA: [np.nan, 2.2],
NOISE_FORMULA: [1, 2.2],
}
).set_index([OBSERVABLE_ID])
assert_frame_equal(
Expand Down
56 changes: 54 additions & 2 deletions tests/v2/test_lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,23 @@

from copy import deepcopy

import pysb
import pytest

from petab.v2 import Problem
from petab.v2.lint import *
from petab.v2.models.pysb_model import PySBModel
from petab.v2.models.sbml_model import SbmlModel


@pytest.fixture
def uses_pysb():
    """Clean up PySB's auto-exported symbols around a test.

    The clean-up runs once before the test body and once after it, so that
    symbols exported by one test do not leak into another. The fixture value
    itself is an empty tuple.
    """
    reset_exported_symbols = pysb.SelfExporter.cleanup

    # Start from a clean slate ...
    reset_exported_symbols()
    yield ()
    # ... and leave one behind for whatever runs next.
    reset_exported_symbols()


def test_check_experiments():
"""Test ``CheckExperimentTable``."""
problem = Problem()
Expand Down Expand Up @@ -43,7 +55,7 @@ def test_invalid_model_id_in_measurements():
"""Test that measurements with an invalid model ID are caught."""
problem = Problem()
problem.models.append(SbmlModel.from_antimony("p1 = 1", model_id="model1"))
problem.add_observable("obs1", "A")
problem.add_observable("obs1", "A", 1)
problem.add_measurement("obs1", experiment_id="e1", time=0, measurement=1)

check = CheckMeasurementModelId()
Expand All @@ -70,7 +82,7 @@ def test_undefined_experiment_id_in_measurements():
"""Test that measurements with an undefined experiment ID are caught."""
problem = Problem()
problem.add_experiment("e1", 0, "c1")
problem.add_observable("obs1", "A")
problem.add_observable("obs1", "A", 1)
problem.add_measurement("obs1", experiment_id="e1", time=0, measurement=1)

check = CheckUndefinedExperiments()
Expand Down Expand Up @@ -107,3 +119,43 @@ def test_validate_initial_change_symbols():
problem.parameter_tables[0].parameters.remove(problem["p2"])
assert (error := check.run(problem)) is not None
assert "contains additional symbols: {'p2'}" in error.message


def test_check_mapping_table(uses_pysb):
    """Exercise ``CheckMappingTable`` on valid and invalid mappings."""
    problem = Problem()
    mapping_check = CheckMappingTable()

    # Valid: a PySB model with a monomer inside a compartment, where the
    # model entity is mapped to a valid PEtab ID.
    # NOTE(review): `A_` is not defined in this module; presumably PySB's
    # self-export makes it available after `pysb.Monomer("A_")` -- the
    # order of the PySB statements is therefore kept as is.
    valid_model = pysb.Model("test_model")
    pysb.Monomer("A_")
    species = A_() ** pysb.Compartment("C")
    pysb.Initial(species, pysb.Parameter("a0", 1))
    problem.model = PySBModel(model=valid_model, model_id="test_model")
    problem.add_mapping("A", "A_() ** C")

    assert mapping_check.run(problem) is None

    parameter_check = CheckAllParametersPresentInParameterTable()
    assert parameter_check.run(problem) is None

    # Still valid: a PEtab ID without a model ID, carrying only a name for
    # annotation.
    problem.add_mapping(petab_id="p2", model_id=None, name="Parameter 2")
    problem.add_parameter("p2", estimate=True, nominal_value=1, lb=0, ub=10)

    assert mapping_check.run(problem) is None

    # Invalid: the mapped PEtab ID `A` is also referenced in the model.
    pysb.SelfExporter.cleanup()
    invalid_model = pysb.Model("test_model_invalid")
    pysb.Monomer("A_")
    species = A_() ** pysb.Compartment("C")
    pysb.Initial(species, pysb.Parameter("A", 1))
    problem.model = PySBModel(
        model=invalid_model, model_id="test_model_invalid"
    )

    error = mapping_check.run(problem)
    assert error is not None
    expected_fragment = (
        "`A` is used in the mapping table and referenced directly"
    )
    assert expected_fragment in error.message
Loading