Skip to content

Commit

Permalink
Add guard logic to prevent crash when no newlines in response (#45)
Browse files Browse the repository at this point in the history
* Add guard logic to prevent crash when no newlines in response

Closes #44

* Bump version to 0.6.2 in `setup.py`
  • Loading branch information
heyodai authored Nov 9, 2023
1 parent 0a3af1d commit 45f2c85
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 15 deletions.
35 changes: 21 additions & 14 deletions cli/magic_commit/magic_commit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ class OpenAIKeyError(Exception):
"""Custom exception for OpenAI API key errors."""



class Llama2ServerError(Exception):
"""Custom exception for Llama2 server errors."""
pass

pass


def is_git_repository(directory: str) -> bool:
Expand Down Expand Up @@ -190,27 +189,31 @@ def generate_commit_message(
# Call the Llama2 server
response = call_llama2_server(llama2_url, messages)
print(response)
response = response['choices'][0]['message']['content'].strip()
response = response["choices"][0]["message"]["content"].strip()
else:
# Use OpenAI's service
openai.api_key = api_key
response = openai.ChatCompletion.create(model=model, messages=messages)
response = response.choices[0].message.content.strip()

# Strip the first line of response
# Assign it to start if it is empty
# Otherwise, remove the first line from the response
if start:
response = response.split("\n", 1)[1]
# Split the response by newline and store the result
split_response = response.split("\n", 1)

# Check if split_response contains at least 2 elements
if len(split_response) > 1:
response = split_response[1] if start else split_response[0]
else:
start = response.split("\n", 1)[0]
response = response.split("\n", 1)[1]
# If there is no newline, the whole response is either the start or the generated message
if not start:
start = response
# If start is already set, we leave response as is, or set it to an empty string
else:
response = ""

# Render and return the template
return render_final_template(start, response, ticket).strip()



def call_llama2_server(url: str, messages: list) -> dict:
"""
Call the Llama2 server.
Expand All @@ -237,7 +240,9 @@ def call_llama2_server(url: str, messages: list) -> dict:
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
raise Llama2ServerError(f"An error occurred while connecting to the Llama2 server: {e}")
raise Llama2ServerError(
f"An error occurred while connecting to the Llama2 server: {e}"
)


def render_template(message: str, template_name: str) -> str:
Expand Down Expand Up @@ -322,7 +327,7 @@ def run_magic_commit(
api_key: str,
model: str,
show_loading_message: bool,
llama2_url: str = None
llama2_url: str = None,
) -> str:
"""
Generate a commit message and return it.
Expand Down Expand Up @@ -359,7 +364,9 @@ def run_magic_commit(
diff = run_git_diff(directory)
if not check_git_status(directory): # Check if there are staged changes
return "⛔ Warning: No staged changes detected. Please stage some changes before running magic-commit."
commit_message = generate_commit_message(diff, start, ticket, api_key, model, llama2_url)
commit_message = generate_commit_message(
diff, start, ticket, api_key, model, llama2_url
)
finally:
# Ensure the loading animation stops
if show_loading_message:
Expand Down
2 changes: 1 addition & 1 deletion cli/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name="magic-commit",
version="0.6.1",
version="0.6.2",
packages=find_packages(),
include_package_data=True, # This line is needed to include non-code files
package_data={
Expand Down

0 comments on commit 45f2c85

Please sign in to comment.