Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Doc/library/stdtypes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5757,6 +5757,13 @@ Frozen dictionaries
Like dictionaries, frozendicts are :ref:`generic <generics>` over two types,
signifying (respectively) the types of the frozendict's keys and values.

.. classmethod:: fromkeys(iterable, value=None, /)

Similar to :meth:`dict.fromkeys`, but call again the type constructor
with an initialized :class:`frozendict` if the type is a
:class:`frozendict` subclass or if the constructor returned a
:class:`frozendict`.

.. versionadded:: 3.15


Expand Down
31 changes: 19 additions & 12 deletions Lib/test/test_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1939,8 +1939,11 @@ def test_fromkeys(self):
# Subclass which overrides the constructor
created = frozendict(x=1)
class FrozenDictSubclass(frozendict):
def __new__(self):
return created
def __new__(cls, *args, **kwargs):
if args or kwargs:
return super().__new__(cls, *args, **kwargs)
else:
return created

fd = FrozenDictSubclass.fromkeys("abc")
self.assertEqual(fd, frozendict(x=1, a=None, b=None, c=None))
Expand All @@ -1952,6 +1955,20 @@ def __new__(self):
self.assertEqual(type(fd), FrozenDictSubclass)
self.assertEqual(created, frozendict(x=1))

# Dict subclass with a constructor which returns a frozendict
# by default
class DictSubclass(dict):
def __new__(cls, *args, **kwargs):
if args or kwargs:
return super().__new__(cls, *args, **kwargs)
else:
return created

fd = DictSubclass.fromkeys("abc")
self.assertEqual(fd, frozendict(x=1, a=None, b=None, c=None))
self.assertEqual(type(fd), DictSubclass)
self.assertEqual(created, frozendict(x=1))

# Subclass which doesn't override the constructor
class FrozenDictSubclass2(frozendict):
pass
Expand All @@ -1960,16 +1977,6 @@ class FrozenDictSubclass2(frozendict):
self.assertEqual(fd, frozendict(a=None, b=None, c=None))
self.assertEqual(type(fd), FrozenDictSubclass2)

# Dict subclass which overrides the constructor
class DictSubclass(dict):
def __new__(self):
return created

fd = DictSubclass.fromkeys("abc")
self.assertEqual(fd, frozendict(x=1, a=None, b=None, c=None))
self.assertEqual(type(fd), DictSubclass)
self.assertEqual(created, frozendict(x=1))

def test_pickle(self):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
for fd in (
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
:meth:`!frozendict.fromkeys` now only tracks the :class:`frozendict` in the
garbage collector once the dictionary is fully initialized. Patch by Donghee Na
and Victor Stinner.
173 changes: 120 additions & 53 deletions Objects/dictobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ static PyObject* frozendict_new(PyTypeObject *type, PyObject *args,
PyObject *kwds);
static PyObject* frozendict_new_untracked(PyTypeObject *type);
static PyObject* dict_new(PyTypeObject *type, PyObject *args, PyObject *kwds);
static PyObject* dict_new_untracked(PyTypeObject *type);
static int dict_merge(PyObject *a, PyObject *b, int override, PyObject **dupkey);
static int dict_contains(PyObject *op, PyObject *key);
static int dict_merge_from_seq2(PyObject *d, PyObject *seq2, int override);
Expand Down Expand Up @@ -3414,40 +3415,47 @@ dict_set_fromkeys(PyDictObject *mp, PyObject *iterable, PyObject *value)
PyObject *
_PyDict_FromKeys(PyObject *cls, PyObject *iterable, PyObject *value)
{
PyObject *it; /* iter(iterable) */
PyObject *key;
PyObject *it = NULL; /* iter(iterable) */
PyObject *d;
int status;
int need_copy = 0;

d = _PyObject_CallNoArgs(cls);
if (cls == (PyObject*)&PyFrozenDict_Type) {
// gh-151722: Create a frozendict which is not tracked by the GC.
d = frozendict_new_untracked(&PyFrozenDict_Type);
}
else {
// Dict subclass, or frozendict subclass which overrides
// the constructor.
d = _PyObject_CallNoArgs(cls);
}
if (d == NULL) {
return NULL;
}

// If cls is a dict or frozendict subclass with overridden constructor,
// copy the frozendict.
PyTypeObject *cls_type = _PyType_CAST(cls);
if (PyFrozenDict_Check(d) && cls_type->tp_new != frozendict_new) {
// Subclass-friendly copy
PyObject *copy;
if (PyObject_IsSubclass(cls, (PyObject*)&PyFrozenDict_Type)) {
copy = frozendict_new(cls_type, NULL, NULL);
}
else {
copy = dict_new(cls_type, NULL, NULL);
}
// gh-151722: If cls constructor returns a frozendict which is tracked by
// the GC, create a frozendict copy which is not tracked by the GC.
//
// At the function exit, return cls(fd) where fd is a frozendict.
//
// Untracking the frozendict requires tracking again the frozendict on
// error which is more complicated. It's easier to work on a copy.
if (PyFrozenDict_Check(d) && _PyObject_GC_IS_TRACKED(d)) {
need_copy = 1;

PyObject *copy = frozendict_new_untracked(&PyFrozenDict_Type);
if (copy == NULL) {
Py_DECREF(d);
return NULL;
goto Fail;
}
if (dict_merge(copy, d, 1, NULL) < 0) {
Py_DECREF(d);
Py_DECREF(copy);
return NULL;
goto Fail;
}
Py_SETREF(d, copy);
}
assert(!PyFrozenDict_Check(d) || can_modify_dict((PyDictObject*)d));
if (PyFrozenDict_Check(d)) {
assert(can_modify_dict((PyDictObject*)d));
assert(!_PyObject_GC_IS_TRACKED(d));
}

if (PyDict_CheckExact(d)) {
if (PyDict_CheckExact(iterable)) {
Expand All @@ -3456,23 +3464,23 @@ _PyDict_FromKeys(PyObject *cls, PyObject *iterable, PyObject *value)
Py_BEGIN_CRITICAL_SECTION2(d, iterable);
d = (PyObject *)dict_dict_fromkeys(mp, iterable, value);
Py_END_CRITICAL_SECTION2();
return d;
goto Done;
}
else if (PyFrozenDict_CheckExact(iterable)) {
PyDictObject *mp = (PyDictObject *)d;

Py_BEGIN_CRITICAL_SECTION(d);
d = (PyObject *)dict_dict_fromkeys(mp, iterable, value);
Py_END_CRITICAL_SECTION();
return d;
goto Done;
}
else if (PyAnySet_CheckExact(iterable)) {
PyDictObject *mp = (PyDictObject *)d;

Py_BEGIN_CRITICAL_SECTION2(d, iterable);
d = (PyObject *)dict_set_fromkeys(mp, iterable, value);
Py_END_CRITICAL_SECTION2();
return d;
goto Done;
}
}
else if (PyFrozenDict_CheckExact(d)) {
Expand All @@ -3482,71 +3490,113 @@ _PyDict_FromKeys(PyObject *cls, PyObject *iterable, PyObject *value)
Py_BEGIN_CRITICAL_SECTION(iterable);
d = (PyObject *)dict_dict_fromkeys(mp, iterable, value);
Py_END_CRITICAL_SECTION();
return d;
goto Done;
}
else if (PyFrozenDict_CheckExact(iterable)) {
PyDictObject *mp = (PyDictObject *)d;
d = (PyObject *)dict_dict_fromkeys(mp, iterable, value);
return d;
goto Done;
}
else if (PyAnySet_CheckExact(iterable)) {
PyDictObject *mp = (PyDictObject *)d;

Py_BEGIN_CRITICAL_SECTION(iterable);
d = (PyObject *)dict_set_fromkeys(mp, iterable, value);
Py_END_CRITICAL_SECTION();
return d;
goto Done;
}
}

it = PyObject_GetIter(iterable);
if (it == NULL){
Py_DECREF(d);
return NULL;
goto Fail;
}

if (PyDict_CheckExact(d)) {
int status = 0;

Py_BEGIN_CRITICAL_SECTION(d);
while ((key = PyIter_Next(it)) != NULL) {
while (1) {
PyObject *key;
status = PyIter_NextItem(it, &key);
if (status <= 0) {
break;
}

status = setitem_lock_held((PyDictObject *)d, key, value);
Py_DECREF(key);
if (status < 0) {
assert(PyErr_Occurred());
goto dict_iter_exit;
break;
}
}
dict_iter_exit:;
Py_END_CRITICAL_SECTION();

if (status < 0) {
goto Fail;
}
}
else if (PyFrozenDict_Check(d)) {
while ((key = PyIter_Next(it)) != NULL) {
while (1) {
PyObject *key;
int status = PyIter_NextItem(it, &key);
if (status < 0) {
goto Fail;
}
if (status == 0) {
break;
}

// setitem_take2_lock_held consumes a reference to key
status = setitem_take2_lock_held((PyDictObject *)d,
key, Py_NewRef(value));
if (status < 0) {
assert(PyErr_Occurred());
goto Fail;
}
}
}
else {
while ((key = PyIter_Next(it)) != NULL) {
while (1) {
PyObject *key;
int status = PyIter_NextItem(it, &key);
if (status < 0) {
goto Fail;
}
if (status == 0) {
break;
}

status = PyObject_SetItem(d, key, value);
Py_DECREF(key);
if (status < 0)
if (status < 0) {
goto Fail;
}
}

}

if (PyErr_Occurred())
goto Fail;
assert(!PyErr_Occurred());
Py_DECREF(it);
return d;
goto Done;

Fail:
Py_DECREF(it);
assert(PyErr_Occurred());
Py_XDECREF(it);
Py_DECREF(d);
return NULL;

Done:
if (d == NULL) {
return NULL;
}

if (need_copy) {
PyObject *copy = _PyObject_CallOneArg(cls, d);
Py_SETREF(d, copy);
}
else if (!_PyObject_GC_IS_TRACKED(d)) {
_PyObject_GC_TRACK(d);
}
return d;
}

/* Methods */
Expand Down Expand Up @@ -4147,9 +4197,6 @@ dict_dict_merge(PyDictObject *mp, PyDictObject *other, int override, PyObject **
set_keys(mp, keys);
STORE_USED(mp, other->ma_used);
ASSERT_CONSISTENT(mp);
if (PyDict_Check(mp)) {
assert(_PyObject_GC_IS_TRACKED(mp));
}
return 0;
}
}
Expand Down Expand Up @@ -4316,7 +4363,12 @@ dict_merge_api(PyObject *a, PyObject *b, int override, PyObject **dupkey)
}
return -1;
}
return dict_merge(a, b, override, dupkey);

int res = dict_merge(a, b, override, dupkey);
if (PyDict_Check(a)) {
assert(_PyObject_GC_IS_TRACKED(a));
}
return res;
}

int
Expand Down Expand Up @@ -4475,10 +4527,15 @@ copy_lock_held(PyObject *o, int as_frozendict)
}
if (copy == NULL)
return NULL;
if (dict_merge(copy, o, 1, NULL) == 0)
return copy;
Py_DECREF(copy);
return NULL;
if (dict_merge(copy, o, 1, NULL) < 0) {
Py_DECREF(copy);
return NULL;
}

if (PyDict_Check(copy)) {
assert(_PyObject_GC_IS_TRACKED(copy));
}
return copy;
}

PyObject *
Expand Down Expand Up @@ -5239,11 +5296,11 @@ static PyNumberMethods dict_as_number = {
.nb_inplace_or = _PyDict_IOr,
};

static PyObject *
dict_new_untracked(PyTypeObject *type)
static PyObject*
anydict_new_untracked(PyTypeObject *type)
{
assert(type != NULL);
// dict subclasses must implement the GC protocol
// dict and frozendict subclasses must implement the GC protocol
assert(_PyType_IS_GC(type));

PyObject *self = _PyType_AllocNoTrack(type, 0);
Expand All @@ -5262,6 +5319,14 @@ dict_new_untracked(PyTypeObject *type)
return self;
}

static PyObject*
dict_new_untracked(PyTypeObject *type)
{
assert(PyObject_IsSubclass((PyObject*)type, (PyObject*)&PyDict_Type));

return anydict_new_untracked(type);
}

static PyObject *
dict_new(PyTypeObject *type, PyObject *Py_UNUSED(args), PyObject *Py_UNUSED(kwds))
{
Expand Down Expand Up @@ -8377,7 +8442,9 @@ frozendict_hash(PyObject *op)
static PyObject *
frozendict_new_untracked(PyTypeObject *type)
{
PyObject *d = dict_new_untracked(type);
assert(PyObject_IsSubclass((PyObject*)type, (PyObject*)&PyFrozenDict_Type));

PyObject *d = anydict_new_untracked(type);
if (d == NULL) {
return NULL;
}
Expand Down
Loading