Skip to content

Commit

Permalink
Send image to AI (#1006)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kitenite authored Jan 10, 2025
1 parent 9c3fe42 commit 8dd7765
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 76 deletions.
1 change: 1 addition & 0 deletions apps/studio/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
"@xterm/xterm": "^5.6.0-beta.70",
"@zonke-cloud/sdk": "^0.1.6",
"ai": "^3.4.29",
"browser-image-compression": "^2.0.2",
"chokidar": "^4.0.1",
"electron-log": "^5.2.0",
"electron-updater": "^6.3.4",
Expand Down
68 changes: 34 additions & 34 deletions apps/studio/src/lib/editor/engine/chat/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
type FileMessageContext,
type HighlightMessageContext,
} from '@onlook/models/chat';
import type { DomElement } from '@onlook/models/element';
import { makeAutoObservable, reaction } from 'mobx';
import type { EditorEngine } from '..';

Expand All @@ -14,34 +15,53 @@ export class ChatContext {
makeAutoObservable(this);
reaction(
() => this.editorEngine.elements.selected,
() => this.getChatContext(true).then((context) => (this.context = context)),
() => this.getChatContext().then((context) => (this.context = context)),
);
}

async getChatContext(skipContent: boolean = false) {
async getChatContext(): Promise<ChatMessageContext[]> {
const selected = this.editorEngine.elements.selected;
if (selected.length === 0) {
return [];
}

const fileNames = new Set<string>();
const highlightedContext = await this.getHighlightedContext(selected, fileNames);
const fileContext = await this.getFileContext(fileNames);
const imageContext = this.context.filter(
(context) => context.type === MessageContextType.IMAGE,
);
return [...fileContext, ...highlightedContext, ...imageContext];
}

private async getFileContext(fileNames: Set<string>): Promise<FileMessageContext[]> {
const fileContext: FileMessageContext[] = [];
for (const fileName of fileNames) {
const fileContent = await this.editorEngine.code.getFileContent(fileName, true);
if (fileContent === null) {
continue;
}
fileContext.push({
type: MessageContextType.FILE,
displayName: fileName,
path: fileName,
content: fileContent,
});
}
return fileContext;
}

private async getHighlightedContext(
selected: DomElement[],
fileNames: Set<string>,
): Promise<HighlightMessageContext[]> {
const highlightedContext: HighlightMessageContext[] = [];
for (const node of selected) {
const oid = node.oid;
if (!oid) {
continue;
}

let codeBlock: string | null;

// Skip content for display context
if (skipContent) {
codeBlock = '';
} else {
codeBlock = await this.editorEngine.code.getCodeBlock(oid);
}

const codeBlock = await this.editorEngine.code.getCodeBlock(oid);
if (codeBlock === null) {
continue;
}
Expand All @@ -50,6 +70,7 @@ export class ChatContext {
if (!templateNode) {
continue;
}

highlightedContext.push({
type: MessageContextType.HIGHLIGHT,
displayName: node.tagName.toLowerCase(),
Expand All @@ -61,28 +82,7 @@ export class ChatContext {
fileNames.add(templateNode.path);
}

const fileContext: FileMessageContext[] = [];
for (const fileName of fileNames) {
let fileContent: string | null;

// Skip content for display context
if (skipContent) {
fileContent = '';
} else {
fileContent = await this.editorEngine.code.getFileContent(fileName, true);
}
if (fileContent === null) {
continue;
}
fileContext.push({
type: MessageContextType.FILE,
displayName: fileName,
path: fileName,
content: fileContent,
});
}

return [...fileContext, ...highlightedContext];
return highlightedContext;
}

clear() {
Expand Down
25 changes: 22 additions & 3 deletions apps/studio/src/lib/editor/engine/chat/message/user.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import { PromptProvider } from '@onlook/ai/src/prompt/provider';
import type { ChatMessageContext } from '@onlook/models/chat';
import type { ChatMessageContext, ImageMessageContext } from '@onlook/models/chat';
import {
ChatMessageRole,
ChatMessageType,
MessageContextType,
type UserChatMessage,
} from '@onlook/models/chat';
import type { CoreUserMessage } from 'ai';
import type { CoreUserMessage, ImagePart, TextPart } from 'ai';
import { nanoid } from 'nanoid/non-secure';

export class UserChatMessageImpl implements UserChatMessage {
Expand Down Expand Up @@ -51,10 +51,29 @@ export class UserChatMessageImpl implements UserChatMessage {
});
}

getImagePart(image: ImageMessageContext): ImagePart {
return {
type: 'image',
image: image.content,
mimeType: image.mimeType,
};
}

getTextPart(): TextPart {
return {
type: 'text',
text: this.hydratedContent,
};
}

toCoreMessage(): CoreUserMessage {
const imageParts = this.context
.filter((c) => c.type === MessageContextType.IMAGE)
.map(this.getImagePart);
const textPart = this.getTextPart();
return {
role: this.role,
content: this.hydratedContent,
content: [...imageParts, textPart],
};
}
}
1 change: 1 addition & 0 deletions apps/studio/src/lib/editor/engine/chat/mockData.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ const MOCK_USER_MSG = new UserChatMessageImpl('Test message with some selected f
{
type: MessageContextType.IMAGE,
content: 'https://example.com/screenshot',
mimeType: 'image/png',
displayName: 'screenshot.png',
},
]);
Expand Down
77 changes: 39 additions & 38 deletions apps/studio/src/routes/editor/EditPanel/ChatTab/ChatInput.tsx
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import { useEditorEngine } from '@/components/Context';
import type { ChatMessageContext } from '@onlook/models/chat';
import type { ChatMessageContext, ImageMessageContext } from '@onlook/models/chat';
import { MessageContextType } from '@onlook/models/chat';
import { Button } from '@onlook/ui/button';
import { Icons } from '@onlook/ui/icons';
import { Textarea } from '@onlook/ui/textarea';
import { Tooltip, TooltipContent, TooltipPortal, TooltipTrigger } from '@onlook/ui/tooltip';
import { cn } from '@onlook/ui/utils';
import imageCompression from 'browser-image-compression';
import { AnimatePresence } from 'framer-motion';
import { observer } from 'mobx-react-lite';
import { useState } from 'react';
Expand Down Expand Up @@ -65,17 +66,8 @@ export const ChatInput = observer(() => {
if (inputElement.files && inputElement.files.length > 0) {
const file = inputElement.files[0];
const fileName = file.name;
const reader = new FileReader();
reader.onload = (event) => {
const base64URL = event.target?.result as string;
editorEngine.chat.context.context.push({
type: MessageContextType.IMAGE,
content: base64URL,
displayName: fileName,
});
setTimeout(() => setIsHandlingFile(false), 100);
};
reader.readAsDataURL(file);
handleImageEvent(file, fileName);
setTimeout(() => setIsHandlingFile(false), 100);
} else {
setIsHandlingFile(false);
}
Expand All @@ -93,18 +85,7 @@ export const ChatInput = observer(() => {
if (!file) {
continue;
}

const reader = new FileReader();
reader.onload = (event) => {
const base64URL = event.target?.result as string;
editorEngine.chat.context.context.push({
type: MessageContextType.IMAGE,
content: base64URL,
displayName: 'Pasted image',
});
e.currentTarget.focus();
};
reader.readAsDataURL(file);
handleImageEvent(file, 'Pasted image');
break;
}
}
Expand All @@ -121,24 +102,44 @@ export const ChatInput = observer(() => {
if (!file) {
continue;
}

const reader = new FileReader();
reader.onload = (event) => {
const base64URL = event.target?.result as string;
editorEngine.chat.context.context.push({
type: MessageContextType.IMAGE,
content: base64URL,
displayName: file.name || 'Dropped image',
});
const textarea = e.currentTarget.querySelector('textarea');
textarea?.focus();
};
reader.readAsDataURL(file);
handleImageEvent(file, 'Dropped image');
break;
}
}
};

const handleImageEvent = async (file: File, displayName?: string) => {
const reader = new FileReader();
reader.onload = async (event) => {
const compressedImage = await compressImage(file);
const base64URL = compressedImage || (event.target?.result as string);
const contextImage: ImageMessageContext = {
type: MessageContextType.IMAGE,
content: base64URL,
mimeType: file.type,
displayName: displayName || file.name,
};
editorEngine.chat.context.context.push(contextImage);
};
reader.readAsDataURL(file);
};

async function compressImage(file: File): Promise<string | undefined> {
const options = {
maxSizeMB: 1,
maxWidthOrHeight: 1024,
};

try {
const compressedFile = await imageCompression(file, options);
const base64URL = imageCompression.getDataUrlFromFile(compressedFile);
console.log(`Image size reduced from ${file.size} to ${compressedFile.size} (bytes)`);
return base64URL;
} catch (error) {
console.error('Error compressing image:', error);
}
}

const handleDragOver = (e: React.DragEvent<HTMLDivElement>) => {
e.preventDefault();
};
Expand Down Expand Up @@ -287,7 +288,7 @@ export const ChatInput = observer(() => {
<Button
variant={'ghost'}
size={'icon'}
className="w-9 h-9 text-foreground-tertiary group hover:bg-transparent opacity-0"
className="w-9 h-9 text-foreground-tertiary group hover:bg-transparent"
onClick={handleOpenFileDialog}
disabled={disabled}
>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ export function SentContextPill({ context }: { context: ChatMessageContext }) {
key={context.displayName}
>
{getContextIcon(context)}
<span>{getTruncatedName(context)}</span>
<span className="truncate">{getTruncatedName(context)}</span>
</span>
);
}
Binary file modified bun.lockb
Binary file not shown.
1 change: 1 addition & 0 deletions packages/models/src/chat/message/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ export type HighlightMessageContext = BaseMessageContext & {

export type ImageMessageContext = BaseMessageContext & {
type: MessageContextType.IMAGE;
mimeType: string;
};

export type ChatMessageContext = FileMessageContext | HighlightMessageContext | ImageMessageContext;

0 comments on commit 8dd7765

Please sign in to comment.