Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions hamilton/function_modifiers/recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,12 +803,11 @@ def __init__(
# We can decouple it so that on_input selects the target dataframe parameter that will inject into the next node
# pass_dataframe_as selects the original dataframe we want to extract columns from
# columns_to_pass is an optional helper that can be toggled on/off, so no need to raise this error.
if (
int(pass_dataframe_as is None) + int(columns_to_pass is None) + int(on_input is None)
== 1
):
n_set = sum(arg is not None for arg in (pass_dataframe_as, columns_to_pass, on_input))
if n_set != 1:
raise ValueError(
"You must specify only one of ``columns_to_pass``, ``pass_dataframe_as``, and ``on_input``. "
"You must specify exactly one of ``columns_to_pass``, ``pass_dataframe_as``, "
"and ``on_input``. "
"This is because specifying ``pass_dataframe_as`` or ``on_input`` injects into "
"the set of columns, allowing you to perform your own extraction"
"from the dataframe. We then execute all columns in the subdag"
Expand Down
51 changes: 51 additions & 0 deletions tests/function_modifiers/test_recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,3 +687,54 @@ def test_columns_and_subdag_nodes_do_not_clash():

assert not with_columns_base.contains_duplicates([node_a, node_c])
assert with_columns_base.contains_duplicates([node_a, node_b, node_c])


class _BaseValidationSub(with_columns_base):
    """Minimal concrete subclass used to exercise `with_columns_base.__init__`-time
    validation. The abstract methods are trivially stubbed because no test in this
    block invokes them.
    """

    def get_initial_nodes(self, fn, params):
        """Stub: return an empty name and an empty node list."""
        return "", []

    def get_subdag_nodes(self, fn, config):
        """Stub: return an empty node list."""
        return []

    def chain_subdag_nodes(self, fn, inject_parameter, generated_nodes):
        """Stub: return ``None``."""
        return None

    def validate(self, fn):
        """Stub: perform no validation."""
        pass


def _dummy_subdag_fn() -> int:
return 0


def test_with_columns_base_raises_when_no_mutex_arg_set():
    """Expect a ``ValueError`` when none of ``columns_to_pass``,
    ``pass_dataframe_as`` or ``on_input`` is supplied.
    """
    with pytest.raises(ValueError, match="exactly one of"):
        _BaseValidationSub(_dummy_subdag_fn, select=["x"])


def test_with_columns_base_raises_when_two_mutex_args_set():
    """Expect a ``ValueError`` when two of the mutually exclusive arguments
    (here ``columns_to_pass`` and ``on_input``) are supplied together.
    """
    with pytest.raises(ValueError, match="exactly one of"):
        _BaseValidationSub(
            _dummy_subdag_fn,
            columns_to_pass=["a"],
            on_input="b",
            select=["x"],
        )


def test_with_columns_base_raises_when_all_three_mutex_args_set():
    """Expect a ``ValueError`` when ``columns_to_pass``, ``pass_dataframe_as``
    and ``on_input`` are all supplied at once.
    """
    with pytest.raises(ValueError, match="exactly one of"):
        _BaseValidationSub(
            _dummy_subdag_fn,
            columns_to_pass=["a"],
            pass_dataframe_as="b",
            on_input="c",
            select=["x"],
        )


def test_with_columns_base_accepts_exactly_one_mutex_arg():
    """Each of the three single-argument configurations must instantiate cleanly."""
    single_arg_cases = (
        {"columns_to_pass": ["a"]},
        {"pass_dataframe_as": "b"},
        {"on_input": "c"},
    )
    for mutex_kwargs in single_arg_cases:
        # Construction itself is the check: any ValueError from __init__ fails the test.
        instance = _BaseValidationSub(
            _dummy_subdag_fn, select=["x"], dataframe_type=object, **mutex_kwargs
        )
        assert isinstance(instance, _BaseValidationSub)
Loading