Isolate questions code in own module to ease debugging
pydanny committed Oct 22, 2023
1 parent ee5d09c commit 8db5efd
Showing 3 changed files with 70 additions and 64 deletions.
20 changes: 20 additions & 0 deletions src/interviewkit/cli.py
@@ -4,6 +4,7 @@
from pathlib import Path
from typing_extensions import Annotated

from questions import generate_questions_from_transcript
from slicer import audio_slicing
from transcript import transcribe_from_paths

@@ -41,6 +42,25 @@ def slice(
"""Slices an audio file into smaller audio files."""
audio_slicing(source, start, duration)

@app.command()
def generate_questions(source: Annotated[
        Path,
        typer.Argument(
            exists=True,
            file_okay=True,
            dir_okay=False,
            readable=True,
            resolve_path=True,
            help="Source transcript file",
        ),
    ],
    target: Path
):
    """Generates questions from a transcript."""

    questions = generate_questions_from_transcript(source.read_text())
    target.write_text(questions)


@app.command()
def transcribe(
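The new subcommand reads the transcript file and writes the generated questions to the target path. A minimal usage sketch, not part of this commit: it assumes the module is run directly, and the exact command name depends on the installed Typer version, which may keep generate_questions or convert it to generate-questions.

    # hypothetical invocation of the new Typer subcommand; file names are examples
    python src/interviewkit/cli.py generate_questions interview-transcript.txt questions.txt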
50 changes: 50 additions & 0 deletions src/interviewkit/questions.py
@@ -0,0 +1,50 @@
from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc
from clarifai_grpc.grpc.api.status import status_code_pb2

# # Securely get your credentials
# TODO: Pass in arguments or use env vars
CLARIFAI_PAT = ''
# Specify the correct user_id/app_id pairings
# Since you're making inferences outside your app's scope
CLARIFAI_USER_ID = 'meta'
CLARIFAI_APP_ID = 'Llama-2'
# Change these to whatever model and text URL you want to use
CLARIFAI_MODEL_ID = 'llama2-70b-chat'
CLARIFAI_MODEL_VERSION_ID = 'acba9c1995f8462390d7cb77d482810b'


def generate_questions_from_transcript(transcript: str):

    channel = ClarifaiChannel.get_grpc_channel()
    stub = service_pb2_grpc.V2Stub(channel)

    metadata = (('authorization', 'Key ' + CLARIFAI_PAT),)
    userDataObject = resources_pb2.UserAppIDSet(
        user_id=CLARIFAI_USER_ID, app_id=CLARIFAI_APP_ID)

    post_model_outputs_response = stub.PostModelOutputs(
        service_pb2.PostModelOutputsRequest(
            user_app_id=userDataObject,
            model_id=CLARIFAI_MODEL_ID,
            version_id=CLARIFAI_MODEL_VERSION_ID,
            inputs=[
                resources_pb2.Input(
                    data=resources_pb2.Data(
                        text=resources_pb2.Text(
                            raw=transcript
                        )
                    )
                )
            ]
        ),
        metadata=metadata
    )

    if post_model_outputs_response.status.code != status_code_pb2.SUCCESS:
        print(post_model_outputs_response.status)
        status = post_model_outputs_response.status.description
        raise Exception(f"Post model outputs failed, status: {status}")

    output = post_model_outputs_response.outputs[0]
    return output.data.text.raw
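The module leaves CLARIFAI_PAT blank, with a TODO to pass credentials in or read them from the environment; the commented-out block being removed from transcript.py below already sketches an os.getenv approach. A minimal sketch of that for this module, assuming the environment variable names mirror the constants (the fallback defaults simply repeat the values hard-coded above):

    import os

    # Sketch only: pull Clarifai credentials from the environment instead of
    # hard-coding them; this wiring is not part of this commit.
    CLARIFAI_PAT = os.getenv("CLARIFAI_PAT", "")
    CLARIFAI_USER_ID = os.getenv("CLARIFAI_USER_ID", "meta")
    CLARIFAI_APP_ID = os.getenv("CLARIFAI_APP_ID", "Llama-2")
    CLARIFAI_MODEL_ID = os.getenv("CLARIFAI_MODEL_ID", "llama2-70b-chat")
    CLARIFAI_MODEL_VERSION_ID = os.getenv("CLARIFAI_MODEL_VERSION_ID", "acba9c1995f8462390d7cb77d482810b")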
64 changes: 0 additions & 64 deletions src/interviewkit/transcript.py
@@ -1,9 +1,6 @@
from pathlib import Path
from rich.console import Console
import sys
from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc
from clarifai_grpc.grpc.api.status import status_code_pb2

try:
    import whisper
@@ -20,53 +17,6 @@ class Transcript(BaseModel):
"""The Transcript entity represents the transcript of an interview."""
content: str

# # Securely get your credentials
# PAT = os.getenv('CLARIFAI_PAT')
# USER_ID = os.getenv('CLARIFAI_USER_ID')
# APP_ID = os.getenv('CLARIFAI_APP_ID')
# MODEL_ID = os.getenv('CLARIFAI_MODEL_ID')
# MODEL_VERSION_ID = os.getenv('CLARIFAI_MODEL_VERSION_ID')
PAT = ''
# Specify the correct user_id/app_id pairings
# Since you're making inferences outside your app's scope
USER_ID = 'meta'
APP_ID = 'Llama-2'
# Change these to whatever model and text URL you want to use
MODEL_ID = 'llama2-70b-chat'
MODEL_VERSION_ID = 'acba9c1995f8462390d7cb77d482810b'

def generate_questions(transcript_chunk):
    channel = ClarifaiChannel.get_grpc_channel()
    stub = service_pb2_grpc.V2Stub(channel)

    metadata = (('authorization', 'Key ' + PAT),)
    userDataObject = resources_pb2.UserAppIDSet(user_id=USER_ID, app_id=APP_ID)

    post_model_outputs_response = stub.PostModelOutputs(
        service_pb2.PostModelOutputsRequest(
            user_app_id=userDataObject,
            model_id=MODEL_ID,
            version_id=MODEL_VERSION_ID,
            inputs=[
                resources_pb2.Input(
                    data=resources_pb2.Data(
                        text=resources_pb2.Text(
                            raw=transcript_chunk
                        )
                    )
                )
            ]
        ),
        metadata=metadata
    )

    if post_model_outputs_response.status.code != status_code_pb2.SUCCESS:
        print(post_model_outputs_response.status)
        raise Exception(f"Post model outputs failed, status: {post_model_outputs_response.status.description}")

    output = post_model_outputs_response.outputs[0]
    return output.data.text.raw

def transcribe_from_paths(source: Path, target: Path) -> None:
console.print("Loading whisper base model...")
model = whisper.load_model("base")
@@ -89,20 +39,6 @@ def transcribe_from_paths(source: Path, target: Path) -> None:
console.print("Transcript saved to:")
console.print(f" [green bold]{target / source.name}.txt[/green bold]")

    # Generate questions from the transcript
    transcript_chunk = result['text']  # Assuming 'result' contains the transcribed text
    # Debug: Print type and value of transcript_chunk
    print(f"Type of transcript_chunk: {type(transcript_chunk)}")
    print(f"Value of transcript_chunk: {transcript_chunk}")

    # Ensure transcript_chunk is a string
    if not isinstance(transcript_chunk, str):
        print("Warning: transcript_chunk is not a string. Trying to convert...")
        transcript_chunk = str(transcript_chunk)

    questions = generate_questions(transcript_chunk)
    console.print("Generated Questions:\n", questions)

if __name__ == "__main__":
    source = Path(sys.argv[1])
    target = Path(sys.argv[2])
