Skip to content

Commit 63f8037

Browse files
Merge pull request #3 from dataiku/feature/testing-release-0.1.0
Feature/testing release 0.1.0
2 parents e165e87 + 6e8b46b commit 63f8037

3 files changed

Lines changed: 186 additions & 0 deletions

File tree

tests/python/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
pandas>=1.0,<1.1
2+
pytest==6.1.0

tests/python/unit/test_base.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import pandas as pd
2+
import os.path
3+
4+
from data_loader import DataLoader
5+
from tempfile import NamedTemporaryFile
6+
from nearest_neighbor.base import NearestNeighborSearch
7+
8+
9+
def test_build_save_index():
10+
11+
params = {'unique_id_column': 'images',
12+
'feature_columns': ['prediction'],
13+
'algorithm': 'annoy',
14+
'expert': True,
15+
'annoy_metric': 'angular',
16+
'annoy_num_trees': 10}
17+
18+
# Load data into vector format for indexing
19+
columns = [params["unique_id_column"]] + params["feature_columns"]
20+
input_df = pd.read_csv('./tests/resources/caltech_embeddings.csv')
21+
# Restrict to selected columns
22+
input_df = input_df[columns]
23+
data_loader = DataLoader(params["unique_id_column"], params["feature_columns"])
24+
(vector_ids, vectors) = data_loader.convert_df_to_vectors(input_df)
25+
nearest_neighbor = NearestNeighborSearch(num_dimensions=vectors.shape[1], **params)
26+
with NamedTemporaryFile() as tmp:
27+
nearest_neighbor.build_save_index(vectors=vectors, index_path=tmp.name)
28+
assert os.path.isfile(tmp.name)
29+
30+
31+
def test_find_neighbors_df():
32+
33+
params = {'unique_id_column': 'images',
34+
'feature_columns': ['prediction'],
35+
'algorithm': 'annoy',
36+
'expert': True,
37+
'annoy_metric': 'angular',
38+
'annoy_num_trees': 10}
39+
40+
index_config = {'algorithm': 'annoy',
41+
'num_dimensions': 2048,
42+
'annoy_metric': 'angular',
43+
'annoy_num_trees': 10,
44+
'feature_columns': ['prediction'],
45+
'expert': True}
46+
47+
# Load data into vector format for indexing
48+
columns = [params["unique_id_column"]] + params["feature_columns"]
49+
input_df = pd.read_csv('./tests/resources/caltech_embeddings.csv')
50+
input_df = input_df[columns]
51+
data_loader = DataLoader(params["unique_id_column"], params["feature_columns"])
52+
(vector_ids, vectors) = data_loader.convert_df_to_vectors(input_df)
53+
nearest_neighbor = NearestNeighborSearch(num_dimensions=vectors.shape[1], **params)
54+
with NamedTemporaryFile() as tmp:
55+
nearest_neighbor.build_save_index(vectors=vectors, index_path=tmp.name)
56+
params = {'unique_id_column': 'images', 'feature_columns': ['prediction'], 'num_neighbors': 5}
57+
nearest_neighbor = NearestNeighborSearch(**index_config)
58+
nearest_neighbor.load_index(tmp.name)
59+
# Find nearest neighbors in input dataset
60+
df = nearest_neighbor.find_neighbors_df(input_df, **params, index_vector_ids=vector_ids)
61+
actual = sorted(list(df[df['input_id'] == '34719_ostrich.jpg']['neighbor_id']))
62+
expected = ['107505_ostrich.jpg', '185189_ostrich.jpg', '213657_ostrich.jpg', '229350_ostrich.jpg', '34719_ostrich.jpg']
63+
assert len(actual) == len(expected)
64+
assert all([actual_item == expected_item for actual_item, expected_item in zip(actual, expected)])

0 commit comments

Comments
 (0)