Skip to content

Commit 7de1c03

Browse files
authored
Merge pull request #177 from maciejkula/fix_segfault_not_fitted
Fix segfault when predicting on a model that hasn't been fitted.
2 parents 436b1de + 4be2761 commit 7de1c03

3 files changed

Lines changed: 46 additions & 0 deletions

File tree

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
### Fixed
77
- recall_at_k and precision_at_k now work correctly at k=1 (thanks to Zank Bennett).
88
- Moved Movielens data to data release to prevent grouplens server flakiness from affecting users.
9+
- Fix segfault when trying to predict from a model that has not been fitted.
910

1011
## [1.12][2017-01-26]
1112
### Changed

lightfm/lightfm.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,25 @@ def _reset_state(self):
177177
self.user_bias_gradients = None
178178
self.user_bias_momentum = None
179179

180+
def _check_initialized(self):
181+
182+
for var in (self.item_embeddings,
183+
self.item_embedding_gradients,
184+
self.item_embedding_momentum,
185+
self.item_biases,
186+
self.item_bias_gradients,
187+
self.item_bias_momentum,
188+
self.user_embeddings,
189+
self.user_embedding_gradients,
190+
self.user_embedding_momentum,
191+
self.user_biases,
192+
self.user_bias_gradients,
193+
self.user_bias_momentum):
194+
195+
if var is None:
196+
raise ValueError('You must fit the model before '
197+
'trying to obtain predictions.')
198+
180199
def _initialize(self, no_components, no_item_features, no_user_features):
181200
"""
182201
Initialise internal latent representations.
@@ -596,6 +615,8 @@ def predict(self, user_ids, item_ids, item_features=None,
596615
by the inputs.
597616
"""
598617

618+
self._check_initialized()
619+
599620
if not isinstance(user_ids, np.ndarray):
600621
user_ids = np.repeat(np.int32(user_ids), len(item_ids))
601622

@@ -670,6 +691,8 @@ def predict_rank(self, test_interactions, train_interactions=None,
670691
input interactions matrix.
671692
"""
672693

694+
self._check_initialized()
695+
673696
n_users, n_items = test_interactions.shape
674697

675698
(user_features,
@@ -731,6 +754,8 @@ def get_item_representations(self, features=None):
731754
Biases and latent representations for items.
732755
"""
733756

757+
self._check_initialized()
758+
734759
if features is None:
735760
return self.item_biases, self.item_embeddings
736761

@@ -758,6 +783,8 @@ def get_user_representations(self, features=None):
758783
Biases and latent representations for users.
759784
"""
760785

786+
self._check_initialized()
787+
761788
if features is None:
762789
return self.user_biases, self.user_embeddings
763790

tests/test_api.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,3 +332,21 @@ def test_sklearn_api():
332332
params['invalid_param'] = 666
333333
with pytest.raises(ValueError):
334334
model.set_params(**params)
335+
336+
337+
def test_predict_not_fitte():
338+
339+
model = LightFM()
340+
341+
with pytest.raises(ValueError):
342+
model.predict(np.arange(10),
343+
np.arange(10))
344+
345+
with pytest.raises(ValueError):
346+
model.predict_rank(1)
347+
348+
with pytest.raises(ValueError):
349+
model.get_user_representations()
350+
351+
with pytest.raises(ValueError):
352+
model.get_item_representations()

0 commit comments

Comments
 (0)