@@ -177,6 +177,25 @@ def _reset_state(self):
177177 self .user_bias_gradients = None
178178 self .user_bias_momentum = None
179179
180+ def _check_initialized (self ):
181+
182+ for var in (self .item_embeddings ,
183+ self .item_embedding_gradients ,
184+ self .item_embedding_momentum ,
185+ self .item_biases ,
186+ self .item_bias_gradients ,
187+ self .item_bias_momentum ,
188+ self .user_embeddings ,
189+ self .user_embedding_gradients ,
190+ self .user_embedding_momentum ,
191+ self .user_biases ,
192+ self .user_bias_gradients ,
193+ self .user_bias_momentum ):
194+
195+ if var is None :
196+ raise ValueError ('You must fit the model before '
197+ 'trying to obtain predictions.' )
198+
180199 def _initialize (self , no_components , no_item_features , no_user_features ):
181200 """
182201 Initialise internal latent representations.
@@ -596,6 +615,8 @@ def predict(self, user_ids, item_ids, item_features=None,
596615 by the inputs.
597616 """
598617
618+ self ._check_initialized ()
619+
599620 if not isinstance (user_ids , np .ndarray ):
600621 user_ids = np .repeat (np .int32 (user_ids ), len (item_ids ))
601622
@@ -670,6 +691,8 @@ def predict_rank(self, test_interactions, train_interactions=None,
670691 input interactions matrix.
671692 """
672693
694+ self ._check_initialized ()
695+
673696 n_users , n_items = test_interactions .shape
674697
675698 (user_features ,
@@ -731,6 +754,8 @@ def get_item_representations(self, features=None):
731754 Biases and latent representations for items.
732755 """
733756
757+ self ._check_initialized ()
758+
734759 if features is None :
735760 return self .item_biases , self .item_embeddings
736761
@@ -758,6 +783,8 @@ def get_user_representations(self, features=None):
758783 Biases and latent representations for users.
759784 """
760785
786+ self ._check_initialized ()
787+
761788 if features is None :
762789 return self .user_biases , self .user_embeddings
763790
0 commit comments