diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-prompt/src/main/java/org/apache/shenyu/plugin/ai/prompt/AiPromptPlugin.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-prompt/src/main/java/org/apache/shenyu/plugin/ai/prompt/AiPromptPlugin.java index b46b56b83dd2..9e380027489c 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-prompt/src/main/java/org/apache/shenyu/plugin/ai/prompt/AiPromptPlugin.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-prompt/src/main/java/org/apache/shenyu/plugin/ai/prompt/AiPromptPlugin.java @@ -111,7 +111,7 @@ private String decorateBody(final String originalBody, final AiPromptConfig aiPr prependMap.put(Constants.ROLE, aiPromptConfig.getPreRole()); decoratedMessages.add(prependMap); } - decoratedMessages.add(messages.get(0)); + decoratedMessages.addAll(messages); // If append in aiPromptConfig is not empty, add append to the end of message body if (Objects.nonNull(aiPromptConfig.getAppend()) && Objects.nonNull(aiPromptConfig.getPostRole())) { // Assemble append content role diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-prompt/src/test/java/org/apache/shenyu/plugin/ai/prompt/AiPromptPluginTest.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-prompt/src/test/java/org/apache/shenyu/plugin/ai/prompt/AiPromptPluginTest.java index b2aeb29b32dc..f4b57e69c469 100644 --- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-prompt/src/test/java/org/apache/shenyu/plugin/ai/prompt/AiPromptPluginTest.java +++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-prompt/src/test/java/org/apache/shenyu/plugin/ai/prompt/AiPromptPluginTest.java @@ -17,12 +17,16 @@ package org.apache.shenyu.plugin.ai.prompt; +import org.apache.shenyu.common.dto.convert.plugin.AiPromptConfig; import org.apache.shenyu.common.enums.PluginEnum; +import org.apache.shenyu.common.utils.GsonUtils; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.http.codec.HttpMessageReader; +import java.lang.reflect.Method; import java.util.List; +import java.util.Map; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.mock; @@ -50,4 +54,146 @@ void testGetOrder() { assertEquals(PluginEnum.AI_PROMPT.getCode(), plugin.getOrder()); } + + @Test + void testDecorateBodyReturnsOriginalWhenNoMessages() throws Exception { + AiPromptConfig config = new AiPromptConfig(); + String body = "{\"model\":\"gpt-4\"}"; + String result = invokeDecorateBody(body, config); + assertEquals(body, result); + } + + @Test + void testDecorateBodyReturnsOriginalWhenMessagesEmpty() throws Exception { + AiPromptConfig config = new AiPromptConfig(); + String body = "{\"messages\":[]}"; + String result = invokeDecorateBody(body, config); + assertEquals(body, result); + } + + @SuppressWarnings("unchecked") + @Test + void testDecorateBodyWithPrependAndAppend() throws Exception { + AiPromptConfig config = new AiPromptConfig(); + config.setPrepend("system instruction"); + config.setPreRole("system"); + config.setAppend("post instruction"); + config.setPostRole("assistant"); + + String body = "{\"messages\":[{\"role\":\"user\",\"content\":\"hello\"}]}"; + String result = invokeDecorateBody(body, config); + + Map resultMap = GsonUtils.getInstance().convertToMap(result); + List> messages = (List>) resultMap.get("messages"); + + assertEquals(3, messages.size()); + assertEquals("system instruction", messages.get(0).get("content")); + assertEquals("system", messages.get(0).get("role")); + assertEquals("hello", messages.get(1).get("content")); + assertEquals("user", messages.get(1).get("role")); + assertEquals("post instruction", messages.get(2).get("content")); + assertEquals("assistant", messages.get(2).get("role")); + } + + @SuppressWarnings("unchecked") + @Test + void testDecorateBodyWithOnlyPrepend() throws Exception { + AiPromptConfig config = new AiPromptConfig(); + config.setPrepend("system instruction"); + config.setPreRole("system"); + + String body = "{\"messages\":[{\"role\":\"user\",\"content\":\"hello\"}]}"; + String result = invokeDecorateBody(body, config); + + Map resultMap = GsonUtils.getInstance().convertToMap(result); + List> messages = (List>) resultMap.get("messages"); + + assertEquals(2, messages.size()); + assertEquals("system instruction", messages.get(0).get("content")); + assertEquals("hello", messages.get(1).get("content")); + } + + @SuppressWarnings("unchecked") + @Test + void testDecorateBodyPreservesAllOriginalMessages() throws Exception { + AiPromptConfig config = new AiPromptConfig(); + config.setPrepend("system prompt"); + config.setPreRole("system"); + config.setAppend("footer"); + config.setPostRole("assistant"); + + String body = "{\"messages\":[" + + "{\"role\":\"system\",\"content\":\"sys\"}," + + "{\"role\":\"user\",\"content\":\"hi\"}," + + "{\"role\":\"assistant\",\"content\":\"hello\"}," + + "{\"role\":\"user\",\"content\":\"how are you\"}" + + "]}"; + String result = invokeDecorateBody(body, config); + + Map resultMap = GsonUtils.getInstance().convertToMap(result); + List> messages = (List>) resultMap.get("messages"); + + assertEquals(6, messages.size()); + assertEquals("system prompt", messages.get(0).get("content")); + assertEquals("sys", messages.get(1).get("content")); + assertEquals("hi", messages.get(2).get("content")); + assertEquals("hello", messages.get(3).get("content")); + assertEquals("how are you", messages.get(4).get("content")); + assertEquals("footer", messages.get(5).get("content")); + } + + @SuppressWarnings("unchecked") + @Test + void testDecorateBodyNoPrependOrAppend() throws Exception { + AiPromptConfig config = new AiPromptConfig(); + + String body = "{\"messages\":[{\"role\":\"user\",\"content\":\"hello\"}]}"; + String result = invokeDecorateBody(body, config); + + Map resultMap = GsonUtils.getInstance().convertToMap(result); + List> messages = (List>) resultMap.get("messages"); + + assertEquals(1, messages.size()); + assertEquals("hello", messages.get(0).get("content")); + } + + @SuppressWarnings("unchecked") + @Test + void testDecorateBodyPrependWithoutRoleIgnored() throws Exception { + AiPromptConfig config = new AiPromptConfig(); + config.setPrepend("system instruction"); + + String body = "{\"messages\":[{\"role\":\"user\",\"content\":\"hello\"}]}"; + String result = invokeDecorateBody(body, config); + + Map resultMap = GsonUtils.getInstance().convertToMap(result); + List> messages = (List>) resultMap.get("messages"); + + assertEquals(1, messages.size()); + assertEquals("hello", messages.get(0).get("content")); + } + + @SuppressWarnings("unchecked") + @Test + void testDecorateBodyPreservesOtherFields() throws Exception { + AiPromptConfig config = new AiPromptConfig(); + config.setPrepend("prefix"); + config.setPreRole("system"); + + String body = "{\"model\":\"gpt-4\",\"temperature\":0.7,\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}]}"; + String result = invokeDecorateBody(body, config); + + Map resultMap = GsonUtils.getInstance().convertToMap(result); + assertEquals("gpt-4", resultMap.get("model")); + assertEquals(0.7, ((Number) resultMap.get("temperature")).doubleValue(), 0.001); + + List> messages = (List>) resultMap.get("messages"); + assertEquals(2, messages.size()); + } + + private String invokeDecorateBody(final String body, final AiPromptConfig config) throws Exception { + Method method = AiPromptPlugin.class.getDeclaredMethod("decorateBody", String.class, AiPromptConfig.class); + method.setAccessible(true); + return (String) method.invoke(plugin, body, config); + } }