From a1999b21b08916d6d6fbab2330aec4cd8e1f0e40 Mon Sep 17 00:00:00 2001 From: James Parrott <80779630+JamesParrott@users.noreply.github.com> Date: Fri, 26 Jun 2026 15:39:25 +0100 Subject: [PATCH] Add Round trip Shapefile test --- tests/hypothesis_tests.py | 154 ++++++++++++++++++++++++++------------ 1 file changed, 106 insertions(+), 48 deletions(-) diff --git a/tests/hypothesis_tests.py b/tests/hypothesis_tests.py index 462dbfd..8a4b830 100644 --- a/tests/hypothesis_tests.py +++ b/tests/hypothesis_tests.py @@ -441,7 +441,7 @@ def test_MultiPatch_roundtrips( (31, "MultiPatch", multipatch), ] -def code_and_shape_strat_from_triple(t): +def code_and_shape_strategy_from_triple(t): x, _name, shapes = t return tuples( just(x), @@ -451,12 +451,40 @@ def code_and_shape_strat_from_triple(t): max_size=MAX_NUM_SHAPES, ), ) -codes_and_shapes_strats = [ - code_and_shape_strat_from_triple(t) +codes_and_shapes_strategies = [ + code_and_shape_strategy_from_triple(t) for t in shape_codes_names_and_strategies ] -codes_and_shapes = one_of(codes_and_shapes_strats) +codes_and_shapes = one_of(codes_and_shapes_strategies) + + +def _assert_reader_matches_expected_shapes(r, code_ex, expected_shapes): + assert r.shapeType == code_ex + + for actual, expected in itertools.zip_longest(r.shapes(), expected_shapes): + + assert isinstance(actual, (shp.SHAPE_CLASS_FROM_SHAPETYPE[code_ex], shp.NullShape)) + assert actual.points_3D == expected.points_3D + # Don't assert actual.oid == expected.oid it's defined by + # actual.oid indicates the order actual was written in, expected.oid + # is not currently encoded (as we'd have to resort the entire Shapefile after each shape) + assert actual.parts == expected.parts, f"{type(actual.parts)=}, {type(expected.parts)=}" + + if (m := getattr(actual, "m", None)): + assert m == expected.m, f"{type(m)=}, {type(expected.m)=}" + else: + assert not hasattr(expected, "m") + + if (z := getattr(actual, "z", None)): + assert z == expected.z, f"{type(z)=}, {type(expected.z)=}" + else: + assert not hasattr(expected, "z") + + if (partTypes := getattr(actual, "partTypes", None)): + assert actual.partTypes == expected.partTypes, f"{type(actual.partTypes)=}, {type(expected.partTypes)=}" + else: + assert not hasattr(expected, "partTypes") @pytest.mark.hypothesis @given(codes_and_shapes=codes_and_shapes) @@ -470,31 +498,8 @@ def test_shp_reader_writer_roundtrip(codes_and_shapes)-> None: w.shape(shape) with shp.ShpReader(shp=stream) as r: - assert r.shapeType == code_ex + _assert_reader_matches_expected_shapes(r, code_ex, expected_shapes) - for actual, expected in itertools.zip_longest(r.shapes(), expected_shapes): - - assert isinstance(actual, (shp.SHAPE_CLASS_FROM_SHAPETYPE[code_ex], shp.NullShape)) - assert actual.points_3D == expected.points_3D - # Don't assert actual.oid == expected.oid it's defined by - # actual.oid indicates the order actual was written in, expected.oid - # is not currently encoded (as we'd have to resort the entire Shapefile after each shape) - assert actual.parts == expected.parts, f"{type(actual.parts)=}, {type(expected.parts)=}" - - if (m := getattr(actual, "m", None)): - assert m == expected.m, f"{type(m)=}, {type(expected.m)=}" - else: - assert not hasattr(expected, "m") - - if (z := getattr(actual, "z", None)): - assert z == expected.z, f"{type(z)=}, {type(expected.z)=}" - else: - assert not hasattr(expected, "z") - - if (partTypes := getattr(actual, "partTypes", None)): - assert actual.partTypes == expected.partTypes, f"{type(actual.partTypes)=}, {type(expected.partTypes)=}" - else: - assert not hasattr(expected, "partTypes") @pytest.mark.hypothesis @@ -613,7 +618,6 @@ def record_value_for_field(name: str, field_type: str, size: int, decimal: int = def _dbf_fields_and_record_strategy( draw, max_fields=10, # In DbfWriter.__init__, max_num_fields: int = 2046, - max_records=20, ): fields = draw(lists(dbf_fields(), min_size=1, max_size=max_fields)) @@ -630,12 +634,30 @@ def dbf_fields_and_records( max_records=20, ): - fields, record_strategy = _dbf_fields_and_record_strategy(draw, max_fields, max_records) + fields, record_strategy = _dbf_fields_and_record_strategy(draw, max_fields) records = draw(lists(record_strategy, min_size=0, max_size=max_records)) return fields, records +def _assert_reader_matches_expected_records(r, fields, written_records): + for f_r, f_w in itertools.zip_longest(r.data_fields, fields): + actual_field_dict = f_r._asdict() + for k in ("field_type", "size", "decimal"): + assert actual_field_dict[k] == f_w[k], f"{k=}, {actual_field_dict[k]=}, {f_w[k]=}" + for exp_rec, actual_rec in itertools.zip_longest(written_records, r.records()): + for expected, actual, field in itertools.zip_longest(exp_rec, actual_rec, fields): + field_type = field["field_type"] + decimal = field["decimal"] + if field_type == "D": + if isinstance(expected, datetime.date): + expected = expected.strftime("%Y%m%d") + if isinstance(actual, datetime.date): + actual = actual.strftime("%Y%m%d") + elif field_type in ("N", "F") and decimal >= 1: + expected = float(format(expected, f".{decimal}f")) + assert actual == expected, f"{actual=}, {expected=}, {field_type=}, {type(actual)=}, {type(expected)=}" + @pytest.mark.hypothesis @given(fields_and_records=dbf_fields_and_records()) @@ -656,21 +678,57 @@ def test_dbf_reader_writer_roundtrip(fields_and_records)-> None: with shp.DbfReader(dbf=stream) as r: - actual_fields = iter(r.fields) - next(actual_fields) # skip deletion flag - for f_r, f_w in itertools.zip_longest(actual_fields, fields): - actual_field_dict = f_r._asdict() - for k in ("field_type", "size", "decimal"): - assert actual_field_dict[k] == f_w[k], f"{k=}, {actual_field_dict[k]=}, {f_w[k]=}" - for exp_rec, actual_rec in itertools.zip_longest(written_records, r.records()): - for expected, actual, field in itertools.zip_longest(exp_rec, actual_rec, fields): - field_type = field["field_type"] - decimal = field["decimal"] - if field_type == "D": - if isinstance(expected, datetime.date): - expected = expected.strftime("%Y%m%d") - if isinstance(actual, datetime.date): - actual = actual.strftime("%Y%m%d") - elif field_type in ("N", "F") and decimal >= 1: - expected = float(format(expected, f".{decimal}f")) - assert actual == expected, f"{actual=}, {expected=}, {field_type=}, {type(actual)=}, {type(expected)=}" + _assert_reader_matches_expected_records(r, fields, written_records) +# def code_and_shape_strategy_from_triple(t): +# x, _name, shapes = t +# return tuples( +# just(x), +# lists( +# one_of(shapes, null_shapes), +# min_size = 0, # Empty shp files are in the ESRI spec. +# max_size=MAX_NUM_SHAPES, +# ), +# ) +# codes_and_shapes_strategies = [ +# code_and_shape_strategy_from_triple(t) +# for t in shape_codes_names_and_strategies +# ] + +# codes_and_shapes = one_of(codes_and_shapes_strategies) + +@composite +def codes_fields_shapes_and_records(draw): + code, shapes = draw(codes_and_shapes) + fields, records_strategy = _dbf_fields_and_record_strategy(draw, max_fields=10) + N = len(shapes) + records = [draw(records_strategy) for _ in range(N)] + + return code, fields, zip(shapes, records) + +@pytest.mark.hypothesis +@settings(suppress_health_check=[HealthCheck.too_slow, HealthCheck.data_too_large]) +@given(codes_fields_shapes_and_records=codes_fields_shapes_and_records()) +def test_shapefile_reader_writer_roundtrip(codes_fields_shapes_and_records)-> None: + + code_ex, fields_ex, shapes_and_records = codes_fields_shapes_and_records + streams = {"shp" : io.BytesIO(), "shx" : io.BytesIO(), "dbf" : io.BytesIO(),} + + expected_shapes = [] + expected_records = [] + + with shp.Writer(shapeType = code_ex, strict=True, **streams) as w: + for field in fields_ex: + w.field(**field) + + for shape, record in shapes_and_records: + try: + w.record(*record) + except shp.DbfStringDataLoss: + continue + w.shape(shape) + expected_shapes.append(shape) + expected_records.append(record) + + with shp.Reader(**streams) as r: + _assert_reader_matches_expected_records(r, fields_ex, expected_records) + _assert_reader_matches_expected_shapes(r, code_ex, expected_shapes) \ No newline at end of file