From 8db5efd9a12d601b7174d83d0addc28be027a96c Mon Sep 17 00:00:00 2001 From: Daniel Roy Greenfeld <62857+pydanny@users.noreply.github.com> Date: Sun, 22 Oct 2023 18:25:55 +0200 Subject: [PATCH] Isolate questions code in own module to ease debugging --- src/interviewkit/cli.py | 20 +++++++++++ src/interviewkit/questions.py | 50 ++++++++++++++++++++++++++ src/interviewkit/transcript.py | 64 ---------------------------------- 3 files changed, 70 insertions(+), 64 deletions(-) create mode 100644 src/interviewkit/questions.py diff --git a/src/interviewkit/cli.py b/src/interviewkit/cli.py index 64102d5..b29f030 100644 --- a/src/interviewkit/cli.py +++ b/src/interviewkit/cli.py @@ -4,6 +4,7 @@ from pathlib import Path from typing_extensions import Annotated +from questions import generate_questions_from_transcript from slicer import audio_slicing from transcript import transcribe_from_paths @@ -41,6 +42,25 @@ def slice( """Slices an audio file into smaller audio files.""" audio_slicing(source, start, duration) +@app.command() +def generate_questions(source: Annotated[ + Path, + typer.Argument( + exists=True, + file_okay=True, + dir_okay=False, + readable=True, + resolve_path=True, + help="Source transcript file", + ), + ], + target: Path + ): + """Generates questions from a transcript.""" + + questions = generate_questions_from_transcript(source.read_text()) + target.write_text(questions) + @app.command() def transcribe( diff --git a/src/interviewkit/questions.py b/src/interviewkit/questions.py new file mode 100644 index 0000000..743f458 --- /dev/null +++ b/src/interviewkit/questions.py @@ -0,0 +1,50 @@ +from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel +from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc +from clarifai_grpc.grpc.api.status import status_code_pb2 + +# # Securely get your credentials +# TODO: Pass in arguments or use env vars +CLARIFAI_PAT = '' +# Specify the correct user_id/app_id pairings +# Since you're making inferences outside your app's scope +CLARIFAI_USER_ID = 'meta' +CLARIFAI_APP_ID = 'Llama-2' +# Change these to whatever model and text URL you want to use +CLARIFAI_MODEL_ID = 'llama2-70b-chat' +CLARIFAI_MODEL_VERSION_ID = 'acba9c1995f8462390d7cb77d482810b' + + +def generate_questions_from_transcript(transcript: str): + + channel = ClarifaiChannel.get_grpc_channel() + stub = service_pb2_grpc.V2Stub(channel) + + metadata = (('authorization', 'Key ' + CLARIFAI_PAT),) + userDataObject = resources_pb2.UserAppIDSet( + user_id=CLARIFAI_USER_ID, app_id=CLARIFAI_APP_ID) + + post_model_outputs_response = stub.PostModelOutputs( + service_pb2.PostModelOutputsRequest( + user_app_id=userDataObject, + model_id=CLARIFAI_MODEL_ID, + version_id=CLARIFAI_MODEL_VERSION_ID, + inputs=[ + resources_pb2.Input( + data=resources_pb2.Data( + text=resources_pb2.Text( + raw=transcript + ) + ) + ) + ] + ), + metadata=metadata + ) + + if post_model_outputs_response.status.code != status_code_pb2.SUCCESS: + print(post_model_outputs_response.status) + status = post_model_outputs_response.status.description + raise Exception(f"Post model outputs failed, status: {status}") + + output = post_model_outputs_response.outputs[0] + return output.data.text.raw diff --git a/src/interviewkit/transcript.py b/src/interviewkit/transcript.py index 423e331..bfc1aaf 100644 --- a/src/interviewkit/transcript.py +++ b/src/interviewkit/transcript.py @@ -1,9 +1,6 @@ from pathlib import Path from rich.console import Console import sys -from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel -from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc -from clarifai_grpc.grpc.api.status import status_code_pb2 try: import whisper @@ -20,53 +17,6 @@ class Transcript(BaseModel): """The Transcript entity represents the transcript of an interview.""" content: str -# # Securely get your credentials -# PAT = os.getenv('CLARIFAI_PAT') -# USER_ID = os.getenv('CLARIFAI_USER_ID') -# APP_ID = os.getenv('CLARIFAI_APP_ID') -# MODEL_ID = os.getenv('CLARIFAI_MODEL_ID') -# MODEL_VERSION_ID = os.getenv('CLARIFAI_MODEL_VERSION_ID') -PAT = '' -# Specify the correct user_id/app_id pairings -# Since you're making inferences outside your app's scope -USER_ID = 'meta' -APP_ID = 'Llama-2' -# Change these to whatever model and text URL you want to use -MODEL_ID = 'llama2-70b-chat' -MODEL_VERSION_ID = 'acba9c1995f8462390d7cb77d482810b' - -def generate_questions(transcript_chunk): - channel = ClarifaiChannel.get_grpc_channel() - stub = service_pb2_grpc.V2Stub(channel) - - metadata = (('authorization', 'Key ' + PAT),) - userDataObject = resources_pb2.UserAppIDSet(user_id=USER_ID, app_id=APP_ID) - - post_model_outputs_response = stub.PostModelOutputs( - service_pb2.PostModelOutputsRequest( - user_app_id=userDataObject, - model_id=MODEL_ID, - version_id=MODEL_VERSION_ID, - inputs=[ - resources_pb2.Input( - data=resources_pb2.Data( - text=resources_pb2.Text( - raw=transcript_chunk - ) - ) - ) - ] - ), - metadata=metadata - ) - - if post_model_outputs_response.status.code != status_code_pb2.SUCCESS: - print(post_model_outputs_response.status) - raise Exception(f"Post model outputs failed, status: {post_model_outputs_response.status.description}") - - output = post_model_outputs_response.outputs[0] - return output.data.text.raw - def transcribe_from_paths(source: Path, target: Path) -> None: console.print("Loading whisper base model...") model = whisper.load_model("base") @@ -89,20 +39,6 @@ def transcribe_from_paths(source: Path, target: Path) -> None: console.print("Transcript saved to:") console.print(f" [green bold]{target / source.name}.txt[/green bold]") - # Generate questions from the transcript - transcript_chunk = result['text'] # Assuming 'result' contains the transcribed text - # Debug: Print type and value of transcript_chunk - print(f"Type of transcript_chunk: {type(transcript_chunk)}") - print(f"Value of transcript_chunk: {transcript_chunk}") - - # Ensure transcript_chunk is a string - if not isinstance(transcript_chunk, str): - print("Warning: transcript_chunk is not a string. Trying to convert...") - transcript_chunk = str(transcript_chunk) - - questions = generate_questions(transcript_chunk) - console.print("Generated Questions:\n", questions) - if __name__ == "__main__": source = Path(sys.argv[1]) target = Path(sys.argv[2])