Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from api.google_embedder_client import GoogleEmbedderClient
from api.azureai_client import AzureAIClient
from api.dashscope_client import DashscopeClient
from api.litellm_client import LiteLLMClient
from adalflow import GoogleGenAIClient, OllamaClient

# Get API keys from environment variables
Expand Down Expand Up @@ -63,7 +64,8 @@
"OllamaClient": OllamaClient,
"BedrockClient": BedrockClient,
"AzureAIClient": AzureAIClient,
"DashscopeClient": DashscopeClient
"DashscopeClient": DashscopeClient,
"LiteLLMClient": LiteLLMClient,
}

def replace_env_placeholders(config: Union[Dict[str, Any], List[Any], str, Any]) -> Union[Dict[str, Any], List[Any], str, Any]:
Expand Down Expand Up @@ -131,15 +133,16 @@ def load_generator_config():
if provider_config.get("client_class") in CLIENT_CLASSES:
provider_config["model_client"] = CLIENT_CLASSES[provider_config["client_class"]]
# Fall back to default mapping based on provider_id
elif provider_id in ["google", "openai", "openrouter", "ollama", "bedrock", "azure", "dashscope"]:
elif provider_id in ["google", "openai", "openrouter", "ollama", "bedrock", "azure", "dashscope", "litellm"]:
default_map = {
"google": GoogleGenAIClient,
"openai": OpenAIClient,
"openrouter": OpenRouterClient,
"ollama": OllamaClient,
"bedrock": BedrockClient,
"azure": AzureAIClient,
"dashscope": DashscopeClient
"dashscope": DashscopeClient,
"litellm": LiteLLMClient,
}
Comment thread
RheagalFire marked this conversation as resolved.
Outdated
provider_config["model_client"] = default_map[provider_id]
else:
Expand Down
15 changes: 15 additions & 0 deletions api/config/generator.json
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,21 @@
"top_p": 0.8
}
}
},
"litellm": {
"client_class": "LiteLLMClient",
"default_model": "openai/gpt-4o",
"supportsCustomModel": true,
"models": {
"openai/gpt-4o": {
"temperature": 0.7,
"top_p": 0.8
},
"anthropic/claude-sonnet-4-20250514": {
"temperature": 0.7,
"top_p": 0.8
}
}
}
}
}
Expand Down
234 changes: 234 additions & 0 deletions api/litellm_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
"""LiteLLM ModelClient integration.

Routes to 100+ LLM providers via litellm.completion().
Provider API keys are read from environment variables automatically
(OPENAI_API_KEY, ANTHROPIC_API_KEY, AWS_ACCESS_KEY_ID, GEMINI_API_KEY, etc.).

Model names use LiteLLM format: "provider/model-name", e.g.:
anthropic/claude-sonnet-4-20250514, openai/gpt-4o,
bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0

See https://docs.litellm.ai/docs/providers for the full list.
"""

import logging
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
TypeVar,
Union,
)

import backoff

from adalflow.core.model_client import ModelClient
from adalflow.core.types import (
CompletionUsage,
EmbedderOutput,
GeneratorOutput,
ModelType,
)
from adalflow.components.model_client.utils import parse_embedding_response

log = logging.getLogger(__name__)
T = TypeVar("T")


def get_first_message_content(completion) -> Optional[str]:
    """Return the text content of the first choice of a chat completion.

    Args:
        completion: A chat-completion response object exposing
            ``choices[0].message.content``.

    Returns:
        The ``content`` attribute of the first choice's message. This may be
        ``None`` when the message carries no text content.

    Raises:
        IndexError: If ``completion.choices`` is empty.
        AttributeError: If the response does not have the expected shape.
    """
    first_choice = completion.choices[0]
    return first_choice.message.content


class LiteLLMClient(ModelClient):
    __doc__ = r"""A component wrapper for the LiteLLM AI gateway.

    LiteLLM routes to 100+ LLM providers (OpenAI, Anthropic, Google, AWS Bedrock,
    Azure, Ollama, etc.) through a single unified interface. Provider API keys are
    read from environment variables automatically.

    Model names use LiteLLM format: ``provider/model-name``.

    Example:
        ```python
        from api.litellm_client import LiteLLMClient
        import adalflow as adal

        client = LiteLLMClient()
        generator = adal.Generator(
            model_client=client,
            model_kwargs={"model": "anthropic/claude-sonnet-4-20250514"}
        )
        response = generator(prompt_kwargs={"input_str": "What is LLM?"})
        ```

    Args:
        api_key (Optional[str]): API key for the provider. If not provided,
            LiteLLM reads from the provider's standard env var (e.g. ANTHROPIC_API_KEY).
        base_url (Optional[str]): Custom API base URL (e.g. for LiteLLM proxy server).
        chat_completion_parser: A function to parse the chat completion response.
            Defaults to extracting the first message content.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        chat_completion_parser: Optional[Callable] = None,
    ):
        super().__init__()
        self._api_key = api_key
        self._base_url = base_url
        # Remember the non-streaming parser so that `call` can restore it
        # after a streaming request has temporarily swapped the parser.
        self._default_chat_completion_parser = (
            chat_completion_parser or get_first_message_content
        )
        self.chat_completion_parser = self._default_chat_completion_parser
        self.sync_client = self.init_sync_client()
        self.async_client = None

    def _connection_kwargs(self) -> Dict[str, str]:
        """Build the per-request credential/endpoint kwargs for litellm.

        Only values that were explicitly configured are included, so that
        litellm falls back to its own defaults for anything left unset.
        """
        extra: Dict[str, str] = {}
        if self._api_key:
            extra["api_key"] = self._api_key
        if self._base_url:
            # litellm names this parameter `api_base`, not `base_url`.
            extra["api_base"] = self._base_url
        return extra

    def init_sync_client(self):
        """Return the settings dict stored as ``sync_client``.

        No client object is created here; ``call`` invokes the ``litellm``
        module functions directly.
        """
        return {"api_key": self._api_key, "base_url": self._base_url}

    def init_async_client(self):
        """Return the settings dict that stands in for an async client."""
        return {"api_key": self._api_key, "base_url": self._base_url}

    def convert_inputs_to_api_kwargs(
        self,
        input: Optional[Any] = None,
        model_kwargs: Optional[Dict] = None,
        model_type: ModelType = ModelType.UNDEFINED,
    ) -> Dict:
        """Merge the component input and model kwargs into litellm kwargs.

        Args:
            input: For embedders, a string or a sequence of texts. For LLMs,
                a prompt string, a list of message dicts, or any other object
                (which is stringified into a single user message).
            model_kwargs: Extra keyword arguments (e.g. ``model``); copied,
                never mutated. ``None`` is treated as an empty dict.
            model_type: ``ModelType.EMBEDDER`` or ``ModelType.LLM``.

        Returns:
            A new dict containing ``model_kwargs`` plus either ``input``
            (embedder) or ``messages`` (LLM).

        Raises:
            TypeError: If an embedder input is not a sequence.
            ValueError: If ``model_type`` is neither EMBEDDER nor LLM.
        """
        final_model_kwargs = dict(model_kwargs or {})

        if model_type == ModelType.EMBEDDER:
            if isinstance(input, str):
                input = [input]
            if not isinstance(input, Sequence):
                raise TypeError("input must be a sequence of text")
            final_model_kwargs["input"] = input
            return final_model_kwargs

        if model_type == ModelType.LLM:
            messages: List[Dict[str, str]]
            if isinstance(input, str):
                messages = [{"role": "user", "content": input}]
            elif isinstance(input, list) and all(isinstance(m, dict) for m in input):
                # Already in chat-message form; pass through untouched.
                messages = input
            else:
                messages = [{"role": "user", "content": str(input)}]
            final_model_kwargs["messages"] = messages
            return final_model_kwargs

        raise ValueError(f"model_type {model_type} is not supported")

    def parse_chat_completion(self, completion) -> GeneratorOutput:
        """Run the active parser over ``completion`` and attach token usage.

        The parsed value is returned in ``raw_response`` and ``data`` is left
        as ``None``.
        NOTE(review): presumably the caller post-processes ``raw_response``
        into ``data`` -- confirm against the Generator implementation.

        Returns:
            A ``GeneratorOutput``. If parsing fails, ``error`` holds the
            message and ``raw_response`` is the unparsed completion.
        """
        try:
            data = self.chat_completion_parser(completion)
        except Exception as e:
            log.error(f"Error parsing the completion: {e}")
            return GeneratorOutput(data=None, error=str(e), raw_response=completion)

        try:
            usage = self.track_completion_usage(completion)
            return GeneratorOutput(
                data=None, error=None, raw_response=data, usage=usage
            )
        except Exception as e:
            log.error(f"Error tracking the completion usage: {e}")
            return GeneratorOutput(data=None, error=str(e), raw_response=data)

    def track_completion_usage(self, completion) -> CompletionUsage:
        """Extract token counts from ``completion.usage``.

        Returns:
            A ``CompletionUsage``; all fields are ``None`` when the response
            exposes no usable ``usage`` attribute (the failure is logged).
        """
        try:
            return CompletionUsage(
                completion_tokens=completion.usage.completion_tokens,
                prompt_tokens=completion.usage.prompt_tokens,
                total_tokens=completion.usage.total_tokens,
            )
        except Exception as e:
            log.error(f"Error tracking the completion usage: {e}")
            return CompletionUsage(
                completion_tokens=None, prompt_tokens=None, total_tokens=None
            )

    def parse_embedding_response(self, response) -> EmbedderOutput:
        """Convert an embedding response into an ``EmbedderOutput``.

        Delegates to the module-level ``parse_embedding_response`` helper; on
        failure returns an output with empty ``data`` and the error message.
        """
        try:
            return parse_embedding_response(response)
        except Exception as e:
            log.error(f"Error parsing the embedding response: {e}")
            return EmbedderOutput(data=[], error=str(e), raw_response=response)

    @backoff.on_exception(
        backoff.expo,
        Exception,
        max_time=5,
        # Deterministic failures (missing optional `litellm` package, bad
        # arguments, unsupported model_type) can never succeed on retry.
        giveup=lambda exc: isinstance(exc, (ImportError, TypeError, ValueError)),
    )
    def call(
        self,
        api_kwargs: Optional[Dict] = None,
        model_type: ModelType = ModelType.UNDEFINED,
    ):
        """Synchronously invoke litellm, retrying failures for up to 5s.

        Args:
            api_kwargs: Keyword arguments forwarded to ``litellm.embedding`` or
                ``litellm.completion``. ``None`` is treated as an empty dict.
            model_type: ``ModelType.EMBEDDER`` or ``ModelType.LLM``.

        Returns:
            Whatever the underlying litellm function returns.

        Raises:
            ValueError: If ``model_type`` is neither EMBEDDER nor LLM.
        """
        # Imported lazily because litellm is an optional dependency.
        import litellm

        api_kwargs = api_kwargs or {}
        # Log only the model name: the full kwargs contain the prompt text.
        log.debug(
            "LiteLLM call: model=%s, model_type=%s",
            api_kwargs.get("model"),
            model_type,
        )
        extra = self._connection_kwargs()

        if model_type == ModelType.EMBEDDER:
            return litellm.embedding(drop_params=True, **api_kwargs, **extra)

        if model_type == ModelType.LLM:
            if api_kwargs.get("stream", False):
                self.chat_completion_parser = _handle_streaming_response
            elif self.chat_completion_parser is _handle_streaming_response:
                # A previous streaming call swapped the parser; restore the
                # configured one so this non-streaming response parses.
                self.chat_completion_parser = self._default_chat_completion_parser
            return litellm.completion(drop_params=True, **api_kwargs, **extra)

        raise ValueError(f"model_type {model_type} is not supported")

    async def acall(
        self,
        api_kwargs: Optional[Dict] = None,
        model_type: ModelType = ModelType.UNDEFINED,
    ):
        """Asynchronously invoke litellm (no retry wrapper).

        NOTE(review): unlike ``call``, this does not switch the completion
        parser for ``stream=True`` requests -- confirm whether async
        streaming is expected to be supported.

        Args:
            api_kwargs: Keyword arguments forwarded to ``litellm.aembedding``
                or ``litellm.acompletion``. ``None`` is treated as empty.
            model_type: ``ModelType.EMBEDDER`` or ``ModelType.LLM``.

        Raises:
            ValueError: If ``model_type`` is neither EMBEDDER nor LLM.
        """
        # Imported lazily because litellm is an optional dependency.
        import litellm

        api_kwargs = api_kwargs or {}
        extra = self._connection_kwargs()

        if model_type == ModelType.EMBEDDER:
            return await litellm.aembedding(drop_params=True, **api_kwargs, **extra)

        if model_type == ModelType.LLM:
            return await litellm.acompletion(drop_params=True, **api_kwargs, **extra)

        raise ValueError(f"model_type {model_type} is not supported")

    @classmethod
    def from_dict(cls, data: Dict[str, Any]):
        """Construct a client by passing ``data`` as constructor kwargs."""
        return cls(**data)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the component, omitting client handles.

        The internal default-parser reference is excluded as well so the
        serialized form does not gain a key for that bookkeeping attribute.
        """
        exclude = ["sync_client", "async_client", "_default_chat_completion_parser"]
        return super().to_dict(exclude=exclude)


def _handle_streaming_response(generator):
    """Yield the text fragments carried by a stream of completion chunks.

    Args:
        generator: An iterable of streaming chunks, each exposing
            ``choices[0].delta.content``.

    Yields:
        Each non-empty ``delta.content`` string, in arrival order. Chunks
        with no choices, no delta, or empty/``None`` content are skipped.
    """
    for chunk in generator:
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta
        # Some chunks carry no delta payload at all; skip rather than raise.
        text = delta.content if delta is not None else None
        if text:
            yield text
1 change: 1 addition & 0 deletions api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ boto3 = ">=1.34.0"
websockets = ">=11.0.3"
azure-identity = ">=1.12.0"
azure-core = ">=1.24.0"
litellm = {version = ">=1.60.0,<2.0", optional = true}


[build-system]
Expand Down
Loading