feat: add LiteLLM as AI gateway provider #514
Open
RheagalFire wants to merge 6 commits into AsyncFuncAI:main from RheagalFire:feat/add-litellm-provider
Changes from 1 commit
2461d81 feat: add LiteLLM as AI gateway provider (RheagalFire)
fcf1fa8 fix: wire LiteLLM into serving paths, fix retry logic, add async/stre… (RheagalFire)
0e3789d fix: address review feedback (RheagalFire)
d543ac6 fix: guard backoff import, fix duplicate kwargs, use specific retry e… (RheagalFire)
68d97fd fix: fix 2 failing tests - sync_client assertion and retryable predicate (RheagalFire)
f942064 fix: redact API key in to_dict, align parse_chat_completion with Open… (RheagalFire)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,234 @@ | ||
| """LiteLLM ModelClient integration. | ||
|
|
||
| Routes to 100+ LLM providers via litellm.completion(). | ||
| Provider API keys are read from environment variables automatically | ||
| (OPENAI_API_KEY, ANTHROPIC_API_KEY, AWS_ACCESS_KEY_ID, GEMINI_API_KEY, etc.). | ||
|
|
||
| Model names use LiteLLM format: "provider/model-name", e.g.: | ||
| anthropic/claude-sonnet-4-20250514, openai/gpt-4o, | ||
| bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0 | ||
|
|
||
| See https://docs.litellm.ai/docs/providers for the full list. | ||
| """ | ||
|
|
||
| import logging | ||
| from typing import ( | ||
| Any, | ||
| Callable, | ||
| Dict, | ||
| List, | ||
| Optional, | ||
| Sequence, | ||
| TypeVar, | ||
| Union, | ||
| ) | ||
|
|
||
| import backoff | ||
|
|
||
| from adalflow.core.model_client import ModelClient | ||
| from adalflow.core.types import ( | ||
| CompletionUsage, | ||
| EmbedderOutput, | ||
| GeneratorOutput, | ||
| ModelType, | ||
| ) | ||
| from adalflow.components.model_client.utils import parse_embedding_response | ||
|
|
||
| log = logging.getLogger(__name__) | ||
| T = TypeVar("T") | ||
|
|
||
|
|
||
| def get_first_message_content(completion) -> str: | ||
| return completion.choices[0].message.content | ||
|
|
||
|
|
||
| class LiteLLMClient(ModelClient): | ||
| __doc__ = r"""A component wrapper for the LiteLLM AI gateway. | ||
|
|
||
| LiteLLM routes to 100+ LLM providers (OpenAI, Anthropic, Google, AWS Bedrock, | ||
| Azure, Ollama, etc.) through a single unified interface. Provider API keys are | ||
| read from environment variables automatically. | ||
|
|
||
| Model names use LiteLLM format: ``provider/model-name``. | ||
|
|
||
| Example: | ||
| ```python | ||
| from api.litellm_client import LiteLLMClient | ||
| import adalflow as adal | ||
|
|
||
| client = LiteLLMClient() | ||
| generator = adal.Generator( | ||
| model_client=client, | ||
| model_kwargs={"model": "anthropic/claude-sonnet-4-20250514"} | ||
| ) | ||
| response = generator(prompt_kwargs={"input_str": "What is LLM?"}) | ||
| ``` | ||
|
|
||
| Args: | ||
| api_key (Optional[str]): API key for the provider. If not provided, | ||
| LiteLLM reads from the provider's standard env var (e.g. ANTHROPIC_API_KEY). | ||
| base_url (Optional[str]): Custom API base URL (e.g. for LiteLLM proxy server). | ||
| chat_completion_parser: A function to parse the chat completion response. | ||
| Defaults to extracting the first message content. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| api_key: Optional[str] = None, | ||
| base_url: Optional[str] = None, | ||
| chat_completion_parser: Optional[Callable] = None, | ||
| ): | ||
| super().__init__() | ||
| self._api_key = api_key | ||
| self._base_url = base_url | ||
| self.chat_completion_parser = ( | ||
| chat_completion_parser or get_first_message_content | ||
| ) | ||
| self.sync_client = self.init_sync_client() | ||
| self.async_client = None | ||
|
|
||
| def init_sync_client(self): | ||
| return {"api_key": self._api_key, "base_url": self._base_url} | ||
|
|
||
| def init_async_client(self): | ||
| return {"api_key": self._api_key, "base_url": self._base_url} | ||
|
|
||
| def convert_inputs_to_api_kwargs( | ||
| self, | ||
| input: Optional[Any] = None, | ||
| model_kwargs: Dict = {}, | ||
| model_type: ModelType = ModelType.UNDEFINED, | ||
| ) -> Dict: | ||
| final_model_kwargs = model_kwargs.copy() | ||
|
|
||
| if model_type == ModelType.EMBEDDER: | ||
| if isinstance(input, str): | ||
| input = [input] | ||
| if not isinstance(input, Sequence): | ||
| raise TypeError("input must be a sequence of text") | ||
| final_model_kwargs["input"] = input | ||
| elif model_type == ModelType.LLM: | ||
| messages: List[Dict[str, str]] = [] | ||
| if isinstance(input, str): | ||
| messages.append({"role": "user", "content": input}) | ||
| elif isinstance(input, list) and all(isinstance(m, dict) for m in input): | ||
| messages = input | ||
| else: | ||
| messages.append({"role": "user", "content": str(input)}) | ||
| final_model_kwargs["messages"] = messages | ||
| else: | ||
| raise ValueError(f"model_type {model_type} is not supported") | ||
|
|
||
| return final_model_kwargs | ||
|
|
||
| def parse_chat_completion(self, completion) -> GeneratorOutput: | ||
| try: | ||
| data = self.chat_completion_parser(completion) | ||
| except Exception as e: | ||
| log.error(f"Error parsing the completion: {e}") | ||
| return GeneratorOutput(data=None, error=str(e), raw_response=completion) | ||
|
|
||
| try: | ||
| usage = self.track_completion_usage(completion) | ||
| return GeneratorOutput( | ||
| data=None, error=None, raw_response=data, usage=usage | ||
| ) | ||
| except Exception as e: | ||
| log.error(f"Error tracking the completion usage: {e}") | ||
| return GeneratorOutput(data=None, error=str(e), raw_response=data) | ||
|
RheagalFire marked this conversation as resolved.
Outdated
|
||
|
|
||
| def track_completion_usage(self, completion) -> CompletionUsage: | ||
| try: | ||
| return CompletionUsage( | ||
| completion_tokens=completion.usage.completion_tokens, | ||
| prompt_tokens=completion.usage.prompt_tokens, | ||
| total_tokens=completion.usage.total_tokens, | ||
| ) | ||
| except Exception as e: | ||
| log.error(f"Error tracking the completion usage: {e}") | ||
| return CompletionUsage( | ||
| completion_tokens=None, prompt_tokens=None, total_tokens=None | ||
| ) | ||
|
|
||
| def parse_embedding_response(self, response) -> EmbedderOutput: | ||
| try: | ||
| return parse_embedding_response(response) | ||
| except Exception as e: | ||
| log.error(f"Error parsing the embedding response: {e}") | ||
| return EmbedderOutput(data=[], error=str(e), raw_response=response) | ||
|
|
||
| @backoff.on_exception(backoff.expo, Exception, max_time=5) | ||
|
RheagalFire marked this conversation as resolved.
Outdated
|
||
| def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED): | ||
| import litellm | ||
|
|
||
| log.info(f"api_kwargs: {api_kwargs}") | ||
|
RheagalFire marked this conversation as resolved.
Outdated
|
||
|
|
||
| extra = {} | ||
| if self._api_key: | ||
| extra["api_key"] = self._api_key | ||
| if self._base_url: | ||
| extra["api_base"] = self._base_url | ||
|
|
||
| if model_type == ModelType.EMBEDDER: | ||
| return litellm.embedding( | ||
| drop_params=True, | ||
| **api_kwargs, | ||
| **extra, | ||
| ) | ||
| elif model_type == ModelType.LLM: | ||
| if api_kwargs.get("stream", False): | ||
| self.chat_completion_parser = _handle_streaming_response | ||
| return litellm.completion( | ||
| drop_params=True, | ||
| **api_kwargs, | ||
| **extra, | ||
| ) | ||
| else: | ||
| return litellm.completion( | ||
| drop_params=True, | ||
| **api_kwargs, | ||
| **extra, | ||
| ) | ||
|
RheagalFire marked this conversation as resolved.
Outdated
|
||
| else: | ||
| raise ValueError(f"model_type {model_type} is not supported") | ||
|
|
||
| async def acall( | ||
| self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED | ||
| ): | ||
| import litellm | ||
|
|
||
| extra = {} | ||
| if self._api_key: | ||
| extra["api_key"] = self._api_key | ||
| if self._base_url: | ||
| extra["api_base"] = self._base_url | ||
|
|
||
| if model_type == ModelType.EMBEDDER: | ||
| return await litellm.aembedding( | ||
| drop_params=True, | ||
| **api_kwargs, | ||
| **extra, | ||
| ) | ||
| elif model_type == ModelType.LLM: | ||
| return await litellm.acompletion( | ||
| drop_params=True, | ||
| **api_kwargs, | ||
| **extra, | ||
| ) | ||
| else: | ||
| raise ValueError(f"model_type {model_type} is not supported") | ||
|
|
||
| @classmethod | ||
| def from_dict(cls, data: Dict[str, Any]): | ||
| return cls(**data) | ||
|
|
||
| def to_dict(self) -> Dict[str, Any]: | ||
| exclude = ["sync_client", "async_client"] | ||
| output = super().to_dict(exclude=exclude) | ||
| return output | ||
|
|
||
|
|
||
| def _handle_streaming_response(generator): | ||
| for completion in generator: | ||
| if completion.choices and completion.choices[0].delta.content: | ||
| yield completion.choices[0].delta.content | ||
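
A note on the retry decorator in this commit: `@backoff.on_exception(backoff.expo, Exception, max_time=5)` retries on every exception, including the `ValueError` that `call` raises for an unsupported `model_type`, which can never succeed on a second attempt. The `backoff` library does accept a tuple of exception types and a `giveup` predicate for narrowing this. The sketch below illustrates the same idea with the standard library only, so it runs without `backoff` or `litellm` installed. It is not the fix that landed in the later commits; `TransientError`, `retry_on`, `flaky`, and `bad_input` are hypothetical names used purely for illustration.

```python
import time
from functools import wraps
from typing import Callable, Tuple, Type


class TransientError(Exception):
    """Hypothetical stand-in for a retryable failure (timeout, rate limit)."""


def retry_on(
    retryable: Tuple[Type[BaseException], ...],
    max_tries: int = 3,
    base_delay: float = 0.01,
) -> Callable:
    """Retry with exponential backoff, but only for the listed exception
    types; any other exception propagates on the first attempt."""

    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapper(*args, **kwargs):
            for attempt in range(max_tries):
                try:
                    return fn(*args, **kwargs)
                except retryable:
                    if attempt == max_tries - 1:
                        raise  # out of attempts: surface the last error
                    time.sleep(base_delay * (2 ** attempt))

        return wrapper

    return decorator


# Call counters so the retry behaviour can be observed.
calls = {"flaky": 0, "bad_input": 0}


@retry_on((TransientError,))
def flaky() -> str:
    """Fails twice with a retryable error, then succeeds."""
    calls["flaky"] += 1
    if calls["flaky"] < 3:
        raise TransientError("temporary failure")
    return "ok"


@retry_on((TransientError,))
def bad_input() -> None:
    """Raises a non-retryable error, mirroring the unsupported model_type case."""
    calls["bad_input"] += 1
    raise ValueError("model_type is not supported")
```

With this shape, `flaky()` is attempted until it succeeds or runs out of tries, while `bad_input()` fails on its first call without sleeping. Which LiteLLM exception classes count as transient is left open here and would need to be checked against the LiteLLM documentation.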