Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 25 additions & 29 deletions Lib/test/test_zipfile/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5840,59 +5840,55 @@ class StripExtraTests(unittest.TestCase):

ZIP64_EXTRA = 1

strip_extra = staticmethod(zipfile.ZipFile._strip_extra_fields)

def test_no_data(self):
# Fields here are header-only: each packs an ID and a declared length of 0.
# Every expectation is asserted through both zipfile._Extra.strip and
# self.strip_extra.
s = struct.Struct("<HH")
a = s.pack(self.ZIP64_EXTRA, 0)
b = s.pack(2, 0)
c = s.pack(3, 0)

# Single field: the ZIP64 field is removed, a field with another ID is
# returned as-is, and a stray byte after a field is preserved.
self.assertEqual(b'', zipfile._Extra.strip(a, (self.ZIP64_EXTRA,)))
self.assertEqual(b, zipfile._Extra.strip(b, (self.ZIP64_EXTRA,)))
self.assertEqual(
b+b"z", zipfile._Extra.strip(b+b"z", (self.ZIP64_EXTRA,)))
self.assertEqual(b'', self.strip_extra(a, (self.ZIP64_EXTRA,)))
self.assertEqual(b, self.strip_extra(b, (self.ZIP64_EXTRA,)))
self.assertEqual(b+b"z", self.strip_extra(b+b"z", (self.ZIP64_EXTRA,)))

# The ZIP64 field is removed whether it comes first, in the middle or last;
# the remaining fields keep their relative order.
self.assertEqual(b+c, zipfile._Extra.strip(a+b+c, (self.ZIP64_EXTRA,)))
self.assertEqual(b+c, zipfile._Extra.strip(b+a+c, (self.ZIP64_EXTRA,)))
self.assertEqual(b+c, zipfile._Extra.strip(b+c+a, (self.ZIP64_EXTRA,)))
self.assertEqual(b+c, self.strip_extra(a+b+c, (self.ZIP64_EXTRA,)))
self.assertEqual(b+c, self.strip_extra(b+a+c, (self.ZIP64_EXTRA,)))
self.assertEqual(b+c, self.strip_extra(b+c+a, (self.ZIP64_EXTRA,)))

def test_with_data(self):
# Same scenarios as the header-only case, but each field is followed by a
# payload whose size matches the length packed into its header (1, 2, 3).
# Every expectation is asserted through both zipfile._Extra.strip and
# self.strip_extra.
s = struct.Struct("<HH")
a = s.pack(self.ZIP64_EXTRA, 1) + b"a"
b = s.pack(2, 2) + b"bb"
c = s.pack(3, 3) + b"ccc"

# Single field: the ZIP64 field is removed together with its payload, a
# field with another ID is returned as-is, and a stray trailing byte is kept.
self.assertEqual(b"", zipfile._Extra.strip(a, (self.ZIP64_EXTRA,)))
self.assertEqual(b, zipfile._Extra.strip(b, (self.ZIP64_EXTRA,)))
self.assertEqual(
b+b"z", zipfile._Extra.strip(b+b"z", (self.ZIP64_EXTRA,)))
self.assertEqual(b"", self.strip_extra(a, (self.ZIP64_EXTRA,)))
self.assertEqual(b, self.strip_extra(b, (self.ZIP64_EXTRA,)))
self.assertEqual(b+b"z", self.strip_extra(b+b"z", (self.ZIP64_EXTRA,)))

# The ZIP64 field is removed from the first, middle and last position.
self.assertEqual(b+c, zipfile._Extra.strip(a+b+c, (self.ZIP64_EXTRA,)))
self.assertEqual(b+c, zipfile._Extra.strip(b+a+c, (self.ZIP64_EXTRA,)))
self.assertEqual(b+c, zipfile._Extra.strip(b+c+a, (self.ZIP64_EXTRA,)))
self.assertEqual(b+c, self.strip_extra(a+b+c, (self.ZIP64_EXTRA,)))
self.assertEqual(b+c, self.strip_extra(b+a+c, (self.ZIP64_EXTRA,)))
self.assertEqual(b+c, self.strip_extra(b+c+a, (self.ZIP64_EXTRA,)))

def test_multiples(self):
# The ZIP64 field occurs more than once in the input; every occurrence
# must be removed. Each expectation is asserted through both
# zipfile._Extra.strip and self.strip_extra.
s = struct.Struct("<HH")
a = s.pack(self.ZIP64_EXTRA, 1) + b"a"
b = s.pack(2, 2) + b"bb"

# Only ZIP64 fields (optionally followed by a stray byte and/or another
# field): nothing of the ZIP64 fields survives, the rest is kept.
self.assertEqual(b"", zipfile._Extra.strip(a+a, (self.ZIP64_EXTRA,)))
self.assertEqual(b"", zipfile._Extra.strip(a+a+a, (self.ZIP64_EXTRA,)))
self.assertEqual(
b"z", zipfile._Extra.strip(a+a+b"z", (self.ZIP64_EXTRA,)))
self.assertEqual(
b+b"z", zipfile._Extra.strip(a+a+b+b"z", (self.ZIP64_EXTRA,)))
self.assertEqual(b"", self.strip_extra(a+a, (self.ZIP64_EXTRA,)))
self.assertEqual(b"", self.strip_extra(a+a+a, (self.ZIP64_EXTRA,)))
self.assertEqual(b"z", self.strip_extra(a+a+b"z", (self.ZIP64_EXTRA,)))
self.assertEqual(b+b"z", self.strip_extra(a+a+b+b"z", (self.ZIP64_EXTRA,)))

# Two ZIP64 fields around a kept field, in every arrangement.
self.assertEqual(b, zipfile._Extra.strip(a+a+b, (self.ZIP64_EXTRA,)))
self.assertEqual(b, zipfile._Extra.strip(a+b+a, (self.ZIP64_EXTRA,)))
self.assertEqual(b, zipfile._Extra.strip(b+a+a, (self.ZIP64_EXTRA,)))
self.assertEqual(b, self.strip_extra(a+a+b, (self.ZIP64_EXTRA,)))
self.assertEqual(b, self.strip_extra(a+b+a, (self.ZIP64_EXTRA,)))
self.assertEqual(b, self.strip_extra(b+a+a, (self.ZIP64_EXTRA,)))

def test_too_short(self):
# Inputs of 0 to 3 bytes are shorter than the 4-byte "<HH" field header
# used by the other tests; they must be returned unchanged. Each
# expectation is asserted through both zipfile._Extra.strip and
# self.strip_extra.
self.assertEqual(b"", zipfile._Extra.strip(b"", (self.ZIP64_EXTRA,)))
self.assertEqual(b"z", zipfile._Extra.strip(b"z", (self.ZIP64_EXTRA,)))
self.assertEqual(
b"zz", zipfile._Extra.strip(b"zz", (self.ZIP64_EXTRA,)))
self.assertEqual(
b"zzz", zipfile._Extra.strip(b"zzz", (self.ZIP64_EXTRA,)))
self.assertEqual(b"", self.strip_extra(b"", (self.ZIP64_EXTRA,)))
self.assertEqual(b"z", self.strip_extra(b"z", (self.ZIP64_EXTRA,)))
self.assertEqual(b"zz", self.strip_extra(b"zz", (self.ZIP64_EXTRA,)))
self.assertEqual(b"zzz", self.strip_extra(b"zzz", (self.ZIP64_EXTRA,)))


class StatIO(_pyio.BytesIO):
Expand Down
72 changes: 32 additions & 40 deletions Lib/zipfile/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,43 +194,6 @@ class LargeZipFile(Exception):

_DD_SIGNATURE = 0x08074b50


class _Extra(bytes):
    """One raw extra-field record, tagged with the ID from its header.

    The instance's bytes are the record exactly as it appeared in the
    input (4-byte header followed by the payload); ``id`` holds the
    header ID, or None when no complete header could be read.
    """

    # Record header: two little-endian unsigned 16-bit values (ID, length).
    FIELD_STRUCT = struct.Struct('<HH')

    def __new__(cls, val, id=None):
        # bytes is immutable, so the content must be supplied here;
        # the ID is attached afterwards in __init__.
        return super().__new__(cls, val)

    def __init__(self, val, id=None):
        self.id = id

    @classmethod
    def read_one(cls, raw):
        """Split the first record off *raw*.

        Return ``(record, rest)``.  If *raw* holds fewer than 4 bytes the
        header cannot be unpacked; the leftover bytes are then returned
        as a record with ``id=None`` and *rest* is empty.  A declared
        length that runs past the end of *raw* is clamped by slicing, so
        the record simply contains whatever bytes remain.
        """
        try:
            xid, xlen = cls.FIELD_STRUCT.unpack(raw[:4])
        except struct.error:
            xid = None
            xlen = 0
        end = 4 + xlen
        return cls(raw[:end], xid), raw[end:]

    @classmethod
    def split(cls, data):
        """Yield the records contained in *data*, in order."""
        # use memoryview for zero-copy slices
        rest = memoryview(data)
        while rest:
            # Go through cls (not _Extra) so subclasses get their own type
            # back, consistent with read_one() and strip().
            extra, rest = cls.read_one(rest)
            yield extra

    @classmethod
    def strip(cls, data, xids):
        """Remove Extra fields with specified IDs.

        Records whose ``id`` is in *xids* are dropped; everything else,
        including trailing bytes too short to form a header (``id`` is
        None), is kept in its original order.  Returns plain bytes.
        """
        return b''.join(
            ex
            for ex in cls.split(data)
            if ex.id not in xids
        )


def _check_zipfile(fp):
try:
endrec = _EndRecData(fp)
Expand Down Expand Up @@ -2670,11 +2633,11 @@ def _write_end_record(self):
extra_data = zinfo.extra
min_version = 0
if extra:
# Append a ZIP64 field to the extra's
extra_data = _Extra.strip(extra_data, (1,))
# Prepend a ZIP64 field to the extra's
extra_data = struct.pack(
'<HH' + 'Q'*len(extra),
1, 8*len(extra), *extra) + extra_data
1, 8*len(extra), *extra
) + self._strip_extra_fields(extra_data, (1,))

min_version = ZIP64_VERSION

Expand Down Expand Up @@ -2741,6 +2704,35 @@ def _write_end_record(self):
self.fp.truncate()
self.fp.flush()

@staticmethod
def _strip_extra_fields(data, field_ids):
"""Remove Extra fields with specified IDs and return a bytearray.

data should be bytes or bytearray.
"""
result = bytearray()

# early return for empty extra data
if not data:
return result

# walk the fields by offset; each kept field is copied into result
data_len = len(data)
pos = 0
while pos + 4 <= data_len:
xid, xlen = struct.unpack_from('<HH', data, pos)
if pos + 4 + xlen > data_len:
break
if xid not in field_ids:
result.extend(data[pos:pos + 4 + xlen])
pos += 4 + xlen

# keep remaining trailing bytes (e.g. truncated or malformed data)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this added? In the old read_one no data was returned for trailing data with < 4 bytes.

@danny0838 danny0838 Jun 25, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the original behavior: when the input is shorter than 4 bytes, read_one returns the original bytes wrapped with id=None, so they are never stripped. The current code just adds a comment to make that more explicit.

You can check the tests (especially the test_too_short) to confirm that the behavior is not changed.

if pos < data_len:
result.extend(data[pos:])

return result

def _fpclose(self, fp):
assert self._fileRefCnt > 0
self._fileRefCnt -= 1
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Replace :class:`!_Extra` with :meth:`!ZipFile._strip_extra_fields` in the :mod:`zipfile` module.
Loading