Skip to content

Commit

Permalink
Updated to use gpt-4o by default
Browse files Browse the repository at this point in the history
  • Loading branch information
beveradb committed May 15, 2024
1 parent 255d9b7 commit 6a273ee
Show file tree
Hide file tree
Showing 5 changed files with 939 additions and 913 deletions.
22 changes: 11 additions & 11 deletions .github/removetriton.patch
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
1106d1105
1164d1163
< triton = ">=2.0.0,<3"
2033d2031
< triton = {version = "2.2.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
2129,2151d2126
2067d2065
< triton = {version = "2.3.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""}
2163,2185d2160
< name = "triton"
< version = "2.2.0"
< version = "2.3.0"
< description = "A language and compiler for custom Deep Learning operations"
< optional = false
< python-versions = "*"
< files = [
< {file = "triton-2.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2294514340cfe4e8f4f9e5c66c702744c4a117d25e618bd08469d0bfed1e2e5"},
< {file = "triton-2.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da58a152bddb62cafa9a857dd2bc1f886dbf9f9c90a2b5da82157cd2b34392b0"},
< {file = "triton-2.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0af58716e721460a61886668b205963dc4d1e4ac20508cc3f623aef0d70283d5"},
< {file = "triton-2.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8fe46d3ab94a8103e291bd44c741cc294b91d1d81c1a2888254cbf7ff846dab"},
< {file = "triton-2.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8ce26093e539d727e7cf6f6f0d932b1ab0574dc02567e684377630d86723ace"},
< {file = "triton-2.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:227cc6f357c5efcb357f3867ac2a8e7ecea2298cd4606a8ba1e931d1d5a947df"},
< {file = "triton-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ce4b8ff70c48e47274c66f269cce8861cf1dc347ceeb7a67414ca151b1822d8"},
< {file = "triton-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c3d9607f85103afdb279938fc1dd2a66e4f5999a58eb48a346bd42738f986dd"},
< {file = "triton-2.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:218d742e67480d9581bafb73ed598416cc8a56f6316152e5562ee65e33de01c0"},
< {file = "triton-2.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:381ec6b3dac06922d3e4099cfc943ef032893b25415de295e82b1a82b0359d2c"},
< {file = "triton-2.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:038e06a09c06a164fef9c48de3af1e13a63dc1ba3c792871e61a8e79720ea440"},
< {file = "triton-2.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d8f636e0341ac348899a47a057c3daea99ea7db31528a225a3ba4ded28ccc65"},
< ]
<
< [package.dependencies]
Expand Down
17 changes: 14 additions & 3 deletions lyrics_transcriber/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
log_level=logging.DEBUG,
log_formatter=None,
transcription_model="medium",
llm_model="gpt-4-1106-preview",
llm_model="gpt-4o",
llm_prompt_matching="lyrics_transcriber/llm_prompts/llm_prompt_lyrics_matching_andrew_handwritten_20231118.txt",
llm_prompt_correction="lyrics_transcriber/llm_prompts/llm_prompt_lyrics_correction_andrew_handwritten_20231118.txt",
render_video=False,
Expand Down Expand Up @@ -66,7 +66,15 @@ def __init__(
self.llm_model = llm_model
self.llm_prompt_matching = llm_prompt_matching
self.llm_prompt_correction = llm_prompt_correction

self.openai_client = OpenAI()

# Uncomment for local models e.g. with ollama
# self.openai_client = OpenAI(
# base_url="http://localhost:11434/v1",
# api_key="ollama",
# )

self.openai_client.log = self.log_level

self.render_video = render_video
Expand Down Expand Up @@ -391,8 +399,11 @@ def calculate_llm_costs(self):
},
}

input_cost = price_dollars_per_1000_tokens[self.llm_model]["input"] * (self.outputs["llm_token_usage"]["input"] / 1000)
output_cost = price_dollars_per_1000_tokens[self.llm_model]["output"] * (self.outputs["llm_token_usage"]["output"] / 1000)
input_price = price_dollars_per_1000_tokens.get(self.llm_model, {"input": 0, "output": 0})["input"]
output_price = price_dollars_per_1000_tokens.get(self.llm_model, {"input": 0, "output": 0})["output"]

input_cost = input_price * (self.outputs["llm_token_usage"]["input"] / 1000)
output_cost = output_price * (self.outputs["llm_token_usage"]["output"] / 1000)

self.outputs["llm_costs_usd"]["input"] = round(input_cost, 3)
self.outputs["llm_costs_usd"]["output"] = round(output_cost, 3)
Expand Down
4 changes: 2 additions & 2 deletions lyrics_transcriber/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def main():

parser.add_argument(
"--llm_model",
default="gpt-4-1106-preview",
help="Optional: LLM model to use (currently only supports OpenAI chat completion models, e.g. gpt-4-1106-preview). Default: gpt-3.5-turbo-1106",
default="gpt-4o",
help="Optional: LLM model to use (currently only supports OpenAI chat completion compatible models",
)

parser.add_argument(
Expand Down
Loading

0 comments on commit 6a273ee

Please sign in to comment.