Skip to content

Commit

Permalink
Fix request settings and stop words in HF provider (eclipse-theia#14504)
Browse files Browse the repository at this point in the history
fixed eclipse-theia#14503

Signed-off-by: Jonas Helming <[email protected]>
  • Loading branch information
JonasHelming authored Nov 26, 2024
1 parent eb9a7e4 commit 55e29ed
Showing 1 changed file with 38 additions and 13 deletions.
51 changes: 38 additions & 13 deletions packages/ai-hugging-face/src/node/huggingface-language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,50 +67,75 @@ export class HuggingFaceModel implements LanguageModel {
}
}

protected getDefaultSettings(): Record<string, unknown> {
return {
max_new_tokens: 2024,
stop: ['<|endoftext|>', '<eos>']
};
}

protected async handleNonStreamingRequest(hfInference: HfInference, request: LanguageModelRequest): Promise<LanguageModelTextResponse> {
const settings = request.settings || this.getDefaultSettings();

const response = await hfInference.textGeneration({
model: this.model,
inputs: toHuggingFacePrompt(request.messages),
parameters: {
temperature: 0.1, // Controls randomness, 0.1 for consistent outputs
max_new_tokens: 200, // Limits response length
return_full_text: false, // Ensures only the generated part is returned, not the prompt
do_sample: true, // Enables sampling for more varied responses
stop: ['<|endoftext|>'] // Stop generation at this token
...settings
}
});

const cleanText = response.generated_text.replace(/<\|endoftext\|>/g, '');
const stopWords = Array.isArray(settings.stop) ? settings.stop : [];
let cleanText = response.generated_text;

stopWords.forEach(stopWord => {
if (cleanText.endsWith(stopWord)) {
cleanText = cleanText.slice(0, -stopWord.length).trim();
}
});

return {
text: cleanText
};
}

protected async handleStreamingRequest(hfInference: HfInference, request: LanguageModelRequest, cancellationToken?: CancellationToken): Promise<LanguageModelResponse> {
protected async handleStreamingRequest(
hfInference: HfInference,
request: LanguageModelRequest,
cancellationToken?: CancellationToken
): Promise<LanguageModelResponse> {
const settings = request.settings || this.getDefaultSettings();

const stream = hfInference.textGenerationStream({
model: this.model,
inputs: toHuggingFacePrompt(request.messages),
parameters: {
temperature: 0.1,
max_new_tokens: 200,
return_full_text: false,
do_sample: true,
stop: ['<|endoftext|>']
...settings
}
});

const stopWords = Array.isArray(settings.stop) ? settings.stop : [];

const asyncIterator = {
async *[Symbol.asyncIterator](): AsyncIterator<LanguageModelStreamResponsePart> {
for await (const chunk of stream) {
const content = chunk.token.text.replace(/<\|endoftext\|>/g, '');
let content = chunk.token.text;

stopWords.forEach(stopWord => {
if (content.endsWith(stopWord)) {
content = content.slice(0, -stopWord.length).trim();
}
});

yield { content };

if (cancellationToken?.isCancellationRequested) {
break;
}
}
}
};

return { stream: asyncIterator };
}

Expand Down

0 comments on commit 55e29ed

Please sign in to comment.