@@ -45,9 +45,11 @@ class LightFM(object):
4545 user_alpha: float, optional
4646 L2 penalty on user features.
4747 max_sampled: int, optional
48- maximum number of negative samples used during WARP fitting. Defaults to
49- the number of items divided by 10. Setting this to lower number may improve the speed of
50- WARP fitting at the expense of some accuracy.
48+ maximum number of negative samples used during WARP fitting. WARP fitting
49+ requires a lot of sampling to find negative triplets for users that are
50+ already well represented by the model; this can lead to very long training
51+ times and overfitting. Setting this to a higher number will generally lead to
52+ longer training times, but may in some cases improve accuracy. Defaults to 10.
5153 random_state: int seed, RandomState instance, or None
5254 The seed of the pseudo random number generator to use when shuffling the data and
5355 initializing the parameters.
@@ -116,7 +118,7 @@ def __init__(self, no_components=10, k=5, n=10,
116118 learning_schedule = 'adagrad' ,
117119 loss = 'logistic' ,
118120 learning_rate = 0.05 , rho = 0.95 , epsilon = 1e-6 ,
119- item_alpha = 0.0 , user_alpha = 0.0 , max_sampled = None ,
121+ item_alpha = 0.0 , user_alpha = 0.0 , max_sampled = 10 ,
120122 random_state = None ):
121123
122124 assert item_alpha >= 0.0
@@ -129,7 +131,7 @@ def __init__(self, no_components=10, k=5, n=10,
129131 assert learning_schedule in ('adagrad' , 'adadelta' )
130132 assert loss in ('logistic' , 'warp' , 'bpr' , 'warp-kos' )
131133
132- if max_sampled is not None and max_sampled < 1 :
134+ if max_sampled < 1 :
133135 raise ValueError ('max_sampled must be a positive integer' )
134136
135137 self .loss = loss
@@ -294,9 +296,6 @@ def _process_sample_weight(self, interactions, sample_weight):
294296
295297 def _get_lightfm_data (self ):
296298
297- max_sampled = (self .max_sampled if self .max_sampled is not None
298- else self .item_embeddings .shape [0 ] / 10 )
299-
300299 lightfm_data = FastLightFM (self .item_embeddings ,
301300 self .item_embedding_gradients ,
302301 self .item_embedding_momentum ,
@@ -314,7 +313,7 @@ def _get_lightfm_data(self):
314313 self .learning_rate ,
315314 self .rho ,
316315 self .epsilon ,
317- max_sampled )
316+ self . max_sampled )
318317
319318 return lightfm_data
320319
0 commit comments