Skip to content

Commit ac58a47

Browse files
fyellinmhostetter
andauthored
Fix unpickling of FieldArray instances in new interpreter contexts (#639)
* Allow pickling of galois fields * Dynamically create and test pickled objects * Add to changelog * Make `FieldArray` subclasses hashable * Improve unit tests * Address @fyellin comments --------- Co-authored-by: mhostetter <matthostetter@gmail.com> Co-authored-by: mhostetter <mhostetter@users.noreply.github.com>
1 parent 79cc1fa commit ac58a47

6 files changed

Lines changed: 133 additions & 39 deletions

File tree

docs/release-notes/unreleased.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@ tocdepth: 2
1212

1313
### Features
1414

15+
- Make `FieldArray` subclasses hashable based on their properties. ([#639](https://github.com/mhostetter/galois/pull/639))
1516
- Improved FFT speed (~1.6x speedup) in `np.fft.fft()` with mixed-radix Cooley-Tukey algorithm. ([#620](https://github.com/mhostetter/galois/pull/620))
1617

1718
### Fixes
1819

19-
-
20+
- Fixed bug where `FieldArray` instances couldn't be unpickled if the `FieldArray` class had not yet be created. ([#639](https://github.com/mhostetter/galois/pull/639))
2021

2122
### Performance
2223

src/galois/_fields/_factory.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44

55
from __future__ import annotations
66

7+
import copyreg
78
import sys
89
import types
910
import warnings
10-
from typing import Type, overload
11+
from typing import Any, Dict, Tuple, Type, overload
1112

1213
from typing_extensions import Literal
1314

@@ -18,6 +19,7 @@
1819
from ..typing import PolyLike
1920
from ._array import FieldArray
2021
from ._gf2 import GF2
22+
from ._meta import FieldArrayMeta
2123
from ._primitive_element import is_primitive_element, primitive_element
2224
from ._ufunc import UFuncMixin_2_m, UFuncMixin_p_1, UFuncMixin_p_m
2325

@@ -531,3 +533,46 @@ def _GF_extension(
531533

532534

533535
_GF_extension._classes = {}
536+
537+
538+
def _reconstruct_field_class(args: Tuple, kwargs: Dict[str, Any]):
539+
"""
540+
Reconstruct a field class via `galois.GF(...)`.
541+
542+
Pickle's reduce protocol passes positional args only, so we wrap keyword arguments
543+
in a dict and unpack them here.
544+
"""
545+
return GF(*args, **kwargs)
546+
547+
548+
def _reduce_field_class(field_cls) -> Tuple[object, Tuple[Dict[str, Any]]]:
549+
"""
550+
Pickle reducer for dynamically-created field classes (FieldArray subclasses).
551+
552+
We serialize the minimal set of constructor kwargs needed to reconstruct the same
553+
field class via the GF factory on unpickle.
554+
"""
555+
args = (
556+
int(field_cls.characteristic),
557+
int(field_cls.degree),
558+
)
559+
560+
kwargs: Dict[str, Any] = {
561+
"primitive_element": int(field_cls.primitive_element),
562+
"verify": False,
563+
"compile": field_cls.ufunc_mode, # Restore the field's current ufunc mode on reconstruction
564+
"repr": field_cls.element_repr,
565+
}
566+
567+
# Only extension fields have an irreducible polynomial. Encode as a string to avoid
568+
# formatting / parsing issues.
569+
if field_cls.degree > 1:
570+
kwargs["irreducible_poly"] = str(field_cls.irreducible_poly)
571+
572+
# Return (callable, args) where args is a tuple of positional args; we pass kwargs as one arg.
573+
return (_reconstruct_field_class, (args, kwargs))
574+
575+
576+
# Register pickling for the metaclass used by field classes.
577+
# FieldArrayMeta is your metaclass (import it appropriately here).
578+
copyreg.pickle(FieldArrayMeta, _reduce_field_class)

src/galois/_fields/_meta.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ def __init__(cls, name, bases, namespace, **kwargs):
5151
cls._name = f"GF({cls._order_str})"
5252
cls._long_name = f"GF({cls._order_str}, primitive_element={cls._primitive_element_str!r}, irreducible_poly={cls._irreducible_poly_str!r})"
5353

54+
def __hash__(cls):
55+
t = (cls.characteristic, cls.degree, cls._irreducible_poly_int, int(cls._primitive_element))
56+
return hash(t)
57+
5458
def __repr__(cls) -> str:
5559
if cls.order == 0:
5660
# This is not a runtime-created subclass, so return the base class name.

src/galois/_polys/_poly.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1014,7 +1014,7 @@ def __int__(self) -> int:
10141014
return self.__index__()
10151015

10161016
def __hash__(self):
1017-
t = (self.field.order, *self.nonzero_degrees.tolist(), *self.nonzero_coeffs.tolist())
1017+
t = (hash(self.field), hash(tuple(self.nonzero_degrees.tolist())), hash(tuple(self.nonzero_coeffs.tolist())))
10181018
return hash(t)
10191019

10201020
@overload

tests/fields/test_classes.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
A pytest module to test the class attributes of FieldArray subclasses.
33
"""
44

5-
import pickle
6-
75
import numpy as np
86
import pytest
97

@@ -117,37 +115,3 @@ def test_is_primitive_poly():
117115
poly = galois.conway_poly(3, 101)
118116
GF = galois.GF(3**101, irreducible_poly=poly, primitive_element="x", verify=False)
119117
assert GF.is_primitive_poly
120-
121-
122-
def test_pickle_class(tmp_path):
123-
GF = galois.GF(13)
124-
with open(tmp_path / "class.pkl", "wb") as f:
125-
pickle.dump(GF, f)
126-
with open(tmp_path / "class.pkl", "rb") as f:
127-
GF_loaded = pickle.load(f)
128-
assert GF is GF_loaded
129-
130-
GF = galois.GF(3**5)
131-
with open(tmp_path / "class.pkl", "wb") as f:
132-
pickle.dump(GF, f)
133-
with open(tmp_path / "class.pkl", "rb") as f:
134-
GF_loaded = pickle.load(f)
135-
assert GF is GF_loaded
136-
137-
138-
def test_pickle_array(tmp_path):
139-
GF = galois.GF(13)
140-
x = GF.Random(10)
141-
with open(tmp_path / "array.pkl", "wb") as f:
142-
pickle.dump(x, f)
143-
with open(tmp_path / "array.pkl", "rb") as f:
144-
x_loaded = pickle.load(f)
145-
assert np.array_equal(x, x_loaded)
146-
147-
GF = galois.GF(3**5)
148-
x = GF.Random(10)
149-
with open(tmp_path / "array.pkl", "wb") as f:
150-
pickle.dump(x, f)
151-
with open(tmp_path / "array.pkl", "rb") as f:
152-
x_loaded = pickle.load(f)
153-
assert np.array_equal(x, x_loaded)

tests/fields/test_pickle.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
"""
2+
A pytest module to test pickling and unpickling of FieldArray subclasses and their instances.
3+
"""
4+
5+
import os
6+
import pickle
7+
import subprocess
8+
import sys
9+
10+
import numpy as np
11+
import pytest
12+
13+
import galois
14+
15+
16+
@pytest.mark.parametrize("order", [3, 3**2])
17+
def test_pickle_field_array_in_same_interpreter(order, tmp_path):
18+
GF = galois.GF(order)
19+
x = GF.Random(10)
20+
21+
# Write the pickle artifact
22+
pkl_path = tmp_path / "field_array.pkl"
23+
with pkl_path.open("wb") as f:
24+
pickle.dump(x, f)
25+
26+
# Read the pickle artifact
27+
with pkl_path.open("rb") as f:
28+
x2 = pickle.load(f)
29+
30+
assert type(x) is type(x2)
31+
assert np.array_equal(x, x2)
32+
33+
34+
@pytest.mark.parametrize("order", [3, 3**2])
35+
def test_pickle_field_array_in_new_interpreter(order, tmp_path):
36+
GF = galois.GF(order)
37+
x = GF.Random(10)
38+
39+
# Write the pickle artifact
40+
pkl_path = tmp_path / "field_array.pkl"
41+
with pkl_path.open("wb") as f:
42+
pickle.dump(x, f)
43+
44+
# Capture "expected" in a representation that is easy to embed as Python literals.
45+
# Use integers of the underlying representation, not repr(x).
46+
expected_hash = hash(type(x))
47+
expected_properties = type(x).properties
48+
expected_values = x.tolist()
49+
50+
# Run a fresh interpreter that ONLY imports galois and unpickles.
51+
code = f"""
52+
import pickle
53+
import numpy as np
54+
import galois
55+
56+
with open(r"{pkl_path}", "rb") as f:
57+
x2 = pickle.load(f)
58+
59+
assert hash(type(x2)) == {expected_hash}
60+
assert type(x2).properties == {expected_properties!r}
61+
assert x2.tolist() == {expected_values!r}
62+
"""
63+
64+
env = os.environ.copy()
65+
66+
# Ensure subprocess imports your working tree version (editable checkouts, etc.)
67+
repo_root = os.getcwd()
68+
env["PYTHONPATH"] = repo_root + os.pathsep + env.get("PYTHONPATH", "")
69+
70+
result = subprocess.run(
71+
[sys.executable, "-c", code],
72+
env=env,
73+
capture_output=True,
74+
text=True,
75+
check=False,
76+
)
77+
78+
assert result.returncode == 0, (
79+
f"Unpickling in a fresh interpreter failed.\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}\n"
80+
)

0 commit comments

Comments
 (0)