Skip to content

Commit 79c9c3a

Browse files
Fix issue with not computing numberofpairs if should is false
1 parent 2d99cc5 commit 79c9c3a

10 files changed

Lines changed: 91 additions & 24 deletions

File tree

algo-common/src/main/java/org/neo4j/gds/result/SimilarityStatistics.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,6 @@ public record Histogram(Optional<DoubleHistogram> histogram, boolean success)
9393

9494
public record SimilarityStats(Optional<DoubleHistogram> histogram, long computeMilliseconds, boolean success) {}
9595

96-
public record SimilarityDistributionResults(long numberOfEntries, Map<String,Object> distribution, long computeMilliseconds) {}
96+
public record SimilarityDistributionResults(long numberOfSimilarityPairs, Map<String,Object> distribution, long computeMilliseconds) {}
9797

9898
}

procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/similarity/FilteredKnnResultBuilderForStatsMode.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.neo4j.gds.termination.TerminationFlag;
2828

2929
import java.util.Optional;
30+
import java.util.OptionalLong;
3031
import java.util.stream.Stream;
3132

3233
class FilteredKnnResultBuilderForStatsMode implements StatsResultBuilder<FilteredKnnResult, Stream<KnnStatsResult>> {
@@ -64,7 +65,8 @@ public Stream<KnnStatsResult> build(
6465
configuration.concurrency(),
6566
filteredKnnResult.similarityResultStream(),
6667
shouldComputeSimilarityDistribution,
67-
terminationFlag
68+
terminationFlag,
69+
OptionalLong.of(filteredKnnResult.numberOfSimilarityPairs())
6870
);
6971

7072
return Stream.of(
@@ -73,7 +75,7 @@ public Stream<KnnStatsResult> build(
7375
timings.computeMillis,
7476
similarityStats.computeMilliseconds(),
7577
filteredKnnResult.nodesCompared(),
76-
filteredKnnResult.numberOfSimilarityPairs(),
78+
similarityStats.numberOfSimilarityPairs(),
7779
similarityStats.distribution(),
7880
filteredKnnResult.didConverge(),
7981
filteredKnnResult.ranIterations(),

procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/similarity/GenericNodeSimilarityResultBuilderForStatsMode.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
import java.util.Map;
2828
import java.util.Optional;
29+
import java.util.OptionalLong;
2930
import java.util.stream.Stream;
3031

3132
class GenericNodeSimilarityResultBuilderForStatsMode {
@@ -51,7 +52,8 @@ Stream<SimilarityStatsResult> build(
5152
nodeSimilarityResult.isStream() ? concurrency : new Concurrency(1),
5253
nodeSimilarityResult.isStream()? nodeSimilarityResult.streamResult() : nodeSimilarityResult.topKMap().stream(),
5354
shouldComputeSimilarityDistribution,
54-
terminationFlag
55+
terminationFlag,
56+
OptionalLong.empty()
5557
);
5658

5759
return Stream.of(
@@ -60,7 +62,7 @@ Stream<SimilarityStatsResult> build(
6062
timings.computeMillis,
6163
similarityStats.computeMilliseconds(),
6264
nodeSimilarityResult.comparedNodes(),
63-
similarityStats.numberOfEntries(),
65+
similarityStats.numberOfSimilarityPairs(),
6466
similarityStats.distribution(),
6567
configurationMap
6668
)

procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/similarity/KnnResultBuilderForStatsMode.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.neo4j.gds.termination.TerminationFlag;
2828

2929
import java.util.Optional;
30+
import java.util.OptionalLong;
3031
import java.util.stream.Stream;
3132

3233
class KnnResultBuilderForStatsMode implements StatsResultBuilder<KnnResult, Stream<KnnStatsResult>> {
@@ -65,7 +66,8 @@ public Stream<KnnStatsResult> build(
6566
configuration.concurrency(),
6667
knnResult.streamSimilarityResult(),
6768
shouldComputeSimilarityDistribution,
68-
terminationFlag
69+
terminationFlag,
70+
OptionalLong.of(knnResult.totalSimilarityPairs())
6971
);
7072

7173
return Stream.of(
@@ -74,7 +76,7 @@ public Stream<KnnStatsResult> build(
7476
timings.computeMillis,
7577
similarityStats.computeMilliseconds(),
7678
knnResult.nodesCompared(),
77-
knnResult.totalSimilarityPairs(),
79+
similarityStats.numberOfSimilarityPairs(),
7880
similarityStats.distribution(),
7981
knnResult.didConverge(),
8082
knnResult.ranIterations(),

procedures/algorithms-facade/src/main/java/org/neo4j/gds/procedures/algorithms/similarity/SimilarityStatsProcessor.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.neo4j.gds.termination.TerminationFlag;
3030

3131
import java.util.Map;
32+
import java.util.OptionalLong;
3233
import java.util.concurrent.atomic.AtomicLong;
3334
import java.util.concurrent.atomic.LongAdder;
3435
import java.util.stream.Stream;
@@ -46,9 +47,16 @@ SimilarityStatistics.SimilarityDistributionResults computeSimilarityDistribution
4647
Concurrency concurrency,
4748
Stream<SimilarityResult> similarityResultStream,
4849
boolean shouldComputeSimilarityDistribution,
49-
TerminationFlag terminationFlag
50+
TerminationFlag terminationFlag,
51+
OptionalLong numberOfEntries
5052
) {
51-
if (!shouldComputeSimilarityDistribution) return EMPTY;
53+
if (!shouldComputeSimilarityDistribution && numberOfEntries.isPresent()){
54+
return new SimilarityStatistics.SimilarityDistributionResults(
55+
numberOfEntries.getAsLong(),
56+
Map.of(),
57+
0
58+
);
59+
}
5260

5361
var statsMillis = new AtomicLong();
5462
Map<String,Object> distribution;
@@ -71,7 +79,7 @@ SimilarityStatistics.SimilarityDistributionResults computeSimilarityDistribution
7179
}
7280

7381
return new SimilarityStatistics.SimilarityDistributionResults(
74-
adder.longValue(),
82+
numberOfEntries.orElse(adder.longValue()),
7583
distribution,
7684
statsMillis.get()
7785
);

procedures/pushback-procedures-facade/src/main/java/org/neo4j/gds/procedures/algorithms/similarity/stats/FilteredKnnStatsResultTransformer.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.neo4j.gds.termination.TerminationFlag;
2929

3030
import java.util.Map;
31+
import java.util.OptionalLong;
3132
import java.util.stream.Stream;
3233

3334
public class FilteredKnnStatsResultTransformer implements ResultTransformer<TimedAlgorithmResult<FilteredKnnResult>, Stream<KnnStatsResult>> {
@@ -68,7 +69,8 @@ public Stream<KnnStatsResult> apply(TimedAlgorithmResult<FilteredKnnResult> time
6869
concurrency,
6970
knnResult.similarityResultStream(),
7071
shouldComputeSimilarityDistribution,
71-
terminationFlag
72+
terminationFlag,
73+
OptionalLong.of(knnResult.numberOfSimilarityPairs())
7274
);
7375

7476

@@ -78,7 +80,7 @@ public Stream<KnnStatsResult> apply(TimedAlgorithmResult<FilteredKnnResult> time
7880
timedAlgorithmResult.computeMillis(),
7981
similarityStats.computeMilliseconds(),
8082
knnResult.nodesCompared(),
81-
knnResult.numberOfSimilarityPairs(),
83+
similarityStats.numberOfSimilarityPairs(),
8284
similarityStats.distribution(),
8385
knnResult.didConverge(),
8486
knnResult.ranIterations(),

procedures/pushback-procedures-facade/src/main/java/org/neo4j/gds/procedures/algorithms/similarity/stats/KnnStatsResultTransformer.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.neo4j.gds.termination.TerminationFlag;
2929

3030
import java.util.Map;
31+
import java.util.OptionalLong;
3132
import java.util.stream.Stream;
3233

3334
public class KnnStatsResultTransformer implements ResultTransformer<TimedAlgorithmResult<KnnResult>, Stream<KnnStatsResult>> {
@@ -68,7 +69,8 @@ public Stream<KnnStatsResult> apply(TimedAlgorithmResult<KnnResult> timedAlgorit
6869
concurrency,
6970
knnResult.streamSimilarityResult(),
7071
shouldComputeSimilarityDistribution,
71-
terminationFlag
72+
terminationFlag,
73+
OptionalLong.of(knnResult.totalSimilarityPairs())
7274
);
7375

7476

@@ -78,7 +80,7 @@ public Stream<KnnStatsResult> apply(TimedAlgorithmResult<KnnResult> timedAlgorit
7880
timedAlgorithmResult.computeMillis(),
7981
similarityStats.computeMilliseconds(),
8082
knnResult.nodesCompared(),
81-
knnResult.totalSimilarityPairs(),
83+
similarityStats.numberOfSimilarityPairs(),
8284
similarityStats.distribution(),
8385
knnResult.didConverge(),
8486
knnResult.ranIterations(),

procedures/pushback-procedures-facade/src/main/java/org/neo4j/gds/procedures/algorithms/similarity/stats/NodeSimilarityStatsResultTransformer.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.neo4j.gds.termination.TerminationFlag;
2828

2929
import java.util.Map;
30+
import java.util.OptionalLong;
3031
import java.util.stream.Stream;
3132

3233
public class NodeSimilarityStatsResultTransformer implements ResultTransformer<TimedAlgorithmResult<NodeSimilarityResult>, Stream<SimilarityStatsResult>> {
@@ -65,7 +66,8 @@ public Stream<SimilarityStatsResult> apply(TimedAlgorithmResult<NodeSimilarityRe
6566
result.isStream() ? concurrency : new Concurrency(1),
6667
result.isStream()? result.streamResult() : result.topKMap().stream(),
6768
shouldComputeSimilarityDistribution,
68-
terminationFlag
69+
terminationFlag,
70+
OptionalLong.empty()
6971
);
7072

7173
return Stream.of(
@@ -74,7 +76,7 @@ public Stream<SimilarityStatsResult> apply(TimedAlgorithmResult<NodeSimilarityRe
7476
timedAlgorithmResult.computeMillis(),
7577
similarityStats.computeMilliseconds(),
7678
result.comparedNodes(),
77-
similarityStats.numberOfEntries(),
79+
similarityStats.numberOfSimilarityPairs(),
7880
similarityStats.distribution(),
7981
configuration
8082
)

procedures/pushback-procedures-facade/src/main/java/org/neo4j/gds/procedures/algorithms/similarity/stats/SimilarityStatsTools.java

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
*/
2020
package org.neo4j.gds.procedures.algorithms.similarity.stats;
2121

22-
import org.neo4j.gds.algorithms.similarity.ActualSimilaritySummaryBuilder;
2322
import org.neo4j.gds.algorithms.similarity.SimilarityResultStreamDelegate;
23+
import org.neo4j.gds.algorithms.similarity.SimilaritySummaryBuilderFactory;
2424
import org.neo4j.gds.api.properties.relationships.RelationshipWithPropertyConsumer;
2525
import org.neo4j.gds.core.concurrency.Concurrency;
2626
import org.neo4j.gds.core.utils.ProgressTimer;
@@ -29,6 +29,7 @@
2929
import org.neo4j.gds.termination.TerminationFlag;
3030

3131
import java.util.Map;
32+
import java.util.OptionalLong;
3233
import java.util.concurrent.atomic.AtomicLong;
3334
import java.util.concurrent.atomic.LongAdder;
3435
import java.util.stream.Stream;
@@ -47,20 +48,28 @@ static SimilarityStatistics.SimilarityDistributionResults computeSimilarityDistr
4748
Concurrency concurrency,
4849
Stream<SimilarityResult> similarityResultStream,
4950
boolean shouldComputeSimilarityDistribution,
50-
TerminationFlag terminationFlag
51+
TerminationFlag terminationFlag,
52+
OptionalLong numberOfEntries
5153
) {
52-
if (!shouldComputeSimilarityDistribution) return EMPTY;
5354

55+
if (!shouldComputeSimilarityDistribution && numberOfEntries.isPresent()){
56+
return new SimilarityStatistics.SimilarityDistributionResults(
57+
numberOfEntries.getAsLong(),
58+
Map.of(),
59+
0
60+
);
61+
}
5462
var statsMillis = new AtomicLong();
5563
Map<String,Object> distribution;
56-
var similaritySummaryBuilder = ActualSimilaritySummaryBuilder.create(concurrency);
64+
var similaritySummaryBuilder = SimilaritySummaryBuilderFactory.create(concurrency,shouldComputeSimilarityDistribution);
5765
LongAdder adder = new LongAdder();
5866

5967
RelationshipWithPropertyConsumer relationshipWithPropertyConsumer= (s,t,w)->{
6068
adder.increment();
6169
similaritySummaryBuilder.accept(s,t,w);
6270
return true;
6371
};
72+
6473
var similarityResultStreamDelegate = new SimilarityResultStreamDelegate();
6574
try (var ignored = ProgressTimer.start(statsMillis::set)) {
6675
similarityResultStreamDelegate.consumeStream(
@@ -74,7 +83,7 @@ static SimilarityStatistics.SimilarityDistributionResults computeSimilarityDistr
7483
}
7584

7685
return new SimilarityStatistics.SimilarityDistributionResults(
77-
adder.longValue(),
86+
numberOfEntries.orElse(adder.longValue()),
7887
distribution,
7988
statsMillis.get()
8089
);

procedures/pushback-procedures-facade/src/test/java/org/neo4j/gds/procedures/algorithms/similarity/stats/SimilarityStatsToolsTest.java

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.neo4j.gds.similarity.SimilarityResult;
2626
import org.neo4j.gds.termination.TerminationFlag;
2727

28+
import java.util.OptionalLong;
2829
import java.util.stream.Stream;
2930

3031
import static org.assertj.core.api.Assertions.assertThat;
@@ -39,9 +40,25 @@ void shouldReturnEmptyDistributionIfNotSpecified() {
3940
new Concurrency(1),
4041
Stream.of(new SimilarityResult(0, 1, 10)),
4142
false,
42-
TerminationFlag.RUNNING_TRUE
43+
TerminationFlag.RUNNING_TRUE,
44+
OptionalLong.empty()
4345
);
4446
assertThat(stats.distribution()).isEmpty();
47+
assertThat(stats.numberOfSimilarityPairs()).isEqualTo(1);
48+
}
49+
50+
@Test
51+
void shouldReturnEmptyDistributionIfNotSpecifiedButPairsProvided() {
52+
53+
var stats = SimilarityStatsTools.computeSimilarityDistribution(
54+
new Concurrency(1),
55+
Stream.of(new SimilarityResult(0, 1, 10)),
56+
false,
57+
TerminationFlag.RUNNING_TRUE,
58+
OptionalLong.of(100)
59+
);
60+
assertThat(stats.distribution()).isEmpty();
61+
assertThat(stats.numberOfSimilarityPairs()).isEqualTo(100);
4562
}
4663

4764
@Test
@@ -55,11 +72,32 @@ void shouldReturnValidDistributionIfTrue(){
5572
new SimilarityResult(1,3,9)
5673
),
5774
true,
58-
TerminationFlag.RUNNING_TRUE
75+
TerminationFlag.RUNNING_TRUE,
76+
OptionalLong.empty()
77+
);
78+
assertThat(stats.distribution().get("mean")).asInstanceOf(DOUBLE).isCloseTo(8.0, Offset.offset(1e-3));
79+
assertThat(stats.computeMilliseconds()).isGreaterThanOrEqualTo(0L);
80+
assertThat(stats.numberOfSimilarityPairs()).isEqualTo(3L);
81+
82+
}
83+
84+
@Test
85+
void shouldReturnValueBasedOnProvided(){
86+
87+
var stats = SimilarityStatsTools.computeSimilarityDistribution(
88+
new Concurrency(1),
89+
Stream.of(
90+
new SimilarityResult(0, 1, 10),
91+
new SimilarityResult(1,2,5),
92+
new SimilarityResult(1,3,9)
93+
),
94+
true,
95+
TerminationFlag.RUNNING_TRUE,
96+
OptionalLong.of(100L)
5997
);
6098
assertThat(stats.distribution().get("mean")).asInstanceOf(DOUBLE).isCloseTo(8.0, Offset.offset(1e-3));
6199
assertThat(stats.computeMilliseconds()).isGreaterThanOrEqualTo(0L);
62-
assertThat(stats.numberOfEntries()).isEqualTo(3L);
100+
assertThat(stats.numberOfSimilarityPairs()).isEqualTo(100L);
63101

64102
}
65103

0 commit comments

Comments
 (0)