GPT-4o LLM support, multi-file at once support
adithya-aiplanet authored and tarun-aiplanet committed May 27, 2024
1 parent aae011d commit f77bc7d
Showing 9 changed files with 294 additions and 68 deletions.
3 changes: 2 additions & 1 deletion src/beyondllm/llms/__init__.py
@@ -4,4 +4,5 @@
from .chatopenai import ChatOpenAIModel
from .azurechat import AzureOpenAIModel
from .ollama import OllamaModel
from .multimodal import GeminiMultiModal
from .multimodal import GeminiMultiModal
from .gpt4o import GPT4OpenAIModel
185 changes: 185 additions & 0 deletions src/beyondllm/llms/gpt4o.py
@@ -0,0 +1,185 @@
from beyondllm.llms.base import BaseLLMModel, ModelConfig
from typing import Any, Dict, List, Optional, Union
import os
from dataclasses import dataclass, field
import base64
import subprocess, sys

@dataclass
class GPT4OpenAIModel:
"""
    Class representing a chat Large Language Model (LLM) using OpenAI GPT-4o with vision capabilities.
    Example:
    from beyondllm.llms import GPT4OpenAIModel
    llm = GPT4OpenAIModel(model="gpt-4o", api_key="", model_kwargs={"max_tokens": 512, "temperature": 0.1})
"""

api_key: str = ""
model: str = field(default="gpt-4o")
model_kwargs: Optional[Dict] = None

def __post_init__(self):
if not self.api_key:
self.api_key = os.getenv("OPENAI_API_KEY")
if not self.api_key:
raise ValueError(
"OPENAI_API_KEY is not provided and not found in environment variables."
)
self.load_llm()

def load_llm(self):
try:
import openai
except ImportError:
raise ImportError("OpenAI library is not installed. Please install it with ``pip install openai``.")

try:
self.client = openai.OpenAI(api_key=self.api_key)

except Exception as e:
            raise Exception(f"Failed to load the model from OpenAI: {e}")

return self.client

def _process_image(self, image_path):
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")

def _process_audio(self, audio_path):
        # Transcribe the audio with Whisper, closing the file handle afterwards
        with open(audio_path, "rb") as audio_file:
            transcription = self.client.audio.transcriptions.create(
                model="whisper-1", file=audio_file
            )
        return transcription.text

def _process_video(self, video_path, seconds_per_frame=1):
base64Frames = []
base_video_path, _ = os.path.splitext(video_path)

try:
import cv2
except ImportError:
            user_agree = input("The feature you're trying to use requires an additional library: opencv-python. Would you like to install it now? [y/N]: ")
if user_agree.lower() == 'y':
subprocess.check_call([sys.executable, "-m", "pip", "install", "opencv-python"])
import cv2
else:
raise ImportError("The required 'opencv-python' is not installed.")

video = cv2.VideoCapture(video_path)
total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
fps = video.get(cv2.CAP_PROP_FPS)
        frames_to_skip = max(int(fps * seconds_per_frame), 1)  # at least 1, so the loop below always advances
curr_frame = 0

while curr_frame < total_frames - 1:
video.set(cv2.CAP_PROP_POS_FRAMES, curr_frame)
success, frame = video.read()
if not success:
break
_, buffer = cv2.imencode(".jpg", frame)
base64Frames.append(base64.b64encode(buffer).decode("utf-8"))
curr_frame += frames_to_skip
video.release()

audio_path = f"{base_video_path}.mp3"
try:
from moviepy.editor import VideoFileClip
except ImportError:
            user_agree = input("The feature you're trying to use requires an additional library: moviepy. Would you like to install it now? [y/N]: ")
if user_agree.lower() == 'y':
subprocess.check_call([sys.executable, "-m", "pip", "install", "moviepy"])
from moviepy.editor import VideoFileClip
else:
raise ImportError("The required 'moviepy' is not installed.")

clip = VideoFileClip(video_path)
clip.audio.write_audiofile(audio_path, bitrate="32k")
clip.audio.close()
clip.close()

return base64Frames, audio_path

    def predict(self, prompt: Any, media_paths: Optional[Union[str, List[str]]] = None):
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt}
]

# Ensure media_paths is a list if it's not None
if media_paths:
if isinstance(media_paths, str):
media_paths = [media_paths]

for media_path in media_paths:
media_type = media_path.split(".")[-1].lower()
                if media_type in ["jpg", "jpeg", "png"]:
                    base64_image = self._process_image(media_path)
                    # Send the image via the structured image_url content format so the
                    # vision model receives it as an image rather than an inline markdown string.
                    messages.append(
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "image_url",
                                    "image_url": {"url": f"data:image/{media_type};base64,{base64_image}"},
                                }
                            ],
                        }
                    )
elif media_type in ["mp3", "wav"]:
transcription = self._process_audio(media_path)
messages.append(
{"role": "user", "content": f"The audio transcription is: {transcription}"}
)
elif media_type in ["mp4", "avi", "webm"]:
base64Frames, audio_path = self._process_video(media_path)
transcription = self._process_audio(audio_path)
messages.append(
{
"role": "user",
"content": [
"These are the frames from the video.",
*map(
lambda x: {
"type": "image_url",
"image_url": {
"url": f"data:image/jpg;base64,{x}",
"detail": "low",
},
},
base64Frames,
),
{"type": "text", "text": f"The audio transcription is: {transcription}"},
{"type": "text", "text": prompt},
],
},
)

# transcription = self._process_audio(audio_path)
# messages.append(
# {"role": "user", "content": "These are the frames from the video:"}
# )
# for frame in base64Frames:
# messages.append(
# {
# "role": "user",
# "content": f"![frame](data:image/jpg;base64,{frame})"
# }
# )
# messages.append(
# {"role": "user", "content": f"The audio transcription is: {transcription}"}
# )
else:
raise ValueError(f"Unsupported media type: {media_type}")

if self.model_kwargs is not None:
response = self.client.chat.completions.create(
model=self.model, messages=messages, **self.model_kwargs
)
else:
response = self.client.chat.completions.create(
model=self.model, messages=messages
)

return response.choices[0].message.content

@staticmethod
def load_from_kwargs(self, kwargs):
model_config = ModelConfig(**kwargs)
self.config = model_config
self.load_llm()
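For reference, a minimal usage sketch of the new class, assuming OPENAI_API_KEY is set in the environment and that chart.png and meeting.mp3 are placeholder local files:

from beyondllm.llms import GPT4OpenAIModel

llm = GPT4OpenAIModel(model="gpt-4o", model_kwargs={"max_tokens": 512, "temperature": 0.1})

# Text plus a single image
answer = llm.predict("What does this chart show?", media_paths="chart.png")

# Several media files at once; audio is transcribed with Whisper before being added to the prompt
answer = llm.predict(
    "Summarise the chart together with the recorded discussion.",
    media_paths=["chart.png", "meeting.mp3"],
)
print(answer)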
22 changes: 18 additions & 4 deletions src/beyondllm/loaders/llamaParseLoader.py
@@ -12,12 +12,26 @@ class LlamaParseLoader(BaseLoader):
chunk_overlap: int = 100

def load(self, path):
"""Load data from a file to be parsed by LlamaParse cloud API."""
"""Load data from a file, list of files, or directory to be parsed by LlamaParse cloud API."""
llama_parse_key = self.llama_parse_key or os.getenv('LLAMA_CLOUD_API_KEY')
input_files = []

if isinstance(path, str):
if os.path.isdir(path):
for root, _, files in os.walk(path):
for file in files:
input_files.append(os.path.join(root, file))
else:
input_files.append(path)
elif isinstance(path, list):
            input_files.extend(path)

try:
docs = LlamaParse(result_type="markdown",api_key=llama_parse_key).load_data(path)
except:
raise ValueError("File not compatible/no result returned from Llamaparse")
docs = []
for file in input_files:
docs.extend(LlamaParse(result_type="markdown", api_key=llama_parse_key).load_data(file))
except Exception as e:
raise ValueError(f"File not compatible/no result returned from Llamaparse: {e}")
return docs

def split(self, documents):
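A hedged sketch of the new multi-input loading, calling the loader directly and assuming LLAMA_CLOUD_API_KEY is set in the environment; reports/ and the listed PDF names are placeholders:

from beyondllm.loaders.llamaParseLoader import LlamaParseLoader

loader = LlamaParseLoader()

# A directory is walked recursively and every file inside it is parsed
docs = loader.load("reports/")

# A list of individual files is also accepted
docs = loader.load(["reports/q1.pdf", "reports/q2.pdf"])
chunks = loader.split(docs)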
13 changes: 9 additions & 4 deletions src/beyondllm/loaders/notionLoader.py
@@ -23,12 +23,17 @@ class NotionLoader(BaseLoader):
chunk_overlap: int = 100

def load(self, path):
"""Load Notion page data from the page ID of your Notion page: The hash value at the end of your URL"""
"""Load Notion page data from the page ID of your Notion page or a list of page IDs."""
integration_token = self.notion_integration_token or os.getenv('NOTION_INTEGRATION_TOKEN')
page_ids = []

if isinstance(path, str):
page_ids.append(path)
        elif isinstance(path, list):
page_ids.extend(path)

loader = NotionPageReader(integration_token=integration_token)
docs = loader.load_data(
page_ids=[path]
)
docs = loader.load_data(page_ids=page_ids)
return docs

def split(self, documents):
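In the same spirit, a brief sketch of passing either one Notion page ID or a list of IDs, assuming NOTION_INTEGRATION_TOKEN is set in the environment; the IDs below are placeholders:

from beyondllm.loaders.notionLoader import NotionLoader

loader = NotionLoader()

docs = loader.load("0123456789abcdef0123456789abcdef")      # single page ID
docs = loader.load([
    "0123456789abcdef0123456789abcdef",
    "abcdef0123456789abcdef0123456789",
])                                                           # several pages at once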
14 changes: 13 additions & 1 deletion src/beyondllm/loaders/simpleLoader.py
@@ -2,6 +2,7 @@
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from dataclasses import dataclass
import os

@dataclass
class SimpleLoader(BaseLoader):
@@ -16,7 +17,18 @@ class SimpleLoader(BaseLoader):

def load(self, path):
"""Load data from a file."""
docs = SimpleDirectoryReader(input_files=[path]).load_data()
input_files = []

if isinstance(path, str):
if os.path.isdir(path):
docs = SimpleDirectoryReader(path).load_data()
else:
input_files.append(path)
docs = SimpleDirectoryReader(input_files=input_files).load_data()
elif isinstance(path, list):
input_files.extend(path)
docs = SimpleDirectoryReader(input_files=input_files).load_data()

return docs

def split(self, documents):
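A short sketch of the three input forms SimpleLoader now accepts; the paths are placeholders:

from beyondllm.loaders.simpleLoader import SimpleLoader

loader = SimpleLoader()

docs = loader.load("notes/report.pdf")             # single file
docs = loader.load("notes/")                       # whole directory
docs = loader.load(["notes/a.md", "notes/b.md"])   # explicit list of files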
16 changes: 9 additions & 7 deletions src/beyondllm/loaders/urlLoader.py
@@ -20,13 +20,15 @@ class UrlLoader(BaseLoader):
chunk_overlap: int = 150

def load(self, path):
"""
Load web page data from a file.
Requires a url to be passed to read the HTML data of the page.
"""
docs = SimpleWebPageReader(html_to_text=True).load_data(
[path]
)
"""Load data from a single URL or a list of URLs."""
urls = []

if isinstance(path, str):
urls.append(path)
elif isinstance(path, list):
urls.extend(path)

        docs = SimpleWebPageReader(html_to_text=True).load_data(urls)
return docs

def split(self, documents):
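A sketch of the new list form for UrlLoader; the URLs are placeholders:

from beyondllm.loaders.urlLoader import UrlLoader

loader = UrlLoader()

docs = loader.load("https://example.com/blog/post-1")
docs = loader.load([
    "https://example.com/blog/post-1",
    "https://example.com/blog/post-2",
])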
15 changes: 11 additions & 4 deletions src/beyondllm/loaders/youtubeLoader.py
@@ -23,11 +23,18 @@ class YoutubeLoader(BaseLoader):
chunk_overlap: int = 150

def load(self, path):
"""Load youtube video transcript data from the URL of the video."""
"""Load YouTube video transcript data from a single URL or a list of URLs."""
ytlinks = []

if isinstance(path, str):
ytlinks.append(path)
elif isinstance(path, list):
ytlinks.extend(path)

loader = YoutubeTranscriptReader()
docs = loader.load_data(
ytlinks=[path]
)
docs = loader.load_data(ytlinks=ytlinks)
return docs

def split(self, documents):
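And a sketch of the equivalent change for YoutubeLoader; the video URLs are placeholders:

from beyondllm.loaders.youtubeLoader import YoutubeLoader

loader = YoutubeLoader()

docs = loader.load("https://www.youtube.com/watch?v=VIDEO_ID_1")
docs = loader.load([
    "https://www.youtube.com/watch?v=VIDEO_ID_1",
    "https://www.youtube.com/watch?v=VIDEO_ID_2",
])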