Skip to content

Commit 78563a4

Browse files
committed
Add test coverage
1 parent 088053b commit 78563a4

2 files changed

Lines changed: 128 additions & 6 deletions

File tree

tests/python/unit/python-lib/ner/test_flair.py

Lines changed: 58 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,16 @@
1+
# -*- coding: utf-8 -*-
12
import json
23

34
import pandas as pd
5+
import pytest
6+
from flair.models import SequenceTagger
47

58
from ner.constants import (
69
COLUMN_PER_ENTITY_FORMAT,
710
JSON_KEY_PER_ENTITY_FORMAT,
811
JSON_LABELING_FORMAT
912
)
10-
from ner.flair import extract_entities
13+
from ner.flair import extract_entities, get_model
1114

1215
TEST_SENTENCE = "Mark Zuckerberg is one of the founders of Facebook, a company from the United States"
1316

@@ -42,6 +45,58 @@ def test_extract_entities():
4245
'ORG': ['Facebook'],
4346
'LOC': ['United States'],
4447
})
45-
4648
})
47-
)
49+
)
50+
51+
def test_extract_entities_empty_text():
    """An empty input string yields one row with an empty sentence and no entities."""
    empty_text = ''
    frame = pd.DataFrame({'text': [empty_text]})
    extracted = extract_entities(frame['text'], JSON_LABELING_FORMAT, "en")

    assert len(extracted) == 1

    sentence = extracted['sentence'].iloc[0]
    assert sentence == ''

    # The entities cell holds JSON text; it must decode to an empty list.
    decoded = json.loads(extracted['entities'].iloc[0])
    assert decoded == []
57+
58+
def test_extract_entities_no_entities():
    """A sentence without named entities is returned unchanged with an empty entity list."""
    plain_text = 'Hello world, this is a simple test.'
    frame = pd.DataFrame({'text': [plain_text]})
    extracted = extract_entities(frame['text'], JSON_LABELING_FORMAT, "en")

    assert len(extracted) == 1

    sentence = extracted['sentence'].iloc[0]
    assert sentence == 'Hello world, this is a simple test.'

    # The entities cell holds JSON text; it must decode to an empty list.
    decoded = json.loads(extracted['entities'].iloc[0])
    assert decoded == []
64+
65+
def test_extract_entities_unicode():
    """Non-ASCII characters survive extraction and the entities cell is valid JSON."""
    accented = 'Müller works at Nestlé in Zürich.'
    frame = pd.DataFrame({'text': [accented]})
    extracted = extract_entities(frame['text'], JSON_LABELING_FORMAT, "en")

    assert len(extracted) == 1

    # The sentence must come back exactly as it went in.
    sentence = extracted['sentence'].iloc[0]
    assert sentence == accented

    # json.loads raising here would indicate an encoding problem in the output.
    decoded = json.loads(extracted['entities'].iloc[0])
    assert isinstance(decoded, list)
75+
76+
def test_extract_entities_multiple_same_type():
    """A sentence with several people and places yields per-entity columns.

    At least one of the ``PER`` / ``LOC`` columns is required to be present,
    so the test cannot pass vacuously when no entity column is produced.
    Every expected column that is present must hold JSON that decodes to a
    list.
    """
    df = pd.DataFrame({'text': ['John and Mary went to Paris and London.']})
    result = extract_entities(df['text'], COLUMN_PER_ENTITY_FORMAT, "en")
    assert len(result) == 1

    present = [column for column in ('PER', 'LOC') if column in result.columns]
    assert present, "expected a PER or LOC column in the output"

    for column in present:
        entities = json.loads(result[column].iloc[0])
        assert isinstance(entities, list)
87+
88+
def test_extract_entities_multiple_rows():
    """Two input sentences produce two output rows."""
    sentences = [
        'Apple is based in California.',
        'Microsoft was founded by Bill Gates.',
    ]
    frame = pd.DataFrame({'text': sentences})
    extracted = extract_entities(frame['text'], JSON_LABELING_FORMAT, "en")

    assert len(extracted) == 2
95+
96+
def test_get_model_legacy_mapping():
    """The "en" identifier resolves to a flair ``SequenceTagger`` instance."""
    tagger = get_model("en")
    assert isinstance(tagger, SequenceTagger)
99+
100+
def test_get_model_invalid_id():
    """Requesting a model for an unknown identifier raises ``KeyError``."""
    unknown_id = "invalid_language_code"
    with pytest.raises(KeyError):
        get_model(unknown_id)

tests/python/unit/python-lib/ner/test_spacy.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1+
# -*- coding: utf-8 -*-
12
import json
3+
24
import pandas as pd
5+
import pytest
6+
import spacy
37

48
from ner.constants import (
59
COLUMN_PER_ENTITY_FORMAT,
610
JSON_KEY_PER_ENTITY_FORMAT,
711
JSON_LABELING_FORMAT
812
)
9-
from ner.spacy import extract_entities
10-
13+
from ner.spacy import extract_entities, get_model, SPACY_LANGUAGE_MODELS_LEGACY_MAPPING
1114

1215
TEST_SENTENCE = "Mark Zuckerberg is one of the founders of Facebook, a company from the United States"
1316

@@ -43,4 +46,68 @@ def test_extract_entities():
4346
'GPE': ['the United States']
4447
})
4548
})
46-
)
49+
)
50+
51+
def test_extract_entities_empty_text():
    """An empty input string yields one row with an empty sentence and no entities."""
    empty_text = ''
    frame = pd.DataFrame({'text': [empty_text]})
    extracted = extract_entities(frame['text'], JSON_LABELING_FORMAT, "en")

    assert len(extracted) == 1

    sentence = extracted['sentence'].iloc[0]
    assert sentence == ''

    # The entities cell holds JSON text; it must decode to an empty list.
    decoded = json.loads(extracted['entities'].iloc[0])
    assert decoded == []
57+
58+
def test_extract_entities_no_entities():
    """A sentence without named entities is returned unchanged with an empty entity list."""
    plain_text = 'Hello world, this is a simple test.'
    frame = pd.DataFrame({'text': [plain_text]})
    extracted = extract_entities(frame['text'], JSON_LABELING_FORMAT, "en")

    assert len(extracted) == 1

    sentence = extracted['sentence'].iloc[0]
    assert sentence == 'Hello world, this is a simple test.'

    # The entities cell holds JSON text; it must decode to an empty list.
    decoded = json.loads(extracted['entities'].iloc[0])
    assert decoded == []
64+
65+
def test_extract_entities_unicode():
    """Non-ASCII characters survive extraction and the entities cell is valid JSON."""
    accented = 'Müller works at Nestlé in Zürich.'
    frame = pd.DataFrame({'text': [accented]})
    extracted = extract_entities(frame['text'], JSON_LABELING_FORMAT, "en")

    assert len(extracted) == 1

    # The sentence must come back exactly as it went in.
    sentence = extracted['sentence'].iloc[0]
    assert sentence == accented

    # json.loads raising here would indicate an encoding problem in the output.
    decoded = json.loads(extracted['entities'].iloc[0])
    assert isinstance(decoded, list)
75+
76+
def test_extract_entities_multiple_same_type():
    """A sentence with several people and places yields per-entity columns.

    At least one of the ``PERSON`` / ``GPE`` columns is required to be
    present, so the test cannot pass vacuously when no entity column is
    produced. Every expected column that is present must hold JSON that
    decodes to a list.
    """
    df = pd.DataFrame({'text': ['John and Mary went to Paris and London.']})
    result = extract_entities(df['text'], COLUMN_PER_ENTITY_FORMAT, "en")
    assert len(result) == 1

    present = [column for column in ('PERSON', 'GPE') if column in result.columns]
    assert present, "expected a PERSON or GPE column in the output"

    for column in present:
        entities = json.loads(result[column].iloc[0])
        assert isinstance(entities, list)
87+
88+
def test_extract_entities_multiple_rows():
    """Two input sentences produce two output rows."""
    sentences = [
        'Apple is based in California.',
        'Microsoft was founded by Bill Gates.',
    ]
    frame = pd.DataFrame({'text': sentences})
    extracted = extract_entities(frame['text'], JSON_LABELING_FORMAT, "en")

    assert len(extracted) == 2
95+
96+
def test_get_model_english():
    """The "en" identifier resolves to a non-None model exposing a ``pipe`` attribute."""
    english_model = get_model("en")

    assert english_model is not None
    assert hasattr(english_model, 'pipe')
100+
101+
def test_get_model_french():
    """The "fr" identifier resolves to a non-None model exposing a ``pipe`` attribute."""
    french_model = get_model("fr")

    assert french_model is not None
    assert hasattr(french_model, 'pipe')
105+
106+
def test_get_model_invalid_id():
    """Requesting a model for an unknown identifier raises ``KeyError``."""
    unknown_id = "invalid_language_code"
    with pytest.raises(KeyError):
        get_model(unknown_id)
109+
110+
def test_language_models_mapping_completeness():
    """Every expected language code is contained in the legacy spaCy model mapping."""
    required_codes = ("en", "es", "zh", "pl", "fr", "de", "ja", "nb")
    absent = [
        code
        for code in required_codes
        if code not in SPACY_LANGUAGE_MODELS_LEGACY_MAPPING
    ]
    assert not absent

0 commit comments

Comments
 (0)