Skip to content

Commit a28a1f6

Browse files
test commit
1 parent 4cb00b1 commit a28a1f6

11 files changed

Lines changed: 2290 additions & 1437 deletions

python/rustystats/formula.py

Lines changed: 109 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,73 @@ def _get_column(data: pl.DataFrame, column: str) -> np.ndarray:
167167
return data[column].to_numpy()
168168

169169

170+
def _extract_needed_columns(
171+
terms: dict[str, dict[str, Any]],
172+
response: str | None = None,
173+
interactions: list[dict[str, Any]] | None = None,
174+
offset: str | np.ndarray | None = None,
175+
weights: str | np.ndarray | None = None,
176+
complement: str | np.ndarray | None = None,
177+
) -> set[str]:
178+
"""Extract all DataFrame column names needed to build this model.
179+
180+
Parameters
181+
----------
182+
terms : dict
183+
Term specifications (same format as glm_dict).
184+
response : str, optional
185+
Response column name. Omit for prediction (no response needed).
186+
interactions, offset, weights, complement
187+
Same as glm_dict parameters.
188+
"""
189+
import re
190+
191+
cols: set[str] = set()
192+
if response is not None:
193+
cols.add(response)
194+
195+
for var_name, spec in terms.items():
196+
term_type = spec.get("type", "linear")
197+
if term_type == "expression":
198+
expr = spec["expr"]
199+
for token in re.findall(r"\b([A-Za-z_]\w*)\b", expr):
200+
cols.add(token)
201+
else:
202+
cols.add(var_name)
203+
204+
if interactions:
205+
for ix in interactions:
206+
for key in ix:
207+
if key in ("include_main", "target_encoding", "frequency_encoding", "prior_weight"):
208+
continue
209+
cols.add(key)
210+
211+
if isinstance(offset, str):
212+
cols.add(offset)
213+
if isinstance(weights, str):
214+
cols.add(weights)
215+
if isinstance(complement, str):
216+
cols.add(complement)
217+
218+
return cols
219+
220+
221+
def _collect_lazyframe(
    data: pl.DataFrame | pl.LazyFrame,
    needed_columns: set[str],
) -> pl.DataFrame:
    """If data is a LazyFrame, select only needed columns and collect. Otherwise return as-is.

    ``needed_columns`` may contain false positives (lexical tokens
    extracted from expression terms that are not real columns), so it is
    intersected with the LazyFrame's schema before selecting — selecting
    a name absent from the schema would raise in polars. A genuinely
    missing model column still surfaces later, when the design matrix is
    built from the collected frame.
    """
    import polars as pl

    if not isinstance(data, pl.LazyFrame):
        return data

    if needed_columns:
        # Resolve the scan schema without materializing data, then keep
        # only columns that actually exist.
        available = set(data.collect_schema().names())
        selectable = sorted(needed_columns & available)
        if selectable:
            return data.select(selectable).collect()

    # No usable column subset: fall back to collecting the full frame.
    return data.collect()
235+
236+
170237
# Import from interactions module (the canonical implementation)
171238
from rustystats.interactions import InteractionBuilder
172239

@@ -1400,7 +1467,7 @@ def diagnostics_json(
14001467

14011468
def predict(
14021469
self,
1403-
new_data: pl.DataFrame,
1470+
new_data: pl.DataFrame | pl.LazyFrame,
14041471
offset: str | np.ndarray | None = None,
14051472
complement: str | np.ndarray | None = None,
14061473
) -> np.ndarray:
@@ -1409,8 +1476,9 @@ def predict(
14091476
14101477
Parameters
14111478
----------
1412-
new_data : pl.DataFrame
1479+
new_data : pl.DataFrame or pl.LazyFrame
14131480
New data to predict on. Must have the same columns as training data.
1481+
If a LazyFrame, only needed columns are collected.
14141482
offset : str or array-like, optional
14151483
Offset for new data. If None and the model was fit with an offset
14161484
column name, that column will be extracted from new_data.
@@ -1442,6 +1510,18 @@ def predict(
14421510
"Use fittedvalues for training data predictions."
14431511
)
14441512

1513+
# Resolve LazyFrame: select only columns needed for prediction
1514+
if self._terms_dict is not None:
1515+
needed = _extract_needed_columns(
1516+
terms=self._terms_dict,
1517+
interactions=self._interactions_spec,
1518+
offset=offset if offset is not None else self._offset_spec,
1519+
complement=complement if complement is not None else self._complement_spec,
1520+
)
1521+
new_data = _collect_lazyframe(new_data, needed)
1522+
else:
1523+
new_data = _collect_lazyframe(new_data, set())
1524+
14451525
# Build design matrix for new data using stored encoding state
14461526
X_new = self._builder.transform_new_data(new_data)
14471527

@@ -2211,7 +2291,7 @@ def __init__(
22112291
self._offset_spec = offset
22122292
self._weights_spec = weights
22132293
self._seed = seed
2214-
self._complement_spec = complement if isinstance(complement, (str, GLMModel)) else None
2294+
self._complement_spec = complement if isinstance(complement, str | GLMModel) else None
22152295
self._complement_values = None # Set by _process_complement
22162296

22172297
# Build formula string for compatibility (used in results/diagnostics)
@@ -2475,7 +2555,7 @@ def fit(
24752555
def glm_dict(
24762556
response: str,
24772557
terms: dict[str, dict[str, Any]],
2478-
data: pl.DataFrame,
2558+
data: pl.DataFrame | pl.LazyFrame,
24792559
interactions: list[dict[str, Any]] | None = None,
24802560
intercept: bool = True,
24812561
family: str = "gaussian",
@@ -2511,8 +2591,10 @@ def glm_dict(
25112591
- ``{"type": "expression", "expr": "x**2"}`` - expression
25122592
- ``{"type": "linear", "monotonicity": "increasing"}`` - constrained
25132593
2514-
data : pl.DataFrame
2515-
Polars DataFrame containing the data.
2594+
data : pl.DataFrame or pl.LazyFrame
2595+
Polars DataFrame or LazyFrame containing the data. If a LazyFrame
2596+
is passed, only the columns needed by the model are collected,
2597+
enabling optimized reads from Parquet/CSV scans.
25162598
interactions : list of dict, optional
25172599
List of interaction specifications. Each is a dict with variable
25182600
names as keys and their specs as values, plus 'include_main'.
@@ -2561,6 +2643,16 @@ def glm_dict(
25612643
... offset="Exposure",
25622644
... ).fit()
25632645
2646+
>>> # LazyFrame: only needed columns are collected
2647+
>>> lf = pl.scan_parquet("insurance.parquet")
2648+
>>> result = rs.glm_dict(
2649+
... response="ClaimCount",
2650+
... terms={"VehAge": {"type": "linear"}, "Region": {"type": "categorical"}},
2651+
... data=lf,
2652+
... family="poisson",
2653+
... offset="Exposure",
2654+
... ).fit()
2655+
25642656
>>> # Lasso credibility: shrink state model toward countrywide rates
25652657
>>> state_result = rs.glm_dict(
25662658
... response="ClaimCount",
@@ -2574,6 +2666,17 @@ def glm_dict(
25742666
... complement="countrywide_rate",
25752667
... ).fit(regularization="lasso")
25762668
"""
2669+
# Resolve LazyFrame: select only needed columns, then collect
2670+
needed = _extract_needed_columns(
2671+
terms,
2672+
response=response,
2673+
interactions=interactions,
2674+
offset=offset,
2675+
weights=weights,
2676+
complement=complement,
2677+
)
2678+
data = _collect_lazyframe(data, needed)
2679+
25772680
return FormulaGLMDict(
25782681
response=response,
25792682
terms=terms,

0 commit comments

Comments
 (0)