@@ -23,12 +23,9 @@ def test_build_save_index():
2323 data_loader = DataLoader (params ["unique_id_column" ], params ["feature_columns" ])
2424 (vector_ids , vectors ) = data_loader .convert_df_to_vectors (input_df )
2525 nearest_neighbor = NearestNeighborSearch (num_dimensions = vectors .shape [1 ], ** params )
26- tmp = NamedTemporaryFile ()
27- nearest_neighbor .build_save_index (vectors = vectors , index_path = tmp .name )
28- file_exist = os .path .isfile (tmp .name )
29- tmp .close ()
30- # Test if file is created as a result
31- assert file_exist
26+ with NamedTemporaryFile () as tmp :
27+ nearest_neighbor .build_save_index (vectors = vectors , index_path = tmp .name )
28+ assert os .path .isfile (tmp .name )
3229
3330
3431def test_find_neighbors_df ():
@@ -54,15 +51,14 @@ def test_find_neighbors_df():
5451 data_loader = DataLoader (params ["unique_id_column" ], params ["feature_columns" ])
5552 (vector_ids , vectors ) = data_loader .convert_df_to_vectors (input_df )
5653 nearest_neighbor = NearestNeighborSearch (num_dimensions = vectors .shape [1 ], ** params )
57- tmp = NamedTemporaryFile ()
58- nearest_neighbor .build_save_index (vectors = vectors , index_path = tmp .name )
59- params = {'unique_id_column' : 'images' , 'feature_columns' : ['prediction' ], 'num_neighbors' : 5 }
60- nearest_neighbor = NearestNeighborSearch (** index_config )
61- nearest_neighbor .load_index (tmp .name )
62- # Find nearest neighbors in input dataset
63- df = nearest_neighbor .find_neighbors_df (input_df , ** params , index_vector_ids = vector_ids )
64- actual = sorted (list (df [df ['input_id' ] == '34719_ostrich.jpg' ]['neighbor_id' ]))
65- expected = ['107505_ostrich.jpg' , '185189_ostrich.jpg' , '213657_ostrich.jpg' , '229350_ostrich.jpg' , '34719_ostrich.jpg' ]
66- tmp .close ()
54+ with NamedTemporaryFile () as tmp :
55+ nearest_neighbor .build_save_index (vectors = vectors , index_path = tmp .name )
56+ params = {'unique_id_column' : 'images' , 'feature_columns' : ['prediction' ], 'num_neighbors' : 5 }
57+ nearest_neighbor = NearestNeighborSearch (** index_config )
58+ nearest_neighbor .load_index (tmp .name )
59+ # Find nearest neighbors in input dataset
60+ df = nearest_neighbor .find_neighbors_df (input_df , ** params , index_vector_ids = vector_ids )
61+ actual = sorted (list (df [df ['input_id' ] == '34719_ostrich.jpg' ]['neighbor_id' ]))
62+ expected = ['107505_ostrich.jpg' , '185189_ostrich.jpg' , '213657_ostrich.jpg' , '229350_ostrich.jpg' , '34719_ostrich.jpg' ]
6763 assert len (actual ) == len (expected )
6864 assert all ([actual_item == expected_item for actual_item , expected_item in zip (actual , expected )])
0 commit comments