Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions presidio-analyzer/presidio_analyzer/analyzer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,17 @@ def __init__(

self.context_aware_enhancer = context_aware_enhancer

def get_recognizers(self, language: Optional[str] = None) -> List[EntityRecognizer]:
def get_recognizers(
self,
language: Optional[str] = None,
countries: Optional[List[str]] = None,
) -> List[EntityRecognizer]:
"""
Return a list of PII recognizers currently loaded.

:param language: Return the recognizers supporting a given language.
:param countries: Optional country filter (case-insensitive ISO-3166
alpha-2 codes). Locale-agnostic recognizers are always returned.
:return: List of [Recognizer] as a RecognizersAllResponse
"""
if not language:
Expand All @@ -128,19 +134,27 @@ def get_recognizers(self, language: Optional[str] = None) -> List[EntityRecogniz
for language in languages:
logger.info(f"Fetching all recognizers for language {language}")
recognizers.extend(
self.registry.get_recognizers(language=language, all_fields=True)
self.registry.get_recognizers(
language=language, all_fields=True, countries=countries
)
)

return list(set(recognizers))

def get_supported_entities(self, language: Optional[str] = None) -> List[str]:
def get_supported_entities(
self,
language: Optional[str] = None,
countries: Optional[List[str]] = None,
) -> List[str]:
"""
Return a list of the entities that can be detected.

:param language: Return only entities supported in a specific language.
:param countries: Optional country filter to apply while collecting
supported entities.
:return: List of entity names
"""
recognizers = self.get_recognizers(language=language)
recognizers = self.get_recognizers(language=language, countries=countries)
supported_entities = []
for recognizer in recognizers:
supported_entities.extend(recognizer.get_supported_entities())
Expand All @@ -161,6 +175,7 @@ def analyze(
allow_list_match: Optional[str] = "exact",
regex_flags: Optional[int] = re.DOTALL | re.MULTILINE | re.IGNORECASE,
nlp_artifacts: Optional[NlpArtifacts] = None,
countries: Optional[List[str]] = None,
) -> List[RecognizerResult]:
"""
Find PII entities in text using different PII recognizers for a given language.
Expand All @@ -185,6 +200,9 @@ def analyze(
- if `exact`, results which exactly match any value in the allow_list would be allowed and not be returned as potential PII.
:param regex_flags: regex flags to be used for when allow_list_match is "regex"
:param nlp_artifacts: precomputed NlpArtifacts
:param countries: Optional country filter (case-insensitive ISO-3166
alpha-2 codes). When provided, country-specific recognizers run only
when their country code is included. Locale-agnostic recognizers always run.
:return: an array of the found entities in the text

:Example:
Expand All @@ -210,12 +228,15 @@ def analyze(
entities=entities,
all_fields=all_fields,
ad_hoc_recognizers=ad_hoc_recognizers,
countries=countries,
)

if all_fields:
# Since all_fields=True, list all entities by iterating
# over all recognizers
entities = self.get_supported_entities(language=language)
entities = self.get_supported_entities(
language=language, countries=countries
)

# run the nlp pipeline over the given text, store the results in
# a NlpArtifacts instance
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
import logging
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Type, Union
Expand Down Expand Up @@ -164,6 +163,7 @@ def get_recognizers(
entities: Optional[List[str]] = None,
all_fields: bool = False,
ad_hoc_recognizers: Optional[List[EntityRecognizer]] = None,
countries: Optional[Iterable[str]] = None,
) -> List[EntityRecognizer]:
"""
Return a list of recognizers which supports the specified name and language.
Expand All @@ -173,6 +173,10 @@ def get_recognizers(
:param all_fields: a flag to return all fields of a requested language.
:param ad_hoc_recognizers: Additional recognizers provided by the user
as part of the request
:param countries: Optional country filter (case-insensitive ISO-3166
alpha-2 codes). When provided, country-specific recognizers are returned
only when their country code is included. Locale-agnostic recognizers
are always returned.
:return: A list of the recognizers which supports the supplied entities
and language
"""
Expand All @@ -182,9 +186,13 @@ def get_recognizers(
if entities is None and all_fields is False:
raise ValueError("No entities provided")

all_possible_recognizers = copy.copy(self.recognizers)
all_possible_recognizers = list(self.recognizers)
if ad_hoc_recognizers:
all_possible_recognizers.extend(ad_hoc_recognizers)
Comment on lines +189 to 191
if countries is not None:
all_possible_recognizers = RecognizerListLoader.filter_by_countries(
all_possible_recognizers, countries
)
Comment on lines +192 to +195

# filter out unwanted recognizers
to_return = set()
Expand Down Expand Up @@ -370,20 +378,26 @@ def _get_supported_languages(self) -> List[str]:
return list(set(languages))

def get_supported_entities(
self, languages: Optional[List[str]] = None
self,
languages: Optional[List[str]] = None,
countries: Optional[Iterable[str]] = None,
) -> List[str]:
"""
Return the supported entities by the set of recognizers loaded.

:param languages: The languages to get the supported entities for.
If languages=None, returns all entities for all languages.
:param countries: Optional country filter to apply while collecting
supported entities.
"""
if not languages:
languages = self._get_supported_languages()

supported_entities = []
for language in languages:
recognizers = self.get_recognizers(language=language, all_fields=True)
recognizers = self.get_recognizers(
language=language, all_fields=True, countries=countries
)

for recognizer in recognizers:
supported_entities.extend(recognizer.get_supported_entities())
Expand Down
59 changes: 59 additions & 0 deletions presidio-analyzer/tests/test_analyzer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,65 @@ def test_when_get_recognizers_then_returns_supported_language():
assert len(response) == 1


def test_when_get_recognizers_with_countries_then_filters_loaded_recognizers():
us_recognizer = PatternRecognizer(
"ID",
name="US ID",
patterns=[Pattern("us", regex="US-123", score=0.8)],
country_code="us",
)
uk_recognizer = PatternRecognizer(
"ID",
name="UK ID",
patterns=[Pattern("uk", regex="UK-123", score=0.8)],
country_code="uk",
)
generic_recognizer = PatternRecognizer(
"ID",
name="Generic ID",
patterns=[Pattern("generic", regex="ID-123", score=0.8)],
)
registry = RecognizerRegistry(
recognizers=[us_recognizer, uk_recognizer, generic_recognizer],
supported_languages=["en"],
)
analyzer = AnalyzerEngine(registry=registry, nlp_engine=NlpEngineMock())

response = analyzer.get_recognizers(language="en", countries=["us"])

assert {rec.name for rec in response} == {"US ID", "Generic ID"}


def test_when_analyze_with_countries_then_runs_matching_recognizers_only():
us_recognizer = PatternRecognizer(
"US_ID",
patterns=[Pattern("us", regex="US-123", score=0.8)],
country_code="us",
)
uk_recognizer = PatternRecognizer(
"UK_ID",
patterns=[Pattern("uk", regex="UK-123", score=0.8)],
country_code="uk",
)
generic_recognizer = PatternRecognizer(
"GENERIC_ID",
patterns=[Pattern("generic", regex="ID-123", score=0.8)],
)
registry = RecognizerRegistry(
recognizers=[us_recognizer, uk_recognizer, generic_recognizer],
supported_languages=["en"],
)
analyzer = AnalyzerEngine(registry=registry, nlp_engine=NlpEngineMock())

results = analyzer.analyze(
text="US-123 UK-123 ID-123",
language="en",
countries=["us"],
)

assert {result.entity_type for result in results} == {"US_ID", "GENERIC_ID"}


def test_when_add_recognizer_then_also_outputs_others(spacy_nlp_engine):
pattern = Pattern("rocket pattern", r"\W*(rocket)\W*", 0.8)
pattern_recognizer = PatternRecognizer(
Expand Down
50 changes: 50 additions & 0 deletions presidio-analyzer/tests/test_recognizer_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,56 @@ def test_get_country_codes_after_country_filter():
assert registry.get_country_codes() == ["us"]


def test_get_recognizers_filters_by_countries_at_request_time():
"""``get_recognizers(countries=...)`` filters an already-loaded registry."""
us_recognizer = PatternRecognizer(
supported_entity="ID",
name="US ID",
patterns=[Pattern("us", regex="US-123", score=0.8)],
country_code="us",
)
uk_recognizer = PatternRecognizer(
supported_entity="ID",
name="UK ID",
patterns=[Pattern("uk", regex="UK-123", score=0.8)],
country_code="uk",
)
generic_recognizer = PatternRecognizer(
supported_entity="ID",
name="Generic ID",
patterns=[Pattern("generic", regex="ID-123", score=0.8)],
)
registry = RecognizerRegistry(
recognizers=[us_recognizer, uk_recognizer, generic_recognizer]
)

unfiltered = registry.get_recognizers(language="en", entities=["ID"])
us_filtered = registry.get_recognizers(
language="en", entities=["ID"], countries=["US"]
)
country_agnostic_only = registry.get_recognizers(
language="en", entities=["ID"], countries=[]
)

assert {rec.name for rec in unfiltered} == {"US ID", "UK ID", "Generic ID"}
assert {rec.name for rec in us_filtered} == {"US ID", "Generic ID"}
assert {rec.name for rec in country_agnostic_only} == {"Generic ID"}


def test_get_recognizers_supports_non_list_iterable_with_ad_hoc_recognizers():
tuple_recognizer = create_mock_pattern_recognizer("en", "PERSON", "tuple")
ad_hoc_recognizer = create_mock_pattern_recognizer("en", "PERSON", "ad hoc")
registry = RecognizerRegistry(recognizers=(tuple_recognizer,))

recognizers = registry.get_recognizers(
language="en",
entities=["PERSON"],
ad_hoc_recognizers=[ad_hoc_recognizer],
)

assert {rec.name for rec in recognizers} == {"tuple", "ad hoc"}


def test_get_country_codes_excludes_locale_agnostic_recognizers():
"""Filter out locale-agnostic recognizers from the country code report.

Expand Down
Loading