Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 106 additions & 48 deletions tests/hypothesis_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def test_MultiPatch_roundtrips(
(31, "MultiPatch", multipatch),
]

def code_and_shape_strat_from_triple(t):
def code_and_shape_strategy_from_triple(t):
x, _name, shapes = t
return tuples(
just(x),
Expand All @@ -451,12 +451,40 @@ def code_and_shape_strat_from_triple(t):
max_size=MAX_NUM_SHAPES,
),
)
codes_and_shapes_strats = [
code_and_shape_strat_from_triple(t)
codes_and_shapes_strategies = [
code_and_shape_strategy_from_triple(t)
for t in shape_codes_names_and_strategies
]

codes_and_shapes = one_of(codes_and_shapes_strats)
codes_and_shapes = one_of(codes_and_shapes_strategies)


def _assert_reader_matches_expected_shapes(r, code_ex, expected_shapes):
    """Assert that the shapes read back through ``r`` match ``expected_shapes``.

    Args:
        r: Reader-like object exposing ``shapeType`` and ``shapes()``.
        code_ex: Shape type code the data was written with.
        expected_shapes: The shapes that were written, in write order.

    Raises:
        AssertionError: On the first mismatch found.
    """
    assert r.shapeType == code_ex

    # zip_longest pads the shorter side with None, so a difference in the
    # number of shapes surfaces as a failure instead of being truncated away.
    for actual, expected in itertools.zip_longest(r.shapes(), expected_shapes):
        shape_class = shp.SHAPE_CLASS_FROM_SHAPETYPE[code_ex]
        assert isinstance(actual, (shape_class, shp.NullShape))
        assert actual.points_3D == expected.points_3D
        # oid is deliberately not compared: actual.oid reflects the order in
        # which the shape was written, whereas expected.oid is not currently
        # encoded (that would require re-sorting the entire Shapefile after
        # each shape).
        assert actual.parts == expected.parts, f"{type(actual.parts)=}, {type(expected.parts)=}"

        # Optional attributes: when the shape read back carries no (truthy)
        # value, the expected shape must not define the attribute at all.
        m = getattr(actual, "m", None)
        if not m:
            assert not hasattr(expected, "m")
        else:
            assert m == expected.m, f"{type(m)=}, {type(expected.m)=}"

        z = getattr(actual, "z", None)
        if not z:
            assert not hasattr(expected, "z")
        else:
            assert z == expected.z, f"{type(z)=}, {type(expected.z)=}"

        if not getattr(actual, "partTypes", None):
            assert not hasattr(expected, "partTypes")
        else:
            assert actual.partTypes == expected.partTypes, f"{type(actual.partTypes)=}, {type(expected.partTypes)=}"

@pytest.mark.hypothesis
@given(codes_and_shapes=codes_and_shapes)
Expand All @@ -470,31 +498,8 @@ def test_shp_reader_writer_roundtrip(codes_and_shapes)-> None:
w.shape(shape)

with shp.ShpReader(shp=stream) as r:
assert r.shapeType == code_ex
_assert_reader_matches_expected_shapes(r, code_ex, expected_shapes)

for actual, expected in itertools.zip_longest(r.shapes(), expected_shapes):

assert isinstance(actual, (shp.SHAPE_CLASS_FROM_SHAPETYPE[code_ex], shp.NullShape))
assert actual.points_3D == expected.points_3D
# Don't assert actual.oid == expected.oid it's defined by
# actual.oid indicates the order actual was written in, expected.oid
# is not currently encoded (as we'd have to resort the entire Shapefile after each shape)
assert actual.parts == expected.parts, f"{type(actual.parts)=}, {type(expected.parts)=}"

if (m := getattr(actual, "m", None)):
assert m == expected.m, f"{type(m)=}, {type(expected.m)=}"
else:
assert not hasattr(expected, "m")

if (z := getattr(actual, "z", None)):
assert z == expected.z, f"{type(z)=}, {type(expected.z)=}"
else:
assert not hasattr(expected, "z")

if (partTypes := getattr(actual, "partTypes", None)):
assert actual.partTypes == expected.partTypes, f"{type(actual.partTypes)=}, {type(expected.partTypes)=}"
else:
assert not hasattr(expected, "partTypes")


@pytest.mark.hypothesis
Expand Down Expand Up @@ -613,7 +618,6 @@ def record_value_for_field(name: str, field_type: str, size: int, decimal: int =
def _dbf_fields_and_record_strategy(
draw,
max_fields=10, # In DbfWriter.__init__, max_num_fields: int = 2046,
max_records=20,
):

fields = draw(lists(dbf_fields(), min_size=1, max_size=max_fields))
Expand All @@ -630,12 +634,30 @@ def dbf_fields_and_records(
max_records=20,
):

fields, record_strategy = _dbf_fields_and_record_strategy(draw, max_fields, max_records)
fields, record_strategy = _dbf_fields_and_record_strategy(draw, max_fields)

records = draw(lists(record_strategy, min_size=0, max_size=max_records))

return fields, records

def _assert_reader_matches_expected_records(r, fields, written_records):
    """Assert that the fields and records read through ``r`` match what was written.

    Args:
        r: Reader-like object exposing ``data_fields`` (items providing
            ``_asdict()``) and ``records()``.
        fields: The written field specs, as mappings with at least the keys
            ``"field_type"``, ``"size"`` and ``"decimal"``.
        written_records: The written records, each a sequence of values in
            field order.

    Raises:
        AssertionError: On the first field attribute or record value that
            differs.
    """

    def as_yyyymmdd(value):
        # Dates are compared in their YYYYMMDD text form; anything that is
        # not a date is passed through unchanged.
        if isinstance(value, datetime.date):
            return value.strftime("%Y%m%d")
        return value

    compared_keys = ("field_type", "size", "decimal")
    for read_field, f_w in itertools.zip_longest(r.data_fields, fields):
        actual_field_dict = read_field._asdict()
        for k in compared_keys:
            assert actual_field_dict[k] == f_w[k], f"{k=}, {actual_field_dict[k]=}, {f_w[k]=}"

    record_pairs = itertools.zip_longest(written_records, r.records())
    for written_rec, read_rec in record_pairs:
        for expected, actual, spec in itertools.zip_longest(written_rec, read_rec, fields):
            field_type = spec["field_type"]
            places = spec["decimal"]
            if field_type == "D":
                expected = as_yyyymmdd(expected)
                actual = as_yyyymmdd(actual)
            elif field_type in ("N", "F") and places >= 1:
                # Round the written number to the field's declared number of
                # decimal places before comparing it with the value read back.
                expected = float(format(expected, f".{places}f"))
            assert actual == expected, f"{actual=}, {expected=}, {field_type=}, {type(actual)=}, {type(expected)=}"


@pytest.mark.hypothesis
@given(fields_and_records=dbf_fields_and_records())
Expand All @@ -656,21 +678,57 @@ def test_dbf_reader_writer_roundtrip(fields_and_records)-> None:


with shp.DbfReader(dbf=stream) as r:
actual_fields = iter(r.fields)
next(actual_fields) # skip deletion flag
for f_r, f_w in itertools.zip_longest(actual_fields, fields):
actual_field_dict = f_r._asdict()
for k in ("field_type", "size", "decimal"):
assert actual_field_dict[k] == f_w[k], f"{k=}, {actual_field_dict[k]=}, {f_w[k]=}"
for exp_rec, actual_rec in itertools.zip_longest(written_records, r.records()):
for expected, actual, field in itertools.zip_longest(exp_rec, actual_rec, fields):
field_type = field["field_type"]
decimal = field["decimal"]
if field_type == "D":
if isinstance(expected, datetime.date):
expected = expected.strftime("%Y%m%d")
if isinstance(actual, datetime.date):
actual = actual.strftime("%Y%m%d")
elif field_type in ("N", "F") and decimal >= 1:
expected = float(format(expected, f".{decimal}f"))
assert actual == expected, f"{actual=}, {expected=}, {field_type=}, {type(actual)=}, {type(expected)=}"
_assert_reader_matches_expected_records(r, fields, written_records)
# def code_and_shape_strategy_from_triple(t):
# x, _name, shapes = t
# return tuples(
# just(x),
# lists(
# one_of(shapes, null_shapes),
# min_size = 0, # Empty shp files are in the ESRI spec.
# max_size=MAX_NUM_SHAPES,
# ),
# )
# codes_and_shapes_strategies = [
# code_and_shape_strategy_from_triple(t)
# for t in shape_codes_names_and_strategies
# ]

# codes_and_shapes = one_of(codes_and_shapes_strategies)

@composite
def codes_fields_shapes_and_records(draw):
    """Draw a shape type code, dbf field specs, and paired shapes and records.

    Exactly one record is drawn for every drawn shape.

    Args:
        draw: The draw function supplied by ``@composite``.

    Returns:
        A ``(code, fields, shapes_and_records)`` tuple, where
        ``shapes_and_records`` is a list of ``(shape, record)`` pairs.
    """
    code, shapes = draw(codes_and_shapes)
    fields, records_strategy = _dbf_fields_and_record_strategy(draw, max_fields=10)
    records = [draw(records_strategy) for _ in shapes]

    # Materialise the pairs: a bare zip object can only be iterated once and
    # has an opaque repr when Hypothesis reports the drawn example.
    return code, fields, list(zip(shapes, records))

@pytest.mark.hypothesis
@settings(suppress_health_check=[HealthCheck.too_slow, HealthCheck.data_too_large])
@given(codes_fields_shapes_and_records=codes_fields_shapes_and_records())
def test_shapefile_reader_writer_roundtrip(codes_fields_shapes_and_records) -> None:
    """Write shapes and records with ``shp.Writer`` and read them back with ``shp.Reader``.

    Everything goes through in-memory shp/shx/dbf streams. Only the
    (shape, record) pairs that were actually written are expected on read.
    """
    code_ex, fields_ex, shapes_and_records = codes_fields_shapes_and_records
    streams = {ext: io.BytesIO() for ext in ("shp", "shx", "dbf")}

    kept_shapes = []
    kept_records = []

    with shp.Writer(shapeType=code_ex, strict=True, **streams) as w:
        for field_spec in fields_ex:
            w.field(**field_spec)

        for shape, record in shapes_and_records:
            # The record is written first; if the writer rejects it with
            # DbfStringDataLoss the whole pair is skipped, so its shape is
            # neither written nor expected.
            try:
                w.record(*record)
            except shp.DbfStringDataLoss:
                continue
            w.shape(shape)
            kept_shapes.append(shape)
            kept_records.append(record)

    with shp.Reader(**streams) as r:
        _assert_reader_matches_expected_records(r, fields_ex, kept_records)
        _assert_reader_matches_expected_shapes(r, code_ex, kept_shapes)
Loading