Skip to content

Commit

Permalink
OpenAI rolling Summarization works
Browse files Browse the repository at this point in the history
Rolling summarization using OpenAI works, using the '-detail {0.01 - 1.00}' command
  • Loading branch information
rmusser01 committed May 16, 2024
1 parent 56f42d4 commit 3147fde
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 36 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ By default videos, transcriptions and summaries are stored in a folder with the
------------
### Similar/Other projects:
- https://github.com/Dicklesworthstone/bulk_transcribe_youtube_videos_from_playlist/tree/main

- https://github.com/akashe/YoutubeSummarizer
------------

### <a name="credits"></a>Credits
Expand Down
168 changes: 133 additions & 35 deletions summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,16 @@
#######
# Function Sections
#
# Config Loading
# System Checks
# Processing Paths and local file handling
# Video Download/Handling
# Audio Transcription
# Diarization
# Chunking-related Techniques & Functions
# Tokenization-related Techniques & Functions
# Summarizers
# Gradio UI
# Main
#
#######
Expand Down Expand Up @@ -363,9 +367,9 @@ def process_url(url, num_speakers, whisper_model, custom_prompt, offset, api_nam
video_file_path = None
print("API Name received:", api_name) # Debugging line
try:
results = main(url, api_name=api_name, api_key=api_key, num_speakers=num_speakers,
whisper_model=whisper_model, offset=offset, vad_filter=vad_filter,
download_video_flag=download_video, custom_prompt=custom_prompt)
results = main(url, api_name=api_name, api_key=api_key, num_speakers=num_speakers, whisper_model=whisper_model,
offset=offset, vad_filter=vad_filter, download_video_flag=download_video,
custom_prompt=custom_prompt)
if results:
transcription_result = results[0]

Expand Down Expand Up @@ -558,7 +562,7 @@ def convert_to_wav(video_file_path, offset=0, overwrite=False):
logging.debug("ffmpeg being ran on windows")

if sys.platform.startswith('win'):
ffmpeg_cmd = "..\\Bin\\ffmpeg.exe"
ffmpeg_cmd = ".\\Bin\\ffmpeg.exe"
logging.debug(f"ffmpeg_cmd: {ffmpeg_cmd}")
else:
ffmpeg_cmd = 'ffmpeg' # Assume 'ffmpeg' is in PATH for non-Windows systems
Expand Down Expand Up @@ -781,19 +785,33 @@ def speech_to_text(audio_file_path, selected_source_lang='en', whisper_model='sm
#


# This is dirty and shameful and terrible. It should be replaced with a proper implementation.
# Anyway, let's get to it...
client = OpenAI(api_key=openai_api_key)
def get_chat_completion(messages, model='gpt-4-turbo'):
    """Request a chat completion and return the text of the first choice.

    Args:
        messages: Chat messages passed straight through to the API call.
        model: Model name to request.

    Returns:
        ``message.content`` of the first choice in the response.
    """
    # Relies on the module-level ``client``; temperature is pinned to 0.
    completion = client.chat.completions.create(
        messages=messages,
        model=model,
        temperature=0,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content


# This function chunks a text into smaller pieces based on a maximum token count and a delimiter
def chunk_on_delimiter(input_string: str,
                       max_tokens: int,
                       delimiter: str) -> List[str]:
    """Split ``input_string`` on ``delimiter`` and regroup the pieces into chunks.

    Args:
        input_string: Text to split.
        max_tokens: Token budget forwarded to ``combine_chunks_with_no_minimum``.
        delimiter: Separator used to split the text and to re-join the pieces.

    Returns:
        The combined chunks, each with ``delimiter`` appended.
    """
    pieces = input_string.split(delimiter)
    # Combine exactly once; a repeated call would only recompute the same result.
    combined_chunks, _, dropped_chunk_count = combine_chunks_with_no_minimum(
        pieces, max_tokens, chunk_delimiter=delimiter, add_ellipsis_for_overflow=True)
    if dropped_chunk_count > 0:
        print(f"Warning: {dropped_chunk_count} chunks were dropped due to exceeding the token limit.")
    return [f"{chunk}{delimiter}" for chunk in combined_chunks]


# This function combines chunks into larger pieces based on a maximum token count
# This function combines text chunks into larger blocks without exceeding a specified token count.
# It returns the combined chunks, their original indices, and the number of dropped chunks due to overflow.
def combine_chunks_with_no_minimum(
chunks: List[str],
max_tokens: int,
Expand All @@ -810,16 +828,19 @@ def combine_chunks_with_no_minimum(
candidate_indices = []
for chunk_i, chunk in enumerate(chunks):
chunk_with_header = [chunk] if header is None else [header, chunk]
#FIXME MAKE NOT OPENAI SPECIFIC
if len(openai_tokenize(chunk_delimiter.join(chunk_with_header))) > max_tokens:
print(f"warning: chunk overflow")
if (
add_ellipsis_for_overflow
# FIXME MAKE NOT OPENAI SPECIFIC
and len(openai_tokenize(chunk_delimiter.join(candidate + ["..."]))) <= max_tokens
):
candidate.append("...")
dropped_chunk_count += 1
continue # this case would break downstream assumptions
# estimate token count with the current chunk added
# FIXME MAKE NOT OPENAI SPECIFIC
extended_candidate_token_count = len(openai_tokenize(chunk_delimiter.join(candidate + [chunk])))
# If the token count exceeds max_tokens, add the current candidate to output and start a new candidate
if extended_candidate_token_count > max_tokens:
Expand All @@ -837,7 +858,8 @@ def combine_chunks_with_no_minimum(
output_indices.append(candidate_indices)
return output, output_indices, dropped_chunk_count

def openai_chunk_summarize(text: str,

def rolling_summarize(text: str,
detail: float = 0,
model: str = 'gpt-4-turbo',
additional_instructions: Optional[str] = None,
Expand Down Expand Up @@ -877,11 +899,13 @@ def openai_chunk_summarize(text: str,
num_chunks = int(min_chunks + detail * (max_chunks - min_chunks))

# adjust chunk_size based on interpolated number of chunks
# FIXME MAKE NOT OPENAI SPECIFIC
document_length = len(openai_tokenize(text))
chunk_size = max(minimum_chunk_size, document_length // num_chunks)
text_chunks = chunk_on_delimiter(text, chunk_size, chunk_delimiter)
if verbose:
print(f"Splitting the text into {len(text_chunks)} chunks to be summarized.")
# FIXME MAKE NOT OPENAI SPECIFIC
print(f"Chunk lengths are {[len(openai_tokenize(x)) for x in text_chunks]}")

# set system message
Expand All @@ -894,8 +918,7 @@ def openai_chunk_summarize(text: str,
if summarize_recursively and accumulated_summaries:
# Creating a structured prompt for recursive summarization
accumulated_summaries_string = '\n\n'.join(accumulated_summaries)
user_message_content = (f"Previous summaries:\n\n{accumulated_summaries_string}\n\nText to summarize "
f"next:\n\n{chunk}")
user_message_content = f"Previous summaries:\n\n{accumulated_summaries_string}\n\nText to summarize next:\n\n{chunk}"
else:
# Directly passing the chunk for summarization without recursive context
user_message_content = chunk
Expand All @@ -911,6 +934,7 @@ def openai_chunk_summarize(text: str,
accumulated_summaries.append(response)

# Compile final summary from partial summaries
global final_summary
final_summary = '\n\n'.join(accumulated_summaries)

return final_summary
Expand All @@ -930,20 +954,6 @@ def openai_tokenize(text: str) -> List[str]:
return encoding.encode(text)

# openai summarize chunks
def summarize_with_detail(detail, final_summary):
    """Run ``openai_chunk_summarize`` on ``final_summary`` at the given ``detail`` level.

    Verbose output is always enabled. Returns whatever the summarizer returns.
    """
    return openai_chunk_summarize(final_summary, detail=detail, verbose=True)


def get_chat_completion(messages, model='gpt-4-turbo'):
    """Request a chat completion and return the text of the first choice.

    A new ``OpenAI`` client is built on every call using the placeholder key below.

    Args:
        messages: Chat messages passed straight through to the API call.
        model: Model name to request.

    Returns:
        ``message.content`` of the first choice in the response.
    """
    api_client = OpenAI(api_key="<OPENAI_API_KEY_REPLACE_ME>")
    completion = api_client.chat.completions.create(
        messages=messages,
        model=model,
        temperature=0,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content


#
#
Expand All @@ -956,6 +966,7 @@ def get_chat_completion(messages, model='gpt-4-turbo'):
#
#


def extract_text_from_segments(segments):
logging.debug(f"Main: extracting text from {segments}")
text = ' '.join([segment['text'] for segment in segments])
Expand Down Expand Up @@ -1349,6 +1360,26 @@ def save_summary_to_file(summary, file_path):
#######################################################################################################################


#######################################################################################################################
# Summarization with Detail
#

def summarize_with_detail_openai(text, detail, verbose=False):
    """Summarize ``text`` with ``rolling_summarize`` at the given ``detail`` level.

    Args:
        text: Text to summarize.
        detail: Detail level forwarded to ``rolling_summarize``.
        verbose: Forwarded to ``rolling_summarize`` (previously ignored and
            always sent as ``True``).

    Returns:
        The summary returned by ``rolling_summarize``.
    """
    # FIXME: make this function not specific to the artificial intelligence example
    summary_with_detail_variable = rolling_summarize(text, detail=detail, verbose=verbose)
    # Print the token count of the resulting summary.
    print(len(openai_tokenize(summary_with_detail_variable)))
    return summary_with_detail_variable


def summarize_with_detail_recursive_openai(text, detail, verbose=False):
    """Summarize ``text`` with ``rolling_summarize`` in recursive mode.

    Args:
        text: Text to summarize.
        detail: Detail level forwarded to ``rolling_summarize``.
        verbose: Forwarded to ``rolling_summarize``.

    Returns:
        The summary returned by ``rolling_summarize``. It is also printed, as
        before; earlier versions printed it without returning it.
    """
    summary_with_recursive_summarization = rolling_summarize(
        text, detail=detail, summarize_recursively=True, verbose=verbose)
    print(summary_with_recursive_summarization)
    return summary_with_recursive_summarization

#
#
#######################################################################################################################


#######################################################################################################################
# Gradio UI
#
Expand Down Expand Up @@ -1539,8 +1570,9 @@ def toggle_ui(mode):
# Main()
#
def main(input_path, api_name=None, api_key=None, num_speakers=2, whisper_model="small.en", offset=0, vad_filter=False,
download_video_flag=False, demo_mode=False, custom_prompt=None, overwrite=False):
global summary
download_video_flag=False, demo_mode=False, custom_prompt=None, overwrite=False,
rolling_summarization=None, detail=0.01):
global summary, audio_file
if input_path is None and args.user_interface:
return []
start_time = time.monotonic()
Expand Down Expand Up @@ -1568,6 +1600,7 @@ def main(input_path, api_name=None, api_key=None, num_speakers=2, whisper_model=
if path.startswith('http'):
logging.debug("MAIN: URL Detected")
info_dict = get_youtube(path)
json_file_path = None
if info_dict:
logging.debug("MAIN: Creating path for video file...")
download_path = create_download_directory(info_dict['title'])
Expand Down Expand Up @@ -1599,10 +1632,43 @@ def main(input_path, api_name=None, api_key=None, num_speakers=2, whisper_model=
'transcription': segments
}
results.append(transcription_result)
logging.info(f"Transcription complete: {audio_file}")
logging.info(f"MAIN: Transcription complete: {audio_file}")

# Perform rolling summarization based on API Name, detail level, and if an API key exists
# Will remove the API key once rolling is added for llama.cpp
if rolling_summarization:
logging.info("MAIN: Rolling Summarization")

# Extract the text from the segments
text = extract_text_from_segments(segments)

# Set the json_file_path
json_file_path = audio_file.replace('.wav', '.segments.json')

# Perform rolling summarization
summary = summarize_with_detail_openai(text, detail=args.detail_level, verbose=False)

# Handle the summarized output
if summary:
transcription_result['summary'] = summary
logging.info("MAIN: Rolling Summarization successful.")
save_summary_to_file(summary, json_file_path)
else:
logging.warning("MAIN: Rolling Summarization failed.")

# if api_name and api_key:
# logging.debug(f"MAIN: Rolling summarization being performed by {api_name}")
# json_file_path = audio_file.replace('.wav', '.segments.json')
# if api_name.lower() == 'openai':
# openai_api_key = api_key if api_key else config.get('API', 'openai_api_key',
# fallback=None)
# try:
# logging.debug(f"MAIN: trying to summarize with openAI")
# summary = (openai_api_key, json_file_path, openai_model, custom_prompt)
# except requests.exceptions.ConnectionError:
# requests.status_code = "Connection: "
# Perform summarization based on the specified API
if api_name and api_key:
elif api_name and api_key:
logging.debug(f"MAIN: Summarization being performed by {api_name}")
json_file_path = audio_file.replace('.wav', '.segments.json')
if api_name.lower() == 'openai':
Expand Down Expand Up @@ -1677,6 +1743,10 @@ def main(input_path, api_name=None, api_key=None, num_speakers=2, whisper_model=
transcription_result['summary'] = summary
logging.info(f"Summary generated using {api_name} API")
save_summary_to_file(summary, json_file_path)
elif final_summary:
logging.info(f"Rolling summary generated using {api_name} API")
logging.info(f"Final Rolling summary is {final_summary}\n\n")
save_summary_to_file(final_summary, json_file_path)
else:
logging.warning(f"Failed to generate summary using {api_name} API")
else:
Expand Down Expand Up @@ -1707,8 +1777,16 @@ def main(input_path, api_name=None, api_key=None, num_speakers=2, whisper_model=
parser.add_argument('-ui', '--user_interface', action='store_true', help="Launch the Gradio user interface")
parser.add_argument('-demo', '--demo_mode', action='store_true', help='Enable demo mode')
parser.add_argument('-prompt', '--custom_prompt', type=str,
help='Pass in a custom prompt to be used in place of the existing one.(Probably should just '
help='Pass in a custom prompt to be used in place of the existing one.\n (Probably should just '
'modify the script itself...)')
parser.add_argument('-overwrite', '--overwrite', action='store_true', help='Overwrite existing files')
parser.add_argument('-roll', '--rolling_summarization', action='store_true', help='Enable rolling summarization')
parser.add_argument('-detail', '--detail_level', type=float, help='Mandatory if rolling summarization is enabled, '
'defines the chunk size.\n Default is 0.01(lots '
'of chunks) -> 1.00 (few chunks)\n Currently '
'only OpenAI works. ',
default=0.01,)
# parser.add_argument('-o', '--output_path', type=str, help='Path to save the output file')
# parser.add_argument('--log_file', action=str, help='Where to save logfile (non-default)')
args = parser.parse_args()

Expand Down Expand Up @@ -1740,20 +1818,38 @@ def main(input_path, api_name=None, api_key=None, num_speakers=2, whisper_model=
logging.info('Starting the transcription and summarization process.')
logging.info(f'Input path: {args.input_path}')
logging.info(f'API Name: {args.api_name}')
logging.debug(f'API Key: {args.api_key}') # ehhhhh
logging.info(f'Number of speakers: {args.num_speakers}')
logging.info(f'Whisper model: {args.whisper_model}')
logging.info(f'Offset: {args.offset}')
logging.info(f'VAD filter: {args.vad_filter}')
logging.info(f'Log Level: {args.log_level}') # lol
logging.info(f'Demo Mode: {args.demo_mode}')
logging.info(f'Custom Prompt: {args.custom_prompt}')
logging.info(f'Overwrite: {args.overwrite}')
logging.info(f'Rolling Summarization: {args.rolling_summarization}')
logging.info(f'User Interface: {args.user_interface}')
logging.info(f'Video Download: {args.video}')
# logging.info(f'Save File location: {args.output_path}')
# logging.info(f'Log File location: {args.log_file}')

# Get all API keys from the config
api_keys = {key: value for key, value in config.items('API') if key.endswith('_api_key')}

# Rolling Summarization will only be performed if an API is specified and the API key is available
# and the rolling summarization flag is set
#
summary = None # Initialize to ensure it's always defined
if args.api_name and args.rolling_summarization and any(
key.startswith(args.api_name) and value is not None for key, value in api_keys.items()):
logging.info(f'MAIN: API used: {args.api_name}')
logging.info('MAIN: Rolling Summarization will be performed.')

elif args.api_name:
logging.info(f'MAIN: API used: {args.api_name}')
logging.info('MAIN: Summarization (not rolling) will be performed.')

if args.api_name and args.api_key:
logging.info(f'API: {args.api_name}')
logging.info('Summarization will be performed.')
summary = None # Initialize to ensure it's always defined
else:
logging.info('No API specified. Summarization will not be performed.')
summary = None # Initialize to ensure it's always defined

logging.debug("Platform check being performed...")
platform_check()
Expand All @@ -1765,7 +1861,9 @@ def main(input_path, api_name=None, api_key=None, num_speakers=2, whisper_model=
try:
results = main(args.input_path, api_name=args.api_name, api_key=args.api_key,
num_speakers=args.num_speakers, whisper_model=args.whisper_model, offset=args.offset,
vad_filter=args.vad_filter, download_video_flag=args.video, overwrite=args.overwrite)
vad_filter=args.vad_filter, download_video_flag=args.video, overwrite=args.overwrite,
rolling_summarization=args.rolling_summarization, custom_prompt=args.custom_prompt,
demo_mode=args.demo_mode, detail=args.detail_level)
logging.info('Transcription process completed.')
except Exception as e:
logging.error('An error occurred during the transcription process.')
Expand Down

0 comments on commit 3147fde

Please sign in to comment.