diff --git a/src/main/java/cn/van/business/service/SocialMediaService.java b/src/main/java/cn/van/business/service/SocialMediaService.java index b6f59ef..a685a3f 100644 --- a/src/main/java/cn/van/business/service/SocialMediaService.java +++ b/src/main/java/cn/van/business/service/SocialMediaService.java @@ -1,7 +1,7 @@ package cn.van.business.service; import cn.hutool.core.util.StrUtil; -import cn.van.business.util.ds.OllamaClientUtil; +import cn.van.business.util.ds.SocialMediaLlmClient; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.redis.core.StringRedisTemplate; @@ -23,7 +23,7 @@ import java.util.Map; public class SocialMediaService { @Autowired - private OllamaClientUtil ollamaClientUtil; + private SocialMediaLlmClient socialMediaLlmClient; @Autowired private MarketingImageService marketingImageService; @@ -102,7 +102,7 @@ public class SocialMediaService { String promptTemplate = getPromptTemplate("keywords", DEFAULT_KEYWORDS_PROMPT); String prompt = String.format(promptTemplate, productName); - String response = ollamaClientUtil.getResponse(prompt); + String response = socialMediaLlmClient.getResponse(prompt); if (StrUtil.isNotBlank(response)) { // 解析关键词 @@ -191,7 +191,7 @@ public class SocialMediaService { String prompt = String.format(promptTemplate, productName, priceInfo.toString(), keywordsInfo); - String content = ollamaClientUtil.getResponse(prompt.toString()); + String content = socialMediaLlmClient.getResponse(prompt.toString()); if (StrUtil.isNotBlank(content)) { result.put("success", true); diff --git a/src/main/java/cn/van/business/util/ds/OllamaClientUtil.java b/src/main/java/cn/van/business/util/ds/OllamaClientUtil.java index b0963c8..bb57c38 100644 --- a/src/main/java/cn/van/business/util/ds/OllamaClientUtil.java +++ b/src/main/java/cn/van/business/util/ds/OllamaClientUtil.java @@ -1,5 +1,6 @@ package cn.van.business.util.ds; +import cn.hutool.core.util.StrUtil; import cn.hutool.http.HttpRequest; import cn.hutool.http.HttpResponse; import cn.hutool.http.Method; @@ -28,12 +29,17 @@ public class OllamaClientUtil { private String ollamaModel; /** - * 调用 Ollama /api/chat 并返回助手回复的文本内容 - * - * @param inputText 用户输入的文本 - * @return 模型回复的文本内容 + * 调用 Ollama /api/chat 并返回助手回复的文本内容(使用 application.yml 默认地址与模型) */ public String getResponse(String inputText) throws IOException { + return getResponse(inputText, null, null); + } + + /** + * @param overrideBaseUrl 非空时覆盖配置的 Ollama 根地址(不含路径,如 http://127.0.0.1:11434) + * @param overrideModel 非空时覆盖配置的模型名 + */ + public String getResponse(String inputText, String overrideBaseUrl, String overrideModel) throws IOException { if (inputText == null || inputText.trim().isEmpty()) { throw new IllegalArgumentException("输入文本不能为空"); } @@ -41,9 +47,12 @@ public class OllamaClientUtil { throw new IllegalArgumentException("输入文本过长"); } - String url = ollamaBaseUrl.replaceAll("/$", "") + "/api/chat"; + String base = StrUtil.isNotBlank(overrideBaseUrl) ? overrideBaseUrl.trim() : ollamaBaseUrl; + String model = StrUtil.isNotBlank(overrideModel) ? overrideModel.trim() : ollamaModel; + + String url = base.replaceAll("/$", "") + "/api/chat"; Map requestBody = new HashMap<>(); - requestBody.put("model", ollamaModel); + requestBody.put("model", model); requestBody.put("messages", new Object[]{ Map.of("role", "user", "content", inputText) }); @@ -58,7 +67,7 @@ public class OllamaClientUtil { .body(jsonBody) .timeout(60000); - logger.info("请求 Ollama API: URL={}, model={}", url, ollamaModel); + logger.info("请求 Ollama API: URL={}, model={}", url, model); HttpResponse response = request.execute(); diff --git a/src/main/java/cn/van/business/util/ds/SocialMediaLlmClient.java b/src/main/java/cn/van/business/util/ds/SocialMediaLlmClient.java new file mode 100644 index 0000000..0aa225f --- /dev/null +++ b/src/main/java/cn/van/business/util/ds/SocialMediaLlmClient.java @@ -0,0 +1,115 @@ +package cn.van.business.util.ds; + +import cn.hutool.core.util.StrUtil; +import cn.hutool.http.HttpRequest; +import cn.hutool.http.HttpResponse; +import cn.hutool.http.Method; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.data.redis.core.StringRedisTemplate; +import org.springframework.stereotype.Component; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +/** + * 社媒文案/关键词所用大模型客户端:支持本地 Ollama 或与 OpenAI 兼容的 HTTP 接口(含远程 API、Ollama /v1 等)。 + * 配置存 Redis,与若依后台「大模型接入」一致;未配置时走 {@link OllamaClientUtil} 的 yml 默认。 + */ +@Component +public class SocialMediaLlmClient { + + private static final org.slf4j.Logger log = org.slf4j.LoggerFactory.getLogger(SocialMediaLlmClient.class); + private static final ObjectMapper objectMapper = new ObjectMapper(); + + private static final String KEY_MODE = "social_media:llm:mode"; + private static final String KEY_BASE_URL = "social_media:llm:base_url"; + private static final String KEY_API_KEY = "social_media:llm:api_key"; + private static final String KEY_MODEL = "social_media:llm:model"; + + private static final String MODE_OLLAMA = "ollama"; + private static final String MODE_OPENAI = "openai"; + + @Autowired(required = false) + private StringRedisTemplate redisTemplate; + + @Autowired + private OllamaClientUtil ollamaClientUtil; + + public String getResponse(String inputText) throws IOException { + if (redisTemplate == null) { + return ollamaClientUtil.getResponse(inputText); + } + + String mode = trimOrNull(redisTemplate.opsForValue().get(KEY_MODE)); + String baseUrl = trimOrNull(redisTemplate.opsForValue().get(KEY_BASE_URL)); + String apiKey = trimOrNull(redisTemplate.opsForValue().get(KEY_API_KEY)); + String model = trimOrNull(redisTemplate.opsForValue().get(KEY_MODEL)); + + if (mode == null && baseUrl == null && model == null && apiKey == null) { + return ollamaClientUtil.getResponse(inputText); + } + + String effectiveMode = mode != null ? mode.toLowerCase() : MODE_OLLAMA; + + if (MODE_OPENAI.equals(effectiveMode)) { + if (StrUtil.isBlank(baseUrl)) { + throw new IOException("OpenAI 兼容模式需在后台配置完整的 Chat Completions 地址"); + } + if (StrUtil.isBlank(model)) { + throw new IOException("OpenAI 兼容模式需在后台配置模型名称"); + } + return callOpenAiCompatible(baseUrl, apiKey, model, inputText); + } + + return ollamaClientUtil.getResponse(inputText, baseUrl, model); + } + + private static String trimOrNull(String s) { + if (s == null) { + return null; + } + String t = s.trim(); + return t.isEmpty() ? null : t; + } + + private String callOpenAiCompatible(String chatCompletionsUrl, String apiKey, String model, String userText) + throws IOException { + Map requestBody = new HashMap<>(); + requestBody.put("model", model); + requestBody.put("messages", new Map[]{ + Map.of("role", "user", "content", userText) + }); + requestBody.put("temperature", 0.7); + + String jsonBody = objectMapper.writeValueAsString(requestBody); + + HttpRequest request = HttpRequest.of(chatCompletionsUrl.trim()) + .method(Method.POST) + .header("Content-Type", "application/json") + .header("Accept", "application/json") + .body(jsonBody) + .timeout(120000); + + if (StrUtil.isNotBlank(apiKey)) { + request.header("Authorization", "Bearer " + apiKey.trim()); + } + + log.info("请求 OpenAI 兼容 API: URL={}, model={}", chatCompletionsUrl, model); + + HttpResponse response = request.execute(); + if (response.getStatus() != 200) { + log.error("OpenAI 兼容 API 失败: status={}, body={}", response.getStatus(), response.body()); + throw new IOException("API 调用失败,HTTP 状态码: " + response.getStatus()); + } + + JsonNode root = objectMapper.readTree(response.body()); + JsonNode choices = root.path("choices"); + if (choices.isEmpty() || !choices.get(0).path("message").path("content").isTextual()) { + throw new IOException("API 返回数据格式异常,未找到回复内容"); + } + return choices.get(0).path("message").path("content").asText(); + } +}