Skip to content

Commit 6da1e9b

Browse files
authored
Merge pull request #126 from supabase-community/feat/drop-sql-file
feat: import sql files
2 parents d4f4b9c + c2d7298 commit 6da1e9b

File tree

7 files changed

+255
-23
lines changed

7 files changed

+255
-23
lines changed

apps/postgres-new/components/chat.tsx

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -77,28 +77,43 @@ export default function Chat() {
7777

7878
const sendCsv = useCallback(
7979
async (file: File) => {
80-
if (file.type !== 'text/csv') {
81-
// Add an artificial tool call requesting the CSV
82-
// with an error indicating the file wasn't a CSV
83-
appendMessage({
84-
role: 'assistant',
85-
content: '',
86-
toolInvocations: [
87-
{
88-
state: 'result',
89-
toolCallId: generateId(),
90-
toolName: 'requestCsv',
91-
args: {},
92-
result: {
93-
success: false,
94-
error: `The file has type '${file.type}'. Let the user know that only CSV imports are currently supported.`,
80+
const fileId = generateId()
81+
82+
await saveFile(fileId, file)
83+
84+
const text = await file.text()
85+
86+
// Add an artificial tool call requesting the CSV
87+
// with the file result all in one operation.
88+
appendMessage({
89+
role: 'assistant',
90+
content: '',
91+
toolInvocations: [
92+
{
93+
state: 'result',
94+
toolCallId: generateId(),
95+
toolName: 'requestCsv',
96+
args: {},
97+
result: {
98+
success: true,
99+
fileId: fileId,
100+
file: {
101+
name: file.name,
102+
size: file.size,
103+
type: file.type,
104+
lastModified: file.lastModified,
95105
},
106+
preview: text.split('\n').slice(0, 4).join('\n').trim(),
96107
},
97-
],
98-
})
99-
return
100-
}
108+
},
109+
],
110+
})
111+
},
112+
[appendMessage]
113+
)
101114

115+
const sendSql = useCallback(
116+
async (file: File) => {
102117
const fileId = generateId()
103118

104119
await saveFile(fileId, file)
@@ -114,7 +129,7 @@ export default function Chat() {
114129
{
115130
state: 'result',
116131
toolCallId: generateId(),
117-
toolName: 'requestCsv',
132+
toolName: 'requestSql',
118133
args: {},
119134
result: {
120135
success: true,
@@ -125,7 +140,7 @@ export default function Chat() {
125140
type: file.type,
126141
lastModified: file.lastModified,
127142
},
128-
preview: text.split('\n').slice(0, 4).join('\n').trim(),
143+
preview: text.split('\n').slice(0, 10).join('\n').trim(),
129144
},
130145
},
131146
],
@@ -147,7 +162,16 @@ export default function Chat() {
147162
const [file] = files
148163

149164
if (file) {
150-
await sendCsv(file)
165+
if (file.type === 'text/csv' || file.name.endsWith('.csv')) {
166+
await sendCsv(file)
167+
} else if (file.type === 'application/sql' || file.name.endsWith('.sql')) {
168+
await sendSql(file)
169+
} else {
170+
appendMessage({
171+
role: 'assistant',
172+
content: `Only CSV and SQL files are currently supported.`,
173+
})
174+
}
151175
}
152176
},
153177
cursorElement: (

apps/postgres-new/components/ide.tsx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ export default function IDE({ children, className }: IDEProps) {
5151
return toolInvocations
5252
.map((tool) =>
5353
// Only include SQL that successfully executed against the DB
54-
tool.toolName === 'executeSql' && 'result' in tool && tool.result.success === true
54+
(tool.toolName === 'executeSql' || tool.toolName === 'importSql') &&
55+
'result' in tool &&
56+
tool.result.success === true
5557
? tool.args.sql
5658
: undefined
5759
)

apps/postgres-new/components/tools/index.tsx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import CsvRequest from './csv-request'
66
import ExecutedSql from './executed-sql'
77
import GeneratedChart from './generated-chart'
88
import GeneratedEmbedding from './generated-embedding'
9+
import SqlImport from './sql-import'
10+
import SqlRequest from './sql-request'
911

1012
export type ToolUiProps = {
1113
toolInvocation: ToolInvocation
@@ -23,6 +25,10 @@ export function ToolUi({ toolInvocation }: ToolUiProps) {
2325
return <CsvImport toolInvocation={toolInvocation} />
2426
case 'exportCsv':
2527
return <CsvExport toolInvocation={toolInvocation} />
28+
case 'requestSql':
29+
return <SqlRequest toolInvocation={toolInvocation} />
30+
case 'importSql':
31+
return <SqlImport toolInvocation={toolInvocation} />
2632
case 'renameConversation':
2733
return <ConversationRename toolInvocation={toolInvocation} />
2834
case 'embed':
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import { useMemo } from 'react'
2+
import { formatSql } from '~/lib/sql-util'
3+
import { ToolInvocation } from '~/lib/tools'
4+
import CodeAccordion from '../code-accordion'
5+
6+
export type SqlImportProps = {
7+
toolInvocation: ToolInvocation<'importSql'>
8+
}
9+
10+
export default function SqlImport({ toolInvocation }: SqlImportProps) {
11+
const { fileId, sql } = toolInvocation.args
12+
13+
const formattedSql = useMemo(() => formatSql(sql), [sql])
14+
15+
if (!('result' in toolInvocation)) {
16+
return null
17+
}
18+
19+
const { result } = toolInvocation
20+
21+
if (!result.success) {
22+
return (
23+
<CodeAccordion
24+
title="Error executing SQL"
25+
language="sql"
26+
code={formattedSql ?? sql}
27+
error={result.error}
28+
/>
29+
)
30+
}
31+
32+
return <CodeAccordion title="Executed SQL" language="sql" code={formattedSql ?? sql} />
33+
}
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import { generateId } from 'ai'
2+
import { useChat } from 'ai/react'
3+
import { m } from 'framer-motion'
4+
import { Paperclip } from 'lucide-react'
5+
import { loadFile, saveFile } from '~/lib/files'
6+
import { ToolInvocation } from '~/lib/tools'
7+
import { downloadFile } from '~/lib/util'
8+
import { useWorkspace } from '../workspace'
9+
10+
export type SqlRequestProps = {
11+
toolInvocation: ToolInvocation<'requestSql'>
12+
}
13+
14+
export default function SqlRequest({ toolInvocation }: SqlRequestProps) {
15+
const { databaseId } = useWorkspace()
16+
17+
const { addToolResult } = useChat({
18+
id: databaseId,
19+
api: '/api/chat',
20+
})
21+
22+
if ('result' in toolInvocation) {
23+
const { result } = toolInvocation
24+
25+
if (!result.success) {
26+
return (
27+
<m.div
28+
layout="position"
29+
layoutId={toolInvocation.toolCallId}
30+
className="self-end px-5 py-2.5 text-base rounded-full bg-destructive flex gap-2 items-center text-lighter italic"
31+
>
32+
No SQL file selected
33+
</m.div>
34+
)
35+
}
36+
37+
return (
38+
<m.div
39+
layout="position"
40+
layoutId={toolInvocation.toolCallId}
41+
className="self-end px-5 py-2.5 text-base rounded-full bg-border flex gap-2 items-center text-lighter italic"
42+
style={{
43+
// same value as tailwind, used to keep constant radius during framer animation
44+
// see: https://www.framer.com/motion/layout-animations/##scale-correction
45+
borderRadius: 9999,
46+
}}
47+
>
48+
<Paperclip size={14} />
49+
<m.span
50+
className="cursor-pointer hover:underline"
51+
layout
52+
onClick={async () => {
53+
const file = await loadFile(result.fileId)
54+
downloadFile(file)
55+
}}
56+
>
57+
{result.file.name}
58+
</m.span>
59+
</m.div>
60+
)
61+
}
62+
63+
return (
64+
<m.div layout="position" layoutId={toolInvocation.toolCallId}>
65+
<input
66+
type="file"
67+
onChange={async (e) => {
68+
if (e.target.files) {
69+
try {
70+
const [file] = Array.from(e.target.files)
71+
72+
if (!file) {
73+
throw new Error('No file found')
74+
}
75+
76+
if (file.type !== 'text/sql') {
77+
throw new Error('File is not a SQL file')
78+
}
79+
80+
const fileId = generateId()
81+
82+
await saveFile(fileId, file)
83+
84+
const text = await file.text()
85+
86+
addToolResult({
87+
toolCallId: toolInvocation.toolCallId,
88+
result: {
89+
success: true,
90+
fileId: fileId,
91+
file: {
92+
name: file.name,
93+
size: file.size,
94+
type: file.type,
95+
lastModified: file.lastModified,
96+
},
97+
preview: text.split('\n').slice(0, 10).join('\n').trim(),
98+
},
99+
})
100+
} catch (error) {
101+
addToolResult({
102+
toolCallId: toolInvocation.toolCallId,
103+
result: {
104+
success: false,
105+
error: error instanceof Error ? error.message : 'An unknown error occurred',
106+
},
107+
})
108+
}
109+
}
110+
}}
111+
/>
112+
</m.div>
113+
)
114+
}

apps/postgres-new/lib/hooks.ts

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,25 @@ export function useOnToolCall(databaseId: string) {
435435
}
436436
}
437437
}
438+
case 'importSql': {
439+
const { fileId } = toolCall.args
440+
441+
try {
442+
const file = await loadFile(fileId)
443+
await db.exec(await file.text())
444+
await refetchTables()
445+
446+
return {
447+
success: true,
448+
message: 'The SQL file has been executed successfully.',
449+
}
450+
} catch (error) {
451+
return {
452+
success: false,
453+
error: error instanceof Error ? error.message : 'An unknown error has occurred',
454+
}
455+
}
456+
}
438457
}
439458
},
440459
[dbManager, refetchTables, updateDatabase, databaseId, vectorDataTypeId]

apps/postgres-new/lib/tools.ts

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,40 @@ export const tools = {
162162
})
163163
),
164164
},
165+
requestSql: {
166+
description: codeBlock`
167+
Requests a SQL file upload from the user.
168+
`,
169+
args: z.object({}),
170+
result: result(
171+
z.object({
172+
fileId: z.string(),
173+
file: z.object({
174+
name: z.string(),
175+
size: z.number(),
176+
type: z.string(),
177+
lastModified: z.number(),
178+
}),
179+
preview: z.string(),
180+
})
181+
),
182+
},
183+
importSql: {
184+
description: codeBlock`
185+
Executes a Postgres SQL file with the specified ID against the user's database. Call \`requestSql\` first.
186+
`,
187+
args: z.object({
188+
fileId: z.string().describe('The ID of the SQL file to execute'),
189+
sql: z.string().describe(codeBlock`
190+
The Postgres SQL file content to execute against the user's database.
191+
`),
192+
}),
193+
result: result(
194+
z.object({
195+
message: z.string(),
196+
})
197+
),
198+
},
165199
embed: {
166200
description: codeBlock`
167201
Generates vector embeddings for texts. Use with pgvector extension.

0 commit comments

Comments
 (0)