Skip to content

Commit

Permalink
Merge pull request #58 from cloudera/mob/main
Browse files Browse the repository at this point in the history
Provide Sessions with Inference Model Persistence
  • Loading branch information
ewilliams-cloudera authored Dec 6, 2024
2 parents aeb82ee + 09264b6 commit 303a4a4
Show file tree
Hide file tree
Showing 38 changed files with 679 additions and 263 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ RAG Studio requires AWS for access to both LLM and embedding models. Please comp

- A S3 bucket to store the documents
- The following models configured and accessible via AWS Bedrock. Any of the models not enabled will not function in the UI.
- Llama3.1 8b Instruct V1 (`meta.llama3-1-8b-instruct-v1:0`) - This model is required for the RAG Studio to function
- Llama3.1 70b Instruct V1 (`meta.llama3-1-70b-instruct-v1:0`)
- Llama3.1 405b Instruct V1 (`meta.llama3-1-405b-instruct-v1:0`)
- Llama3.1 8b Instruct v1 (`meta.llama3-1-8b-instruct-v1:0`) - This model is required for the RAG Studio to function
- Llama3.1 70b Instruct v1 (`meta.llama3-1-70b-instruct-v1:0`)
- Cohere Command R+ v1 (`cohere.command-r-plus-v1:0`)
- For Embedding, you will need to enable the following model in AWS Bedrock:
- Cohere English Embedding v3 (`meta.cohere-english-embedding-v3:0`)

Expand Down
4 changes: 3 additions & 1 deletion backend/src/main/java/com/cloudera/cai/rag/Types.java
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,7 @@ public record Session(
Instant timeUpdated,
String createdById,
String updatedById,
Instant lastInteractionTime) {}
Instant lastInteractionTime,
String inferenceModel,
Integer responseChunks) {}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/*******************************************************************************
/*
* CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
* (C) Cloudera, Inc. 2024
* All rights reserved.
Expand Down Expand Up @@ -63,6 +63,13 @@ public Types.Session create(@RequestBody Types.Session input, HttpServletRequest
return sessionService.create(input);
}

@PostMapping(path = "/{id}", consumes = "application/json", produces = "application/json")
public Types.Session update(@RequestBody Types.Session input, HttpServletRequest request) {
String username = userTokenCookieDecoder.extractUsername(request.getCookies());
input = input.withUpdatedById(username);
return sessionService.update(input);
}

@DeleteMapping(path = "/{id}")
public void delete(@PathVariable Long id) {
sessionService.delete(id);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/*******************************************************************************
/*
* CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
* (C) Cloudera, Inc. 2024
* All rights reserved.
Expand Down Expand Up @@ -65,8 +65,8 @@ public Long create(Types.Session input) {
handle -> {
var sql =
"""
INSERT INTO CHAT_SESSION (name, created_by_id, updated_by_id)
VALUES (:name, :createdById, :updatedById)
INSERT INTO CHAT_SESSION (name, created_by_id, updated_by_id, inference_model, response_chunks)
VALUES (:name, :createdById, :updatedById, :inferenceModel, :responseChunks)
""";
Long id = insertSession(input, handle, sql);
insertSessionDataSources(handle, id, input.dataSourceIds());
Expand Down Expand Up @@ -123,6 +123,8 @@ private Stream<Types.Session.SessionBuilder> querySessions(Query query) {
Types.Session.builder()
.id(sessionId)
.name(rowView.getColumn("name", String.class))
.inferenceModel(rowView.getColumn("inference_model", String.class))
.responseChunks(rowView.getColumn("response_chunks", Integer.class))
.createdById(rowView.getColumn("created_by_id", String.class))
.timeCreated(rowView.getColumn("time_created", Instant.class))
.updatedById(rowView.getColumn("updated_by_id", String.class))
Expand Down Expand Up @@ -157,9 +159,20 @@ public static SessionRepository createNull() {
}

public void delete(Long id) {
jdbi.useHandle(
handle -> handle.execute("UPDATE CHAT_SESSION SET DELETED = ? WHERE ID = ?", true, id));
}

public void update(Types.Session input) {
jdbi.useHandle(
handle -> {
handle.execute("UPDATE CHAT_SESSION SET DELETED = ? WHERE ID = ?", true, id);
var sql =
"""
UPDATE CHAT_SESSION
SET name = :name, updated_by_id = :updatedById, inference_model = :inferenceModel, response_chunks = :responseChunks
WHERE id = :id
""";
handle.createUpdate(sql).bindMethods(input).execute();
});
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/*******************************************************************************
/*
* CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
* (C) Cloudera, Inc. 2024
* All rights reserved.
Expand Down Expand Up @@ -56,6 +56,11 @@ public Types.Session create(Types.Session input) {
return sessionRepository.getSessionById(id);
}

public Types.Session update(Types.Session input) {
sessionRepository.update(input);
return sessionRepository.getSessionById(input.id());
}

public List<Types.Session> getSessions() {
return sessionRepository.getSessions();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
* (C) Cloudera, Inc. 2024
* All rights reserved.
*
* Applicable Open Source License: Apache 2.0
*
* NOTE: Cloudera open source products are modular software products
* made up of hundreds of individual components, each of which was
* individually copyrighted. Each Cloudera open source product is a
* collective work under U.S. Copyright Law. Your license to use the
* collective work is as provided in your written agreement with
* Cloudera. Used apart from the collective work, this file is
* licensed for your use pursuant to the open source license
* identified above.
*
* This code is provided to you pursuant a written agreement with
* (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
* this code. If you do not have a written agreement with Cloudera nor
* with an authorized and properly licensed third party, you do not
* have any rights to access nor to use this code.
*
* Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the
* contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
* KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
* WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
* IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
* AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
* ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
* OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
* CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
* RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
* BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
* DATA.
*/

SET MODE MYSQL;

BEGIN;

ALTER TABLE CHAT_SESSION DROP COLUMN inference_model;
ALTER TABLE CHAT_SESSION DROP COLUMN response_chunks;

COMMIT;
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
* (C) Cloudera, Inc. 2024
* All rights reserved.
*
* Applicable Open Source License: Apache 2.0
*
* NOTE: Cloudera open source products are modular software products
* made up of hundreds of individual components, each of which was
* individually copyrighted. Each Cloudera open source product is a
* collective work under U.S. Copyright Law. Your license to use the
* collective work is as provided in your written agreement with
* Cloudera. Used apart from the collective work, this file is
* licensed for your use pursuant to the open source license
* identified above.
*
* This code is provided to you pursuant a written agreement with
* (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
* this code. If you do not have a written agreement with Cloudera nor
* with an authorized and properly licensed third party, you do not
* have any rights to access nor to use this code.
*
* Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the
* contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
* KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
* WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
* IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
* AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
* ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
* OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
* CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
* RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
* BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
* DATA.
*/

SET MODE MYSQL;

BEGIN;

ALTER TABLE CHAT_SESSION ADD COLUMN inference_model VARCHAR(255);
ALTER TABLE CHAT_SESSION ADD COLUMN response_chunks INTEGER DEFAULT 5;


COMMIT;
2 changes: 2 additions & 0 deletions backend/src/main/resources/migrations/migrations.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@
11_add_session_deleted.up.sql
12_add_doc_deleted.down.sql
12_add_doc_deleted.up.sql
13_add_chat_configuration.down.sql
13_add_chat_configuration.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
* (C) Cloudera, Inc. 2024
* All rights reserved.
*
* Applicable Open Source License: Apache 2.0
*
* NOTE: Cloudera open source products are modular software products
* made up of hundreds of individual components, each of which was
* individually copyrighted. Each Cloudera open source product is a
* collective work under U.S. Copyright Law. Your license to use the
* collective work is as provided in your written agreement with
* Cloudera. Used apart from the collective work, this file is
* licensed for your use pursuant to the open source license
* identified above.
*
* This code is provided to you pursuant a written agreement with
* (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
* this code. If you do not have a written agreement with Cloudera nor
* with an authorized and properly licensed third party, you do not
* have any rights to access nor to use this code.
*
* Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the
* contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
* KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
* WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
* IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
* AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
* ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
* OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
* CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
* RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
* BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
* DATA.
*/

SET MODE MYSQL;

BEGIN;

ALTER TABLE CHAT_SESSION DROP COLUMN inference_model;
ALTER TABLE CHAT_SESSION DROP COLUMN response_chunks;

COMMIT;
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
* (C) Cloudera, Inc. 2024
* All rights reserved.
*
* Applicable Open Source License: Apache 2.0
*
* NOTE: Cloudera open source products are modular software products
* made up of hundreds of individual components, each of which was
* individually copyrighted. Each Cloudera open source product is a
* collective work under U.S. Copyright Law. Your license to use the
* collective work is as provided in your written agreement with
* Cloudera. Used apart from the collective work, this file is
* licensed for your use pursuant to the open source license
* identified above.
*
* This code is provided to you pursuant a written agreement with
* (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
* this code. If you do not have a written agreement with Cloudera nor
* with an authorized and properly licensed third party, you do not
* have any rights to access nor to use this code.
*
* Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the
* contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
* KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
* WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
* IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
* AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
* ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
* OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
* CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
* RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
* BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
* DATA.
*/

SET MODE MYSQL;

BEGIN;

ALTER TABLE CHAT_SESSION ADD COLUMN inference_model VARCHAR(255);
ALTER TABLE CHAT_SESSION ADD COLUMN response_chunks INTEGER DEFAULT 5;


COMMIT;
3 changes: 2 additions & 1 deletion backend/src/test/java/com/cloudera/cai/rag/TestData.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@

public class TestData {
public static Types.Session createTestSessionInstance(String sessionName) {
return new Types.Session(null, sessionName, List.of(1L, 2L, 3L), null, null, null, null, null);
return new Types.Session(
null, sessionName, List.of(1L, 2L, 3L), null, null, null, null, null, "test-model", 3);
}

public static Types.RagDataSource createTestDataSourceInstance(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/*******************************************************************************
/*
* CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
* (C) Cloudera, Inc. 2024
* All rights reserved.
Expand Down Expand Up @@ -62,6 +62,8 @@ void create() throws JsonProcessingException {
Types.Session result = sessionController.create(input, request);
assertThat(result.id()).isNotNull();
assertThat(result.name()).isEqualTo(sessionName);
assertThat(result.inferenceModel()).isEqualTo(input.inferenceModel());
assertThat(result.responseChunks()).isEqualTo(input.responseChunks());
assertThat(result.dataSourceIds()).containsExactlyInAnyOrder(1L, 2L, 3L);
assertThat(result.timeCreated()).isNotNull();
assertThat(result.timeUpdated()).isNotNull();
Expand All @@ -70,6 +72,45 @@ void create() throws JsonProcessingException {
assertThat(result.lastInteractionTime()).isNull();
}

@Test
void update() throws JsonProcessingException {
SessionController sessionController = new SessionController(SessionService.createNull());
var request = new MockHttpServletRequest();
request.setCookies(
new MockCookie("_basusertoken", UserTokenCookieDecoderTest.encodeCookie("test-user")));
var sessionName = "test";
Types.Session input = TestData.createTestSessionInstance(sessionName);
Types.Session result = sessionController.create(input, request);

var updatedResponseChunks = 1;
var updatedInferenceModel = "new-model-name";
var updatedName = "new-name";

request = new MockHttpServletRequest();
request.setCookies(
new MockCookie(
"_basusertoken", UserTokenCookieDecoderTest.encodeCookie("update-test-user")));

var updatedSession =
sessionController.update(
result
.withInferenceModel(updatedInferenceModel)
.withResponseChunks(updatedResponseChunks)
.withName(updatedName),
request);

assertThat(updatedSession.id()).isNotNull();
assertThat(updatedSession.name()).isEqualTo(updatedName);
assertThat(updatedSession.inferenceModel()).isEqualTo(updatedInferenceModel);
assertThat(updatedSession.responseChunks()).isEqualTo(updatedResponseChunks);
assertThat(updatedSession.dataSourceIds()).containsExactlyInAnyOrder(1L, 2L, 3L);
assertThat(updatedSession.timeCreated()).isNotNull();
assertThat(updatedSession.timeUpdated()).isNotNull();
assertThat(updatedSession.createdById()).isEqualTo("test-user");
assertThat(updatedSession.updatedById()).isEqualTo("update-test-user");
assertThat(updatedSession.lastInteractionTime()).isNull();
}

@Test
void delete() {
SessionService sessionService = SessionService.createNull();
Expand Down
5 changes: 4 additions & 1 deletion llm-service/app/rag_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,13 @@

from pydantic import BaseModel, ConfigDict

from app.services.models import DEFAULT_BEDROCK_LLM_MODEL


class RagPredictConfiguration(BaseModel):
model_config = ConfigDict(protected_namespaces=())

top_k: int = 5
chunk_size: int = 512
model_name: str = "meta.llama3-1-8b-instruct-v1:0"
model_name: str = DEFAULT_BEDROCK_LLM_MODEL
exclude_knowledge_base: Optional[bool] = False
Loading

0 comments on commit 303a4a4

Please sign in to comment.