Skip to content

Commit d1e4a54

Browse files
authored
🐛 fix: update convertUsage to handle XAI provider and adjust OpenAIStream to pass provider (#8557)
1 parent 1bc8815 commit d1e4a54

File tree

3 files changed

+59
-7
lines changed

3 files changed

+59
-7
lines changed

src/libs/model-runtime/utils/streams/openai/openai.ts

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import {
2323
const transformOpenAIStream = (
2424
chunk: OpenAI.ChatCompletionChunk,
2525
streamContext: StreamContext,
26+
provider?: string,
2627
): StreamProtocolChunk | StreamProtocolChunk[] => {
2728
// handle the first chunk error
2829
if (FIRST_CHUNK_ERROR_KEY in chunk) {
@@ -45,7 +46,7 @@ const transformOpenAIStream = (
4546
if (!item) {
4647
if (chunk.usage) {
4748
const usage = chunk.usage;
48-
return { data: convertUsage(usage), id: chunk.id, type: 'usage' };
49+
return { data: convertUsage(usage, provider), id: chunk.id, type: 'usage' };
4950
}
5051

5152
return { data: chunk, id: chunk.id, type: 'data' };
@@ -155,7 +156,7 @@ const transformOpenAIStream = (
155156

156157
if (chunk.usage) {
157158
const usage = chunk.usage;
158-
return { data: convertUsage(usage), id: chunk.id, type: 'usage' };
159+
return { data: convertUsage(usage, provider), id: chunk.id, type: 'usage' };
159160
}
160161

161162
// xAI Live Search 功能返回引用源
@@ -274,7 +275,7 @@ const transformOpenAIStream = (
274275
// litellm 的返回结果中,存在 delta 为空,但是有 usage 的情况
275276
if (chunk.usage) {
276277
const usage = chunk.usage;
277-
return { data: convertUsage(usage), id: chunk.id, type: 'usage' };
278+
return { data: convertUsage(usage, provider), id: chunk.id, type: 'usage' };
278279
}
279280

280281
// 其余情况下,返回 delta 和 index
@@ -321,6 +322,9 @@ export const OpenAIStream = (
321322
) => {
322323
const streamStack: StreamContext = { id: '' };
323324

325+
const transformWithProvider = (chunk: OpenAI.ChatCompletionChunk, streamContext: StreamContext) =>
326+
transformOpenAIStream(chunk, streamContext, provider);
327+
324328
const readableStream =
325329
stream instanceof ReadableStream ? stream : convertIterableToStream(stream);
326330

@@ -330,7 +334,7 @@ export const OpenAIStream = (
330334
// provider like huggingface or minimax will return error in the stream,
331335
// so in the first Transformer, we need to handle the error
332336
.pipeThrough(createFirstErrorHandleTransformer(bizErrorTypeTransformer, provider))
333-
.pipeThrough(createTokenSpeedCalculator(transformOpenAIStream, { inputStartAt, streamStack }))
337+
.pipeThrough(createTokenSpeedCalculator(transformWithProvider, { inputStartAt, streamStack }))
334338
.pipeThrough(createSSEProtocolTransformer((c) => c, streamStack))
335339
.pipeThrough(createCallbacksTransformer(callbacks))
336340
);

src/libs/model-runtime/utils/usageConverter.test.ts

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import OpenAI from 'openai';
22
import { describe, expect, it } from 'vitest';
33

4-
import { convertUsage } from './usageConverter';
4+
import { convertUsage, convertResponseUsage } from './usageConverter';
55

66
describe('convertUsage', () => {
77
it('should convert basic OpenAI usage data correctly', () => {
@@ -246,4 +246,48 @@ describe('convertUsage', () => {
246246
expect(result).not.toHaveProperty('outputReasoningTokens');
247247
expect(result).not.toHaveProperty('outputAudioTokens');
248248
});
249+
250+
it('should handle XAI provider correctly where completion_tokens does not include reasoning_tokens', () => {
251+
// Arrange
252+
const xaiUsage: OpenAI.Completions.CompletionUsage = {
253+
prompt_tokens: 6103,
254+
completion_tokens: 66, // 这个不包含 reasoning_tokens
255+
total_tokens: 6550,
256+
prompt_tokens_details: {
257+
audio_tokens: 0,
258+
cached_tokens: 0,
259+
},
260+
completion_tokens_details: {
261+
accepted_prediction_tokens: 0,
262+
audio_tokens: 0,
263+
reasoning_tokens: 381, // 这是额外的 reasoning tokens
264+
rejected_prediction_tokens: 0,
265+
},
266+
};
267+
268+
// Act
269+
const xaiResult = convertUsage(xaiUsage, 'xai');
270+
271+
// Assert
272+
expect(xaiResult).toMatchObject({
273+
totalInputTokens: 6103,
274+
totalOutputTokens: 66,
275+
outputTextTokens: 66, // 不减去 reasoning_tokens
276+
outputReasoningTokens: 381,
277+
totalTokens: 6550,
278+
});
279+
280+
// 测试其他 provider(默认行为)
281+
const defaultResult = convertUsage(xaiUsage);
282+
283+
// 默认行为: outputTextTokens 应该是 completion_tokens - reasoning_tokens - audio_tokens = 66 - 381 - 0 = -315
284+
expect(defaultResult.outputTextTokens).toBe(-315);
285+
expect(defaultResult).toMatchObject({
286+
totalInputTokens: 6103,
287+
totalOutputTokens: 66,
288+
outputTextTokens: -315, // 负数确实会出现在结果中
289+
outputReasoningTokens: 381,
290+
totalTokens: 6550,
291+
});
292+
});
249293
});

src/libs/model-runtime/utils/usageConverter.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import OpenAI from 'openai';
22

33
import { ModelTokensUsage } from '@/types/message';
44

5-
export const convertUsage = (usage: OpenAI.Completions.CompletionUsage): ModelTokensUsage => {
5+
export const convertUsage = (usage: OpenAI.Completions.CompletionUsage, provider?: string): ModelTokensUsage => {
66
// 目前只有 pplx 才有 citation_tokens
77
const inputTextTokens = usage.prompt_tokens || 0;
88
const inputCitationTokens = (usage as any).citation_tokens || 0;
@@ -17,7 +17,11 @@ export const convertUsage = (usage: OpenAI.Completions.CompletionUsage): ModelTo
1717
const totalOutputTokens = usage.completion_tokens;
1818
const outputReasoning = usage.completion_tokens_details?.reasoning_tokens || 0;
1919
const outputAudioTokens = usage.completion_tokens_details?.audio_tokens || 0;
20-
const outputTextTokens = totalOutputTokens - outputReasoning - outputAudioTokens;
20+
21+
// XAI 的 completion_tokens 不包含 reasoning_tokens,需要特殊处理
22+
const outputTextTokens = provider === 'xai'
23+
? totalOutputTokens - outputAudioTokens
24+
: totalOutputTokens - outputReasoning - outputAudioTokens;
2125

2226
const totalTokens = inputCitationTokens + usage.total_tokens;
2327

0 commit comments

Comments (0)