diff --git a/packages/graphrag/graphrag/index/operations/summarize_communities/community_reports_extractor.py b/packages/graphrag/graphrag/index/operations/summarize_communities/community_reports_extractor.py index 4513abe95b..efae3cb82e 100644 --- a/packages/graphrag/graphrag/index/operations/summarize_communities/community_reports_extractor.py +++ b/packages/graphrag/graphrag/index/operations/summarize_communities/community_reports_extractor.py @@ -49,15 +49,35 @@ class CommunityReportsResult: structured_output: CommunityReportResponse | None -class CommunityReportsExtractor: - """Community reports extractor class definition.""" +def _is_unsupported_response_format_error(e: Exception) -> bool: + """Check if the error is due to unsupported response_format.""" + msg = str(e).lower() + return any( + phrase in msg + for phrase in [ + "response_format", + "json_schema", + "unsupported", + "unavailable", + ] + ) + - _model: "LLMCompletion" - _extraction_prompt: str - _output_formatter_prompt: str - _on_error: ErrorHandlerFn - _max_report_length: int +def _parse_json_from_text( + text: str, model: type[CommunityReportResponse] +) -> CommunityReportResponse: + """Extract and parse JSON from LLM text output.""" + text = text or "" + # Try code block first + m = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL) + if not m: + m = re.search(r"(\{.*\})", text, re.DOTALL) + json_text = m.group(1).strip() if m else text.strip() + data = json.loads(json_text) + return model(**data) + +class CommunityReportsExtractor: def __init__( self, model: "LLMCompletion", @@ -65,26 +85,61 @@ def __init__( max_report_length: int, on_error: ErrorHandlerFn | None = None, ): - """Init method definition.""" self._model = model self._extraction_prompt = extraction_prompt self._on_error = on_error or (lambda _e, _s, _d: None) self._max_report_length = max_report_length async def __call__(self, input_text: str): - """Call method definition.""" output = None try: prompt = self._extraction_prompt.format(**{ INPUT_TEXT_KEY: input_text, MAX_LENGTH_KEY: str(self._max_report_length), }) - response = await self._model.completion_async( - messages=prompt, - response_format=CommunityReportResponse, # A model is required when using json mode - ) - output = response.formatted_response # type: ignore + # Strategy 1: Try structured output with Pydantic model (json_schema) + try: + response = await self._model.completion_async( + messages=prompt, + response_format=CommunityReportResponse, + ) + output = response.formatted_response + except Exception as schema_error: + if not _is_unsupported_response_format_error(schema_error): + raise + + logger.warning( + "json_schema not supported by provider, " + "falling back to json_object mode" + ) + + # Strategy 2: Fallback to json_object + try: + response = await self._model.completion_async( + messages=prompt, + response_format={"type": "json_object"}, + ) + output = _parse_json_from_text( + response.content, CommunityReportResponse + ) + except Exception as json_error: + if not _is_unsupported_response_format_error(json_error): + raise + + logger.warning( + "json_object not supported by provider, " + "falling back to plain text parsing" + ) + + # Strategy 3: Final fallback to plain text + response = await self._model.completion_async( + messages=prompt, + ) + output = _parse_json_from_text( + response.content, CommunityReportResponse + ) + except Exception as e: logger.exception("error generating community report") self._on_error(e, traceback.format_exc(), None)