Skip to content

Commit 6855f43

Browse files
committed
fix: refactor dataset cleaning logic and add unit tests for consistency
1 parent 619b00e commit 6855f43

2 files changed

Lines changed: 59 additions & 24 deletions

File tree

src/pyjedai/datamodel.py

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -314,32 +314,37 @@ def clean_dataset(self,
314314
remove_unicodes: bool = True) -> None:
315315
"""Removes stopwords, punctuation, uni-codes, numbers from the dataset.
316316
"""
317-
nltk.download('stopwords')
317+
stop_words = None
318+
if remove_stopwords:
319+
nltk.download('stopwords')
320+
stop_words = set(stopwords.words('english'))
321+
322+
def _clean_dataframe(dataframe: DataFrame, columns: list) -> DataFrame:
323+
if not columns:
324+
return dataframe
325+
326+
cleaned_columns = dataframe.loc[:, columns].applymap(lambda x: x.lower())
327+
328+
if remove_numbers:
329+
cleaned_columns = cleaned_columns.applymap(lambda x: re.sub(r'\d+', '', x))
318330

319-
# Make self.dataset_1 and self.dataset_2 lowercase
320-
self.dataset_1 = self.dataset_1.applymap(lambda x: x.lower())
331+
if remove_unicodes:
332+
cleaned_columns = cleaned_columns.applymap(lambda x: re.sub(r'[^\x00-\x7F]+', '', x))
333+
334+
if remove_punctuation:
335+
cleaned_columns = cleaned_columns.applymap(lambda x: re.sub(r'[^\w\s]','',x))
336+
337+
if remove_stopwords:
338+
cleaned_columns = cleaned_columns.applymap(
339+
lambda x: ' '.join([word for word in x.split() if word not in stop_words])
340+
)
341+
342+
dataframe.loc[:, columns] = cleaned_columns
343+
return dataframe
344+
345+
self.dataset_1 = _clean_dataframe(self.dataset_1, self.attributes_1)
321346
if not self.is_dirty_er:
322-
self.dataset_2 = self.dataset_2.applymap(lambda x: x.lower())
323-
324-
if remove_numbers:
325-
self.dataset_1 = self.dataset_1.applymap(lambda x: re.sub(r'\d+', '', x))
326-
if not self.is_dirty_er:
327-
self.dataset_2 = self.dataset_2.applymap(lambda x: re.sub(r'\d+', '', x))
328-
329-
if remove_unicodes:
330-
self.dataset_1 = self.dataset_1.applymap(lambda x: re.sub(r'[^\x00-\x7F]+', '', x))
331-
if not self.is_dirty_er:
332-
self.dataset_2 = self.dataset_2.applymap(lambda x: re.sub(r'[^\x00-\x7F]+', '', x))
333-
334-
if remove_punctuation:
335-
self.dataset_1 = self.dataset_1.applymap(lambda x: re.sub(r'[^\w\s]','',x))
336-
if not self.is_dirty_er:
337-
self.dataset_2 = self.dataset_2.applymap(lambda x: re.sub(r'[^\w\s]','',x))
338-
339-
if remove_stopwords:
340-
self.dataset_1 = self.dataset_1.applymap(lambda x: ' '.join([word for word in x.split() if word not in (stopwords.words('english'))]))
341-
if not self.is_dirty_er:
342-
self.dataset_2 = self.dataset_2.applymap(lambda x: ' '.join([word for word in x.split() if word not in (stopwords.words('english'))]))
347+
self.dataset_2 = _clean_dataframe(self.dataset_2, self.attributes_2)
343348

344349
self.entities = self.dataset_1 = self.dataset_1.astype(str)
345350

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import os
2+
import sys
3+
4+
import pandas as pd
5+
6+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../src'))
7+
8+
from pyjedai.datamodel import Data
9+
10+
11+
def test_clean_dataset_keeps_identifier_columns_and_cached_mappings_consistent():
    """Cleaning must not rewrite the id column nor invalidate cached id mappings."""
    records = pd.DataFrame({"id": ["A-1"], "name": ["Alice-1"]})
    matches = pd.DataFrame([["A-1", "A-1"]])
    data = Data(dataset_1=records, id_column_name_1="id", ground_truth=matches)

    # Strip digits and punctuation; leave the stop-word and unicode passes off.
    data.clean_dataset(
        remove_numbers=True,
        remove_punctuation=True,
        remove_stopwords=False,
        remove_unicodes=False,
    )

    row = data.dataset_1.loc[0]
    assert row["id"] == "A-1"       # identifier column survives untouched
    assert row["name"] == "alice"   # "Alice-1" -> lower-cased, digit and dash removed

    # The cached id structures still key on the original identifier.
    assert list(data._ids_mapping_1.keys()) == ["A-1"]
    assert list(data.duplicate_of.keys()) == ["A-1"]
    assert data._ids_mapping_1["A-1"] == 0
    assert data._are_true_positives("A-1", "A-1")

0 commit comments

Comments
 (0)