@@ -167,6 +167,73 @@ def _get_column(data: pl.DataFrame, column: str) -> np.ndarray:
167167 return data [column ].to_numpy ()
168168
169169
170+ def _extract_needed_columns (
171+ terms : dict [str , dict [str , Any ]],
172+ response : str | None = None ,
173+ interactions : list [dict [str , Any ]] | None = None ,
174+ offset : str | np .ndarray | None = None ,
175+ weights : str | np .ndarray | None = None ,
176+ complement : str | np .ndarray | None = None ,
177+ ) -> set [str ]:
178+ """Extract all DataFrame column names needed to build this model.
179+
180+ Parameters
181+ ----------
182+ terms : dict
183+ Term specifications (same format as glm_dict).
184+ response : str, optional
185+ Response column name. Omit for prediction (no response needed).
186+ interactions, offset, weights, complement
187+ Same as glm_dict parameters.
188+ """
189+ import re
190+
191+ cols : set [str ] = set ()
192+ if response is not None :
193+ cols .add (response )
194+
195+ for var_name , spec in terms .items ():
196+ term_type = spec .get ("type" , "linear" )
197+ if term_type == "expression" :
198+ expr = spec ["expr" ]
199+ for token in re .findall (r"\b([A-Za-z_]\w*)\b" , expr ):
200+ cols .add (token )
201+ else :
202+ cols .add (var_name )
203+
204+ if interactions :
205+ for ix in interactions :
206+ for key in ix :
207+ if key in ("include_main" , "target_encoding" , "frequency_encoding" , "prior_weight" ):
208+ continue
209+ cols .add (key )
210+
211+ if isinstance (offset , str ):
212+ cols .add (offset )
213+ if isinstance (weights , str ):
214+ cols .add (weights )
215+ if isinstance (complement , str ):
216+ cols .add (complement )
217+
218+ return cols
219+
220+
221+ def _collect_lazyframe (
222+ data : pl .DataFrame | pl .LazyFrame ,
223+ needed_columns : set [str ],
224+ ) -> pl .DataFrame :
225+ """If data is a LazyFrame, select only needed columns and collect. Otherwise return as-is."""
226+ import polars as pl
227+
228+ if not isinstance (data , pl .LazyFrame ):
229+ return data
230+
231+ if needed_columns :
232+ return data .select (sorted (needed_columns )).collect ()
233+
234+ return data .collect ()
235+
236+
170237# Import from interactions module (the canonical implementation)
171238from rustystats .interactions import InteractionBuilder
172239
@@ -1400,7 +1467,7 @@ def diagnostics_json(
14001467
14011468 def predict (
14021469 self ,
1403- new_data : pl .DataFrame ,
1470+ new_data : pl .DataFrame | pl . LazyFrame ,
14041471 offset : str | np .ndarray | None = None ,
14051472 complement : str | np .ndarray | None = None ,
14061473 ) -> np .ndarray :
@@ -1409,8 +1476,9 @@ def predict(
14091476
14101477 Parameters
14111478 ----------
1412- new_data : pl.DataFrame
1479+ new_data : pl.DataFrame or pl.LazyFrame
14131480 New data to predict on. Must have the same columns as training data.
1481+ If a LazyFrame, only needed columns are collected.
14141482 offset : str or array-like, optional
14151483 Offset for new data. If None and the model was fit with an offset
14161484 column name, that column will be extracted from new_data.
@@ -1442,6 +1510,18 @@ def predict(
14421510 "Use fittedvalues for training data predictions."
14431511 )
14441512
1513+ # Resolve LazyFrame: select only columns needed for prediction
1514+ if self ._terms_dict is not None :
1515+ needed = _extract_needed_columns (
1516+ terms = self ._terms_dict ,
1517+ interactions = self ._interactions_spec ,
1518+ offset = offset if offset is not None else self ._offset_spec ,
1519+ complement = complement if complement is not None else self ._complement_spec ,
1520+ )
1521+ new_data = _collect_lazyframe (new_data , needed )
1522+ else :
1523+ new_data = _collect_lazyframe (new_data , set ())
1524+
14451525 # Build design matrix for new data using stored encoding state
14461526 X_new = self ._builder .transform_new_data (new_data )
14471527
@@ -2211,7 +2291,7 @@ def __init__(
22112291 self ._offset_spec = offset
22122292 self ._weights_spec = weights
22132293 self ._seed = seed
2214- self ._complement_spec = complement if isinstance (complement , ( str , GLMModel ) ) else None
2294+ self ._complement_spec = complement if isinstance (complement , str | GLMModel ) else None
22152295 self ._complement_values = None # Set by _process_complement
22162296
22172297 # Build formula string for compatibility (used in results/diagnostics)
@@ -2475,7 +2555,7 @@ def fit(
24752555def glm_dict (
24762556 response : str ,
24772557 terms : dict [str , dict [str , Any ]],
2478- data : pl .DataFrame ,
2558+ data : pl .DataFrame | pl . LazyFrame ,
24792559 interactions : list [dict [str , Any ]] | None = None ,
24802560 intercept : bool = True ,
24812561 family : str = "gaussian" ,
@@ -2511,8 +2591,10 @@ def glm_dict(
25112591 - ``{"type": "expression", "expr": "x**2"}`` - expression
25122592 - ``{"type": "linear", "monotonicity": "increasing"}`` - constrained
25132593
2514- data : pl.DataFrame
2515- Polars DataFrame containing the data.
2594+ data : pl.DataFrame or pl.LazyFrame
2595+ Polars DataFrame or LazyFrame containing the data. If a LazyFrame
2596+ is passed, only the columns needed by the model are collected,
2597+ enabling optimized reads from Parquet/CSV scans.
25162598 interactions : list of dict, optional
25172599 List of interaction specifications. Each is a dict with variable
25182600 names as keys and their specs as values, plus 'include_main'.
@@ -2561,6 +2643,16 @@ def glm_dict(
25612643 ... offset="Exposure",
25622644 ... ).fit()
25632645
2646+ >>> # LazyFrame: only needed columns are collected
2647+ >>> lf = pl.scan_parquet("insurance.parquet")
2648+ >>> result = rs.glm_dict(
2649+ ... response="ClaimCount",
2650+ ... terms={"VehAge": {"type": "linear"}, "Region": {"type": "categorical"}},
2651+ ... data=lf,
2652+ ... family="poisson",
2653+ ... offset="Exposure",
2654+ ... ).fit()
2655+
25642656 >>> # Lasso credibility: shrink state model toward countrywide rates
25652657 >>> state_result = rs.glm_dict(
25662658 ... response="ClaimCount",
@@ -2574,6 +2666,17 @@ def glm_dict(
25742666 ... complement="countrywide_rate",
25752667 ... ).fit(regularization="lasso")
25762668 """
2669+ # Resolve LazyFrame: select only needed columns, then collect
2670+ needed = _extract_needed_columns (
2671+ terms ,
2672+ response = response ,
2673+ interactions = interactions ,
2674+ offset = offset ,
2675+ weights = weights ,
2676+ complement = complement ,
2677+ )
2678+ data = _collect_lazyframe (data , needed )
2679+
25772680 return FormulaGLMDict (
25782681 response = response ,
25792682 terms = terms ,
0 commit comments