Skip to content

Commit

Permalink
Merge pull request #22 from small-thinking/add-rag-to-control
Browse files Browse the repository at this point in the history
Voice + RAG
  • Loading branch information
yxjiang authored Feb 25, 2024
2 parents 8a3c493 + a8cffc4 commit 6392e69
Show file tree
Hide file tree
Showing 9 changed files with 609 additions and 10 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,11 @@ Use G|B|V|C|D|E|R|T keys to rotate to absolute orientations. 'F' to cancel a rot
```bash
docker rm $(docker ps -a -q) ; docker images | grep '<none>' | awk '{print $3}' | xargs docker rmi
```




TODOs:
5. Build a key-value store to store the verbal command and the list of commands.
6. Index the key-value store by verbal command into a vector DB.
7. Add RAG after voice recognition.
4 changes: 4 additions & 0 deletions mnlm/client/gpt_control/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,15 @@ def start_conversation(
use_voice_input: bool,
use_voice_output: bool,
use_dummy_robot_arm_server: bool,
use_rag: bool,
logger: Logger,
) -> None:
client = OpenAI()
tools = init_tools(
logger=logger,
verbose=verbose,
use_dummy_robot_arm_server=use_dummy_robot_arm_server,
use_rag=use_rag,
)
assistant = create_assistant(
client=client, tools=tools, logger=logger, verbose=verbose
Expand Down Expand Up @@ -233,12 +235,14 @@ def start_conversation(
use_voice_input = True # Set to True to enable voice input. In docker container, it's not possible.
use_voice_output = True # Set to True to enable voice output. In docker container, it's not possible.
use_dummy_robot_arm_server = False # Set to True to use the simulation mode
use_rag = True
logger = Logger(__name__)
start_conversation(
verbose=verbose,
nudge_user=nudge_user,
use_voice_input=use_voice_input,
use_voice_output=use_voice_output,
use_dummy_robot_arm_server=use_dummy_robot_arm_server,
use_rag=use_rag,
logger=logger,
)
261 changes: 261 additions & 0 deletions mnlm/client/gpt_control/command_indexer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
import argparse
import json
import os
from typing import Any, Dict, List, Optional

import faiss
import numpy as np
from dotenv import load_dotenv
from openai import OpenAI
from utils import Logger


class InstructionIndexer:
    """Build, persist, load and query a FAISS index over verbal instructions.

    Every instruction string of the command bank is embedded with OpenAI's
    embedding API and added to a flat L2 FAISS index. The operations belonging
    to each instruction are kept in a parallel list, so row ``i`` of the index
    corresponds to ``operation_sequences[i]``.
    """

    # Model used for both indexing and querying; the two must match for the
    # distances returned by the index to be meaningful.
    _EMBEDDING_MODEL = "text-embedding-3-small"

    def __init__(self):
        """
        Initialize the InstructionIndexer class.

        Loads the .env file, then creates the logger and the OpenAI client.
        The index starts out empty: call create_index() or
        load_index_and_data() before querying.
        """
        load_dotenv(override=True)
        self.logger = Logger(__file__)
        self.client = OpenAI()
        self.operation_sequences: List[Dict[str, Any]] = []
        self.index = None

    def create_index(
        self, command_bank_file_path: str, index_destination: str, data_destination: str
    ) -> None:
        """
        Create the FAISS index for instructions and save it to disk.
        Args:
            command_bank_file_path (str): Path to the JSON command bank, a
                mapping of instruction -> object with an "operations" entry.
            index_destination (str): Path to save the FAISS index.
            data_destination (str): Path to save the operation sequences.
        Raises:
            ValueError: If the command bank contains no instructions.
        """
        instructions_data = self._load_json_file(command_bank_file_path)
        if not instructions_data:
            # Without this guard the embedding array would be empty and the
            # dimension lookup below would fail with an opaque IndexError.
            raise ValueError(
                f"No instructions found in command bank: {command_bank_file_path}"
            )

        instructions = list(instructions_data)
        operation_sequences = [
            {"instruction": instruction, "operations": entry["operations"]}
            for instruction, entry in instructions_data.items()
        ]

        embeddings = self._embed_instructions(instructions)

        # Creating the FAISS index
        dimension = embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dimension)
        self.index.add(embeddings)
        # Keep the data on the instance so it can be queried right away,
        # without a round trip through load_index_and_data().
        self.operation_sequences = operation_sequences

        # Save the FAISS index and operation sequences
        self._save_index_and_data(
            operation_sequences=operation_sequences,
            index_destination=index_destination,
            data_destination=data_destination,
        )

    def _load_json_file(self, command_bank_file_path: str) -> Dict[str, Any]:
        """
        Load the JSON file containing instructions.
        Args:
            command_bank_file_path (str): Path to the JSON file.
        Returns:
            dict: Dictionary containing instructions.
        """
        self.logger.info(f"Loading JSON file: {command_bank_file_path}")
        with open(command_bank_file_path, "r", encoding="utf-8") as file:
            return json.load(file)

    def _embed_instructions(self, instructions: List[str]) -> np.ndarray:
        """
        Embed instructions using OpenAI's embedding API.
        Args:
            instructions (list): List of instructions.
        Returns:
            np.ndarray: float32 array with one embedding row per instruction.
        """
        self.logger.info("Embedding instructions...")
        embeddings = []
        for instruction in instructions:
            # Ensure instruction is a single line
            single_line = instruction.replace("\n", " ")
            response = self.client.embeddings.create(
                input=[single_line], model=self._EMBEDDING_MODEL
            )
            embeddings.append(response.data[0].embedding)
        return np.array(embeddings, dtype="float32")

    @staticmethod
    def _prepare_destination(path: str) -> None:
        """
        Remove any existing file at path and ensure its parent directory exists.
        Args:
            path (str): Path of the file that is about to be written.
        """
        if os.path.exists(path):
            os.remove(path)
        parent = os.path.dirname(path)
        # A bare file name has an empty directory part, and os.makedirs("")
        # raises, so only create the directory when there is one.
        if parent:
            os.makedirs(parent, exist_ok=True)

    def _save_index_and_data(
        self,
        operation_sequences: List[Dict[str, Any]],
        index_destination: str,
        data_destination: str,
    ) -> None:
        """
        Save the FAISS index and operation sequences to files.
        Args:
            operation_sequences (list): Operation sequences to save as JSON.
            index_destination (str): Path to save the FAISS index.
            data_destination (str): Path to save the operation sequences.
        Raises:
            ValueError: If the index has not been created yet.
        """
        # Ensure the index and data are created
        if self.index is None:
            raise ValueError("Index has not been created. Call create_index() first.")

        # Save the FAISS index
        self._prepare_destination(index_destination)
        faiss.write_index(self.index, index_destination)

        # Save the operation sequences
        self._prepare_destination(data_destination)
        with open(data_destination, "w", encoding="utf-8") as file:
            json.dump(operation_sequences, file, indent=2)

    def load_index_and_data(
        self, index_path: Optional[str] = None, data_path: Optional[str] = None
    ) -> None:
        """
        Load the FAISS index and operation sequences from files.
        Args:
            index_path (str): Path to the FAISS index file. Defaults to
                knowledge/index/instructions.index, resolved relative to the
                parent of this file's directory.
            data_path (str): Path to the operation sequences file. Defaults to
                knowledge/index/instructions_data.json in the same location.
        """
        base_dir = os.path.dirname(os.path.dirname(__file__))
        if not index_path:
            index_path = os.path.join(base_dir, "knowledge/index/instructions.index")
        if not data_path:
            data_path = os.path.join(
                base_dir, "knowledge/index/instructions_data.json"
            )
        self.index = faiss.read_index(index_path)
        with open(data_path, "r", encoding="utf-8") as file:
            self.operation_sequences = json.load(file)

    def retrieve_operation_sequences(self, instruction: str, k: int = 1) -> str:
        """
        Retrieve the operation sequence that best matches a given query.
        Args:
            instruction (str): Query to search for.
            k (int): Number of nearest neighbours to request from the index.
                Only the closest match is returned.
        Returns:
            str: JSON blob with the "instruction" and "operations" of the
                closest match.
        Raises:
            ValueError: If the index is not loaded, k is smaller than 1, or
                the index returns no match.
        """
        # Check cheap preconditions before spending an embedding API call.
        if self.index is None:
            raise ValueError(
                "Index has not been loaded. "
                "Call load_index_and_data() or create_index() first."
            )
        if k < 1:
            raise ValueError(f"k must be at least 1, got {k}.")

        # Embed the query
        query_embedding = self._embed_instructions([instruction])

        # Search the index
        _, indices = self.index.search(query_embedding, k)
        # FAISS pads the result with -1 when it has fewer than k vectors; as a
        # Python index that would silently select the last entry, so drop it.
        retrieved_operations = [
            self.operation_sequences[i] for i in indices[0] if i >= 0
        ]
        if not retrieved_operations:
            raise ValueError(
                f"No operation sequence found for instruction: {instruction}"
            )
        return json.dumps(retrieved_operations[0])


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """
    Parse the command line of the instruction indexer.

    Two sub-commands are available: "index" builds a FAISS index from a JSON
    command bank, and "query" searches an existing index.
    Args:
        argv (list): Arguments to parse. Defaults to None, in which case
            argparse reads them from sys.argv.
    Returns:
        argparse.Namespace: Parsed arguments; "command" holds the sub-command.
    """
    # Setup argument parser
    parser = argparse.ArgumentParser(description="Instruction Indexer")
    # A sub-command is mandatory: without one the namespace lacks the
    # attributes that either branch of the entry point reads.
    subparsers = parser.add_subparsers(
        dest="command", help="Available commands", required=True
    )

    # All defaults are resolved relative to the parent of this file's directory.
    base_dir = os.path.dirname(os.path.dirname(__file__))
    index_destination = os.path.join(base_dir, "knowledge/index/instructions.index")
    data_destination = os.path.join(base_dir, "knowledge/index/instructions_data.json")
    default_command_file_path = os.path.join(base_dir, "knowledge/command_bank.json")

    # Subparser for creating index
    create_index_parser = subparsers.add_parser(
        "index", help="Create a new FAISS index from JSON data"
    )
    create_index_parser.add_argument(
        "--command-bank-file-path",
        type=str,
        default=default_command_file_path,
        help="Path to the JSON file containing instructions",
    )
    create_index_parser.add_argument(
        "--index-destination",
        type=str,
        default=index_destination,
        required=False,
        help="Path to save the FAISS index",
    )
    create_index_parser.add_argument(
        "--data-destination",
        type=str,
        default=data_destination,
        required=False,
        help="Path to save the operation sequences",
    )

    # Subparser for querying index
    query_index_parser = subparsers.add_parser(
        "query", help="Query an existing FAISS index"
    )
    # The query text is mandatory: a missing value would reach the embedding
    # code as None and fail there.
    query_index_parser.add_argument(
        "-q", "--query", type=str, required=True, help="Query to search for"
    )
    query_index_parser.add_argument(
        "--index-path",
        type=str,
        default=index_destination,
        required=False,
        help="Path to the FAISS index file",
    )
    query_index_parser.add_argument(
        "--data-path",
        type=str,
        default=data_destination,
        required=False,
        help="Path to the JSON file containing operation sequences",
    )
    query_index_parser.add_argument(
        "--k",
        type=int,
        default=1,
        required=False,
        help="Number of operation sequences to retrieve",
    )

    # Parse arguments
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Command-line entry point: build a new index, or query an existing one.
    cli_args = parse_args()
    instruction_indexer = InstructionIndexer()
    if cli_args.command != "index":
        # Every command other than "index" is handled as a query.
        instruction_indexer.load_index_and_data(
            index_path=cli_args.index_path, data_path=cli_args.data_path
        )
        best_match = instruction_indexer.retrieve_operation_sequences(
            cli_args.query, cli_args.k
        )
        print(f"Operation sequences: {best_match}")
    else:
        instruction_indexer.create_index(
            command_bank_file_path=cli_args.command_bank_file_path,
            index_destination=cli_args.index_destination,
            data_destination=cli_args.data_destination,
        )
        print("Index and data saved.")
35 changes: 26 additions & 9 deletions mnlm/client/gpt_control/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from command_indexer import InstructionIndexer # type: ignore
from openai import OpenAI
from robot_arm import (
OperationSequenceGenerator,
Expand Down Expand Up @@ -80,18 +81,25 @@ def __init__(
logger: Logger,
gpt_client: Optional[OpenAI] = None,
simulation: bool = True,
use_rag: bool = False,
verbose: bool = False,
):
super().__init__(name=name, logger=logger, verbose=verbose)
# The api doc is under client/knowledge/robot_arm.md
api_document_path = os.path.join(
os.path.dirname(os.path.dirname(__file__)), "knowledge", "robot_arm.md"
)
self.operation_generator = OperationSequenceGenerator(
api_document_path=api_document_path,
gpt_client=gpt_client,
logger=logger,
)
self.use_rag = use_rag
if self.use_rag:
self.indexer = InstructionIndexer()
self.indexer.load_index_and_data()
else:
self.operation_generator = OperationSequenceGenerator(
api_document_path=api_document_path,
gpt_client=gpt_client,
logger=logger,
)

self.simulation = simulation
self.knowledge = f"""
Expand Down Expand Up @@ -125,9 +133,14 @@ def get_signature(self) -> Dict[str, Any]:
def execute(self, instruction: str) -> str:
try:
# Execute operations using the chosen mode (simulation or real)
operations_json = self.operation_generator.translate_prompt_to_sequence(
prompt=instruction
)
if self.use_rag:
operations_json = self.indexer.retrieve_operation_sequences(
instruction=instruction
)
else:
operations_json = self.operation_generator.translate_prompt_to_sequence(
prompt=instruction
)
if self.verbose:
self.logger.info(f"Robot arm command: {operations_json}.")
self.robot_arm_control.execute_operations(operations_json)
Expand All @@ -138,7 +151,10 @@ def execute(self, instruction: str) -> str:


def init_tools(
logger: Logger, verbose: bool = False, use_dummy_robot_arm_server: bool = False
logger: Logger,
verbose: bool = False,
use_dummy_robot_arm_server: bool = False,
use_rag: bool = False,
) -> Dict[str, Any]:
"""Initialize the tools for the assistant.
Expand All @@ -155,6 +171,7 @@ def init_tools(
name="robot_arm",
logger=logger,
verbose=verbose,
use_rag=use_rag,
simulation=use_dummy_robot_arm_server,
),
}
Expand Down
Loading

0 comments on commit 6392e69

Please sign in to comment.