ChatClient 解读
原文链接:SpringAI(GA):ChatClient调用链路解读
教程说明
说明:本教程基于 2025 年 5 月 20 日正式发布的 GA 版,提供以下内容:
- 核心功能模块的快速上手教程
- 核心功能模块的源码级解读
- Spring ai alibaba增强的快速上手教程 + 源码级解读
版本:JDK21 + SpringBoot 3.4.5 + SpringAI 1.0.0 + SpringAI Alibaba 最新版
将陆续完成如下章节教程。本章是第一章(chat 初体验)中的一部分源码解读——ChatClient 解读
- 目前已完成第一、二章源码解读部分
微信推文往届解读可参考:
如需更好的阅读体验,可付费获取飞书云文档《Spring AI 最新教程》的访问权限,目前价格为 39.9,随着内容不断完善,价格会逐步上涨。
注:M6版快速上手教程+源码解读飞书云文档已免费提供
[!TIP]
ChatClient 端设置 advisors、ChatOptions、用户提示信息、系统提示信息、工具等信息,构建 DefaultChatClient.DefaultChatClientRequestSpec,再利用 DefaultChatClientUtils 将其转换为 ChatClientRequest
AdvisorChain 依次链式调用一系列增强器(Advisor),每个增强器的输入是 ChatClientRequest,输出是 ChatClientResponse(其中必定会用到 ChatModelCallAdvisor 或 ChatModelStreamAdvisor)
ChatClient
类的说明:面向对话式 AI 模型的客户端接口,提供了一系列与 AI 对话模型交互的 API;该接口封装了请求构建、调用、响应处理等流程,支持同步调用与流式调用
方法说明
内部接口类说明
Builder(全局的 ChatClient 配置)
ChatClientRequestSpec(当前请求的 ChatClient 配置)
PromptUserSpec(用户提示信息的构建规范)
PromptSystemSpec(系统提示信息的构建规范)
AdvisorSpec(设置增强器及其参数)
public interface ChatClient { static ChatClient create(ChatModel chatModel) { return create(chatModel, ObservationRegistry.NOOP); } static ChatClient create(ChatModel chatModel, ObservationRegistry observationRegistry) { return create(chatModel, observationRegistry, (ChatClientObservationConvention)null); } static ChatClient create(ChatModel chatModel, ObservationRegistry observationRegistry, @Nullable ChatClientObservationConvention observationConvention) { Assert.notNull(chatModel, \"chatModel cannot be null\"); Assert.notNull(observationRegistry, \"observationRegistry cannot be null\"); return builder(chatModel, observationRegistry, observationConvention).build(); } static Builder builder(ChatModel chatModel) { return builder(chatModel, ObservationRegistry.NOOP, (ChatClientObservationConvention)null); } static Builder builder(ChatModel chatModel, ObservationRegistry observationRegistry, @Nullable ChatClientObservationConvention customObservationConvention) { Assert.notNull(chatModel, \"chatModel cannot be null\"); Assert.notNull(observationRegistry, \"observationRegistry cannot be null\"); return new DefaultChatClientBuilder(chatModel, observationRegistry, customObservationConvention); } ChatClientRequestSpec prompt(); ChatClientRequestSpec prompt(String content); ChatClientRequestSpec prompt(Prompt prompt); Builder mutate(); public interface AdvisorSpec { AdvisorSpec param(String k, Object v); AdvisorSpec params(Map<String, Object> p); AdvisorSpec advisors(Advisor... advisors); AdvisorSpec advisors(List<Advisor> advisors); } public interface Builder { Builder defaultAdvisors(Advisor... 
advisor); Builder defaultAdvisors(Consumer<AdvisorSpec> advisorSpecConsumer); Builder defaultAdvisors(List<Advisor> advisors); Builder defaultOptions(ChatOptions chatOptions); Builder defaultUser(String text); Builder defaultUser(Resource text, Charset charset); Builder defaultUser(Resource text); Builder defaultUser(Consumer<PromptUserSpec> userSpecConsumer); Builder defaultSystem(String text); Builder defaultSystem(Resource text, Charset charset); Builder defaultSystem(Resource text); Builder defaultSystem(Consumer<PromptSystemSpec> systemSpecConsumer); Builder defaultTemplateRenderer(TemplateRenderer templateRenderer); Builder defaultToolNames(String... toolNames); Builder defaultTools(Object... toolObjects); Builder defaultToolCallbacks(ToolCallback... toolCallbacks); Builder defaultToolCallbacks(List<ToolCallback> toolCallbacks); Builder defaultToolCallbacks(ToolCallbackProvider... toolCallbackProviders); Builder defaultToolContext(Map<String, Object> toolContext); Builder clone(); ChatClient build(); } public interface CallPromptResponseSpec { String content(); List<String> contents(); ChatResponse chatResponse(); } public interface CallResponseSpec { @Nullable <T> T entity(ParameterizedTypeReference<T> type); @Nullable <T> T entity(StructuredOutputConverter<T> structuredOutputConverter); @Nullable <T> T entity(Class<T> type); ChatClientResponse chatClientResponse(); @Nullable ChatResponse chatResponse(); @Nullable String content(); <T> ResponseEntity<ChatResponse, T> responseEntity(Class<T> type); <T> ResponseEntity<ChatResponse, T> responseEntity(ParameterizedTypeReference<T> type); <T> ResponseEntity<ChatResponse, T> responseEntity(StructuredOutputConverter<T> structuredOutputConverter); } public interface ChatClientRequestSpec { Builder mutate(); ChatClientRequestSpec advisors(Consumer<AdvisorSpec> consumer); ChatClientRequestSpec advisors(Advisor... 
advisors); ChatClientRequestSpec advisors(List<Advisor> advisors); ChatClientRequestSpec messages(Message... messages); ChatClientRequestSpec messages(List<Message> messages); <T extends ChatOptions> ChatClientRequestSpec options(T options); ChatClientRequestSpec toolNames(String... toolNames); ChatClientRequestSpec tools(Object... toolObjects); ChatClientRequestSpec toolCallbacks(ToolCallback... toolCallbacks); ChatClientRequestSpec toolCallbacks(List<ToolCallback> toolCallbacks); ChatClientRequestSpec toolCallbacks(ToolCallbackProvider... toolCallbackProviders); ChatClientRequestSpec toolContext(Map<String, Object> toolContext); ChatClientRequestSpec system(String text); ChatClientRequestSpec system(Resource textResource, Charset charset); ChatClientRequestSpec system(Resource text); ChatClientRequestSpec system(Consumer<PromptSystemSpec> consumer); ChatClientRequestSpec user(String text); ChatClientRequestSpec user(Resource text, Charset charset); ChatClientRequestSpec user(Resource text); ChatClientRequestSpec user(Consumer<PromptUserSpec> consumer); ChatClientRequestSpec templateRenderer(TemplateRenderer templateRenderer); CallResponseSpec call(); StreamResponseSpec stream(); } public interface PromptSystemSpec { PromptSystemSpec text(String text); PromptSystemSpec text(Resource text, Charset charset); PromptSystemSpec text(Resource text); PromptSystemSpec params(Map<String, Object> p); PromptSystemSpec param(String k, Object v); } public interface PromptUserSpec { PromptUserSpec text(String text); PromptUserSpec text(Resource text, Charset charset); PromptUserSpec text(Resource text); PromptUserSpec params(Map<String, Object> p); PromptUserSpec param(String k, Object v); PromptUserSpec media(Media... 
media); PromptUserSpec media(MimeType mimeType, URL url); PromptUserSpec media(MimeType mimeType, Resource resource); } public interface StreamPromptResponseSpec { Flux<ChatResponse> chatResponse(); Flux<String> content(); } public interface StreamResponseSpec { Flux<ChatClientResponse> chatClientResponse(); Flux<ChatResponse> chatResponse(); Flux<String> content(); }}
DefaultChatClient
ChatClient 接口的默认实现类,用于构建和执行与 AI 聊天模型交互的请求
- 内部类 DefaultChatClientRequestSpec 实现了 ChatClient.ChatClientRequestSpec:在构建 Advisor 链(buildAdvisorChain)时,追加 ChatModelCallAdvisor 与 ChatModelStreamAdvisor
public static class DefaultChatClientRequestSpec implements ChatClient.ChatClientRequestSpec {

    // Excerpt: only buildAdvisorChain() is shown. Fields such as advisors, chatModel,
    // observationRegistry and templateRenderer are declared in the full listing.

    /**
     * Builds the advisor chain used to execute this request.
     *
     * <p>A ChatModelCallAdvisor and a ChatModelStreamAdvisor bound to this spec's chat model are
     * appended to {@code this.advisors}. All advisors are then pushed into a new
     * DefaultAroundAdvisorChain configured with this spec's observation registry and template
     * renderer.
     *
     * <p>NOTE(review): {@code this.advisors} is mutated in place, so every invocation appends
     * another pair of model advisors. Confirm whether repeated invocation on one spec matters.
     *
     * @return the assembled advisor chain
     */
    private BaseAdvisorChain buildAdvisorChain() {
        var modelCallAdvisor = ChatModelCallAdvisor.builder().chatModel(this.chatModel).build();
        this.advisors.add(modelCallAdvisor);

        var modelStreamAdvisor = ChatModelStreamAdvisor.builder().chatModel(this.chatModel).build();
        this.advisors.add(modelStreamAdvisor);

        var chainBuilder = DefaultAroundAdvisorChain.builder(this.observationRegistry);
        return chainBuilder
            .pushAll(this.advisors)
            .templateRenderer(this.templateRenderer)
            .build();
    }
}
- 内部类 DefaultPromptUserSpec 实现 ChatClient.PromptUserSpec:设置用户文本内容、参数及媒体(media)
- 内部类 DefaultPromptSystemSpec 实现 ChatClient.PromptSystemSpec:设置系统文本内容、参数
- 内部类 DefaultAdvisorSpec 实现 ChatClient.AdvisorSpec:设置 Advisor,及其 advisor 中用到的参数
- 内部类 DefaultCallResponseSpec 实现 ChatClient.CallResponseSpec:通过 doGetObservableChatClientResponse 方法发起同步请求,在 Observation 观测下调用 BaseAdvisorChain,由其依次执行一系列 Advisor
public static class DefaultCallResponseSpec implements ChatClient.CallResponseSpec { private final ChatClientRequest request; private final BaseAdvisorChain advisorChain; private final ObservationRegistry observationRegistry; private final ChatClientObservationConvention observationConvention; private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest chatClientRequest, @Nullable String outputFormat) { if (outputFormat != null) { chatClientRequest.context().put(ChatClientAttributes.OUTPUTFORMAT.getKey(), outputFormat); } ChatClientObservationContext observationContext = ChatClientObservationContext.builder().request(chatClientRequest).advisors(this.advisorChain.getCallAdvisors()).stream(false).format(outputFormat).build(); Observation observation = ChatClientObservationDocumentation.AICHATCLIENT.observation(this.observationConvention, DefaultChatClient.DEFAULTCHATCLIENTOBSERVATIONCONVENTION, () -> observationContext, this.observationRegistry); ChatClientResponse chatClientResponse = (ChatClientResponse)observation.observe(() -> this.advisorChain.nextCall(chatClientRequest)); return chatClientResponse != null ? chatClientResponse : ChatClientResponse.builder().build(); }}
- 内部类 DefaultStreamResponseSpec 实现 ChatClient.StreamResponseSpec:通过 doGetObservableFluxChatResponse 方法发起流式请求,在 Observation 观测下调用 BaseAdvisorChain,由其依次执行一系列流式 Advisor
public static class DefaultStreamResponseSpec implements ChatClient.StreamResponseSpec { private final ChatClientRequest request; private final BaseAdvisorChain advisorChain; private final ObservationRegistry observationRegistry; private final ChatClientObservationConvention observationConvention; private Flux<ChatClientResponse> doGetObservableFluxChatResponse(ChatClientRequest chatClientRequest) { return Flux.deferContextual((contextView) -> { ChatClientObservationContext observationContext = ChatClientObservationContext.builder().request(chatClientRequest).advisors(this.advisorChain.getStreamAdvisors()).stream(true).build(); Observation observation = ChatClientObservationDocumentation.AICHATCLIENT.observation(this.observationConvention, DefaultChatClient.DEFAULTCHATCLIENTOBSERVATIONCONVENTION, () -> observationContext, this.observationRegistry); observation.parentObservation((Observation)contextView.getOrDefault(\"micrometer.observation\", (Object)null)).start(); Flux var10000 = this.advisorChain.nextStream(chatClientRequest); Objects.requireNonNull(observation); return var10000.doOnError(observation::error).doFinally((s) -> observation.stop()).contextWrite((ctx) -> ctx.put(\"micrometer.observation\", observation)); }); }}
完整代码如下
package org.springframework.ai.chat.client;public class DefaultChatClient implements ChatClient { private static final ChatClientObservationConvention DEFAULTCHATCLIENTOBSERVATIONCONVENTION = new DefaultChatClientObservationConvention(); private static final TemplateRenderer DEFAULTTEMPLATERENDERER = StTemplateRenderer.builder().build(); private final DefaultChatClientRequestSpec defaultChatClientRequest; public DefaultChatClient(DefaultChatClientRequestSpec defaultChatClientRequest) { Assert.notNull(defaultChatClientRequest, \"defaultChatClientRequest cannot be null\"); this.defaultChatClientRequest = defaultChatClientRequest; } public ChatClient.ChatClientRequestSpec prompt() { return new DefaultChatClientRequestSpec(this.defaultChatClientRequest); } public ChatClient.ChatClientRequestSpec prompt(String content) { Assert.hasText(content, \"content cannot be null or empty\"); return this.prompt(new Prompt(content)); } public ChatClient.ChatClientRequestSpec prompt(Prompt prompt) { Assert.notNull(prompt, \"prompt cannot be null\"); DefaultChatClientRequestSpec spec = new DefaultChatClientRequestSpec(this.defaultChatClientRequest); if (prompt.getOptions() != null) { spec.options(prompt.getOptions()); } if (prompt.getInstructions() != null) { spec.messages(prompt.getInstructions()); } return spec; } public ChatClient.Builder mutate() { return this.defaultChatClientRequest.mutate(); } public static class DefaultPromptUserSpec implements ChatClient.PromptUserSpec { private final Map<String, Object> params = new HashMap(); private final List<Media> media = new ArrayList(); @Nullable private String text; public ChatClient.PromptUserSpec media(Media... 
media) { Assert.notNull(media, \"media cannot be null\"); Assert.noNullElements(media, \"media cannot contain null elements\"); this.media.addAll(Arrays.asList(media)); return this; } public ChatClient.PromptUserSpec media(MimeType mimeType, URL url) { Assert.notNull(mimeType, \"mimeType cannot be null\"); Assert.notNull(url, \"url cannot be null\"); try { this.media.add(Media.builder().mimeType(mimeType).data(url.toURI()).build()); return this; } catch (URISyntaxException e) { throw new RuntimeException(e); } } public ChatClient.PromptUserSpec media(MimeType mimeType, Resource resource) { Assert.notNull(mimeType, \"mimeType cannot be null\"); Assert.notNull(resource, \"resource cannot be null\"); this.media.add(Media.builder().mimeType(mimeType).data(resource).build()); return this; } public ChatClient.PromptUserSpec text(String text) { Assert.hasText(text, \"text cannot be null or empty\"); this.text = text; return this; } public ChatClient.PromptUserSpec text(Resource text, Charset charset) { Assert.notNull(text, \"text cannot be null\"); Assert.notNull(charset, \"charset cannot be null\"); try { this.text(text.getContentAsString(charset)); return this; } catch (IOException e) { throw new RuntimeException(e); } } public ChatClient.PromptUserSpec text(Resource text) { Assert.notNull(text, \"text cannot be null\"); this.text(text, Charset.defaultCharset()); return this; } public ChatClient.PromptUserSpec param(String key, Object value) { Assert.hasText(key, \"key cannot be null or empty\"); Assert.notNull(value, \"value cannot be null\"); this.params.put(key, value); return this; } public ChatClient.PromptUserSpec params(Map<String, Object> params) { Assert.notNull(params, \"params cannot be null\"); Assert.noNullElements(params.keySet(), \"param keys cannot contain null elements\"); Assert.noNullElements(params.values(), \"param values cannot contain null elements\"); this.params.putAll(params); return this; } @Nullable protected String text() { return this.text; 
} protected Map<String, Object> params() { return this.params; } protected List<Media> media() { return this.media; } } public static class DefaultPromptSystemSpec implements ChatClient.PromptSystemSpec { private final Map<String, Object> params = new HashMap(); @Nullable private String text; public ChatClient.PromptSystemSpec text(String text) { Assert.hasText(text, \"text cannot be null or empty\"); this.text = text; return this; } public ChatClient.PromptSystemSpec text(Resource text, Charset charset) { Assert.notNull(text, \"text cannot be null\"); Assert.notNull(charset, \"charset cannot be null\"); try { this.text(text.getContentAsString(charset)); return this; } catch (IOException e) { throw new RuntimeException(e); } } public ChatClient.PromptSystemSpec text(Resource text) { Assert.notNull(text, \"text cannot be null\"); this.text(text, Charset.defaultCharset()); return this; } public ChatClient.PromptSystemSpec param(String key, Object value) { Assert.hasText(key, \"key cannot be null or empty\"); Assert.notNull(value, \"value cannot be null\"); this.params.put(key, value); return this; } public ChatClient.PromptSystemSpec params(Map<String, Object> params) { Assert.notNull(params, \"params cannot be null\"); Assert.noNullElements(params.keySet(), \"param keys cannot contain null elements\"); Assert.noNullElements(params.values(), \"param values cannot contain null elements\"); this.params.putAll(params); return this; } @Nullable protected String text() { return this.text; } protected Map<String, Object> params() { return this.params; } } public static class DefaultAdvisorSpec implements ChatClient.AdvisorSpec { private final List<Advisor> advisors = new ArrayList(); private final Map<String, Object> params = new HashMap(); public ChatClient.AdvisorSpec param(String key, Object value) { Assert.hasText(key, \"key cannot be null or empty\"); Assert.notNull(value, \"value cannot be null\"); this.params.put(key, value); return this; } public 
ChatClient.AdvisorSpec params(Map<String, Object> params) { Assert.notNull(params, \"params cannot be null\"); Assert.noNullElements(params.keySet(), \"param keys cannot contain null elements\"); Assert.noNullElements(params.values(), \"param values cannot contain null elements\"); this.params.putAll(params); return this; } public ChatClient.AdvisorSpec advisors(Advisor... advisors) { Assert.notNull(advisors, \"advisors cannot be null\"); Assert.noNullElements(advisors, \"advisors cannot contain null elements\"); this.advisors.addAll(List.of(advisors)); return this; } public ChatClient.AdvisorSpec advisors(List<Advisor> advisors) { Assert.notNull(advisors, \"advisors cannot be null\"); Assert.noNullElements(advisors, \"advisors cannot contain null elements\"); this.advisors.addAll(advisors); return this; } public List<Advisor> getAdvisors() { return this.advisors; } public Map<String, Object> getParams() { return this.params; } } public static class DefaultCallResponseSpec implements ChatClient.CallResponseSpec { private final ChatClientRequest request; private final BaseAdvisorChain advisorChain; private final ObservationRegistry observationRegistry; private final ChatClientObservationConvention observationConvention; public DefaultCallResponseSpec(ChatClientRequest chatClientRequest, BaseAdvisorChain advisorChain, ObservationRegistry observationRegistry, ChatClientObservationConvention observationConvention) { Assert.notNull(chatClientRequest, \"chatClientRequest cannot be null\"); Assert.notNull(advisorChain, \"advisorChain cannot be null\"); Assert.notNull(observationRegistry, \"observationRegistry cannot be null\"); Assert.notNull(observationConvention, \"observationConvention cannot be null\"); this.request = chatClientRequest; this.advisorChain = advisorChain; this.observationRegistry = observationRegistry; this.observationConvention = observationConvention; } public <T> ResponseEntity<ChatResponse, T> responseEntity(Class<T> type) { Assert.notNull(type, 
\"type cannot be null\"); return this.doResponseEntity(new BeanOutputConverter(type)); } public <T> ResponseEntity<ChatResponse, T> responseEntity(ParameterizedTypeReference<T> type) { Assert.notNull(type, \"type cannot be null\"); return this.doResponseEntity(new BeanOutputConverter(type)); } public <T> ResponseEntity<ChatResponse, T> responseEntity(StructuredOutputConverter<T> structuredOutputConverter) { Assert.notNull(structuredOutputConverter, \"structuredOutputConverter cannot be null\"); return this.doResponseEntity(structuredOutputConverter); } protected <T> ResponseEntity<ChatResponse, T> doResponseEntity(StructuredOutputConverter<T> outputConverter) { Assert.notNull(outputConverter, \"structuredOutputConverter cannot be null\"); ChatResponse chatResponse = this.doGetObservableChatClientResponse(this.request, outputConverter.getFormat()).chatResponse(); String responseContent = getContentFromChatResponse(chatResponse); if (responseContent == null) { return new ResponseEntity(chatResponse, (Object)null); } else { T entity = (T)outputConverter.convert(responseContent); return new ResponseEntity(chatResponse, entity); } } @Nullable public <T> T entity(ParameterizedTypeReference<T> type) { Assert.notNull(type, \"type cannot be null\"); return (T)this.doSingleWithBeanOutputConverter(new BeanOutputConverter(type)); } @Nullable public <T> T entity(StructuredOutputConverter<T> structuredOutputConverter) { Assert.notNull(structuredOutputConverter, \"structuredOutputConverter cannot be null\"); return (T)this.doSingleWithBeanOutputConverter(structuredOutputConverter); } @Nullable public <T> T entity(Class<T> type) { Assert.notNull(type, \"type cannot be null\"); BeanOutputConverter<T> outputConverter = new BeanOutputConverter(type); return (T)this.doSingleWithBeanOutputConverter(outputConverter); } @Nullable private <T> T doSingleWithBeanOutputConverter(StructuredOutputConverter<T> outputConverter) { ChatResponse chatResponse = 
this.doGetObservableChatClientResponse(this.request, outputConverter.getFormat()).chatResponse(); String stringResponse = getContentFromChatResponse(chatResponse); return (T)(stringResponse == null ? null : outputConverter.convert(stringResponse)); } public ChatClientResponse chatClientResponse() { return this.doGetObservableChatClientResponse(this.request); } @Nullable public ChatResponse chatResponse() { return this.doGetObservableChatClientResponse(this.request).chatResponse(); } @Nullable public String content() { ChatResponse chatResponse = this.doGetObservableChatClientResponse(this.request).chatResponse(); return getContentFromChatResponse(chatResponse); } private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest chatClientRequest) { return this.doGetObservableChatClientResponse(chatClientRequest, (String)null); } private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest chatClientRequest, @Nullable String outputFormat) { if (outputFormat != null) { chatClientRequest.context().put(ChatClientAttributes.OUTPUTFORMAT.getKey(), outputFormat); } ChatClientObservationContext observationContext = ChatClientObservationContext.builder().request(chatClientRequest).advisors(this.advisorChain.getCallAdvisors()).stream(false).format(outputFormat).build(); Observation observation = ChatClientObservationDocumentation.AICHATCLIENT.observation(this.observationConvention, DefaultChatClient.DEFAULTCHATCLIENTOBSERVATIONCONVENTION, () -> observationContext, this.observationRegistry); ChatClientResponse chatClientResponse = (ChatClientResponse)observation.observe(() -> this.advisorChain.nextCall(chatClientRequest)); return chatClientResponse != null ? 
chatClientResponse : ChatClientResponse.builder().build(); } @Nullable private static String getContentFromChatResponse(@Nullable ChatResponse chatResponse) { return (String)Optional.ofNullable(chatResponse).map(ChatResponse::getResult).map(Generation::getOutput).map(AbstractMessage::getText).orElse((Object)null); } } public static class DefaultStreamResponseSpec implements ChatClient.StreamResponseSpec { private final ChatClientRequest request; private final BaseAdvisorChain advisorChain; private final ObservationRegistry observationRegistry; private final ChatClientObservationConvention observationConvention; public DefaultStreamResponseSpec(ChatClientRequest chatClientRequest, BaseAdvisorChain advisorChain, ObservationRegistry observationRegistry, ChatClientObservationConvention observationConvention) { Assert.notNull(chatClientRequest, \"chatClientRequest cannot be null\"); Assert.notNull(advisorChain, \"advisorChain cannot be null\"); Assert.notNull(observationRegistry, \"observationRegistry cannot be null\"); Assert.notNull(observationConvention, \"observationConvention cannot be null\"); this.request = chatClientRequest; this.advisorChain = advisorChain; this.observationRegistry = observationRegistry; this.observationConvention = observationConvention; } private Flux<ChatClientResponse> doGetObservableFluxChatResponse(ChatClientRequest chatClientRequest) { return Flux.deferContextual((contextView) -> { ChatClientObservationContext observationContext = ChatClientObservationContext.builder().request(chatClientRequest).advisors(this.advisorChain.getStreamAdvisors()).stream(true).build(); Observation observation = ChatClientObservationDocumentation.AICHATCLIENT.observation(this.observationConvention, DefaultChatClient.DEFAULTCHATCLIENTOBSERVATIONCONVENTION, () -> observationContext, this.observationRegistry); observation.parentObservation((Observation)contextView.getOrDefault(\"micrometer.observation\", (Object)null)).start(); Flux var10000 = 
this.advisorChain.nextStream(chatClientRequest); Objects.requireNonNull(observation); return var10000.doOnError(observation::error).doFinally((s) -> observation.stop()).contextWrite((ctx) -> ctx.put(\"micrometer.observation\", observation)); }); } public Flux<ChatClientResponse> chatClientResponse() { return this.doGetObservableFluxChatResponse(this.request); } public Flux<ChatResponse> chatResponse() { return this.doGetObservableFluxChatResponse(this.request).mapNotNull(ChatClientResponse::chatResponse); } public Flux<String> content() { return this.doGetObservableFluxChatResponse(this.request).mapNotNull(ChatClientResponse::chatResponse).map((r) -> r.getResult() != null && r.getResult().getOutput() != null && r.getResult().getOutput().getText() != null ? r.getResult().getOutput().getText() : \"\").filter(StringUtils::hasLength); } } public static class DefaultChatClientRequestSpec implements ChatClient.ChatClientRequestSpec { private final ObservationRegistry observationRegistry; private final ChatClientObservationConvention observationConvention; private final ChatModel chatModel; private final List<Media> media; private final List<String> toolNames; private final List<ToolCallback> toolCallbacks; private final List<Message> messages; private final Map<String, Object> userParams; private final Map<String, Object> systemParams; private final List<Advisor> advisors; private final Map<String, Object> advisorParams; private final Map<String, Object> toolContext; private TemplateRenderer templateRenderer; @Nullable private String userText; @Nullable private String systemText; @Nullable private ChatOptions chatOptions; DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) { this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.toolCallbacks, ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams, ccr.observationRegistry, ccr.observationConvention, ccr.toolContext, ccr.templateRenderer); } 
public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText, Map<String, Object> userParams, @Nullable String systemText, Map<String, Object> systemParams, List<ToolCallback> toolCallbacks, List<Message> messages, List<String> toolNames, List<Media> media, @Nullable ChatOptions chatOptions, List<Advisor> advisors, Map<String, Object> advisorParams, ObservationRegistry observationRegistry, @Nullable ChatClientObservationConvention observationConvention, Map<String, Object> toolContext, @Nullable TemplateRenderer templateRenderer) { this.media = new ArrayList(); this.toolNames = new ArrayList(); this.toolCallbacks = new ArrayList(); this.messages = new ArrayList(); this.userParams = new HashMap(); this.systemParams = new HashMap(); this.advisors = new ArrayList(); this.advisorParams = new HashMap(); this.toolContext = new HashMap(); Assert.notNull(chatModel, \"chatModel cannot be null\"); Assert.notNull(userParams, \"userParams cannot be null\"); Assert.notNull(systemParams, \"systemParams cannot be null\"); Assert.notNull(toolCallbacks, \"toolCallbacks cannot be null\"); Assert.notNull(messages, \"messages cannot be null\"); Assert.notNull(toolNames, \"toolNames cannot be null\"); Assert.notNull(media, \"media cannot be null\"); Assert.notNull(advisors, \"advisors cannot be null\"); Assert.notNull(advisorParams, \"advisorParams cannot be null\"); Assert.notNull(observationRegistry, \"observationRegistry cannot be null\"); Assert.notNull(toolContext, \"toolContext cannot be null\"); this.chatModel = chatModel; this.chatOptions = chatOptions != null ? chatOptions.copy() : (chatModel.getDefaultOptions() != null ? 
chatModel.getDefaultOptions().copy() : null); this.userText = userText; this.userParams.putAll(userParams); this.systemText = systemText; this.systemParams.putAll(systemParams); this.toolNames.addAll(toolNames); this.toolCallbacks.addAll(toolCallbacks); this.messages.addAll(messages); this.media.addAll(media); this.advisors.addAll(advisors); this.advisorParams.putAll(advisorParams); this.observationRegistry = observationRegistry; this.observationConvention = observationConvention != null ? observationConvention : DefaultChatClient.DEFAULTCHATCLIENTOBSERVATIONCONVENTION; this.toolContext.putAll(toolContext); this.templateRenderer = templateRenderer != null ? templateRenderer : DefaultChatClient.DEFAULTTEMPLATERENDERER; } @Nullable public String getUserText() { return this.userText; } public Map<String, Object> getUserParams() { return this.userParams; } @Nullable public String getSystemText() { return this.systemText; } public Map<String, Object> getSystemParams() { return this.systemParams; } @Nullable public ChatOptions getChatOptions() { return this.chatOptions; } public List<Advisor> getAdvisors() { return this.advisors; } public Map<String, Object> getAdvisorParams() { return this.advisorParams; } public List<Message> getMessages() { return this.messages; } public List<Media> getMedia() { return this.media; } public List<String> getToolNames() { return this.toolNames; } public List<ToolCallback> getToolCallbacks() { return this.toolCallbacks; } public Map<String, Object> getToolContext() { return this.toolContext; } public TemplateRenderer getTemplateRenderer() { return this.templateRenderer; } public ChatClient.Builder mutate() { DefaultChatClientBuilder builder = (DefaultChatClientBuilder)ChatClient.builder(this.chatModel, this.observationRegistry, this.observationConvention).defaultTemplateRenderer(this.templateRenderer).defaultToolCallbacks(this.toolCallbacks).defaultToolContext(this.toolContext).defaultToolNames(StringUtils.toStringArray(this.toolNames)); 
if (StringUtils.hasText(this.userText)) { builder.defaultUser((u) -> u.text(this.userText).params(this.userParams).media((Media[])this.media.toArray(new Media[0]))); } if (StringUtils.hasText(this.systemText)) { builder.defaultSystem((s) -> s.text(this.systemText).params(this.systemParams)); } if (this.chatOptions != null) { builder.defaultOptions(this.chatOptions); } builder.addMessages(this.messages); return builder; } public ChatClient.ChatClientRequestSpec advisors(Consumer<ChatClient.AdvisorSpec> consumer) { Assert.notNull(consumer, \"consumer cannot be null\"); DefaultAdvisorSpec advisorSpec = new DefaultAdvisorSpec(); consumer.accept(advisorSpec); this.advisorParams.putAll(advisorSpec.getParams()); this.advisors.addAll(advisorSpec.getAdvisors()); return this; } public ChatClient.ChatClientRequestSpec advisors(Advisor... advisors) { Assert.notNull(advisors, \"advisors cannot be null\"); Assert.noNullElements(advisors, \"advisors cannot contain null elements\"); this.advisors.addAll(Arrays.asList(advisors)); return this; } public ChatClient.ChatClientRequestSpec advisors(List<Advisor> advisors) { Assert.notNull(advisors, \"advisors cannot be null\"); Assert.noNullElements(advisors, \"advisors cannot contain null elements\"); this.advisors.addAll(advisors); return this; } public ChatClient.ChatClientRequestSpec messages(Message... 
messages) { Assert.notNull(messages, \"messages cannot be null\"); Assert.noNullElements(messages, \"messages cannot contain null elements\"); this.messages.addAll(List.of(messages)); return this; } public ChatClient.ChatClientRequestSpec messages(List<Message> messages) { Assert.notNull(messages, \"messages cannot be null\"); Assert.noNullElements(messages, \"messages cannot contain null elements\"); this.messages.addAll(messages); return this; } public <T extends ChatOptions> ChatClient.ChatClientRequestSpec options(T options) { Assert.notNull(options, \"options cannot be null\"); this.chatOptions = options; return this; } public ChatClient.ChatClientRequestSpec toolNames(String... toolNames) { Assert.notNull(toolNames, \"toolNames cannot be null\"); Assert.noNullElements(toolNames, \"toolNames cannot contain null elements\"); this.toolNames.addAll(List.of(toolNames)); return this; } public ChatClient.ChatClientRequestSpec toolCallbacks(ToolCallback... toolCallbacks) { Assert.notNull(toolCallbacks, \"toolCallbacks cannot be null\"); Assert.noNullElements(toolCallbacks, \"toolCallbacks cannot contain null elements\"); this.toolCallbacks.addAll(List.of(toolCallbacks)); return this; } public ChatClient.ChatClientRequestSpec toolCallbacks(List<ToolCallback> toolCallbacks) { Assert.notNull(toolCallbacks, \"toolCallbacks cannot be null\"); Assert.noNullElements(toolCallbacks, \"toolCallbacks cannot contain null elements\"); this.toolCallbacks.addAll(toolCallbacks); return this; } public ChatClient.ChatClientRequestSpec tools(Object... toolObjects) { Assert.notNull(toolObjects, \"toolObjects cannot be null\"); Assert.noNullElements(toolObjects, \"toolObjects cannot contain null elements\"); this.toolCallbacks.addAll(Arrays.asList(ToolCallbacks.from(toolObjects))); return this; } public ChatClient.ChatClientRequestSpec toolCallbacks(ToolCallbackProvider... 
toolCallbackProviders) { Assert.notNull(toolCallbackProviders, \"toolCallbackProviders cannot be null\"); Assert.noNullElements(toolCallbackProviders, \"toolCallbackProviders cannot contain null elements\"); for(ToolCallbackProvider toolCallbackProvider : toolCallbackProviders) { this.toolCallbacks.addAll(List.of(toolCallbackProvider.getToolCallbacks())); } return this; } public ChatClient.ChatClientRequestSpec toolContext(Map<String, Object> toolContext) { Assert.notNull(toolContext, \"toolContext cannot be null\"); Assert.noNullElements(toolContext.keySet(), \"toolContext keys cannot contain null elements\"); Assert.noNullElements(toolContext.values(), \"toolContext values cannot contain null elements\"); this.toolContext.putAll(toolContext); return this; } public ChatClient.ChatClientRequestSpec system(String text) { Assert.hasText(text, \"text cannot be null or empty\"); this.systemText = text; return this; } public ChatClient.ChatClientRequestSpec system(Resource text, Charset charset) { Assert.notNull(text, \"text cannot be null\"); Assert.notNull(charset, \"charset cannot be null\"); try { this.systemText = text.getContentAsString(charset); return this; } catch (IOException e) { throw new RuntimeException(e); } } public ChatClient.ChatClientRequestSpec system(Resource text) { Assert.notNull(text, \"text cannot be null\"); return this.system(text, Charset.defaultCharset()); } public ChatClient.ChatClientRequestSpec system(Consumer<ChatClient.PromptSystemSpec> consumer) { Assert.notNull(consumer, \"consumer cannot be null\"); DefaultPromptSystemSpec systemSpec = new DefaultPromptSystemSpec(); consumer.accept(systemSpec); this.systemText = StringUtils.hasText(systemSpec.text()) ? 
systemSpec.text() : this.systemText; this.systemParams.putAll(systemSpec.params()); return this; } public ChatClient.ChatClientRequestSpec user(String text) { Assert.hasText(text, \"text cannot be null or empty\"); this.userText = text; return this; } public ChatClient.ChatClientRequestSpec user(Resource text, Charset charset) { Assert.notNull(text, \"text cannot be null\"); Assert.notNull(charset, \"charset cannot be null\"); try { this.userText = text.getContentAsString(charset); return this; } catch (IOException e) { throw new RuntimeException(e); } } public ChatClient.ChatClientRequestSpec user(Resource text) { Assert.notNull(text, \"text cannot be null\"); return this.user(text, Charset.defaultCharset()); } public ChatClient.ChatClientRequestSpec user(Consumer<ChatClient.PromptUserSpec> consumer) { Assert.notNull(consumer, \"consumer cannot be null\"); DefaultPromptUserSpec us = new DefaultPromptUserSpec(); consumer.accept(us); this.userText = StringUtils.hasText(us.text()) ? us.text() : this.userText; this.userParams.putAll(us.params()); this.media.addAll(us.media()); return this; } public ChatClient.ChatClientRequestSpec templateRenderer(TemplateRenderer templateRenderer) { Assert.notNull(templateRenderer, \"templateRenderer cannot be null\"); this.templateRenderer = templateRenderer; return this; } public ChatClient.CallResponseSpec call() { BaseAdvisorChain advisorChain = this.buildAdvisorChain(); return new DefaultCallResponseSpec(DefaultChatClientUtils.toChatClientRequest(this), advisorChain, this.observationRegistry, this.observationConvention); } public ChatClient.StreamResponseSpec stream() { BaseAdvisorChain advisorChain = this.buildAdvisorChain(); return new DefaultStreamResponseSpec(DefaultChatClientUtils.toChatClientRequest(this), advisorChain, this.observationRegistry, this.observationConvention); } private BaseAdvisorChain buildAdvisorChain() { this.advisors.add(ChatModelCallAdvisor.builder().chatModel(this.chatModel).build()); 
this.advisors.add(ChatModelStreamAdvisor.builder().chatModel(this.chatModel).build()); return DefaultAroundAdvisorChain.builder(this.observationRegistry).pushAll(this.advisors).templateRenderer(this.templateRenderer).build(); } }}
DefaultChatClientUtils
类作用:用来将 DefaultChatClient.DefaultChatClientRequestSpec 转换为 ChatClientRequest
- 处理系统提示
- 处理用户提示
- 处理工具调用选项
package org.springframework.ai.chat.client;import java.util.ArrayList;import java.util.HashSet;import java.util.List;import java.util.Map;import java.util.Set;import java.util.concurrent.ConcurrentHashMap;import org.springframework.ai.chat.messages.Message;import org.springframework.ai.chat.messages.SystemMessage;import org.springframework.ai.chat.messages.UserMessage;import org.springframework.ai.chat.prompt.ChatOptions;import org.springframework.ai.chat.prompt.Prompt;import org.springframework.ai.chat.prompt.PromptTemplate;import org.springframework.ai.model.tool.ToolCallingChatOptions;import org.springframework.ai.tool.ToolCallback;import org.springframework.util.Assert;import org.springframework.util.CollectionUtils;import org.springframework.util.StringUtils;final class DefaultChatClientUtils { private DefaultChatClientUtils() { } static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClientRequestSpec inputRequest) { Assert.notNull(inputRequest, \"inputRequest cannot be null\"); List<Message> processedMessages = new ArrayList(); String processedSystemText = inputRequest.getSystemText(); if (StringUtils.hasText(processedSystemText)) { if (!CollectionUtils.isEmpty(inputRequest.getSystemParams())) { processedSystemText = PromptTemplate.builder().template(processedSystemText).variables(inputRequest.getSystemParams()).renderer(inputRequest.getTemplateRenderer()).build().render(); } processedMessages.add(new SystemMessage(processedSystemText)); } if (!CollectionUtils.isEmpty(inputRequest.getMessages())) { processedMessages.addAll(inputRequest.getMessages()); } String processedUserText = inputRequest.getUserText(); if (StringUtils.hasText(processedUserText)) { if (!CollectionUtils.isEmpty(inputRequest.getUserParams())) { processedUserText = PromptTemplate.builder().template(processedUserText).variables(inputRequest.getUserParams()).renderer(inputRequest.getTemplateRenderer()).build().render(); } 
processedMessages.add(UserMessage.builder().text(processedUserText).media(inputRequest.getMedia()).build()); } ChatOptions processedChatOptions = inputRequest.getChatOptions(); if (processedChatOptions instanceof ToolCallingChatOptions toolCallingChatOptions) { if (!inputRequest.getToolNames().isEmpty()) { Set<String> toolNames = ToolCallingChatOptions.mergeToolNames(new HashSet(inputRequest.getToolNames()), toolCallingChatOptions.getToolNames()); toolCallingChatOptions.setToolNames(toolNames); } if (!inputRequest.getToolCallbacks().isEmpty()) { List<ToolCallback> toolCallbacks = ToolCallingChatOptions.mergeToolCallbacks(inputRequest.getToolCallbacks(), toolCallingChatOptions.getToolCallbacks()); ToolCallingChatOptions.validateToolCallbacks(toolCallbacks); toolCallingChatOptions.setToolCallbacks(toolCallbacks); } if (!CollectionUtils.isEmpty(inputRequest.getToolContext())) { Map<String, Object> toolContext = ToolCallingChatOptions.mergeToolContext(inputRequest.getToolContext(), toolCallingChatOptions.getToolContext()); toolCallingChatOptions.setToolContext(toolContext); } } return ChatClientRequest.builder().prompt(Prompt.builder().messages(processedMessages).chatOptions(processedChatOptions).build()).context(new ConcurrentHashMap(inputRequest.getAdvisorParams())).build(); }}
AdvisorChain
AdvisorChain 以链式方式依次调用一系列增强器(Advisor),每个增强器的输入是 ChatClientRequest,输出是 ChatClientResponse(其中必定会用到 ChatModelCallAdvisor 或 ChatModelStreamAdvisor)
- ChatModelCallAdvisor 触发 ChatModel 的 call 方法
- ChatModelStreamAdvisor 触发 ChatModel 的 stream 方法
ChatModel
package org.springframework.ai.chat.model;import java.util.Arrays;import org.springframework.ai.chat.messages.Message;import org.springframework.ai.chat.messages.UserMessage;import org.springframework.ai.chat.prompt.ChatOptions;import org.springframework.ai.chat.prompt.Prompt;import org.springframework.ai.model.Model;import reactor.core.publisher.Flux;public interface ChatModel extends Model<Prompt, ChatResponse>, StreamingChatModel { default String call(String message) { Prompt prompt = new Prompt(new UserMessage(message)); Generation generation = this.call(prompt).getResult(); return generation != null ? generation.getOutput().getText() : \"\"; } default String call(Message... messages) { Prompt prompt = new Prompt(Arrays.asList(messages)); Generation generation = this.call(prompt).getResult(); return generation != null ? generation.getOutput().getText() : \"\"; } ChatResponse call(Prompt prompt); default ChatOptions getDefaultOptions() { return ChatOptions.builder().build(); } default Flux<ChatResponse> stream(Prompt prompt) { throw new UnsupportedOperationException(\"streaming is not supported\"); }}
不同厂商各自实现 ChatModel 接口,实现逻辑基本参照官方的 OpenAI 实现(OpenAiChatModel)
在 pom.xml 中引入对应依赖:
<dependency> <groupId>org.springframework.ai</groupId> <artifactId>spring-ai-autoconfigure-model-openai</artifactId></dependency>
OpenAiChatModel
各字段说明
对外暴露的方法
2. 创建观测上下文
3. 执行带观测的模型调用
4. 执行OpenAI接口调用
5. 解析模型返回的choices
6. 将每个 choice 转换为 Generation 对象,构建完整的 Generation 列表
7. 提取限流信息(RateLimit)
8. 计算token使用量
9. 构建最终的ChatResponse并设置上下文
10. 工具调用处理
2. 构建OpenAI流式请求对象
3. 发起流式 API 调用,获取 chunk 数据
4. 创建角色映射表,解决 chunk 中 role 缺失问题
5. 创建观测上下文
6. 启动观测操作
7. 将 chunk 转换为 ChatCompletion 标准格式
8. 转换为 ChatResponse 并构建生成内容(Generation)
9. 处理 usage 字段(仅最终 chunk 包含完整 usage)
10. 工具调用处理
11. 聚合消息流并设置响应
package org.springframework.ai.openai;import io.micrometer.observation.Observation;import io.micrometer.observation.ObservationRegistry;import java.util.ArrayList;import java.util.Base64;import java.util.Collection;import java.util.HashMap;import java.util.List;import java.util.Map;import java.util.Objects;import java.util.concurrent.ConcurrentHashMap;import java.util.stream.Collectors;import org.slf4j.Logger;import org.slf4j.LoggerFactory;import org.springframework.ai.chat.messages.AssistantMessage;import org.springframework.ai.chat.messages.MessageType;import org.springframework.ai.chat.messages.ToolResponseMessage;import org.springframework.ai.chat.messages.UserMessage;import org.springframework.ai.chat.metadata.ChatGenerationMetadata;import org.springframework.ai.chat.metadata.ChatResponseMetadata;import org.springframework.ai.chat.metadata.DefaultUsage;import org.springframework.ai.chat.metadata.EmptyUsage;import org.springframework.ai.chat.metadata.RateLimit;import org.springframework.ai.chat.metadata.Usage;import org.springframework.ai.chat.model.ChatModel;import org.springframework.ai.chat.model.ChatResponse;import org.springframework.ai.chat.model.Generation;import org.springframework.ai.chat.model.MessageAggregator;import org.springframework.ai.chat.observation.ChatModelObservationContext;import org.springframework.ai.chat.observation.ChatModelObservationConvention;import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;import org.springframework.ai.chat.prompt.ChatOptions;import org.springframework.ai.chat.prompt.Prompt;import org.springframework.ai.content.Media;import org.springframework.ai.model.ModelOptionsUtils;import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;import org.springframework.ai.model.tool.ToolCallingChatOptions;import org.springframework.ai.model.tool.ToolCallingManager;import 
org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.MediaContent.InputAudio.Format;
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.support.UsageCalculator;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.core.io.Resource;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MimeType;
import org.springframework.util.MimeTypeUtils;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

/**
 * {@link ChatModel} implementation backed by the OpenAI Chat Completions API.
 * <p>
 * Each call first merges the runtime {@link ChatOptions} carried by the {@link Prompt} with this
 * model's default {@link OpenAiChatOptions} ({@link #buildRequestPrompt(Prompt)}), converts the
 * Spring AI messages into an {@link OpenAiApi.ChatCompletionRequest} ({@link #createRequest(Prompt, boolean)})
 * and maps the OpenAI response back into a {@link ChatResponse}. Both the blocking
 * ({@link #call(Prompt)}) and streaming ({@link #stream(Prompt)}) paths are wrapped in a Micrometer
 * observation. When the {@link ToolExecutionEligibilityPredicate} reports that tool execution is
 * required, the tools are run through the {@link ToolCallingManager}. The model is then invoked
 * again with the updated conversation, unless the tool result is marked {@code returnDirect}.
 */
public class OpenAiChatModel implements ChatModel {

    private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModel.class);

    private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();

    /** Used by {@link Builder#build()} when no ToolCallingManager was configured. */
    private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build();

    // Parsed once instead of on every mapToMediaContent call.
    private static final MimeType AUDIO_MP3_MIME_TYPE = MimeTypeUtils.parseMimeType("audio/mp3");

    private static final MimeType AUDIO_WAV_MIME_TYPE = MimeTypeUtils.parseMimeType("audio/wav");

    /** Options applied to every request; runtime prompt options take precedence when merged. */
    private final OpenAiChatOptions defaultOptions;

    /** Retry policy; only the blocking {@link #internalCall} path uses it. */
    private final RetryTemplate retryTemplate;

    private final OpenAiApi openAiApi;

    private final ObservationRegistry observationRegistry;

    private final ToolCallingManager toolCallingManager;

    private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate;

    private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

    /**
     * Creates a model that uses the {@link DefaultToolExecutionEligibilityPredicate}.
     */
    public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions defaultOptions, ToolCallingManager toolCallingManager,
            RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
        this(openAiApi, defaultOptions, toolCallingManager, retryTemplate, observationRegistry,
                new DefaultToolExecutionEligibilityPredicate());
    }

    /**
     * Creates a fully configured model.
     * @throws IllegalArgumentException if any argument is {@code null}
     */
    public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions defaultOptions, ToolCallingManager toolCallingManager,
            RetryTemplate retryTemplate, ObservationRegistry observationRegistry,
            ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
        Assert.notNull(openAiApi, "openAiApi cannot be null");
        Assert.notNull(defaultOptions, "defaultOptions cannot be null");
        Assert.notNull(toolCallingManager, "toolCallingManager cannot be null");
        Assert.notNull(retryTemplate, "retryTemplate cannot be null");
        Assert.notNull(observationRegistry, "observationRegistry cannot be null");
        Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate cannot be null");
        this.openAiApi = openAiApi;
        this.defaultOptions = defaultOptions;
        this.toolCallingManager = toolCallingManager;
        this.retryTemplate = retryTemplate;
        this.observationRegistry = observationRegistry;
        this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
    }

    /**
     * Blocking chat call: merges options, then delegates to {@link #internalCall}.
     */
    @Override
    public ChatResponse call(Prompt prompt) {
        Prompt requestPrompt = this.buildRequestPrompt(prompt);
        return this.internalCall(requestPrompt, null);
    }

    /**
     * Performs one observed, retried blocking request and recursively handles tool calls.
     * @param prompt a prompt whose options are already merged (see {@link #buildRequestPrompt})
     * @param previousChatResponse response from the previous tool-calling round, used to accumulate
     * token usage; {@code null} on the first round
     * @return the model response, or an empty response when OpenAI returned no body or no choices
     */
    public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
        OpenAiApi.ChatCompletionRequest request = this.createRequest(prompt, false);

        ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
            .prompt(prompt)
            .provider(OpenAiApiConstants.PROVIDER_NAME)
            .build();

        ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
            .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                    this.observationRegistry)
            .observe(() -> {
                ResponseEntity<OpenAiApi.ChatCompletion> completionEntity = this.retryTemplate
                    .execute(ctx -> this.openAiApi.chatCompletionEntity(request, this.getAdditionalHttpHeaders(prompt)));

                OpenAiApi.ChatCompletion chatCompletion = completionEntity.getBody();
                if (chatCompletion == null) {
                    logger.warn("No chat completion returned for prompt: {}", prompt);
                    return new ChatResponse(List.of());
                }

                List<OpenAiApi.ChatCompletion.Choice> choices = chatCompletion.choices();
                if (choices == null) {
                    logger.warn("No choices returned for prompt: {}", prompt);
                    return new ChatResponse(List.of());
                }

                String completionId = chatCompletion.id() != null ? chatCompletion.id() : "";
                List<Generation> generations = choices.stream().map(choice -> {
                    String role = choice.message().role() != null ? choice.message().role().name() : "";
                    return this.buildGeneration(choice, choiceMetadata(completionId, role, choice), request);
                }).toList();

                // Rate-limit information is only available from HTTP headers, i.e. on the blocking path.
                RateLimit rateLimit = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity);
                Usage accumulatedUsage = this.accumulateUsage(chatCompletion.usage(), previousChatResponse);
                ChatResponse chatResponse = new ChatResponse(generations,
                        this.from(chatCompletion, rateLimit, accumulatedUsage));
                observationContext.setResponse(chatResponse);
                return chatResponse;
            });

        if (!this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
            return response;
        }

        ToolExecutionResult toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
        if (toolExecutionResult.returnDirect()) {
            // Return the tool output to the caller without another model round-trip.
            return ChatResponse.builder()
                .from(response)
                .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
                .build();
        }
        // Send the tool results back to the model; `response` carries the usage accumulated so far.
        return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), response);
    }

    /**
     * Streaming chat call: merges options, then delegates to {@link #internalStream}.
     */
    @Override
    public Flux<ChatResponse> stream(Prompt prompt) {
        Prompt requestPrompt = this.buildRequestPrompt(prompt);
        return this.internalStream(requestPrompt, null);
    }

    /**
     * Performs one observed streaming request and recursively handles tool calls.
     * <p>
     * Audio output and audio parameters are rejected for streaming requests: the returned
     * {@link Flux} errors with {@link IllegalArgumentException}.
     * @param prompt a prompt whose options are already merged (see {@link #buildRequestPrompt})
     * @param previousChatResponse response from the previous tool-calling round, used to accumulate
     * token usage; {@code null} on the first round
     */
    public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
        return Flux.deferContextual(contextView -> {
            OpenAiApi.ChatCompletionRequest request = this.createRequest(prompt, true);

            if (request.outputModalities() != null
                    && request.outputModalities().stream().anyMatch(m -> m.equals("audio"))) {
                logger.warn("Audio output is not supported for streaming requests.");
                throw new IllegalArgumentException("Audio output is not supported for streaming requests.");
            }
            if (request.audioParameters() != null) {
                logger.warn("Audio parameters are not supported for streaming requests.");
                throw new IllegalArgumentException("Audio parameters are not supported for streaming requests.");
            }

            Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.openAiApi.chatCompletionStream(request,
                    this.getAdditionalHttpHeaders(prompt));

            // A chunk's message role may be null, so remember the first non-null role per completion id.
            ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();

            ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
                .prompt(prompt)
                .provider(OpenAiApiConstants.PROVIDER_NAME)
                .build();

            Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
                    this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                    this.observationRegistry);
            Observation parentObservation = contextView.getOrDefault("micrometer.observation", null);
            observation.parentObservation(parentObservation).start();

            Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
                .switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
                    try {
                        String id = chatCompletion2.id() == null ? "NO_ID" : chatCompletion2.id();
                        List<Generation> generations = chatCompletion2.choices().stream().map(choice -> {
                            if (choice.message().role() != null) {
                                roleMap.putIfAbsent(id, choice.message().role().name());
                            }
                            return this.buildGeneration(choice,
                                    choiceMetadata(id, roleMap.getOrDefault(id, ""), choice), request);
                        }).toList();
                        Usage accumulatedUsage = this.accumulateUsage(chatCompletion2.usage(), previousChatResponse);
                        // No HTTP headers per chunk, hence no rate-limit information here.
                        return new ChatResponse(generations, this.from(chatCompletion2, null, accumulatedUsage));
                    }
                    catch (Exception e) {
                        logger.error("Error processing chat completion", e);
                        return new ChatResponse(List.of());
                    }
                }))
                // Sliding window of (current, next): when usage reporting is enabled and the next
                // response carries non-empty usage, copy that usage onto the current response.
                .buffer(2, 1)
                .map(bufferList -> {
                    ChatResponse firstResponse = bufferList.get(0);
                    if (request.streamOptions() != null && request.streamOptions().includeUsage()
                            && bufferList.size() == 2) {
                        ChatResponse secondResponse = bufferList.get(1);
                        if (secondResponse != null && secondResponse.getMetadata() != null) {
                            Usage usage = secondResponse.getMetadata().getUsage();
                            if (!UsageCalculator.isEmpty(usage)) {
                                return new ChatResponse(firstResponse.getResults(),
                                        this.from(firstResponse.getMetadata(), usage));
                            }
                        }
                    }
                    return firstResponse;
                });

            Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
                if (!this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
                    return Flux.just(response);
                }
                // subscribeOn(boundedElastic) keeps potentially blocking tool callbacks off the emitting thread.
                return Flux.defer(() -> {
                    ToolExecutionResult toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
                    if (toolExecutionResult.returnDirect()) {
                        return Flux.just(ChatResponse.builder()
                            .from(response)
                            .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
                            .build());
                    }
                    return this.internalStream(
                            new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), response);
                }).subscribeOn(Schedulers.boundedElastic());
            })
                .doOnError(observation::error)
                .doFinally(s -> observation.stop())
                .contextWrite(ctx -> ctx.put("micrometer.observation", observation));

            // Aggregate the streamed pieces so the observation context receives the complete response.
            return new MessageAggregator().aggregate(flux, observationContext::setResponse);
        });
    }

    /**
     * Builds the per-choice metadata map attached to each {@link AssistantMessage}.
     * Missing values are replaced with empty strings or an empty list.
     */
    private static Map<String, Object> choiceMetadata(String id, String role, OpenAiApi.ChatCompletion.Choice choice) {
        return Map.of(
                "id", id,
                "role", role,
                "index", choice.index(),
                "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
                "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "",
                "annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of());
    }

    /**
     * Converts the OpenAI usage (or {@link EmptyUsage} when absent) and adds the usage
     * recorded on {@code previousChatResponse}.
     */
    private Usage accumulateUsage(OpenAiApi.Usage usage, ChatResponse previousChatResponse) {
        Usage currentChatResponseUsage = usage != null ? this.getDefaultUsage(usage) : new EmptyUsage();
        return UsageCalculator.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
    }

    /**
     * Merges default HTTP headers with the prompt's {@link OpenAiChatOptions} headers.
     * Prompt headers win on key collisions.
     */
    private MultiValueMap<String, String> getAdditionalHttpHeaders(Prompt prompt) {
        Map<String, String> headers = new HashMap<>(this.defaultOptions.getHttpHeaders());
        if (prompt.getOptions() instanceof OpenAiChatOptions chatOptions) {
            headers.putAll(chatOptions.getHttpHeaders());
        }
        return CollectionUtils.toMultiValueMap(headers.entrySet()
            .stream()
            .collect(Collectors.toMap(Map.Entry::getKey, e -> List.of(e.getValue()))));
    }

    /**
     * Maps one OpenAI choice to a {@link Generation}, including tool calls, audio output
     * (decoded from Base64 into a {@link Media}) and, when requested, logprobs.
     */
    private Generation buildGeneration(OpenAiApi.ChatCompletion.Choice choice, Map<String, Object> metadata,
            OpenAiApi.ChatCompletionRequest request) {
        List<AssistantMessage.ToolCall> toolCalls = choice.message().toolCalls() == null ? List.of()
                : choice.message()
                    .toolCalls()
                    .stream()
                    .map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), "function",
                            toolCall.function().name(), toolCall.function().arguments()))
                    .toList();

        String finishReason = choice.finishReason() != null ? choice.finishReason().name() : "";
        ChatGenerationMetadata.Builder generationMetadataBuilder = ChatGenerationMetadata.builder()
            .finishReason(finishReason);

        List<Media> media = new ArrayList<>();
        String textContent = choice.message().content();
        OpenAiApi.ChatCompletionMessage.AudioOutput audioOutput = choice.message().audioOutput();
        if (audioOutput != null) {
            // NOTE(review): assumes request.audioParameters() is set whenever audio is returned.
            String mimeType = String.format("audio/%s",
                    request.audioParameters().format().name().toLowerCase(java.util.Locale.ROOT));
            byte[] audioData = Base64.getDecoder().decode(audioOutput.data());
            Resource resource = new ByteArrayResource(audioData);
            media.add(Media.builder()
                .mimeType(MimeTypeUtils.parseMimeType(mimeType))
                .data(resource)
                .id(audioOutput.id())
                .build());
            // Fall back to the audio transcript when the message has no text.
            if (!StringUtils.hasText(textContent)) {
                textContent = audioOutput.transcript();
            }
            generationMetadataBuilder.metadata("audioId", audioOutput.id());
            generationMetadataBuilder.metadata("audioExpiresAt", audioOutput.expiresAt());
        }

        if (Boolean.TRUE.equals(request.logprobs())) {
            generationMetadataBuilder.metadata("logprobs", choice.logprobs());
        }

        AssistantMessage assistantMessage = new AssistantMessage(textContent, metadata, toolCalls, media);
        return new Generation(assistantMessage, generationMetadataBuilder.build());
    }

    /**
     * Builds response metadata from a completion; {@code rateLimit} may be {@code null}.
     */
    private ChatResponseMetadata from(OpenAiApi.ChatCompletion result, RateLimit rateLimit, Usage usage) {
        Assert.notNull(result, "OpenAI ChatCompletionResult must not be null");
        ChatResponseMetadata.Builder builder = ChatResponseMetadata.builder()
            .id(result.id() != null ? result.id() : "")
            .usage(usage)
            .model(result.model() != null ? result.model() : "")
            .keyValue("created", result.created() != null ? result.created() : 0L)
            .keyValue("system-fingerprint", result.systemFingerprint() != null ? result.systemFingerprint() : "");
        if (rateLimit != null) {
            builder.rateLimit(rateLimit);
        }
        return builder.build();
    }

    /**
     * Copies id, model and rate limit from existing metadata, replacing its usage.
     */
    private ChatResponseMetadata from(ChatResponseMetadata chatResponseMetadata, Usage usage) {
        Assert.notNull(chatResponseMetadata, "OpenAI ChatResponseMetadata must not be null");
        ChatResponseMetadata.Builder builder = ChatResponseMetadata.builder()
            .id(chatResponseMetadata.getId() != null ? chatResponseMetadata.getId() : "")
            .usage(usage)
            .model(chatResponseMetadata.getModel() != null ? chatResponseMetadata.getModel() : "");
        if (chatResponseMetadata.getRateLimit() != null) {
            builder.rateLimit(chatResponseMetadata.getRateLimit());
        }
        return builder.build();
    }

    /**
     * Re-shapes a streaming chunk into a {@link OpenAiApi.ChatCompletion}, using each choice's
     * delta as its message, so both paths share the same mapping code.
     */
    private OpenAiApi.ChatCompletion chunkToChatCompletion(OpenAiApi.ChatCompletionChunk chunk) {
        List<OpenAiApi.ChatCompletion.Choice> choices = chunk.choices()
            .stream()
            .map(chunkChoice -> new OpenAiApi.ChatCompletion.Choice(chunkChoice.finishReason(), chunkChoice.index(),
                    chunkChoice.delta(), chunkChoice.logprobs()))
            .toList();
        return new OpenAiApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(), chunk.serviceTier(),
                chunk.systemFingerprint(), "chat.completion", chunk.usage());
    }

    private DefaultUsage getDefaultUsage(OpenAiApi.Usage usage) {
        return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage);
    }

    /**
     * Merges the prompt's runtime options over {@link #defaultOptions}.
     * <p>
     * HTTP headers, tool names, tool callbacks and tool context are merged explicitly
     * rather than overwritten. The resulting tool callbacks are validated.
     * @return a new prompt with the same instructions and the merged {@link OpenAiChatOptions}
     */
    Prompt buildRequestPrompt(Prompt prompt) {
        OpenAiChatOptions runtimeOptions = null;
        if (prompt.getOptions() != null) {
            if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
                runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
                        OpenAiChatOptions.class);
            }
            else {
                runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
                        OpenAiChatOptions.class);
            }
        }

        OpenAiChatOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
                OpenAiChatOptions.class);

        if (runtimeOptions != null) {
            if (runtimeOptions.getTopK() != null) {
                logger.warn("The topK option is not supported by OpenAI chat models. Ignoring.");
            }
            requestOptions.setHttpHeaders(
                    this.mergeHttpHeaders(runtimeOptions.getHttpHeaders(), this.defaultOptions.getHttpHeaders()));
            requestOptions.setInternalToolExecutionEnabled(ModelOptionsUtils.mergeOption(
                    runtimeOptions.getInternalToolExecutionEnabled(),
                    this.defaultOptions.getInternalToolExecutionEnabled()));
            requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(),
                    this.defaultOptions.getToolNames()));
            requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(
                    runtimeOptions.getToolCallbacks(), this.defaultOptions.getToolCallbacks()));
            requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(),
                    this.defaultOptions.getToolContext()));
        }
        else {
            requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders());
            requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
            requestOptions.setToolNames(this.defaultOptions.getToolNames());
            requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
            requestOptions.setToolContext(this.defaultOptions.getToolContext());
        }

        ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks());
        return new Prompt(prompt.getInstructions(), requestOptions);
    }

    /** Returns a new map of the default headers overlaid with the runtime headers. */
    private Map<String, String> mergeHttpHeaders(Map<String, String> runtimeHttpHeaders,
            Map<String, String> defaultHttpHeaders) {
        HashMap<String, String> mergedHttpHeaders = new HashMap<>(defaultHttpHeaders);
        mergedHttpHeaders.putAll(runtimeHttpHeaders);
        return mergedHttpHeaders;
    }

    /**
     * Converts the prompt's messages and options into an OpenAI request.
     * <p>
     * USER/SYSTEM messages become a single message; user media is appended as extra content parts.
     * ASSISTANT messages carry tool calls and at most one audio reference. Each TOOL response
     * becomes its own message; other message types are rejected. {@code streamOptions} is removed
     * from non-streaming requests.
     * @param prompt a prompt whose options are an {@link OpenAiChatOptions} (see {@link #buildRequestPrompt})
     * @param stream whether the request is for the streaming endpoint
     * @throws IllegalArgumentException for unsupported message types, tool responses without an id,
     * or assistant messages with more than one media item
     */
    OpenAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
        List<OpenAiApi.ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(message -> {
            if (message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.SYSTEM) {
                Object content = message.getText();
                if (message instanceof UserMessage userMessage && !CollectionUtils.isEmpty(userMessage.getMedia())) {
                    // Multimodal user input: the text part first, then one part per media item.
                    List<OpenAiApi.ChatCompletionMessage.MediaContent> contentList = new ArrayList<>(
                            List.of(new OpenAiApi.ChatCompletionMessage.MediaContent(message.getText())));
                    contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList());
                    content = contentList;
                }
                return List.of(new OpenAiApi.ChatCompletionMessage(content,
                        Role.valueOf(message.getMessageType().name())));
            }
            else if (message.getMessageType() == MessageType.ASSISTANT) {
                AssistantMessage assistantMessage = (AssistantMessage) message;
                List<OpenAiApi.ChatCompletionMessage.ToolCall> toolCalls = null;
                if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
                    toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> {
                        OpenAiApi.ChatCompletionMessage.ChatCompletionFunction function = new OpenAiApi.ChatCompletionMessage.ChatCompletionFunction(
                                toolCall.name(), toolCall.arguments());
                        return new OpenAiApi.ChatCompletionMessage.ToolCall(toolCall.id(), toolCall.type(), function);
                    }).toList();
                }
                OpenAiApi.ChatCompletionMessage.AudioOutput audioOutput = null;
                if (!CollectionUtils.isEmpty(assistantMessage.getMedia())) {
                    Assert.isTrue(assistantMessage.getMedia().size() == 1,
                            "Only one media content is supported for assistant messages");
                    // Only the id of a previously returned audio output is sent back.
                    audioOutput = new OpenAiApi.ChatCompletionMessage.AudioOutput(
                            assistantMessage.getMedia().get(0).getId(), null, null, null);
                }
                return List.of(new OpenAiApi.ChatCompletionMessage(assistantMessage.getText(), Role.ASSISTANT, null,
                        null, toolCalls, null, audioOutput, null));
            }
            else if (message.getMessageType() == MessageType.TOOL) {
                ToolResponseMessage toolMessage = (ToolResponseMessage) message;
                toolMessage.getResponses()
                    .forEach(response -> Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id"));
                return toolMessage.getResponses()
                    .stream()
                    .map(tr -> new OpenAiApi.ChatCompletionMessage(tr.responseData(), Role.TOOL, tr.name(), tr.id(),
                            null, null, null, null))
                    .toList();
            }
            else {
                throw new IllegalArgumentException(
                        "Unsupported message type: " + String.valueOf(message.getMessageType()));
            }
        }).flatMap(Collection::stream).toList();

        OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, stream);

        OpenAiChatOptions requestOptions = (OpenAiChatOptions) prompt.getOptions();
        request = ModelOptionsUtils.merge(requestOptions, request, OpenAiApi.ChatCompletionRequest.class);

        // Expose the resolved tool definitions to OpenAI as function tools.
        List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions);
        if (!CollectionUtils.isEmpty(toolDefinitions)) {
            request = ModelOptionsUtils.merge(
                    OpenAiChatOptions.builder().tools(this.getFunctionTools(toolDefinitions)).build(), request,
                    OpenAiApi.ChatCompletionRequest.class);
        }

        if (request.streamOptions() != null && !stream) {
            logger.warn("Removing streamOptions from the request as it is not a streaming request!");
            request = request.streamOptions(null);
        }

        return request;
    }

    /**
     * Maps a media item to an OpenAI content part: {@code audio/mp3} and {@code audio/wav} become
     * input audio, and every other MIME type is sent as an image URL.
     */
    private OpenAiApi.ChatCompletionMessage.MediaContent mapToMediaContent(Media media) {
        MimeType mimeType = media.getMimeType();
        if (AUDIO_MP3_MIME_TYPE.equals(mimeType)) {
            return new OpenAiApi.ChatCompletionMessage.MediaContent(
                    new OpenAiApi.ChatCompletionMessage.MediaContent.InputAudio(this.fromAudioData(media.getData()),
                            Format.MP3));
        }
        if (AUDIO_WAV_MIME_TYPE.equals(mimeType)) {
            return new OpenAiApi.ChatCompletionMessage.MediaContent(
                    new OpenAiApi.ChatCompletionMessage.MediaContent.InputAudio(this.fromAudioData(media.getData()),
                            Format.WAV));
        }
        return new OpenAiApi.ChatCompletionMessage.MediaContent(new OpenAiApi.ChatCompletionMessage.MediaContent.ImageUrl(
                this.fromMediaData(media.getMimeType(), media.getData())));
    }

    /** Base64-encodes raw audio bytes; other data types are rejected. */
    private String fromAudioData(Object audioData) {
        if (audioData instanceof byte[] bytes) {
            return Base64.getEncoder().encodeToString(bytes);
        }
        throw new IllegalArgumentException("Unsupported audio data type: " + audioData.getClass().getSimpleName());
    }

    /** Encodes bytes as a Base64 data URL; strings (e.g. URLs) are passed through unchanged. */
    private String fromMediaData(MimeType mimeType, Object mediaContentData) {
        if (mediaContentData instanceof byte[] bytes) {
            return String.format("data:%s;base64,%s", mimeType.toString(), Base64.getEncoder().encodeToString(bytes));
        }
        if (mediaContentData instanceof String text) {
            return text;
        }
        throw new IllegalArgumentException(
                "Unsupported media data type: " + mediaContentData.getClass().getSimpleName());
    }

    /** Converts Spring AI tool definitions into OpenAI function tools. */
    private List<OpenAiApi.FunctionTool> getFunctionTools(List<ToolDefinition> toolDefinitions) {
        return toolDefinitions.stream().map(toolDefinition -> {
            OpenAiApi.FunctionTool.Function function = new OpenAiApi.FunctionTool.Function(toolDefinition.description(),
                    toolDefinition.name(), toolDefinition.inputSchema());
            return new OpenAiApi.FunctionTool(function);
        }).toList();
    }

    /** Returns a copy of the default options. */
    @Override
    public ChatOptions getDefaultOptions() {
        return OpenAiChatOptions.fromOptions(this.defaultOptions);
    }

    @Override
    public String toString() {
        return "OpenAiChatModel [defaultOptions=" + this.defaultOptions + "]";
    }

    /**
     * Replaces the observation convention used for subsequent calls.
     * @throws IllegalArgumentException if {@code observationConvention} is {@code null}
     */
    public void setObservationConvention(ChatModelObservationConvention observationConvention) {
        Assert.notNull(observationConvention, "observationConvention cannot be null");
        this.observationConvention = observationConvention;
    }

    public static Builder builder() {
        return new Builder();
    }

    /**
     * Returns a builder pre-populated from this model. The observation convention set via
     * {@link #setObservationConvention} is not carried over.
     */
    public Builder mutate() {
        return new Builder(this);
    }

    /** Creates an equivalent model via {@link #mutate()}. */
    public OpenAiChatModel clone() {
        return this.mutate().build();
    }

    /**
     * Builder for {@link OpenAiChatModel}. A fresh builder defaults to the OpenAI default chat
     * model with temperature 0.7, the default retry template, a no-op observation registry and the
     * default tool-execution predicate. {@link #build()} falls back to a default
     * {@link ToolCallingManager} when none is configured.
     */
    public static final class Builder {

        private OpenAiApi openAiApi;

        private OpenAiChatOptions defaultOptions;

        private ToolCallingManager toolCallingManager;

        private ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate;

        private RetryTemplate retryTemplate;

        private ObservationRegistry observationRegistry;

        /** Copies every builder-settable property from an existing model. */
        public Builder(OpenAiChatModel model) {
            this.openAiApi = model.openAiApi;
            this.defaultOptions = model.defaultOptions;
            this.toolCallingManager = model.toolCallingManager;
            this.toolExecutionEligibilityPredicate = model.toolExecutionEligibilityPredicate;
            this.retryTemplate = model.retryTemplate;
            this.observationRegistry = model.observationRegistry;
        }

        private Builder() {
            this.defaultOptions = OpenAiChatOptions.builder().model(OpenAiApi.DEFAULT_CHAT_MODEL).temperature(0.7).build();
            this.toolExecutionEligibilityPredicate = new DefaultToolExecutionEligibilityPredicate();
            this.retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
            this.observationRegistry = ObservationRegistry.NOOP;
        }

        public Builder openAiApi(OpenAiApi openAiApi) {
            this.openAiApi = openAiApi;
            return this;
        }

        public Builder defaultOptions(OpenAiChatOptions defaultOptions) {
            this.defaultOptions = defaultOptions;
            return this;
        }

        public Builder toolCallingManager(ToolCallingManager toolCallingManager) {
            this.toolCallingManager = toolCallingManager;
            return this;
        }

        public Builder toolExecutionEligibilityPredicate(
                ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
            this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
            return this;
        }

        public Builder retryTemplate(RetryTemplate retryTemplate) {
            this.retryTemplate = retryTemplate;
            return this;
        }

        public Builder observationRegistry(ObservationRegistry observationRegistry) {
            this.observationRegistry = observationRegistry;
            return this;
        }

        /**
         * Builds the model.
         * @throws IllegalArgumentException if a required property (e.g. {@code openAiApi}) is {@code null}
         */
        public OpenAiChatModel build() {
            ToolCallingManager manager = this.toolCallingManager != null ? this.toolCallingManager
                    : OpenAiChatModel.DEFAULT_TOOL_CALLING_MANAGER;
            return new OpenAiChatModel(this.openAiApi, this.defaultOptions, manager, this.retryTemplate,
                    this.observationRegistry, this.toolExecutionEligibilityPredicate);
        }

    }

}
OpenAiApi
各字段说明
对外暴露的方法
内部枚举类说明
完整代码如下
package org.springframework.ai.openai.api;import com.fasterxml.jackson.annotation.JsonFormat;import com.fasterxml.jackson.annotation.JsonIgnore;import com.fasterxml.jackson.annotation.JsonIgnoreProperties;import com.fasterxml.jackson.annotation.JsonInclude;import com.fasterxml.jackson.annotation.JsonProperty;import com.fasterxml.jackson.annotation.JsonFormat.Feature;import com.fasterxml.jackson.annotation.JsonInclude.Include;import java.util.List;import java.util.Map;import java.util.concurrent.atomic.AtomicBoolean;import java.util.function.Consumer;import java.util.function.Predicate;import org.springframework.ai.model.ApiKey;import org.springframework.ai.model.ChatModelDescription;import org.springframework.ai.model.ModelOptionsUtils;import org.springframework.ai.model.NoopApiKey;import org.springframework.ai.model.SimpleApiKey;import org.springframework.ai.retry.RetryUtils;import org.springframework.core.ParameterizedTypeReference;import org.springframework.http.HttpHeaders;import org.springframework.http.MediaType;import org.springframework.http.ResponseEntity;import org.springframework.util.Assert;import org.springframework.util.CollectionUtils;import org.springframework.util.LinkedMultiValueMap;import org.springframework.util.MultiValueMap;import org.springframework.web.client.ResponseErrorHandler;import org.springframework.web.client.RestClient;import org.springframework.web.reactive.function.client.WebClient;import reactor.core.publisher.Flux;import reactor.core.publisher.Mono;public class OpenAiApi { public static final ChatModel DEFAULTCHATMODEL; public static final String DEFAULTEMBEDDINGMODEL; private static final Predicate<String> SSEDONEPREDICATE; private final String baseUrl; private final ApiKey apiKey; private final MultiValueMap<String, String> headers; private final String completionsPath; private final String embeddingsPath; private final ResponseErrorHandler responseErrorHandler; private final RestClient restClient; private final WebClient 
webClient; private OpenAiStreamFunctionCallingHelper chunkMerger = new OpenAiStreamFunctionCallingHelper(); public Builder mutate() { return new Builder(this); } public static Builder builder() { return new Builder(); } public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> headers, String completionsPath, String embeddingsPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { this.baseUrl = baseUrl; this.apiKey = apiKey; this.headers = headers; this.completionsPath = completionsPath; this.embeddingsPath = embeddingsPath; this.responseErrorHandler = responseErrorHandler; Assert.hasText(completionsPath, \"Completions Path must not be null\"); Assert.hasText(embeddingsPath, \"Embeddings Path must not be null\"); Assert.notNull(headers, \"Headers must not be null\"); Consumer<HttpHeaders> finalHeaders = (h) -> { if (!(apiKey instanceof NoopApiKey)) { h.setBearerAuth(apiKey.getValue()); } h.setContentType(MediaType.APPLICATIONJSON); h.addAll(headers); }; this.restClient = restClientBuilder.clone().baseUrl(baseUrl).defaultHeaders(finalHeaders).defaultStatusHandler(responseErrorHandler).build(); this.webClient = webClientBuilder.clone().baseUrl(baseUrl).defaultHeaders(finalHeaders).build(); } public static String getTextContent(List<ChatCompletionMessage.MediaContent> content) { return (String)content.stream().filter((c) -> \"text\".equals(c.type())).map(ChatCompletionMessage.MediaContent::text).reduce(\"\", (a, b) -> a + b); } public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest chatRequest) { return this.chatCompletionEntity(chatRequest, new LinkedMultiValueMap()); } public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest chatRequest, MultiValueMap<String, String> additionalHttpHeader) { Assert.notNull(chatRequest, \"The request body can not be null.\"); Assert.isTrue(!chatRequest.stream(), \"Request must set the stream 
property to false.\"); Assert.notNull(additionalHttpHeader, \"The additional HTTP headers can not be null.\"); return ((RestClient.RequestBodySpec)((RestClient.RequestBodySpec)this.restClient.post().uri(this.completionsPath, new Object[0])).headers((headers) -> headers.addAll(additionalHttpHeader))).body(chatRequest).retrieve().toEntity(ChatCompletion.class); } public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chatRequest) { return this.chatCompletionStream(chatRequest, new LinkedMultiValueMap()); } public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chatRequest, MultiValueMap<String, String> additionalHttpHeader) { Assert.notNull(chatRequest, \"The request body can not be null.\"); Assert.isTrue(chatRequest.stream(), \"Request must set the stream property to true.\"); AtomicBoolean isInsideTool = new AtomicBoolean(false); return ((WebClient.RequestBodySpec)((WebClient.RequestBodySpec)this.webClient.post().uri(this.completionsPath, new Object[0])).headers((headers) -> headers.addAll(additionalHttpHeader))).body(Mono.just(chatRequest), ChatCompletionRequest.class).retrieve().bodyToFlux(String.class).takeUntil(SSEDONEPREDICATE).filter(SSEDONEPREDICATE.negate()).map((content) -> (ChatCompletionChunk)ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class)).map((chunk) -> { if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) { isInsideTool.set(true); } return chunk; }).windowUntil((chunk) -> { if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) { isInsideTool.set(false); return true; } else { return !isInsideTool.get(); } }).concatMapIterable((window) -> { Mono<ChatCompletionChunk> monoChunk = window.reduce(new ChatCompletionChunk((String)null, (List)null, (Long)null, (String)null, (String)null, (String)null, (String)null, (Usage)null), (previous, current) -> this.chunkMerger.merge(previous, current)); return List.of(monoChunk); }).flatMap((mono) -> mono); } public <T> 
ResponseEntity<EmbeddingList<Embedding>> embeddings(EmbeddingRequest<T> embeddingRequest) { Assert.notNull(embeddingRequest, \"The request body can not be null.\"); Assert.notNull(embeddingRequest.input(), \"The input can not be null.\"); Assert.isTrue(embeddingRequest.input() instanceof String || embeddingRequest.input() instanceof List, \"The input must be either a String, or a List of Strings or List of List of integers.\"); Object var3 = embeddingRequest.input(); if (var3 instanceof List list) { Assert.isTrue(!CollectionUtils.isEmpty(list), \"The input list can not be empty.\"); Assert.isTrue(list.size() <= 2048, \"The list must be 2048 dimensions or less\"); Assert.isTrue(list.get(0) instanceof String || list.get(0) instanceof Integer || list.get(0) instanceof List, \"The input must be either a String, or a List of Strings or list of list of integers.\"); } return ((RestClient.RequestBodySpec)this.restClient.post().uri(this.embeddingsPath, new Object[0])).body(embeddingRequest).retrieve().toEntity(new ParameterizedTypeReference<EmbeddingList<Embedding>>() { }); } String getBaseUrl() { return this.baseUrl; } ApiKey getApiKey() { return this.apiKey; } MultiValueMap<String, String> getHeaders() { return this.headers; } String getCompletionsPath() { return this.completionsPath; } String getEmbeddingsPath() { return this.embeddingsPath; } ResponseErrorHandler getResponseErrorHandler() { return this.responseErrorHandler; } static { DEFAULTCHATMODEL = OpenAiApi.ChatModel.GPT4O; DEFAULTEMBEDDINGMODEL = OpenAiApi.EmbeddingModel.TEXTEMBEDDINGADA002.getValue(); SSEDONEPREDICATE = \"[DONE]\"::equals; } public static enum ChatModel implements ChatModelDescription { O4MINI(\"o4-mini\"), O3(\"o3\"), O3MINI(\"o3-mini\"), O1(\"o1\"), O1MINI(\"o1-mini\"), O1PRO(\"o1-pro\"), GPT41(\"gpt-4.1\"), GPT4O(\"gpt-4o\"), CHATGPT4OLATEST(\"chatgpt-4o-latest\"), GPT4OAUDIOPREVIEW(\"gpt-4o-audio-preview\"), GPT41MINI(\"gpt-4.1-mini\"), GPT41NANO(\"gpt-4.1-nano\"), 
GPT4OMINI(\"gpt-4o-mini\"), GPT4OMINIAUDIOPREVIEW(\"gpt-4o-mini-audio-preview\"), GPT4OREALTIMEPREVIEW(\"gpt-4o-realtime-preview\"), GPT4OMINIREALTIMEPREVIEW(\"gpt-4o-mini-realtime-preview\\n\"), GPT4TURBO(\"gpt-4-turbo\"), GPT4(\"gpt-4\"), GPT35TURBO(\"gpt-3.5-turbo\"), GPT35TURBOINSTRUCT(\"gpt-3.5-turbo-instruct\"), GPT4OSEARCHPREVIEW(\"gpt-4o-search-preview\"), GPT4OMINISEARCHPREVIEW(\"gpt-4o-mini-search-preview\"); public final String value; private ChatModel(String value) { this.value = value; } public String getValue() { return this.value; } public String getName() { return this.value; } } public static enum ChatCompletionFinishReason { @JsonProperty(\"stop\") STOP, @JsonProperty(\"length\") LENGTH, @JsonProperty(\"contentfilter\") CONTENTFILTER, @JsonProperty(\"toolcalls\") TOOLCALLS, @JsonProperty(\"toolcall\") TOOLCALL; } public static enum EmbeddingModel { TEXTEMBEDDING3LARGE(\"text-embedding-3-large\"), TEXTEMBEDDING3SMALL(\"text-embedding-3-small\"), TEXTEMBEDDINGADA002(\"text-embedding-ada-002\"); public final String value; private EmbeddingModel(String value) { this.value = value; } public String getValue() { return this.value; } } @JsonInclude(Include.NONNULL) public static class FunctionTool { @JsonProperty(\"type\") private Type type; @JsonProperty(\"function\") private Function function; public FunctionTool() { this.type = OpenAiApi.FunctionTool.Type.FUNCTION; } public FunctionTool(Type type, Function function) { this.type = OpenAiApi.FunctionTool.Type.FUNCTION; this.type = type; this.function = function; } public FunctionTool(Function function) { this(OpenAiApi.FunctionTool.Type.FUNCTION, function); } public Type getType() { return this.type; } public Function getFunction() { return this.function; } public void setType(Type type) { this.type = type; } public void setFunction(Function function) { this.function = function; } public static enum Type { @JsonProperty(\"function\") FUNCTION; } @JsonInclude(Include.NONNULL) public static class Function 
{ @JsonProperty(\"description\") private String description; @JsonProperty(\"name\") private String name; @JsonProperty(\"parameters\") private Map<String, Object> parameters; @JsonProperty(\"strict\") Boolean strict; @JsonIgnore private String jsonSchema; private Function() { } public Function(String description, String name, Map<String, Object> parameters, Boolean strict) { this.description = description; this.name = name; this.parameters = parameters; this.strict = strict; } public Function(String description, String name, String jsonSchema) { this(description, name, ModelOptionsUtils.jsonToMap(jsonSchema), (Boolean)null); } public String getDescription() { return this.description; } public String getName() { return this.name; } public Map<String, Object> getParameters() { return this.parameters; } public void setDescription(String description) { this.description = description; } public void setName(String name) { this.name = name; } public void setParameters(Map<String, Object> parameters) { this.parameters = parameters; } public Boolean getStrict() { return this.strict; } public void setStrict(Boolean strict) { this.strict = strict; } public String getJsonSchema() { return this.jsonSchema; } public void setJsonSchema(String jsonSchema) { this.jsonSchema = jsonSchema; if (jsonSchema != null) { this.parameters = ModelOptionsUtils.jsonToMap(jsonSchema); } } } } public static enum OutputModality { @JsonProperty(\"audio\") AUDIO, @JsonProperty(\"text\") TEXT; } @JsonInclude(Include.NONNULL) public static record ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Boolean store, Map<String, String> metadata, Double frequencyPenalty, Map<String, Integer> logitBias, Boolean logprobs, Integer topLogprobs, Integer maxTokens, Integer maxCompletionTokens, Integer n, List<OutputModality> outputModalities, AudioParameters audioParameters, Double presencePenalty, ResponseFormat responseFormat, Integer seed, String serviceTier, List<String> stop, Boolean 
stream, StreamOptions streamOptions, Double temperature, Double topP, List<FunctionTool> tools, Object toolChoice, Boolean parallelToolCalls, String user, String reasoningEffort, WebSearchOptions webSearchOptions) { public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature) { this(messages, model, (Boolean)null, (Map)null, (Double)null, (Map)null, (Boolean)null, (Integer)null, (Integer)null, (Integer)null, (Integer)null, (List)null, (AudioParameters)null, (Double)null, (ResponseFormat)null, (Integer)null, (String)null, (List)null, false, (StreamOptions)null, temperature, (Double)null, (List)null, (Object)null, (Boolean)null, (String)null, (String)null, (WebSearchOptions)null); } public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, AudioParameters audio, boolean stream) { this(messages, model, (Boolean)null, (Map)null, (Double)null, (Map)null, (Boolean)null, (Integer)null, (Integer)null, (Integer)null, (Integer)null, List.of(OpenAiApi.OutputModality.AUDIO, OpenAiApi.OutputModality.TEXT), audio, (Double)null, (ResponseFormat)null, (Integer)null, (String)null, (List)null, stream, (StreamOptions)null, (Double)null, (Double)null, (List)null, (Object)null, (Boolean)null, (String)null, (String)null, (WebSearchOptions)null); } public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature, boolean stream) { this(messages, model, (Boolean)null, (Map)null, (Double)null, (Map)null, (Boolean)null, (Integer)null, (Integer)null, (Integer)null, (Integer)null, (List)null, (AudioParameters)null, (Double)null, (ResponseFormat)null, (Integer)null, (String)null, (List)null, stream, (StreamOptions)null, temperature, (Double)null, (List)null, (Object)null, (Boolean)null, (String)null, (String)null, (WebSearchOptions)null); } public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, List<FunctionTool> tools, Object toolChoice) { this(messages, model, 
(Boolean)null, (Map)null, (Double)null, (Map)null, (Boolean)null, (Integer)null, (Integer)null, (Integer)null, (Integer)null, (List)null, (AudioParameters)null, (Double)null, (ResponseFormat)null, (Integer)null, (String)null, (List)null, false, (StreamOptions)null, 0.8, (Double)null, tools, toolChoice, (Boolean)null, (String)null, (String)null, (WebSearchOptions)null); } public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) { this(messages, (String)null, (Boolean)null, (Map)null, (Double)null, (Map)null, (Boolean)null, (Integer)null, (Integer)null, (Integer)null, (Integer)null, (List)null, (AudioParameters)null, (Double)null, (ResponseFormat)null, (Integer)null, (String)null, (List)null, stream, (StreamOptions)null, (Double)null, (Double)null, (List)null, (Object)null, (Boolean)null, (String)null, (String)null, (WebSearchOptions)null); } public ChatCompletionRequest(@JsonProperty(\"messages\") List<ChatCompletionMessage> messages, @JsonProperty(\"model\") String model, @JsonProperty(\"store\") Boolean store, @JsonProperty(\"metadata\") Map<String, String> metadata, @JsonProperty(\"frequencypenalty\") Double frequencyPenalty, @JsonProperty(\"logitbias\") Map<String, Integer> logitBias, @JsonProperty(\"logprobs\") Boolean logprobs, @JsonProperty(\"toplogprobs\") Integer topLogprobs, @JsonProperty(\"maxtokens\") Integer maxTokens, @JsonProperty(\"maxcompletiontokens\") Integer maxCompletionTokens, @JsonProperty(\"n\") Integer n, @JsonProperty(\"modalities\") List<OutputModality> outputModalities, @JsonProperty(\"audio\") AudioParameters audioParameters, @JsonProperty(\"presencepenalty\") Double presencePenalty, @JsonProperty(\"responseformat\") ResponseFormat responseFormat, @JsonProperty(\"seed\") Integer seed, @JsonProperty(\"servicetier\") String serviceTier, @JsonProperty(\"stop\") List<String> stop, @JsonProperty(\"stream\") Boolean stream, @JsonProperty(\"streamoptions\") StreamOptions streamOptions, @JsonProperty(\"temperature\") 
Double temperature, @JsonProperty(\"topp\") Double topP, @JsonProperty(\"tools\") List<FunctionTool> tools, @JsonProperty(\"toolchoice\") Object toolChoice, @JsonProperty(\"paralleltoolcalls\") Boolean parallelToolCalls, @JsonProperty(\"user\") String user, @JsonProperty(\"reasoningeffort\") String reasoningEffort, @JsonProperty(\"websearchoptions\") WebSearchOptions webSearchOptions) { this.messages = messages; this.model = model; this.store = store; this.metadata = metadata; this.frequencyPenalty = frequencyPenalty; this.logitBias = logitBias; this.logprobs = logprobs; this.topLogprobs = topLogprobs; this.maxTokens = maxTokens; this.maxCompletionTokens = maxCompletionTokens; this.n = n; this.outputModalities = outputModalities; this.audioParameters = audioParameters; this.presencePenalty = presencePenalty; this.responseFormat = responseFormat; this.seed = seed; this.serviceTier = serviceTier; this.stop = stop; this.stream = stream; this.streamOptions = streamOptions; this.temperature = temperature; this.topP = topP; this.tools = tools; this.toolChoice = toolChoice; this.parallelToolCalls = parallelToolCalls; this.user = user; this.reasoningEffort = reasoningEffort; this.webSearchOptions = webSearchOptions; } public ChatCompletionRequest streamOptions(StreamOptions streamOptions) { return new ChatCompletionRequest(this.messages, this.model, this.store, this.metadata, this.frequencyPenalty, this.logitBias, this.logprobs, this.topLogprobs, this.maxTokens, this.maxCompletionTokens, this.n, this.outputModalities, this.audioParameters, this.presencePenalty, this.responseFormat, this.seed, this.serviceTier, this.stop, this.stream, streamOptions, this.temperature, this.topP, this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort, this.webSearchOptions); } @JsonProperty(\"messages\") public List<ChatCompletionMessage> messages() { return this.messages; } @JsonProperty(\"model\") public String model() { return this.model; } 
@JsonProperty(\"store\") public Boolean store() { return this.store; } @JsonProperty(\"metadata\") public Map<String, String> metadata() { return this.metadata; } @JsonProperty(\"frequencypenalty\") public Double frequencyPenalty() { return this.frequencyPenalty; } @JsonProperty(\"logitbias\") public Map<String, Integer> logitBias() { return this.logitBias; } @JsonProperty(\"logprobs\") public Boolean logprobs() { return this.logprobs; } @JsonProperty(\"toplogprobs\") public Integer topLogprobs() { return this.topLogprobs; } @JsonProperty(\"maxtokens\") public Integer maxTokens() { return this.maxTokens; } @JsonProperty(\"maxcompletiontokens\") public Integer maxCompletionTokens() { return this.maxCompletionTokens; } @JsonProperty(\"n\") public Integer n() { return this.n; } @JsonProperty(\"modalities\") public List<OutputModality> outputModalities() { return this.outputModalities; } @JsonProperty(\"audio\") public AudioParameters audioParameters() { return this.audioParameters; } @JsonProperty(\"presencepenalty\") public Double presencePenalty() { return this.presencePenalty; } @JsonProperty(\"responseformat\") public ResponseFormat responseFormat() { return this.responseFormat; } @JsonProperty(\"seed\") public Integer seed() { return this.seed; } @JsonProperty(\"servicetier\") public String serviceTier() { return this.serviceTier; } @JsonProperty(\"stop\") public List<String> stop() { return this.stop; } @JsonProperty(\"stream\") public Boolean stream() { return this.stream; } @JsonProperty(\"streamoptions\") public StreamOptions streamOptions() { return this.streamOptions; } @JsonProperty(\"temperature\") public Double temperature() { return this.temperature; } @JsonProperty(\"topp\") public Double topP() { return this.topP; } @JsonProperty(\"tools\") public List<FunctionTool> tools() { return this.tools; } @JsonProperty(\"toolchoice\") public Object toolChoice() { return this.toolChoice; } @JsonProperty(\"paralleltoolcalls\") public Boolean parallelToolCalls() 
{ return this.parallelToolCalls; } @JsonProperty(\"user\") public String user() { return this.user; } @JsonProperty(\"reasoningeffort\") public String reasoningEffort() { return this.reasoningEffort; } @JsonProperty(\"websearchoptions\") public WebSearchOptions webSearchOptions() { return this.webSearchOptions; } public static class ToolChoiceBuilder { public static final String AUTO = \"auto\"; public static final String NONE = \"none\"; public static Object FUNCTION(String functionName) { return Map.of(\"type\", \"function\", \"function\", Map.of(\"name\", functionName)); } } @JsonInclude(Include.NONNULL) public static record AudioParameters(Voice voice, AudioResponseFormat format) { public AudioParameters(@JsonProperty(\"voice\") Voice voice, @JsonProperty(\"format\") AudioResponseFormat format) { this.voice = voice; this.format = format; } @JsonProperty(\"voice\") public Voice voice() { return this.voice; } @JsonProperty(\"format\") public AudioResponseFormat format() { return this.format; } public static enum Voice { @JsonProperty(\"alloy\") ALLOY, @JsonProperty(\"echo\") ECHO, @JsonProperty(\"fable\") FABLE, @JsonProperty(\"onyx\") ONYX, @JsonProperty(\"nova\") NOVA, @JsonProperty(\"shimmer\") SHIMMER; } public static enum AudioResponseFormat { @JsonProperty(\"mp3\") MP3, @JsonProperty(\"flac\") FLAC, @JsonProperty(\"opus\") OPUS, @JsonProperty(\"pcm16\") PCM16, @JsonProperty(\"wav\") WAV; } } @JsonInclude(Include.NONNULL) public static record StreamOptions(Boolean includeUsage) { public static StreamOptions INCLUDEUSAGE = new StreamOptions(true); public StreamOptions(@JsonProperty(\"includeusage\") Boolean includeUsage) { this.includeUsage = includeUsage; } @JsonProperty(\"includeusage\") public Boolean includeUsage() { return this.includeUsage; } } @JsonInclude(Include.NONNULL) public static record WebSearchOptions(SearchContextSize searchContextSize, UserLocation userLocation) { public WebSearchOptions(@JsonProperty(\"searchcontextsize\") SearchContextSize 
searchContextSize, @JsonProperty(\"userlocation\") UserLocation userLocation) { this.searchContextSize = searchContextSize; this.userLocation = userLocation; } @JsonProperty(\"searchcontextsize\") public SearchContextSize searchContextSize() { return this.searchContextSize; } @JsonProperty(\"userlocation\") public UserLocation userLocation() { return this.userLocation; } public static enum SearchContextSize { @JsonProperty(\"low\") LOW, @JsonProperty(\"medium\") MEDIUM, @JsonProperty(\"high\") HIGH; } @JsonInclude(Include.NONNULL) public static record UserLocation(String type, Approximate approximate) { public UserLocation(@JsonProperty(\"type\") String type, @JsonProperty(\"approximate\") Approximate approximate) { this.type = type; this.approximate = approximate; } @JsonProperty(\"type\") public String type() { return this.type; } @JsonProperty(\"approximate\") public Approximate approximate() { return this.approximate; } @JsonInclude(Include.NONNULL) public static record Approximate(String city, String country, String region, String timezone) { public Approximate(@JsonProperty(\"city\") String city, @JsonProperty(\"country\") String country, @JsonProperty(\"region\") String region, @JsonProperty(\"timezone\") String timezone) { this.city = city; this.country = country; this.region = region; this.timezone = timezone; } @JsonProperty(\"city\") public String city() { return this.city; } @JsonProperty(\"country\") public String country() { return this.country; } @JsonProperty(\"region\") public String region() { return this.region; } @JsonProperty(\"timezone\") public String timezone() { return this.timezone; } } } } } @JsonInclude(Include.NONNULL) @JsonIgnoreProperties( ignoreUnknown = true ) public static record ChatCompletionMessage(Object rawContent, Role role, String name, String toolCallId, List<ToolCall> toolCalls, String refusal, AudioOutput audioOutput, List<Annotation> annotations) { public ChatCompletionMessage(Object content, Role role) { this(content, 
role, (String)null, (String)null, (List)null, (String)null, (AudioOutput)null, (List)null); } public ChatCompletionMessage(@JsonProperty(\"content\") Object rawContent, @JsonProperty(\"role\") Role role, @JsonProperty(\"name\") String name, @JsonProperty(\"toolcallid\") String toolCallId, @JsonProperty(\"toolcalls\") @JsonFormat(with = {Feature.ACCEPTSINGLEVALUEASARRAY}) List<ToolCall> toolCalls, @JsonProperty(\"refusal\") String refusal, @JsonProperty(\"audio\") AudioOutput audioOutput, @JsonProperty(\"annotations\") List<Annotation> annotations) { this.rawContent = rawContent; this.role = role; this.name = name; this.toolCallId = toolCallId; this.toolCalls = toolCalls; this.refusal = refusal; this.audioOutput = audioOutput; this.annotations = annotations; } public String content() { if (this.rawContent == null) { return null; } else { Object var2 = this.rawContent; if (var2 instanceof String) { String text = (String)var2; return text; } else { throw new IllegalStateException(\"The content is not a string!\"); } } } @JsonProperty(\"content\") public Object rawContent() { return this.rawContent; } @JsonProperty(\"role\") public Role role() { return this.role; } @JsonProperty(\"name\") public String name() { return this.name; } @JsonProperty(\"toolcallid\") public String toolCallId() { return this.toolCallId; } @JsonProperty(\"toolcalls\") @JsonFormat( with = {Feature.ACCEPTSINGLEVALUEASARRAY} ) public List<ToolCall> toolCalls() { return this.toolCalls; } @JsonProperty(\"refusal\") public String refusal() { return this.refusal; } @JsonProperty(\"audio\") public AudioOutput audioOutput() { return this.audioOutput; } @JsonProperty(\"annotations\") public List<Annotation> annotations() { return this.annotations; } public static enum Role { @JsonProperty(\"system\") SYSTEM, @JsonProperty(\"user\") USER, @JsonProperty(\"assistant\") ASSISTANT, @JsonProperty(\"tool\") TOOL; } @JsonInclude(Include.NONNULL) @JsonIgnoreProperties( ignoreUnknown = true ) public static record 
MediaContent(String type, String text, ImageUrl imageUrl, InputAudio inputAudio) { public MediaContent(String text) { this(\"text\", text, (ImageUrl)null, (InputAudio)null); } public MediaContent(ImageUrl imageUrl) { this(\"imageurl\", (String)null, imageUrl, (InputAudio)null); } public MediaContent(InputAudio inputAudio) { this(\"inputaudio\", (String)null, (ImageUrl)null, inputAudio); } public MediaContent(@JsonProperty(\"type\") String type, @JsonProperty(\"text\") String text, @JsonProperty(\"imageurl\") ImageUrl imageUrl, @JsonProperty(\"inputaudio\") InputAudio inputAudio) { this.type = type; this.text = text; this.imageUrl = imageUrl; this.inputAudio = inputAudio; } @JsonProperty(\"type\") public String type() { return this.type; } @JsonProperty(\"text\") public String text() { return this.text; } @JsonProperty(\"imageurl\") public ImageUrl imageUrl() { return this.imageUrl; } @JsonProperty(\"inputaudio\") public InputAudio inputAudio() { return this.inputAudio; } @JsonInclude(Include.NONNULL) public static record InputAudio(String data, Format format) { public InputAudio(@JsonProperty(\"data\") String data, @JsonProperty(\"format\") Format format) { this.data = data; this.format = format; } @JsonProperty(\"data\") public String data() { return this.data; } @JsonProperty(\"format\") public Format format() { return this.format; } public static enum Format { @JsonProperty(\"mp3\") MP3, @JsonProperty(\"wav\") WAV; } } @JsonInclude(Include.NONNULL) public static record ImageUrl(String url, String detail) { public ImageUrl(String url) { this(url, (String)null); } public ImageUrl(@JsonProperty(\"url\") String url, @JsonProperty(\"detail\") String detail) { this.url = url; this.detail = detail; } @JsonProperty(\"url\") public String url() { return this.url; } @JsonProperty(\"detail\") public String detail() { return this.detail; } } } @JsonInclude(Include.NONNULL) @JsonIgnoreProperties( ignoreUnknown = true ) public static record ToolCall(Integer index, String id, 
String type, ChatCompletionFunction function) { public ToolCall(String id, String type, ChatCompletionFunction function) { this((Integer)null, id, type, function); } public ToolCall(@JsonProperty(\"index\") Integer index, @JsonProperty(\"id\") String id, @JsonProperty(\"type\") String type, @JsonProperty(\"function\") ChatCompletionFunction function) { this.index = index; this.id = id; this.type = type; this.function = function; } @JsonProperty(\"index\") public Integer index() { return this.index; } @JsonProperty(\"id\") public String id() { return this.id; } @JsonProperty(\"type\") public String type() { return this.type; } @JsonProperty(\"function\") public ChatCompletionFunction function() { return this.function; } } @JsonInclude(Include.NONNULL) @JsonIgnoreProperties( ignoreUnknown = true ) public static record ChatCompletionFunction(String name, String arguments) { public ChatCompletionFunction(@JsonProperty(\"name\") String name, @JsonProperty(\"arguments\") String arguments) { this.name = name; this.arguments = arguments; } @JsonProperty(\"name\") public String name() { return this.name; } @JsonProperty(\"arguments\") public String arguments() { return this.arguments; } } @JsonInclude(Include.NONNULL) @JsonIgnoreProperties( ignoreUnknown = true ) public static record AudioOutput(String id, String data, Long expiresAt, String transcript) { public AudioOutput(@JsonProperty(\"id\") String id, @JsonProperty(\"data\") String data, @JsonProperty(\"expiresat\") Long expiresAt, @JsonProperty(\"transcript\") String transcript) { this.id = id; this.data = data; this.expiresAt = expiresAt; this.transcript = transcript; } @JsonProperty(\"id\") public String id() { return this.id; } @JsonProperty(\"data\") public String data() { return this.data; } @JsonProperty(\"expiresat\") public Long expiresAt() { return this.expiresAt; } @JsonProperty(\"transcript\") public String transcript() { return this.transcript; } } @JsonInclude(Include.NONNULL) public static record 
Annotation(String type, UrlCitation urlCitation) { public Annotation(@JsonProperty(\"type\") String type, @JsonProperty(\"urlcitation\") UrlCitation urlCitation) { this.type = type; this.urlCitation = urlCitation; } @JsonProperty(\"type\") public String type() { return this.type; } @JsonProperty(\"urlcitation\") public UrlCitation urlCitation() { return this.urlCitation; } @JsonInclude(Include.NONNULL) public static record UrlCitation(Integer endIndex, Integer startIndex, String title, String url) { public UrlCitation(@JsonProperty(\"endindex\") Integer endIndex, @JsonProperty(\"startindex\") Integer startIndex, @JsonProperty(\"title\") String title, @JsonProperty(\"url\") String url) { this.endIndex = endIndex; this.startIndex = startIndex; this.title = title; this.url = url; } @JsonProperty(\"endindex\") public Integer endIndex() { return this.endIndex; } @JsonProperty(\"startindex\") public Integer startIndex() { return this.startIndex; } @JsonProperty(\"title\") public String title() { return this.title; } @JsonProperty(\"url\") public String url() { return this.url; } } } } @JsonInclude(Include.NONNULL) @JsonIgnoreProperties( ignoreUnknown = true ) public static record ChatCompletion(String id, List<Choice> choices, Long created, String model, String serviceTier, String systemFingerprint, String object, Usage usage) { public ChatCompletion(@JsonProperty(\"id\") String id, @JsonProperty(\"choices\") List<Choice> choices, @JsonProperty(\"created\") Long created, @JsonProperty(\"model\") String model, @JsonProperty(\"servicetier\") String serviceTier, @JsonProperty(\"systemfingerprint\") String systemFingerprint, @JsonProperty(\"object\") String object, @JsonProperty(\"usage\") Usage usage) { this.id = id; this.choices = choices; this.created = created; this.model = model; this.serviceTier = serviceTier; this.systemFingerprint = systemFingerprint; this.object = object; this.usage = usage; } @JsonProperty(\"id\") public String id() { return this.id; } 
@JsonProperty(\"choices\") public List<Choice> choices() { return this.choices; } @JsonProperty(\"created\") public Long created() { return this.created; } @JsonProperty(\"model\") public String model() { return this.model; } @JsonProperty(\"servicetier\") public String serviceTier() { return this.serviceTier; } @JsonProperty(\"systemfingerprint\") public String systemFingerprint() { return this.systemFingerprint; } @JsonProperty(\"object\") public String object() { return this.object; } @JsonProperty(\"usage\") public Usage usage() { return this.usage; } @JsonInclude(Include.NONNULL) @JsonIgnoreProperties( ignoreUnknown = true ) public static record Choice(ChatCompletionFinishReason finishReason, Integer index, ChatCompletionMessage message, LogProbs logprobs) { public Choice(@JsonProperty(\"finishreason\") ChatCompletionFinishReason finishReason, @JsonProperty(\"index\") Integer index, @JsonProperty(\"message\") ChatCompletionMessage message, @JsonProperty(\"logprobs\") LogProbs logprobs) { this.finishReason = finishReason; this.index = index; this.message = message; this.logprobs = logprobs; } @JsonProperty(\"finishreason\") public ChatCompletionFinishReason finishReason() { return this.finishReason; } @JsonProperty(\"index\") public Integer index() { return this.index; } @JsonProperty(\"message\") public ChatCompletionMessage message() { return this.message; } @JsonProperty(\"logprobs\") public LogProbs logprobs() { return this.logprobs; } } } @JsonInclude(Include.NONNULL) @JsonIgnoreProperties( ignoreUnknown = true ) public static record LogProbs(List<Content> content, List<Content> refusal) { public LogProbs(@JsonProperty(\"content\") List<Content> content, @JsonProperty(\"refusal\") List<Content> refusal) { this.content = content; this.refusal = refusal; } @JsonProperty(\"content\") public List<Content> content() { return this.content; } @JsonProperty(\"refusal\") public List<Content> refusal() { return this.refusal; } @JsonInclude(Include.NONNULL) 
@JsonIgnoreProperties( ignoreUnknown = true ) public static record Content(String token, Float logprob, List<Integer> probBytes, List<TopLogProbs> topLogprobs) { public Content(@JsonProperty(\"token\") String token, @JsonProperty(\"logprob\") Float logprob, @JsonProperty(\"bytes\") List<Integer> probBytes, @JsonProperty(\"toplogprobs\") List<TopLogProbs> topLogprobs) { this.token = token; this.logprob = logprob; this.probBytes = probBytes; this.topLogprobs = topLogprobs; } @JsonProperty(\"token\") public String token() { return this.token; } @JsonProperty(\"logprob\") public Float logprob() { return this.logprob; } @JsonProperty(\"bytes\") public List<Integer> probBytes() { return this.probBytes; } @JsonProperty(\"toplogprobs\") public List<TopLogProbs> topLogprobs() { return this.topLogprobs; } @JsonInclude(Include.NONNULL) @JsonIgnoreProperties( ignoreUnknown = true ) public static record TopLogProbs(String token, Float logprob, List<Integer> probBytes) { public TopLogProbs(@JsonProperty(\"token\") String token, @JsonProperty(\"logprob\") Float logprob, @JsonProperty(\"bytes\") List<Integer> probBytes) { this.token = token; this.logprob = logprob; this.probBytes = probBytes; } @JsonProperty(\"token\") public String token() { return this.token; } @JsonProperty(\"logprob\") public Float logprob() { return this.logprob; } @JsonProperty(\"bytes\") public List<Integer> probBytes() { return this.probBytes; } } } } @JsonInclude(Include.NONNULL) @JsonIgnoreProperties( ignoreUnknown = true ) public static record Usage(Integer completionTokens, Integer promptTokens, Integer totalTokens, PromptTokensDetails promptTokensDetails, CompletionTokenDetails completionTokenDetails) { public Usage(Integer completionTokens, Integer promptTokens, Integer totalTokens) { this(completionTokens, promptTokens, totalTokens, (PromptTokensDetails)null, (CompletionTokenDetails)null); } public Usage(@JsonProperty(\"completiontokens\") Integer completionTokens, @JsonProperty(\"prompttokens\") 
Integer promptTokens, @JsonProperty(\"totaltokens\") Integer totalTokens, @JsonProperty(\"prompttokensdetails\") PromptTokensDetails promptTokensDetails, @JsonProperty(\"completiontokensdetails\") CompletionTokenDetails completionTokenDetails) { this.completionTokens = completionTokens; this.promptTokens = promptTokens; this.totalTokens = totalTokens; this.promptTokensDetails = promptTokensDetails; this.completionTokenDetails = completionTokenDetails; } @JsonProperty(\"completiontokens\") public Integer completionTokens() { return this.completionTokens; } @JsonProperty(\"prompttokens\") public Integer promptTokens() { return this.promptTokens; } @JsonProperty(\"totaltokens\") public Integer totalTokens() { return this.totalTokens; } @JsonProperty(\"prompttokensdetails\") public PromptTokensDetails promptTokensDetails() { return this.promptTokensDetails; } @JsonProperty(\"completiontokensdetails\") public CompletionTokenDetails completionTokenDetails() { return this.completionTokenDetails; } @JsonInclude(Include.NONNULL) @JsonIgnoreProperties( ignoreUnknown = true ) public static record PromptTokensDetails(Integer audioTokens, Integer cachedTokens) { public PromptTokensDetails(@JsonProperty(\"audiotokens\") Integer audioTokens, @JsonProperty(\"cachedtokens\") Integer cachedTokens) { this.audioTokens = audioTokens; this.cachedTokens = cachedTokens; } @JsonProperty(\"audiotokens\") public Integer audioTokens() { return this.audioTokens; } @JsonProperty(\"cachedtokens\") public Integer cachedTokens() { return this.cachedTokens; } } @JsonInclude(Include.NONNULL) @JsonIgnoreProperties( ignoreUnknown = true ) public static record CompletionTokenDetails(Integer reasoningTokens, Integer acceptedPredictionTokens, Integer audioTokens, Integer rejectedPredictionTokens) { public CompletionTokenDetails(@JsonProperty(\"reasoningtokens\") Integer reasoningTokens, @JsonProperty(\"acceptedpredictiontokens\") Integer acceptedPredictionTokens, @JsonProperty(\"audiotokens\") Integer 
audioTokens, @JsonProperty(\"rejectedpredictiontokens\") Integer rejectedPredictionTokens) { this.reasoningTokens = reasoningTokens; this.acceptedPredictionTokens = acceptedPredictionTokens; this.audioTokens = audioTokens; this.rejectedPredictionTokens = rejectedPredictionTokens; } @JsonProperty(\"reasoningtokens\") public Integer reasoningTokens() { return this.reasoningTokens; } @JsonProperty(\"acceptedpredictiontokens\") public Integer acceptedPredictionTokens() { return this.acceptedPredictionTokens; } @JsonProperty(\"audiotokens\") public Integer audioTokens() { return this.audioTokens; } @JsonProperty(\"rejectedpredictiontokens\") public Integer rejectedPredictionTokens() { return this.rejectedPredictionTokens; } } } @JsonInclude(Include.NONNULL) @JsonIgnoreProperties( ignoreUnknown = true ) public static record ChatCompletionChunk(String id, List<ChunkChoice> choices, Long created, String model, String serviceTier, String systemFingerprint, String object, Usage usage) { public ChatCompletionChunk(@JsonProperty(\"id\") String id, @JsonProperty(\"choices\") List<ChunkChoice> choices, @JsonProperty(\"created\") Long created, @JsonProperty(\"model\") String model, @JsonProperty(\"servicetier\") String serviceTier, @JsonProperty(\"systemfingerprint\") String systemFingerprint, @JsonProperty(\"object\") String object, @JsonProperty(\"usage\") Usage usage) { this.id = id; this.choices = choices; this.created = created; this.model = model; this.serviceTier = serviceTier; this.systemFingerprint = systemFingerprint; this.object = object; this.usage = usage; } @JsonProperty(\"id\") public String id() { return this.id; } @JsonProperty(\"choices\") public List<ChunkChoice> choices() { return this.choices; } @JsonProperty(\"created\") public Long created() { return this.created; } @JsonProperty(\"model\") public String model() { return this.model; } @JsonProperty(\"servicetier\") public String serviceTier() { return this.serviceTier; } @JsonProperty(\"systemfingerprint\") 
public String systemFingerprint() { return this.systemFingerprint; } @JsonProperty(\"object\") public String object() { return this.object; } @JsonProperty(\"usage\") public Usage usage() { return this.usage; } @JsonInclude(Include.NONNULL) @JsonIgnoreProperties( ignoreUnknown = true ) public static record ChunkChoice(ChatCompletionFinishReason finishReason, Integer index, ChatCompletionMessage delta, LogProbs logprobs) { public ChunkChoice(@JsonProperty(\"finishreason\") ChatCompletionFinishReason finishReason, @JsonProperty(\"index\") Integer index, @JsonProperty(\"delta\") ChatCompletionMessage delta, @JsonProperty(\"logprobs\") LogProbs logprobs) { this.finishReason = finishReason; this.index = index; this.delta = delta; this.logprobs = logprobs; } @JsonProperty(\"finishreason\") public ChatCompletionFinishReason finishReason() { return this.finishReason; } @JsonProperty(\"index\") public Integer index() { return this.index; } @JsonProperty(\"delta\") public ChatCompletionMessage delta() { return this.delta; } @JsonProperty(\"logprobs\") public LogProbs logprobs() { return this.logprobs; } } } @JsonInclude(Include.NONNULL) @JsonIgnoreProperties( ignoreUnknown = true ) public static record Embedding(Integer index, float[] embedding, String object) { public Embedding(Integer index, float[] embedding) { this(index, embedding, \"embedding\"); } public Embedding(@JsonProperty(\"index\") Integer index, @JsonProperty(\"embedding\") float[] embedding, @JsonProperty(\"object\") String object) { this.index = index; this.embedding = embedding; this.object = object; } @JsonProperty(\"index\") public Integer index() { return this.index; } @JsonProperty(\"embedding\") public float[] embedding() { return this.embedding; } @JsonProperty(\"object\") public String object() { return this.object; } } @JsonInclude(Include.NONNULL) public static record EmbeddingRequest<T>(T input, String model, String encodingFormat, Integer dimensions, String user) { public EmbeddingRequest(T 
input, String model) { this(input, model, \"float\", (Integer)null, (String)null); } public EmbeddingRequest(T input) { this(input, OpenAiApi.DEFAULTEMBEDDINGMODEL); } public EmbeddingRequest(@JsonProperty(\"input\") T input, @JsonProperty(\"model\") String model, @JsonProperty(\"encodingformat\") String encodingFormat, @JsonProperty(\"dimensions\") Integer dimensions, @JsonProperty(\"user\") String user) { this.input = input; this.model = model; this.encodingFormat = encodingFormat; this.dimensions = dimensions; this.user = user; } @JsonProperty(\"input\") public T input() { return this.input; } @JsonProperty(\"model\") public String model() { return this.model; } @JsonProperty(\"encodingformat\") public String encodingFormat() { return this.encodingFormat; } @JsonProperty(\"dimensions\") public Integer dimensions() { return this.dimensions; } @JsonProperty(\"user\") public String user() { return this.user; } } @JsonInclude(Include.NONNULL) @JsonIgnoreProperties( ignoreUnknown = true ) public static record EmbeddingList<T>(String object, List<T> data, String model, Usage usage) { public EmbeddingList(@JsonProperty(\"object\") String object, @JsonProperty(\"data\") List<T> data, @JsonProperty(\"model\") String model, @JsonProperty(\"usage\") Usage usage) { this.object = object; this.data = data; this.model = model; this.usage = usage; } @JsonProperty(\"object\") public String object() { return this.object; } @JsonProperty(\"data\") public List<T> data() { return this.data; } @JsonProperty(\"model\") public String model() { return this.model; } @JsonProperty(\"usage\") public Usage usage() { return this.usage; } } public static class Builder { private String baseUrl = \"https://api.openai.com\"; private ApiKey apiKey; private MultiValueMap<String, String> headers = new LinkedMultiValueMap(); private String completionsPath = \"/v1/chat/completions\"; private String embeddingsPath = \"/v1/embeddings\"; private RestClient.Builder restClientBuilder = 
RestClient.builder(); private WebClient.Builder webClientBuilder = WebClient.builder(); private ResponseErrorHandler responseErrorHandler; public Builder() { this.responseErrorHandler = RetryUtils.DEFAULTRESPONSEERRORHANDLER; } public Builder(OpenAiApi api) { this.responseErrorHandler = RetryUtils.DEFAULTRESPONSEERRORHANDLER; this.baseUrl = api.getBaseUrl(); this.apiKey = api.getApiKey(); this.headers = new LinkedMultiValueMap(api.getHeaders()); this.completionsPath = api.getCompletionsPath(); this.embeddingsPath = api.getEmbeddingsPath(); this.restClientBuilder = api.restClient != null ? api.restClient.mutate() : RestClient.builder(); this.webClientBuilder = api.webClient != null ? api.webClient.mutate() : WebClient.builder(); this.responseErrorHandler = api.getResponseErrorHandler(); } public Builder baseUrl(String baseUrl) { Assert.hasText(baseUrl, \"baseUrl cannot be null or empty\"); this.baseUrl = baseUrl; return this; } public Builder apiKey(ApiKey apiKey) { Assert.notNull(apiKey, \"apiKey cannot be null\"); this.apiKey = apiKey; return this; } public Builder apiKey(String simpleApiKey) { Assert.notNull(simpleApiKey, \"simpleApiKey cannot be null\"); this.apiKey = new SimpleApiKey(simpleApiKey); return this; } public Builder headers(MultiValueMap<String, String> headers) { Assert.notNull(headers, \"headers cannot be null\"); this.headers = headers; return this; } public Builder completionsPath(String completionsPath) { Assert.hasText(completionsPath, \"completionsPath cannot be null or empty\"); this.completionsPath = completionsPath; return this; } public Builder embeddingsPath(String embeddingsPath) { Assert.hasText(embeddingsPath, \"embeddingsPath cannot be null or empty\"); this.embeddingsPath = embeddingsPath; return this; } public Builder restClientBuilder(RestClient.Builder restClientBuilder) { Assert.notNull(restClientBuilder, \"restClientBuilder cannot be null\"); this.restClientBuilder = restClientBuilder; return this; } public Builder 
webClientBuilder(WebClient.Builder webClientBuilder) { Assert.notNull(webClientBuilder, \"webClientBuilder cannot be null\"); this.webClientBuilder = webClientBuilder; return this; } public Builder responseErrorHandler(ResponseErrorHandler responseErrorHandler) { Assert.notNull(responseErrorHandler, \"responseErrorHandler cannot be null\"); this.responseErrorHandler = responseErrorHandler; return this; } public OpenAiApi build() { Assert.notNull(this.apiKey, \"apiKey must be set\"); return new OpenAiApi(this.baseUrl, this.apiKey, this.headers, this.completionsPath, this.embeddingsPath, this.restClientBuilder, this.webClientBuilder, this.responseErrorHandler); } }}
学习交流圈
你好，我是影子，曾先后在🐻、新能源、老铁就职，现在是一名AI研发工程师。目前新建了一个交流群——一个人走得快，一群人走得远。关注公众号后可获取我的个人微信，添加时备注“交流”即可入群。另外，本人长期维护一套飞书云文档笔记，涵盖后端、大数据方向的系统化面试资料，可私信免费获取。