Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: show source nodes #20

Merged
merged 3 commits into from
Apr 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion examples/llamaindex_rag/app.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
import sys
import uuid
import logging
import click
import uvicorn
import fastapi
import asyncio
import contextvars
from sqlalchemy import URL
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
Expand All @@ -14,9 +16,31 @@
from llama_index.vector_stores.tidbvector import TiDBVectorStore
from llama_index.readers.web import SimpleWebPageReader


# Setup logging
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger()

# Setup in-memory cache
class InMemoryCache:
def __init__(self):
self.cache = {}

def set(self, key, value):
self.cache[key] = value

def get(self, key):
return self.cache.get(key)

def delete(self, key):
if key in self.cache:
del self.cache[key]

def clear(self):
self.cache.clear()

cache = InMemoryCache()


logger.info("Initializing TiDB Vector Store....")
tidb_connection_url = URL(
Expand Down Expand Up @@ -63,6 +87,17 @@ async def astreamer(response: llamaStreamingResponse):
app = fastapi.FastAPI()
templates = Jinja2Templates(directory="templates")

# Setup contextvars
request_id_contextvar = contextvars.ContextVar('request_id', default=None)

@app.middleware("http")
async def add_request_id_header(request: fastapi.Request, call_next):
request_id = str(uuid.uuid4())
request_id_contextvar.set(request_id)
response = await call_next(request)
response.headers["X-Request-ID"] = request_id
return response


@app.get('/', response_class=HTMLResponse)
def index(request: fastapi.Request):
Expand All @@ -72,9 +107,16 @@ def index(request: fastapi.Request):
@app.get('/ask')
async def ask(q: str):
response = query_engine.query(q)
request_id = request_id_contextvar.get()
cache.set(request_id, vars(response))
return StreamingResponse(astreamer(response), media_type='text/event-stream')


@app.get('/getResponseMeta/{request_id}')
async def response(request_id: str):
return cache.get(request_id)


@click.group(context_settings={'max_content_width': 150})
def cli():
pass
Expand All @@ -83,7 +125,7 @@ def cli():
@cli.command()
@click.option('--host', default='127.0.0.1', help="Host, default=127.0.0.1")
@click.option('--port', default=3000, help="Port, default=3000")
@click.option('--reload', is_flag=True, help="Enable auto-reload")
@click.option('--reload', is_flag=True, default=True, help="Enable auto-reload")
def runserver(host, port, reload):
uvicorn.run(
"__main__:app", host=host, port=port, reload=reload,
Expand Down
48 changes: 37 additions & 11 deletions examples/llamaindex_rag/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -6,39 +6,63 @@
<meta charset="UTF-8">
<title>LlamaIndex & TiDB RAG Demo</title>
<script src="https://cdn.tailwindcss.com"></script>
<script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/vue.min.js"></script>
<script src="https://unpkg.com/[email protected]"></script>
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/css/all.min.css" rel="stylesheet">
<script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/vue.min.js"></script>
</head>

<body>
<div class="flex justify-center py-10 h-svh">
<div id="app" class="w-1/2">
<div id="app" class="flex justify-around p-10 min-h-screen">
<div class="w-1/2 m-4">
<label class="block">
<input type="text" v-model="question" placeholder="Input question..."
class="block p-2 w-full border-solid border-gray-300 focus:border-gray-900 border rounded-none focus:outline-none">
class="block p-2 w-full border-solid border-gray-600 focus:border-gray-900 border rounded-none focus:outline-none">
</label>
<button type="button" class="bg-black block w-full my-2 p-2 rounded text-white" :disabled="loading"
@click="askQuestion">
<i v-if="loading && responses == ''" class="fas fa-spinner animate-spin"></i>
<i v-if="loading && answer == ''" class="fas fa-spinner animate-spin"></i>
Ask
</button>
<div id="responses" class="w-full min-h-max border py-2 px-6">
<div id="answer" class="w-full min-h-max border py-2 px-6 bg-slate-100">
<div v-html="compiledMarkdown">Empty</div>
</div>
</div>
<div class="w-1/2 m-4 border p-2 bg-slate-100 min-h-fit">
<h1>Source Nodes</h1>
<ul v-for="node in responseMeta.source_nodes" class="list-decimal ml-4">
<li v-text="node.node.text" class="pl-4"></li>
</ul>
</div>
</div>
<script>
new Vue({
el: '#app',
data: {
question: 'What did the author do at each stage of his/her growth? Use markdown format if possible',
responses: 'Empty...',
loading: false
answer: 'Empty...',
requestID: '',
loading: false,
responseMeta: {},
},
computed: {
compiledMarkdown: function () {
return marked(this.responses, { sanitize: true });
return marked(this.answer, { sanitize: true });
}
},
watch: {
loading(newVal, oldVal) {
if (!newVal) {
const self = this;
fetch(`/getResponseMeta/${self.requestID}`)
.then(r => r.json())
.then(meta => {
self.responseMeta = meta;
})
.catch(error => {
console.error('Fetch error:', error);
alert('An error occurred while fetching the response.');
});
}
}
},
methods: {
Expand All @@ -48,7 +72,8 @@
alert('Please input a question.');
return;
}
self.responses = '';
self.responseMeta = {};
self.answer = '';
self.loading = true;
fetch(`/ask?q=${encodeURIComponent(this.question)}`)
.then(response => {
Expand All @@ -63,7 +88,7 @@
return;
}
const chunk = new TextDecoder("utf-8").decode(value);
self.responses += chunk;
self.answer += chunk;
controller.enqueue(value);
push();
}).catch(error => {
Expand All @@ -75,6 +100,7 @@
push();
}
});
self.requestID = response.headers.get('X-Request-ID');
return new Response(stream);
})
.catch(error => {
Expand Down
Loading