Skip to content

Commit 94a3f33

Browse files
committed
more
1 parent 78b7df1 commit 94a3f33

6 files changed

Lines changed: 227 additions & 197 deletions

File tree

mbodied/agents/backends/anthropic_backend.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
1516
from typing import Any, List
1617

1718
import anthropic
@@ -80,7 +81,7 @@ def __init__(self, api_key: str | None, client: anthropic.Anthropic | None = Non
8081
client: An optional client for the Anthropic service.
8182
kwargs: Additional keyword arguments.
8283
"""
83-
self.api_key = api_key
84+
self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
8485
self.client = client
8586

8687
self.model = kwargs.pop("model", self.DEFAULT_MODEL)
@@ -117,13 +118,6 @@ def predict(
117118
)
118119
return completion.content[0].text
119120

120-
async def async_predict(
121-
self, message: Message, context: List[Message] | None = None, model: Any | None = None
122-
) -> str:
123-
"""Asynchronously predict the next message in the conversation."""
124-
# For now, we'll use the synchronous method since Anthropic doesn't provide an async API
125-
return self.predict(message, context, model)
126-
127121
def stream(
128122
self, message: Message, context: List[Message] = None, model: str = "claude-3-5-sonnet-20240620", **kwargs
129123
):

mbodied/agents/backends/gemini_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def __init__(
105105
client: An optional client for the Gemini service.
106106
**kwargs: Additional keyword arguments.
107107
"""
108-
self.api_key = api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("MBODI_API_KEY")
108+
self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
109109
self.client = client
110110

111111
self.model = kwargs.pop("model", self.DEFAULT_MODEL)

mbodied/agents/backends/openai_backend.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def __init__(
100100
aclient: Whether to use the asynchronous client.
101101
**kwargs: Additional keyword arguments.
102102
"""
103-
self.api_key = api_key or os.getenv("OPENAI_API_KEY") or os.getenv("MBODI_API_KEY")
103+
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
104104
self.client = client
105105
if self.client is None:
106106
from openai import AsyncOpenAI, OpenAI
@@ -137,7 +137,7 @@ def predict(
137137
**kwargs: Additional keyword arguments.
138138
139139
Returns:
140-
str | tuple[str, List[ToolCall]]:
140+
str | tuple[str, List[ToolCall]]:
141141
When tools are not provided: Just the text response
142142
When tools are provided: A tuple of (text_response, tool_calls)
143143
"""
@@ -163,12 +163,7 @@ def predict(
163163
return completion.choices[0].message.content
164164

165165
def stream(
166-
self,
167-
message: Message,
168-
context: List[Message] = None,
169-
model: str = "gpt-4o",
170-
tools: List[Tool] = None,
171-
**kwargs
166+
self, message: Message, context: List[Message] = None, model: str = "gpt-4o", tools: List[Tool] = None, **kwargs
172167
):
173168
"""Streams a completion for the given messages using the OpenAI API standard.
174169
@@ -178,7 +173,7 @@ def stream(
178173
model: The model to be used for the completion.
179174
tools: Optional list of tools (function calls) available to the model.
180175
**kwargs: Additional keyword arguments.
181-
176+
182177
Yields:
183178
When tools is None:
184179
str: Content delta chunks
@@ -196,7 +191,7 @@ def stream(
196191
tools=tools,
197192
**kwargs,
198193
)
199-
194+
200195
if not tools:
201196
for chunk in stream:
202197
yield chunk.choices[0].delta.content or ""
@@ -208,12 +203,7 @@ def stream(
208203
yield content, tool_calls
209204

210205
async def astream(
211-
self,
212-
message: Message,
213-
context: List[Message] = None,
214-
model: str = "gpt-4o",
215-
tools: List[Tool] = None,
216-
**kwargs
206+
self, message: Message, context: List[Message] = None, model: str = "gpt-4o", tools: List[Tool] = None, **kwargs
217207
):
218208
"""Streams a completion asynchronously for the given messages using the OpenAI API standard.
219209
@@ -223,7 +213,7 @@ async def astream(
223213
model: The model to be used for the completion.
224214
tools: Optional list of tools (function calls) available to the model.
225215
**kwargs: Additional keyword arguments.
226-
216+
227217
Yields:
228218
When tools is None:
229219
str: Content delta chunks
@@ -243,7 +233,7 @@ async def astream(
243233
tools=tools,
244234
**kwargs,
245235
)
246-
236+
247237
if not tools:
248238
async for chunk in stream:
249239
yield chunk.choices[0].delta.content or ""

mbodied/agents/language/language_agent.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,19 @@ def __getitem__(self, key):
8686
raise IndexError("Invalid index")
8787

8888

89-
def make_context_list(context: list[str | Image | Message] | Image | str | Message | None) -> List[Message]:
89+
def make_context_list(
90+
context: list[str | Image | Message] | Image | str | Message | None, model_src: str
91+
) -> List[Message]:
9092
"""Convert the context to a list of messages."""
9193
if isinstance(context, list):
9294
return [Message(content=c) if not isinstance(c, Message) else c for c in context]
9395
if isinstance(context, Message):
9496
return [context]
9597
if isinstance(context, str | Image):
96-
return [Message(role="user", content=[context]), Message(role="assistant", content="Understood.")]
98+
if model_src == "openai":
99+
return [Message(role="system", content=[context])]
100+
else:
101+
return [Message(role="user", content=[context]), Message(role="assistant", content="Understood.")]
97102
return []
98103

99104

@@ -150,7 +155,7 @@ def __init__(
150155
| DirectoryPath
151156
| NewPath = "openai",
152157
context: list | Image | str | Message = None,
153-
api_key: str | None = os.getenv("OPENAI_API_KEY"),
158+
api_key: str | None = None,
154159
model_kwargs: dict = None,
155160
recorder: Literal["default", "omit"] | str = "omit",
156161
recorder_kwargs: dict = None,
@@ -199,7 +204,7 @@ def __init__(
199204
api_key=api_key,
200205
)
201206

202-
self.context = make_context_list(context)
207+
self.context = make_context_list(context, model_src)
203208

204209
def forget_last(self) -> Message:
205210
"""Forget the last message in the context."""

mbodied/agents/language/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,8 @@ def function_to_tool(func: Callable) -> Tool:
104104
return Tool.model_validate(tool_data)
105105

106106

107-
def main():
108-
"""Example usage of function_to_tool utility"""
107+
def main() -> None:
108+
"""Example usage of function_to_tool utility."""
109109
from mbodied.agents.language import LanguageAgent
110110

111111
def move_by(x: float, y: float, z: float, speed: float = 1.0) -> bool:

0 commit comments

Comments
 (0)