Skip to content

Commit 810897a

Browse files
committed
Merge pull request #81 from maciejkula/default_max_sampled
Change default max_sampled to 10.
2 parents 7a9aa38 + a001ff1 commit 810897a

2 files changed

Lines changed: 9 additions & 9 deletions

File tree

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
### Changed
1515
- By default, an OpenMP-less version will be built on OSX. This allows much easier installation at the expense of
1616
performance.
17+
- The default value of the max_sampled argument is now 10. This is a sensible default that allows fast training.
1718

1819
## [1.8][2016-01-14]
1920
### Changed

lightfm/lightfm.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,11 @@ class LightFM(object):
4545
user_alpha: float, optional
4646
L2 penalty on user features.
4747
max_sampled: int, optional
48-
maximum number of negative samples used during WARP fitting. Defaults to
49-
the number of items divided by 10. Setting this to a lower number may improve the speed of
50-
WARP fitting at the expense of some accuracy.
48+
maximum number of negative samples used during WARP fitting. It requires
49+
a lot of sampling to find negative triplets for users that are already
50+
well represented by the model; this can lead to very long training times
51+
and overfitting. Setting this to a higher number will generally lead
52+
to longer training times, but may in some cases improve accuracy.
5153
random_state: int seed, RandomState instance, or None
5254
The seed of the pseudo random number generator to use when shuffling the data and
5355
initializing the parameters.
@@ -116,7 +118,7 @@ def __init__(self, no_components=10, k=5, n=10,
116118
learning_schedule='adagrad',
117119
loss='logistic',
118120
learning_rate=0.05, rho=0.95, epsilon=1e-6,
119-
item_alpha=0.0, user_alpha=0.0, max_sampled=None,
121+
item_alpha=0.0, user_alpha=0.0, max_sampled=10,
120122
random_state=None):
121123

122124
assert item_alpha >= 0.0
@@ -129,7 +131,7 @@ def __init__(self, no_components=10, k=5, n=10,
129131
assert learning_schedule in ('adagrad', 'adadelta')
130132
assert loss in ('logistic', 'warp', 'bpr', 'warp-kos')
131133

132-
if max_sampled is not None and max_sampled < 1:
134+
if max_sampled < 1:
133135
raise ValueError('max_sampled must be a positive integer')
134136

135137
self.loss = loss
@@ -294,9 +296,6 @@ def _process_sample_weight(self, interactions, sample_weight):
294296

295297
def _get_lightfm_data(self):
296298

297-
max_sampled = (self.max_sampled if self.max_sampled is not None
298-
else self.item_embeddings.shape[0] / 10)
299-
300299
lightfm_data = FastLightFM(self.item_embeddings,
301300
self.item_embedding_gradients,
302301
self.item_embedding_momentum,
@@ -314,7 +313,7 @@ def _get_lightfm_data(self):
314313
self.learning_rate,
315314
self.rho,
316315
self.epsilon,
317-
max_sampled)
316+
self.max_sampled)
318317

319318
return lightfm_data
320319

0 commit comments

Comments
 (0)