diff --git a/README.md b/README.md index d1c4c0f..458aac6 100644 --- a/README.md +++ b/README.md @@ -10,9 +10,9 @@ RAG Studio requires AWS for access to both LLM and embedding models. Please comp - A S3 bucket to store the documents - The following models configured and accessible via AWS Bedrock. Any of the models not enabled will not function in the UI. - - Llama3.1 8b Instruct V1 (`meta.llama3-1-8b-instruct-v1:0`) - This model is required for the RAG Studio to function - - Llama3.1 70b Instruct V1 (`meta.llama3-1-70b-instruct-v1:0`) - - Llama3.1 405b Instruct V1 (`meta.llama3-1-405b-instruct-v1:0`) + - Llama3.1 8b Instruct v1 (`meta.llama3-1-8b-instruct-v1:0`) - This model is required for the RAG Studio to function + - Llama3.1 70b Instruct v1 (`meta.llama3-1-70b-instruct-v1:0`) + - Cohere Command R+ v1 (`cohere.command-r-plus-v1:0`) - For Embedding, you will need to enable the following model in AWS Bedrock: - Cohere English Embedding v3 (`meta.cohere-english-embedding-v3:0`) diff --git a/backend/src/main/java/com/cloudera/cai/rag/Types.java b/backend/src/main/java/com/cloudera/cai/rag/Types.java index bfd14fd..ac376ba 100644 --- a/backend/src/main/java/com/cloudera/cai/rag/Types.java +++ b/backend/src/main/java/com/cloudera/cai/rag/Types.java @@ -100,5 +100,7 @@ public record Session( Instant timeUpdated, String createdById, String updatedById, - Instant lastInteractionTime) {} + Instant lastInteractionTime, + String inferenceModel, + Integer responseChunks) {} } diff --git a/backend/src/main/java/com/cloudera/cai/rag/sessions/SessionController.java b/backend/src/main/java/com/cloudera/cai/rag/sessions/SessionController.java index d283fa5..cae16e0 100644 --- a/backend/src/main/java/com/cloudera/cai/rag/sessions/SessionController.java +++ b/backend/src/main/java/com/cloudera/cai/rag/sessions/SessionController.java @@ -1,4 +1,4 @@ -/******************************************************************************* +/* * CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) * (C) Cloudera, Inc. 2024 * All rights reserved. @@ -63,6 +63,13 @@ public Types.Session create(@RequestBody Types.Session input, HttpServletRequest return sessionService.create(input); } + @PostMapping(path = "/{id}", consumes = "application/json", produces = "application/json") + public Types.Session update(@RequestBody Types.Session input, HttpServletRequest request) { + String username = userTokenCookieDecoder.extractUsername(request.getCookies()); + input = input.withUpdatedById(username); + return sessionService.update(input); + } + @DeleteMapping(path = "/{id}") public void delete(@PathVariable Long id) { sessionService.delete(id); diff --git a/backend/src/main/java/com/cloudera/cai/rag/sessions/SessionRepository.java b/backend/src/main/java/com/cloudera/cai/rag/sessions/SessionRepository.java index b795972..3b80677 100644 --- a/backend/src/main/java/com/cloudera/cai/rag/sessions/SessionRepository.java +++ b/backend/src/main/java/com/cloudera/cai/rag/sessions/SessionRepository.java @@ -1,4 +1,4 @@ -/******************************************************************************* +/* * CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) * (C) Cloudera, Inc. 2024 * All rights reserved. @@ -65,8 +65,8 @@ public Long create(Types.Session input) { handle -> { var sql = """ - INSERT INTO CHAT_SESSION (name, created_by_id, updated_by_id) - VALUES (:name, :createdById, :updatedById) + INSERT INTO CHAT_SESSION (name, created_by_id, updated_by_id, inference_model, response_chunks) + VALUES (:name, :createdById, :updatedById, :inferenceModel, :responseChunks) """; Long id = insertSession(input, handle, sql); insertSessionDataSources(handle, id, input.dataSourceIds()); @@ -123,6 +123,8 @@ private Stream querySessions(Query query) { Types.Session.builder() .id(sessionId) .name(rowView.getColumn("name", String.class)) + .inferenceModel(rowView.getColumn("inference_model", String.class)) + .responseChunks(rowView.getColumn("response_chunks", Integer.class)) .createdById(rowView.getColumn("created_by_id", String.class)) .timeCreated(rowView.getColumn("time_created", Instant.class)) .updatedById(rowView.getColumn("updated_by_id", String.class)) @@ -157,9 +159,20 @@ public static SessionRepository createNull() { } public void delete(Long id) { + jdbi.useHandle( + handle -> handle.execute("UPDATE CHAT_SESSION SET DELETED = ? WHERE ID = ?", true, id)); + } + + public void update(Types.Session input) { jdbi.useHandle( handle -> { - handle.execute("UPDATE CHAT_SESSION SET DELETED = ? WHERE ID = ?", true, id); + var sql = + """ + UPDATE CHAT_SESSION + SET name = :name, updated_by_id = :updatedById, inference_model = :inferenceModel, response_chunks = :responseChunks + WHERE id = :id + """; + handle.createUpdate(sql).bindMethods(input).execute(); }); } } diff --git a/backend/src/main/java/com/cloudera/cai/rag/sessions/SessionService.java b/backend/src/main/java/com/cloudera/cai/rag/sessions/SessionService.java index d4b7f09..b0d3466 100644 --- a/backend/src/main/java/com/cloudera/cai/rag/sessions/SessionService.java +++ b/backend/src/main/java/com/cloudera/cai/rag/sessions/SessionService.java @@ -1,4 +1,4 @@ -/******************************************************************************* +/* * CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) * (C) Cloudera, Inc. 2024 * All rights reserved. @@ -56,6 +56,11 @@ public Types.Session create(Types.Session input) { return sessionRepository.getSessionById(id); } + public Types.Session update(Types.Session input) { + sessionRepository.update(input); + return sessionRepository.getSessionById(input.id()); + } + public List getSessions() { return sessionRepository.getSessions(); } diff --git a/backend/src/main/resources/migrations/h2/13_add_chat_configuration.down.sql b/backend/src/main/resources/migrations/h2/13_add_chat_configuration.down.sql new file mode 100644 index 0000000..dc60648 --- /dev/null +++ b/backend/src/main/resources/migrations/h2/13_add_chat_configuration.down.sql @@ -0,0 +1,46 @@ +/* + * CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) + * (C) Cloudera, Inc. 2024 + * All rights reserved. + * + * Applicable Open Source License: Apache 2.0 + * + * NOTE: Cloudera open source products are modular software products + * made up of hundreds of individual components, each of which was + * individually copyrighted. Each Cloudera open source product is a + * collective work under U.S. Copyright Law. Your license to use the + * collective work is as provided in your written agreement with + * Cloudera. Used apart from the collective work, this file is + * licensed for your use pursuant to the open source license + * identified above. + * + * This code is provided to you pursuant a written agreement with + * (i) Cloudera, Inc. or (ii) a third-party authorized to distribute + * this code. If you do not have a written agreement with Cloudera nor + * with an authorized and properly licensed third party, you do not + * have any rights to access nor to use this code. + * + * Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the + * contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY + * KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED + * WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO + * IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, + * AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS + * ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE + * OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR + * CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES + * RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF + * BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF + * DATA. + */ + +SET MODE MYSQL; + +BEGIN; + +ALTER TABLE CHAT_SESSION DROP COLUMN inference_model; +ALTER TABLE CHAT_SESSION DROP COLUMN response_chunks; + +COMMIT; \ No newline at end of file diff --git a/backend/src/main/resources/migrations/h2/13_add_chat_configuration.up.sql b/backend/src/main/resources/migrations/h2/13_add_chat_configuration.up.sql new file mode 100644 index 0000000..06aa54f --- /dev/null +++ b/backend/src/main/resources/migrations/h2/13_add_chat_configuration.up.sql @@ -0,0 +1,47 @@ +/* + * CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) + * (C) Cloudera, Inc. 2024 + * All rights reserved. + * + * Applicable Open Source License: Apache 2.0 + * + * NOTE: Cloudera open source products are modular software products + * made up of hundreds of individual components, each of which was + * individually copyrighted. Each Cloudera open source product is a + * collective work under U.S. Copyright Law. Your license to use the + * collective work is as provided in your written agreement with + * Cloudera. Used apart from the collective work, this file is + * licensed for your use pursuant to the open source license + * identified above. + * + * This code is provided to you pursuant a written agreement with + * (i) Cloudera, Inc. or (ii) a third-party authorized to distribute + * this code. If you do not have a written agreement with Cloudera nor + * with an authorized and properly licensed third party, you do not + * have any rights to access nor to use this code. + * + * Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the + * contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY + * KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED + * WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO + * IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, + * AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS + * ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE + * OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR + * CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES + * RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF + * BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF + * DATA. + */ + +SET MODE MYSQL; + +BEGIN; + +ALTER TABLE CHAT_SESSION ADD COLUMN inference_model VARCHAR(255); +ALTER TABLE CHAT_SESSION ADD COLUMN response_chunks INTEGER DEFAULT 5; + + +COMMIT; \ No newline at end of file diff --git a/backend/src/main/resources/migrations/migrations.txt b/backend/src/main/resources/migrations/migrations.txt index c2b3462..8ecd017 100644 --- a/backend/src/main/resources/migrations/migrations.txt +++ b/backend/src/main/resources/migrations/migrations.txt @@ -23,3 +23,5 @@ 11_add_session_deleted.up.sql 12_add_doc_deleted.down.sql 12_add_doc_deleted.up.sql +13_add_chat_configuration.down.sql +13_add_chat_configuration.up.sql \ No newline at end of file diff --git a/backend/src/main/resources/migrations/postgres/13_add_chat_configuration.down.sql b/backend/src/main/resources/migrations/postgres/13_add_chat_configuration.down.sql new file mode 100644 index 0000000..dc60648 --- /dev/null +++ b/backend/src/main/resources/migrations/postgres/13_add_chat_configuration.down.sql @@ -0,0 +1,46 @@ +/* + * CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) + * (C) Cloudera, Inc. 2024 + * All rights reserved. + * + * Applicable Open Source License: Apache 2.0 + * + * NOTE: Cloudera open source products are modular software products + * made up of hundreds of individual components, each of which was + * individually copyrighted. Each Cloudera open source product is a + * collective work under U.S. Copyright Law. Your license to use the + * collective work is as provided in your written agreement with + * Cloudera. Used apart from the collective work, this file is + * licensed for your use pursuant to the open source license + * identified above. + * + * This code is provided to you pursuant a written agreement with + * (i) Cloudera, Inc. or (ii) a third-party authorized to distribute + * this code. If you do not have a written agreement with Cloudera nor + * with an authorized and properly licensed third party, you do not + * have any rights to access nor to use this code. + * + * Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the + * contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY + * KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED + * WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO + * IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, + * AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS + * ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE + * OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR + * CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES + * RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF + * BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF + * DATA. + */ + +SET MODE MYSQL; + +BEGIN; + +ALTER TABLE CHAT_SESSION DROP COLUMN inference_model; +ALTER TABLE CHAT_SESSION DROP COLUMN response_chunks; + +COMMIT; \ No newline at end of file diff --git a/backend/src/main/resources/migrations/postgres/13_add_chat_configuration.up.sql b/backend/src/main/resources/migrations/postgres/13_add_chat_configuration.up.sql new file mode 100644 index 0000000..06aa54f --- /dev/null +++ b/backend/src/main/resources/migrations/postgres/13_add_chat_configuration.up.sql @@ -0,0 +1,47 @@ +/* + * CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) + * (C) Cloudera, Inc. 2024 + * All rights reserved. + * + * Applicable Open Source License: Apache 2.0 + * + * NOTE: Cloudera open source products are modular software products + * made up of hundreds of individual components, each of which was + * individually copyrighted. Each Cloudera open source product is a + * collective work under U.S. Copyright Law. Your license to use the + * collective work is as provided in your written agreement with + * Cloudera. Used apart from the collective work, this file is + * licensed for your use pursuant to the open source license + * identified above. + * + * This code is provided to you pursuant a written agreement with + * (i) Cloudera, Inc. or (ii) a third-party authorized to distribute + * this code. If you do not have a written agreement with Cloudera nor + * with an authorized and properly licensed third party, you do not + * have any rights to access nor to use this code. + * + * Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the + * contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY + * KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED + * WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO + * IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, + * AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS + * ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE + * OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR + * CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES + * RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF + * BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF + * DATA. + */ + +SET MODE MYSQL; + +BEGIN; + +ALTER TABLE CHAT_SESSION ADD COLUMN inference_model VARCHAR(255); +ALTER TABLE CHAT_SESSION ADD COLUMN response_chunks INTEGER DEFAULT 5; + + +COMMIT; \ No newline at end of file diff --git a/backend/src/test/java/com/cloudera/cai/rag/TestData.java b/backend/src/test/java/com/cloudera/cai/rag/TestData.java index fd93b0f..c5176ad 100644 --- a/backend/src/test/java/com/cloudera/cai/rag/TestData.java +++ b/backend/src/test/java/com/cloudera/cai/rag/TestData.java @@ -45,7 +45,8 @@ public class TestData { public static Types.Session createTestSessionInstance(String sessionName) { - return new Types.Session(null, sessionName, List.of(1L, 2L, 3L), null, null, null, null, null); + return new Types.Session( + null, sessionName, List.of(1L, 2L, 3L), null, null, null, null, null, "test-model", 3); } public static Types.RagDataSource createTestDataSourceInstance( diff --git a/backend/src/test/java/com/cloudera/cai/rag/sessions/SessionControllerTest.java b/backend/src/test/java/com/cloudera/cai/rag/sessions/SessionControllerTest.java index 315a3e5..799a611 100644 --- a/backend/src/test/java/com/cloudera/cai/rag/sessions/SessionControllerTest.java +++ b/backend/src/test/java/com/cloudera/cai/rag/sessions/SessionControllerTest.java @@ -1,4 +1,4 @@ -/******************************************************************************* +/* * CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) * (C) Cloudera, Inc. 2024 * All rights reserved. @@ -62,6 +62,8 @@ void create() throws JsonProcessingException { Types.Session result = sessionController.create(input, request); assertThat(result.id()).isNotNull(); assertThat(result.name()).isEqualTo(sessionName); + assertThat(result.inferenceModel()).isEqualTo(input.inferenceModel()); + assertThat(result.responseChunks()).isEqualTo(input.responseChunks()); assertThat(result.dataSourceIds()).containsExactlyInAnyOrder(1L, 2L, 3L); assertThat(result.timeCreated()).isNotNull(); assertThat(result.timeUpdated()).isNotNull(); @@ -70,6 +72,45 @@ void create() throws JsonProcessingException { assertThat(result.lastInteractionTime()).isNull(); } + @Test + void update() throws JsonProcessingException { + SessionController sessionController = new SessionController(SessionService.createNull()); + var request = new MockHttpServletRequest(); + request.setCookies( + new MockCookie("_basusertoken", UserTokenCookieDecoderTest.encodeCookie("test-user"))); + var sessionName = "test"; + Types.Session input = TestData.createTestSessionInstance(sessionName); + Types.Session result = sessionController.create(input, request); + + var updatedResponseChunks = 1; + var updatedInferenceModel = "new-model-name"; + var updatedName = "new-name"; + + request = new MockHttpServletRequest(); + request.setCookies( + new MockCookie( + "_basusertoken", UserTokenCookieDecoderTest.encodeCookie("update-test-user"))); + + var updatedSession = + sessionController.update( + result + .withInferenceModel(updatedInferenceModel) + .withResponseChunks(updatedResponseChunks) + .withName(updatedName), + request); + + assertThat(updatedSession.id()).isNotNull(); + assertThat(updatedSession.name()).isEqualTo(updatedName); + assertThat(updatedSession.inferenceModel()).isEqualTo(updatedInferenceModel); + assertThat(updatedSession.responseChunks()).isEqualTo(updatedResponseChunks); + assertThat(updatedSession.dataSourceIds()).containsExactlyInAnyOrder(1L, 2L, 3L); + assertThat(updatedSession.timeCreated()).isNotNull(); + assertThat(updatedSession.timeUpdated()).isNotNull(); + assertThat(updatedSession.createdById()).isEqualTo("test-user"); + assertThat(updatedSession.updatedById()).isEqualTo("update-test-user"); + assertThat(updatedSession.lastInteractionTime()).isNull(); + } + @Test void delete() { SessionService sessionService = SessionService.createNull(); diff --git a/llm-service/app/rag_types.py b/llm-service/app/rag_types.py index 38339cb..a6a0be8 100644 --- a/llm-service/app/rag_types.py +++ b/llm-service/app/rag_types.py @@ -40,10 +40,13 @@ from pydantic import BaseModel, ConfigDict +from app.services.models import DEFAULT_BEDROCK_LLM_MODEL + + class RagPredictConfiguration(BaseModel): model_config = ConfigDict(protected_namespaces=()) top_k: int = 5 chunk_size: int = 512 - model_name: str = "meta.llama3-1-8b-instruct-v1:0" + model_name: str = DEFAULT_BEDROCK_LLM_MODEL exclude_knowledge_base: Optional[bool] = False diff --git a/llm-service/app/services/doc_summaries.py b/llm-service/app/services/doc_summaries.py index 586cdd6..46e6000 100644 --- a/llm-service/app/services/doc_summaries.py +++ b/llm-service/app/services/doc_summaries.py @@ -115,7 +115,7 @@ def generate_summary( ## todo: move to somewhere better; these are defaults to use when none are explicitly provided def set_settings_globals() -> None: - Settings.llm = models.get_llm("meta.llama3-8b-instruct-v1:0") + Settings.llm = models.get_llm() Settings.embed_model = models.get_embedding_model() Settings.text_splitter = SentenceSplitter(chunk_size=1024) diff --git a/llm-service/app/services/evaluators.py b/llm-service/app/services/evaluators.py index 8ea5eba..a428237 100644 --- a/llm-service/app/services/evaluators.py +++ b/llm-service/app/services/evaluators.py @@ -47,7 +47,7 @@ def evaluate_response( query: str, chat_response: AgentChatResponse, ) -> tuple[float, float]: - evaluator_llm = models.get_llm("meta.llama3-8b-instruct-v1:0") + evaluator_llm = models.get_llm() relevancy_evaluator = RelevancyEvaluator(llm=evaluator_llm) relevance = relevancy_evaluator.evaluate_response( diff --git a/llm-service/app/services/models.py b/llm-service/app/services/models.py index f64dbfc..602949a 100644 --- a/llm-service/app/services/models.py +++ b/llm-service/app/services/models.py @@ -51,6 +51,7 @@ from .caii import get_llm as caii_llm from .llama_utils import completion_to_prompt, messages_to_prompt +DEFAULT_BEDROCK_LLM_MODEL = "meta.llama3-1-8b-instruct-v1:0" def get_embedding_model() -> BaseEmbedding: if is_caii_enabled(): @@ -58,7 +59,7 @@ def get_embedding_model() -> BaseEmbedding: return BedrockEmbedding(model_name="cohere.embed-english-v3") -def get_llm(model_name: str = None) -> LLM: +def get_llm(model_name: str = DEFAULT_BEDROCK_LLM_MODEL) -> LLM: if is_caii_enabled(): return caii_llm( domain=os.environ["CAII_DOMAIN"], @@ -66,6 +67,9 @@ def get_llm(model_name: str = None) -> LLM: messages_to_prompt=messages_to_prompt, completion_to_prompt=completion_to_prompt, ) + if not model_name: + model_name = DEFAULT_BEDROCK_LLM_MODEL + return BedrockConverse( model=model_name, # context_size=BEDROCK_MODELS.get(model_name, 8192), @@ -90,21 +94,16 @@ def is_caii_enabled() -> bool: domain: str = os.environ.get("CAII_DOMAIN", "") return len(domain) > 0 - def _get_bedrock_llm_models() -> List[Dict[str, Any]]: return [ { - "model_id": "meta.llama3-1-8b-instruct-v1:0", + "model_id": DEFAULT_BEDROCK_LLM_MODEL, "name": "Llama3.1 8B Instruct v1", }, { "model_id": "meta.llama3-1-70b-instruct-v1:0", "name": "Llama3.1 70B Instruct v1", }, - { - "model_id": "meta.llama3-1-405b-instruct-v1:0", - "name": "Llama3.1 405B Instruct v1", - }, { "model_id": "cohere.command-r-plus-v1:0", "name": "Cohere Command R Plus v1", diff --git a/llm-service/app/tests/conftest.py b/llm-service/app/tests/conftest.py index 24149bd..7665cf4 100644 --- a/llm-service/app/tests/conftest.py +++ b/llm-service/app/tests/conftest.py @@ -234,7 +234,7 @@ def llm(monkeypatch: pytest.MonkeyPatch) -> LLM: model = DummyLlm() # Requires that the app usages import the file and not the function directly as python creates a copy when importing the function - monkeypatch.setattr(models, "get_llm", lambda model_name: model) + monkeypatch.setattr(models, "get_llm", lambda : model) return model diff --git a/local-dev.sh b/local-dev.sh index 25072c7..2eaed0f 100755 --- a/local-dev.sh +++ b/local-dev.sh @@ -53,6 +53,8 @@ for sig in INT QUIT HUP TERM; do done trap cleanup EXIT +docker stop qdrant_dev || true + mkdir -p databases docker run --name qdrant_dev --rm -d -p 6333:6333 -p 6334:6334 -v $(pwd)/databases/qdrant_storage:/qdrant/storage:z qdrant/qdrant diff --git a/ui/src/api/chatApi.ts b/ui/src/api/chatApi.ts index 80c5f16..74e17f7 100644 --- a/ui/src/api/chatApi.ts +++ b/ui/src/api/chatApi.ts @@ -46,6 +46,7 @@ import { } from "src/api/utils.ts"; import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; import { suggestedQuestionKey } from "src/api/ragQueryApi.ts"; +import { Session } from "src/api/sessionApi.ts"; export interface SourceNode { node_id: string; @@ -191,3 +192,22 @@ const chatMutation = async ( request, ); }; + +export const createQueryConfiguration = ( + excludeKnowledgeBase: boolean, + activeSession?: Session, +): QueryConfiguration => { + // todo: maybe we should just throw an exception here? + if (!activeSession) { + return { + top_k: 5, + model_name: "", + exclude_knowledge_base: false, + }; + } + return { + top_k: activeSession.responseChunks, + model_name: activeSession.inferenceModel ?? "", + exclude_knowledge_base: excludeKnowledgeBase, + }; +}; diff --git a/ui/src/api/ragQueryApi.ts b/ui/src/api/ragQueryApi.ts index c104f4e..e546a63 100644 --- a/ui/src/api/ragQueryApi.ts +++ b/ui/src/api/ragQueryApi.ts @@ -46,11 +46,6 @@ import { QueryKeys, } from "src/api/utils.ts"; -export interface RagMessage { - role: "user" | "assistant"; - content: string; -} - export interface SuggestQuestionsRequest { data_source_id: string; configuration: QueryConfiguration; @@ -78,9 +73,7 @@ export const useSuggestQuestions = (request: SuggestQuestionsRequest) => { // eslint-disable-next-line @tanstack/query/exhaustive-deps queryKey: suggestedQuestionKey(request.data_source_id), queryFn: () => suggestQuestionsQuery(request), - enabled: - Boolean(request.data_source_id) && - Boolean(request.configuration.model_name), + enabled: Boolean(request.data_source_id), gcTime: 0, }); }; diff --git a/ui/src/api/sessionApi.ts b/ui/src/api/sessionApi.ts index 91a1a08..17a40bf 100644 --- a/ui/src/api/sessionApi.ts +++ b/ui/src/api/sessionApi.ts @@ -53,6 +53,8 @@ export interface Session { id: number; name: string; dataSourceIds: number[]; + inferenceModel?: string; + responseChunks: number; timeCreated: number; timeUpdated: number; createdById: string; @@ -60,6 +62,16 @@ export interface Session { lastInteractionTime: number; } +export type CreateSessionRequest = Pick< + Session, + "name" | "dataSourceIds" | "inferenceModel" | "responseChunks" +>; + +export type UpdateSessionRequest = Pick< + Session, + "responseChunks" | "inferenceModel" | "name" | "id" +>; + export const getSessionsQueryOptions = queryOptions({ queryKey: [QueryKeys.getSessions], queryFn: async () => await getSessionsQuery(), @@ -69,11 +81,6 @@ export const getSessionsQuery = async (): Promise => { return await getRequest(`${ragPath}/${paths.sessions}`); }; -export interface CreateSessionRequest { - name: string; - dataSourceIds: number[]; -} - export const useCreateSessionMutation = ({ onSuccess, onError, @@ -92,6 +99,27 @@ const createSessionMutation = async ( return await postRequest(`${ragPath}/${paths.sessions}`, request); }; +export const useUpdateSessionMutation = ({ + onSuccess, + onError, +}: UseMutationType) => { + return useMutation({ + mutationKey: [MutationKeys.updateSession], + mutationFn: updateSessionMutation, + onSuccess, + onError, + }); +}; + +const updateSessionMutation = async ( + request: UpdateSessionRequest, +): Promise => { + return await postRequest( + `${ragPath}/${paths.sessions}/${request.id.toString()}`, + request, + ); +}; + export const useDeleteSessionMutation = ({ onSuccess, onError, diff --git a/ui/src/api/utils.ts b/ui/src/api/utils.ts index bf3ed9d..b64e61e 100644 --- a/ui/src/api/utils.ts +++ b/ui/src/api/utils.ts @@ -68,6 +68,7 @@ export enum MutationKeys { "testLlmModel" = "testLlmModel", "testEmbeddingModel" = "testEmbeddingModel", "visualizeDataSourceWithUserQuery" = "visualizeDataSourceWithUserQuery", + "updateSession" = "updateSession", } export enum QueryKeys { diff --git a/ui/src/pages/RagChatTab/ChatLayout.tsx b/ui/src/pages/RagChatTab/ChatLayout.tsx index f35f719..0579e9d 100644 --- a/ui/src/pages/RagChatTab/ChatLayout.tsx +++ b/ui/src/pages/RagChatTab/ChatLayout.tsx @@ -43,38 +43,31 @@ import { getSessionsQueryOptions, Session } from "src/api/sessionApi.ts"; import { groupBy } from "lodash"; import { format } from "date-fns"; import { useParams } from "@tanstack/react-router"; -import { useEffect, useState, useMemo } from "react"; -import { QueryConfiguration, useChatHistoryQuery } from "src/api/chatApi.ts"; -import { - defaultQueryConfig, - RagChatContext, -} from "pages/RagChatTab/State/RagChatContext.tsx"; +import { useEffect, useMemo, useState } from "react"; +import { useChatHistoryQuery } from "src/api/chatApi.ts"; +import { RagChatContext } from "pages/RagChatTab/State/RagChatContext.tsx"; import { useGetDataSourcesQuery } from "src/api/dataSourceApi.ts"; import { useSuspenseQuery } from "@tanstack/react-query"; -import { getLlmModelsQueryOptions } from "src/api/modelsApi.ts"; const getSessionForSessionId = (sessionId?: string, sessions?: Session[]) => { return sessions?.find((session) => session.id.toString() === sessionId); }; -const getDataSourceIdForSession = (session?: Session) => { - return session?.dataSourceIds[0]; -}; - function ChatLayout() { const { data: sessions } = useSuspenseQuery(getSessionsQueryOptions); - const { data: llmModels } = useSuspenseQuery(getLlmModelsQueryOptions); + const { sessionId } = useParams({ strict: false }); - const activeSession = getSessionForSessionId(sessionId, sessions); - const dataSourceId = getDataSourceIdForSession(activeSession); const [currentQuestion, setCurrentQuestion] = useState(""); const { data: dataSources, status: dataSourcesStatus } = useGetDataSourcesQuery(); - const [queryConfiguration, setQueryConfiguration] = - useState(defaultQueryConfig); + const [excludeKnowledgeBase, setExcludeKnowledgeBase] = useState(false); const { status: chatHistoryStatus, data: chatHistory } = useChatHistoryQuery( sessionId?.toString() ?? "", ); + + const activeSession = getSessionForSessionId(sessionId, sessions); + const dataSourceId = activeSession?.dataSourceIds[0]; + const dataSourceSize = useMemo(() => { return ( dataSources?.find((ds) => ds.id === dataSourceId)?.totalDocSize ?? null @@ -85,15 +78,6 @@ function ChatLayout() { setCurrentQuestion(""); }, [sessionId]); - useEffect(() => { - if (llmModels.length) { - setQueryConfiguration((prev) => ({ - ...prev, - model_name: llmModels[0].model_id, - })); - } - }, [llmModels, setQueryConfiguration]); - const sessionsByDate = groupBy(sessions, (session) => { const relevantTime = session.lastInteractionTime || session.timeUpdated; return format(relevantTime * 1000, "yyyyMMdd"); @@ -102,17 +86,18 @@ function ChatLayout() { return ( { @@ -92,20 +94,23 @@ describe("ChatBodyController", () => { const renderWithContext = (contextValue: Partial) => { const defaultContextValue: RagChatContextType = { - currentQuestion: "", - chatHistory: [], - dataSourceId: undefined, - dataSourcesStatus: undefined, - queryConfiguration: { - top_k: 5, - model_name: "", - exclude_knowledge_base: false, - }, - setQueryConfiguration: () => null, - setCurrentQuestion: () => null, + chatHistoryQuery: { chatHistoryStatus: undefined, chatHistory: [] }, + currentQuestionState: ["", () => null], + dataSourcesQuery: { dataSourcesStatus: undefined, dataSources: [] }, + excludeKnowledgeBaseState: [false, () => null], dataSourceSize: null, - dataSources: [], - activeSession: undefined, + activeSession: { + dataSourceIds: [], + id: 0, + name: "", + timeCreated: 0, + timeUpdated: 0, + createdById: "", + updatedById: "", + lastInteractionTime: 0, + responseChunks: 5, + inferenceModel: "", + }, }; return render( @@ -119,10 +124,8 @@ describe("ChatBodyController", () => { it("renders NoSessionState when no sessionId and dataSources are available", () => { renderWithContext({ - dataSourceId: undefined, - dataSourcesStatus: undefined, + dataSourcesQuery: { dataSourcesStatus: undefined, dataSources: [] }, dataSourceSize: null, - dataSources: [], activeSession: undefined, }); @@ -131,7 +134,7 @@ describe("ChatBodyController", () => { it("renders ChatLoading when dataSourcesStatus or chatHistoryStatus is pending", () => { renderWithContext({ - dataSourcesStatus: "pending", + dataSourcesQuery: { dataSourcesStatus: "pending", dataSources: [] }, }); expect(screen.getByTestId("chatLoadingSpinner")).toBeTruthy(); @@ -139,7 +142,7 @@ describe("ChatBodyController", () => { it("renders error message when dataSourcesStatus or chatHistoryStatus is error", () => { renderWithContext({ - dataSourcesStatus: "error", + dataSourcesQuery: { dataSourcesStatus: "error", dataSources: [] }, }); expect(screen.getByText("We encountered an error")).toBeTruthy(); @@ -147,18 +150,21 @@ describe("ChatBodyController", () => { it("renders ChatMessageController when chatHistory exists", () => { renderWithContext({ - chatHistory: [ - { - id: "1", - rag_message: { - user: "a test question", - assistant: "a test response", + chatHistoryQuery: { + chatHistoryStatus: undefined, + chatHistory: [ + { + id: "1", + rag_message: { + user: "a test question", + assistant: "a test response", + }, + source_nodes: [], + evaluations: [], + timestamp: 123, }, - source_nodes: [], - evaluations: [], - timestamp: 123, - }, - ], + ], + }, }); expect(screen.getByTestId("chat-message")).toBeTruthy(); @@ -166,7 +172,7 @@ describe("ChatBodyController", () => { it("renders NoDataSourcesState when no dataSources are available", () => { renderWithContext({ - dataSources: [], + dataSourcesQuery: { dataSources: [] }, activeSession: testSession, }); @@ -179,8 +185,7 @@ describe("ChatBodyController", () => { it("renders NoDataSourceForSession when no currentDataSource is found", () => { renderWithContext({ - dataSources: [testDataSource], - dataSourceId: undefined, + dataSourcesQuery: { dataSources: [testDataSource] }, activeSession: { ...testSession, dataSourceIds: [2] }, }); @@ -191,9 +196,9 @@ describe("ChatBodyController", () => { it("renders ChatMessageController when currentQuestion and dataSourceSize are available", () => { renderWithContext({ - currentQuestion: "What is AI?", + currentQuestionState: ["What is AI?", () => null], dataSourceSize: 1, - dataSources: [testDataSource], + dataSourcesQuery: { dataSources: [testDataSource] }, activeSession: testSession, }); @@ -213,7 +218,7 @@ describe("ChatBodyController", () => { renderWithContext({ dataSourceSize: 1, - dataSources: [testDataSource], + dataSourcesQuery: { dataSources: [testDataSource] }, activeSession: testSession, }); diff --git a/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatBodyController.tsx b/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatBodyController.tsx index 2d054c8..c24529e 100644 --- a/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatBodyController.tsx +++ b/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatBodyController.tsx @@ -49,12 +49,10 @@ import NoDataSourceForSession from "pages/RagChatTab/ChatOutput/Placeholders/NoD const ChatBodyController = () => { const { - currentQuestion, - chatHistory, - dataSourcesStatus, - chatHistoryStatus, + currentQuestionState: [currentQuestion], + chatHistoryQuery: { chatHistory, chatHistoryStatus }, + dataSourcesQuery: { dataSources, dataSourcesStatus }, dataSourceSize, - dataSources, activeSession, } = useContext(RagChatContext); const { sessionId } = useParams({ strict: false }); diff --git a/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessageController.tsx b/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessageController.tsx index 6ce9b10..c7b4b7e 100644 --- a/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessageController.tsx +++ b/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessageController.tsx @@ -41,8 +41,12 @@ import ChatMessage from "pages/RagChatTab/ChatOutput/ChatMessages/ChatMessage.ts import { RagChatContext } from "pages/RagChatTab/State/RagChatContext.tsx"; const ChatMessageController = () => { - const { chatHistory, dataSourceId } = useContext(RagChatContext); + const { + chatHistoryQuery: { chatHistory }, + activeSession, + } = useContext(RagChatContext); const scrollEl = useRef(null); + const dataSourceId = activeSession?.dataSourceIds[0]; useEffect(() => { setTimeout(() => { diff --git a/ui/src/pages/RagChatTab/ChatOutput/Placeholders/EmptyChatState.tsx b/ui/src/pages/RagChatTab/ChatOutput/Placeholders/EmptyChatState.tsx index a45f56f..3ef2b1b 100644 --- a/ui/src/pages/RagChatTab/ChatOutput/Placeholders/EmptyChatState.tsx +++ b/ui/src/pages/RagChatTab/ChatOutput/Placeholders/EmptyChatState.tsx @@ -46,8 +46,8 @@ import { cdlGray700 } from "src/cuix/variables.ts"; import Images from "src/components/images/Images.ts"; const DataSourceSummaryCard = () => { - const { dataSourceId } = useContext(RagChatContext); - + const { activeSession } = useContext(RagChatContext); + const dataSourceId = activeSession?.dataSourceIds[0]; const dataSourceSummary = useGetDataSourceSummary({ data_source_id: dataSourceId?.toString() ?? "", queryEnabled: true, @@ -81,7 +81,9 @@ const DataSourceSummaryCard = () => { }; const EmptyChatState = ({ dataSourceSize }: { dataSourceSize: number }) => { - const { dataSourceId } = useContext(RagChatContext); + const { activeSession } = useContext(RagChatContext); + const dataSourceId = activeSession?.dataSourceIds[0]; + return ( { - const { dataSourceId, setCurrentQuestion, queryConfiguration } = - useContext(RagChatContext); - const { sessionId } = useParams({ strict: false }); - + const { + currentQuestionState: [, setCurrentQuestion], + activeSession, + excludeKnowledgeBaseState: [excludeKnowledgeBase], + } = useContext(RagChatContext); + const dataSourceId = activeSession?.dataSourceIds[0]; + const sessionId = activeSession?.id.toString(); const { data, isPending: suggestedQuestionsIsPending, isFetching: suggestedQuestionsIsFetching, } = useSuggestQuestions({ data_source_id: dataSourceId?.toString() ?? "", - configuration: queryConfiguration, + configuration: createQueryConfiguration( + excludeKnowledgeBase, + activeSession, + ), session_id: sessionId ?? "", }); @@ -80,7 +85,10 @@ const SuggestedQuestionsCards = () => { query: suggestedQuestion, data_source_id: dataSourceId.toString(), session_id: sessionId, - configuration: queryConfiguration, + configuration: createQueryConfiguration( + excludeKnowledgeBase, + activeSession, + ), }); } }; diff --git a/ui/src/pages/RagChatTab/ChatOutput/Sources/SourceCard.tsx b/ui/src/pages/RagChatTab/ChatOutput/Sources/SourceCard.tsx index 74b83a4..3193949 100644 --- a/ui/src/pages/RagChatTab/ChatOutput/Sources/SourceCard.tsx +++ b/ui/src/pages/RagChatTab/ChatOutput/Sources/SourceCard.tsx @@ -57,9 +57,10 @@ import { cdlGray600 } from "src/cuix/variables.ts"; import MetaData from "pages/RagChatTab/ChatOutput/Sources/MetaData.tsx"; export const SourceCard = ({ source }: { source: SourceNode }) => { - const { dataSourceId } = useContext(RagChatContext); + const { activeSession } = useContext(RagChatContext); const [showContent, setShowContent] = useState(false); const chunkContents = useGetChunkContents(); + const dataSourceId = activeSession?.dataSourceIds[0]; const documentSummary = useGetDocumentSummary({ data_source_id: dataSourceId?.toString() ?? "", doc_id: source.doc_id, diff --git a/ui/src/pages/RagChatTab/FooterComponents/RagChatQueryInput.tsx b/ui/src/pages/RagChatTab/FooterComponents/RagChatQueryInput.tsx index 901b8b1..682a7d4 100644 --- a/ui/src/pages/RagChatTab/FooterComponents/RagChatQueryInput.tsx +++ b/ui/src/pages/RagChatTab/FooterComponents/RagChatQueryInput.tsx @@ -42,7 +42,7 @@ import { DatabaseFilled, SendOutlined } from "@ant-design/icons"; import { useContext, useState } from "react"; import { RagChatContext } from "pages/RagChatTab/State/RagChatContext.tsx"; import messageQueue from "src/utils/messageQueue.ts"; -import { useChatMutation } from "src/api/chatApi.ts"; +import { createQueryConfiguration, useChatMutation } from "src/api/chatApi.ts"; import { useSuggestQuestions } from "src/api/ragQueryApi.ts"; import { useParams } from "@tanstack/react-router"; import { cdlBlue600 } from "src/cuix/variables.ts"; @@ -51,25 +51,28 @@ import type { SwitchChangeEventHandler } from "antd/lib/switch"; const RagChatQueryInput = () => { const { - dataSourceId, - queryConfiguration, - setCurrentQuestion, - chatHistory, + excludeKnowledgeBaseState: [excludeKnowledgeBase, setExcludeKnowledgeBase], + currentQuestionState: [, setCurrentQuestion], + chatHistoryQuery: { chatHistory }, dataSourceSize, - dataSourcesStatus, - setQueryConfiguration, + dataSourcesQuery: { dataSourcesStatus }, + activeSession, } = useContext(RagChatContext); - + const dataSourceId = activeSession?.dataSourceIds[0]; const [userInput, setUserInput] = useState(""); const { sessionId } = useParams({ strict: false }); + const configuration = createQueryConfiguration( + excludeKnowledgeBase, + activeSession, + ); const { data: sampleQuestions, isPending: sampleQuestionsIsPending, isFetching: sampleQuestionsIsFetching, } = useSuggestQuestions({ data_source_id: dataSourceId?.toString() ?? "", - configuration: queryConfiguration, + configuration, session_id: sessionId ?? "", }); @@ -90,16 +93,16 @@ const RagChatQueryInput = () => { query: userInput, data_source_id: dataSourceId.toString(), session_id: sessionId, - configuration: queryConfiguration, + configuration: createQueryConfiguration( + excludeKnowledgeBase, + activeSession, + ), }); } }; const handleExcludeKnowledgeBase: SwitchChangeEventHandler = (checked) => { - setQueryConfiguration((prev) => ({ - ...prev, - exclude_knowledge_base: !checked, - })); + setExcludeKnowledgeBase(() => !checked); }; return ( @@ -133,7 +136,7 @@ const RagChatQueryInput = () => { } - value={!queryConfiguration.exclude_knowledge_base} + value={!excludeKnowledgeBase} onChange={handleExcludeKnowledgeBase} /> diff --git a/ui/src/pages/RagChatTab/Header/RagChatHeader.tsx b/ui/src/pages/RagChatTab/Header/RagChatHeader.tsx index 96740da..8034cb1 100644 --- a/ui/src/pages/RagChatTab/Header/RagChatHeader.tsx +++ b/ui/src/pages/RagChatTab/Header/RagChatHeader.tsx @@ -38,14 +38,11 @@ import { Session } from "src/api/sessionApi.ts"; import { DataSourceType } from "src/api/dataSourceApi.ts"; -import { Button, Flex, Layout, Tooltip, Typography } from "antd"; -import QueryTimeSettingsModal from "pages/RagChatTab/Settings/QueryTimeSettingsModal.tsx"; -import { useContext } from "react"; -import { QueryConfiguration } from "src/api/chatApi.ts"; -import { RagChatContext } from "pages/RagChatTab/State/RagChatContext.tsx"; +import { Button, Flex, Layout, Typography } from "antd"; +import ChatSettingsModal from "pages/RagChatTab/Settings/ChatSettingsModal.tsx"; import useModal from "src/utils/useModal.ts"; import SettingsIcon from "src/cuix/icons/SettingsIcon"; -import { cdlBlue600 } from "src/cuix/variables.ts"; +import { cdlBlue600, cdlGray600 } from "src/cuix/variables.ts"; const { Header } = Layout; @@ -62,22 +59,19 @@ function getHeaderTitle( return `${activeSession.name} / ${currentDataSource.name}`; } -export const RagChatHeader = (props: { +export const RagChatHeader = ({ + activeSession, + currentDataSource, +}: { activeSession?: Session; currentDataSource?: DataSourceType; }) => { - const { queryConfiguration, setQueryConfiguration } = - useContext(RagChatContext); const settingsModal = useModal(); const handleOpenModal = () => { settingsModal.setIsModalOpen(!settingsModal.isModalOpen); }; - const handleUpdateConfiguration = (formValues: QueryConfiguration) => { - setQueryConfiguration(formValues); - settingsModal.setIsModalOpen(false); - }; return (
@@ -91,33 +85,35 @@ export const RagChatHeader = (props: { marginBottom: 0, }} > - {getHeaderTitle(props.activeSession, props.currentDataSource)} + {getHeaderTitle(activeSession, currentDataSource)} - - - + Chat Settings + + + {" "} - { settingsModal.setIsModalOpen(false); }} - queryConfiguration={queryConfiguration} - handleUpdateConfiguration={handleUpdateConfiguration} />
); diff --git a/ui/src/pages/RagChatTab/RagChat.tsx b/ui/src/pages/RagChatTab/RagChat.tsx index 356ae9e..7aa4ea2 100644 --- a/ui/src/pages/RagChatTab/RagChat.tsx +++ b/ui/src/pages/RagChatTab/RagChat.tsx @@ -46,11 +46,13 @@ import { RagChatHeader } from "pages/RagChatTab/Header/RagChatHeader.tsx"; const { Footer, Content } = Layout; const RagChat = () => { - const { dataSourceId, dataSources, activeSession } = - useContext(RagChatContext); + const { + dataSourcesQuery: { dataSources }, + activeSession, + } = useContext(RagChatContext); const currentDataSource = dataSources.find((dataSource) => { - return dataSource.id === dataSourceId; + return dataSource.id === activeSession?.dataSourceIds[0]; }); return ( diff --git a/ui/src/pages/RagChatTab/Sessions/CreateSessionForm.tsx b/ui/src/pages/RagChatTab/Sessions/CreateSessionForm.tsx index c4854e9..34c6505 100644 --- a/ui/src/pages/RagChatTab/Sessions/CreateSessionForm.tsx +++ b/ui/src/pages/RagChatTab/Sessions/CreateSessionForm.tsx @@ -36,9 +36,12 @@ * DATA. ******************************************************************************/ -import { Form, FormInstance, Input, Select } from "antd"; +import { Form, FormInstance, Input, Select, Slider } from "antd"; import { DataSourceType } from "src/api/dataSourceApi.ts"; import { CreateSessionType } from "pages/RagChatTab/Sessions/CreateSessionModal.tsx"; +import { transformModelOptions } from "src/utils/modelUtils.ts"; +import { ResponseChunksRange } from "pages/RagChatTab/Settings/ResponseChunksSlider.tsx"; +import { useGetLlmModels } from "src/api/modelsApi.ts"; export interface CreateSessionFormProps { form: FormInstance; @@ -46,11 +49,13 @@ export interface CreateSessionFormProps { } const layout = { - labelCol: { span: 8 }, - wrapperCol: { span: 16 }, + labelCol: { span: 12 }, + wrapperCol: { span: 12 }, }; const CreateSessionForm = ({ form, dataSources }: CreateSessionFormProps) => { + const { data } = useGetLlmModels(); + const formatDataSource = (value: DataSourceType) => { return { ...value, @@ -89,6 +94,23 @@ const CreateSessionForm = ({ form, dataSources }: CreateSessionFormProps) => { + + initialValue={ + data === undefined || data.length === 0 ? "" : data[0].model_id + } + name="inferenceModel" + label="Response synthesizer model" + rules={[{ required: true }]} + > + + + + - - - - name="top_k" - initialValue={queryConfiguration.top_k} - label="Maximum number of documents" - > - - - -
- - ); -}; - -export default QueryTimeSettingsModal; diff --git a/ui/src/pages/RagChatTab/State/RagChatContext.tsx b/ui/src/pages/RagChatTab/State/RagChatContext.tsx index e83d495..02d1852 100644 --- a/ui/src/pages/RagChatTab/State/RagChatContext.tsx +++ b/ui/src/pages/RagChatTab/State/RagChatContext.tsx @@ -37,40 +37,30 @@ ******************************************************************************/ import { createContext, Dispatch, SetStateAction } from "react"; -import { ChatMessageType, QueryConfiguration } from "src/api/chatApi.ts"; +import { ChatMessageType } from "src/api/chatApi.ts"; import { Session } from "src/api/sessionApi.ts"; import { DataSourceType } from "src/api/dataSourceApi.ts"; export interface RagChatContextType { - dataSourceId?: number; - queryConfiguration: QueryConfiguration; - setQueryConfiguration: Dispatch>; - setCurrentQuestion: Dispatch>; - currentQuestion: string; - chatHistory: ChatMessageType[]; - chatHistoryStatus?: "error" | "success" | "pending"; - dataSourceSize: number | null; - dataSourcesStatus?: "error" | "success" | "pending"; activeSession?: Session; - dataSources: DataSourceType[]; + currentQuestionState: [string, Dispatch>]; + chatHistoryQuery: { + chatHistory: ChatMessageType[]; + chatHistoryStatus?: "error" | "success" | "pending"; + }; + dataSourcesQuery: { + dataSources: DataSourceType[]; + dataSourcesStatus?: "error" | "success" | "pending"; + }; + dataSourceSize: number | null; + excludeKnowledgeBaseState: [boolean, Dispatch>]; } -export const defaultQueryConfig = { - top_k: 5, - model_name: "", - exclude_knowledge_base: false, -}; - export const RagChatContext = createContext({ - dataSourceId: undefined, // TODO: remove this and have it pulled from active session activeSession: undefined, - dataSources: [], - queryConfiguration: defaultQueryConfig, - setQueryConfiguration: () => null, - setCurrentQuestion: () => null, - currentQuestion: "", - chatHistory: [], + currentQuestionState: ["", () => null], + chatHistoryQuery: { chatHistory: [], chatHistoryStatus: undefined }, + dataSourcesQuery: { dataSources: [], dataSourcesStatus: undefined }, dataSourceSize: null, - chatHistoryStatus: undefined, - dataSourcesStatus: undefined, + excludeKnowledgeBaseState: [false, () => null], });