Skip to content

Commit

Permalink
feat: skip modify files with unstaged changes
Browse files Browse the repository at this point in the history
  • Loading branch information
fynnfluegge committed Aug 25, 2023
1 parent 9f0f350 commit d4a7204
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 43 deletions.
11 changes: 8 additions & 3 deletions doc_comments_ai/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ def run():
file_name = args.dir

if not os.path.isfile(file_name):
sys.exit(f"File {file_name} does not exist")
sys.exit(f"File {utils.get_bold_text(file_name)} does not exist")

if utils.has_unstaged_changes(file_name):
sys.exit(f"File {utils.get_bold_text(file_name)} has unstaged changes")

if args.gpt4:
llm_wrapper = llm.LLM(model="gpt-4")
Expand All @@ -62,7 +65,9 @@ def run():
for node in treesitterNodes:
method_name = utils.get_bold_text(node.name)
if node.doc_comment:
print(f"Method {method_name} already has a doc comment. Skipping...")
print(
f"⚠️ Method {method_name} already has a doc comment. Skipping..."
)
continue

method_source_code = get_source_from_node(node.node)
Expand All @@ -77,7 +82,7 @@ def run():
generated_doc_comments[
method_source_code
] = utils.extract_content_from_markdown_code_block(
documented_method_source_code, programming_language.value
documented_method_source_code
)

spinner.stop()
Expand Down
76 changes: 39 additions & 37 deletions doc_comments_ai/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import os
import re
import subprocess
from doc_comments_ai.constants import Language


def get_programming_language(file_extension: str) -> Language:
"""
Returns the corresponding programming language based on the given file extension.
Retrieves the programming language based on the given file extension.
Args:
file_extension (str): The file extension of the programming file.
file_extension (str): The file extension to determine the programming language.
Returns:
Language: The corresponding programming language if it exists in the mapping, otherwise Language.UNKNOWN.
Language: The programming language associated with the given file extension.
"""
language_mapping = {
".py": Language.PYTHON,
Expand All @@ -28,40 +29,42 @@ def get_programming_language(file_extension: str) -> Language:

def get_file_extension(file_name: str) -> str:
"""
Returns the extension of a file.
Return the extension of the file.
Args:
file_name (str): The name of the file including its extension.
Parameters:
file_name (str): The name of the file.
Returns:
str: The extension of the file.
str: The file extension.
"""
return os.path.splitext(file_name)[-1]


def write_code_snippet_to_file(file_path: str, original_code: str, modified_code: str):
"""
This function replaces the code snippet in the file with the modified code snippet
Replace the original code snippet with the modified code in the given file.
Args:
file_path (str): The path to the file.
original_code (str): The code snippet to be replaced.
modified_code (str): The code snippet to replace the original code.
Returns:
None
"""
with open(file_path, "r") as file:
file_content = file.read()
start_pos = file_content.find(original_code)
if start_pos != -1: # Check if code_string is found in the original content
# Calculate the end position of code_string
if start_pos != -1:
end_pos = start_pos + len(original_code)

# Replace code_string with modified_code_string in the original content
modified_content = (
file_content[:start_pos] + modified_code + file_content[end_pos:]
)

# Open the file in write mode
with open(file_path, "w", encoding="utf-8") as file:
# Write the modified content to the file
file.write(modified_content)


def extract_content_from_markdown_code_block(markdown_code_block, language) -> str:
def extract_content_from_markdown_code_block(markdown_code_block) -> str:
"""
Extracts the content from a markdown code block inside a string.
Expand All @@ -76,35 +79,34 @@ def extract_content_from_markdown_code_block(markdown_code_block, language) -> s
match = re.search(pattern, markdown_code_block, re.DOTALL)

if match:
# TODO fix this
# sometimes the doc comment has ``` block itself, which will break
# the regex pattern. In this case, we need to extract the all
# subsequent ``` blocks and append them to the first one
# subsequent_matches = re.findall("```\n(.*?)```", markdown_code_block, re.DOTALL)
# if subsequent_matches:
# # join all subsequent code blocks
# subsequent_code = "\n".join(subsequent_matches).strip()
# # append the last block
# last_match = re.findall("```(.*?)```", markdown_code_block, re.DOTALL)
# if last_match:
# last_code_block = last_match[-1].strip()
# subsequent_code += "\n" + last_code_block
# # return the first code block + subsequent code blocks
# return match.group(1).strip() + "\n" + subsequent_code

return match.group(1).strip()
else:
return markdown_code_block.strip()


def get_bold_text(text):
"""
Returns the specified text formatted in bold.
Returns the provided text in bold format.
:param text: The text to be formatted.
:return: The formatted text.
"""
return f"\033[01m{text}\033[0m"


def has_unstaged_changes(file):
"""
Check if the given file has any unstaged changes in the Git repository.
Parameters:
- text (str): The text to be formatted.
Args:
file (str): The file to check for unstaged changes.
Returns:
str: The input text formatted in bold.
bool: True if the file has unstaged changes, False otherwise.
"""
return f"\033[01m{text}\033[0m"
try:
# Run the "git diff --quiet" command and capture its output
subprocess.check_output(["git", "diff", "--quiet", file])
return False # No unstaged changes
except subprocess.CalledProcessError:
return True # Unstaged changes exist
6 changes: 3 additions & 3 deletions tests/response_parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
@pytest.mark.usefixtures("response_fixture")
def test_response_parser(response_fixture):
markdown_code_block = utils.extract_content_from_markdown_code_block(
response_fixture, "python"
response_fixture
)
assert (
markdown_code_block
Expand All @@ -27,7 +27,7 @@ def test_response_parser(response_fixture):
@pytest.mark.usefixtures("response_fixture_language_enclosed")
def test_response_parser_language_enclosed(response_fixture_language_enclosed):
markdown_code_block = utils.extract_content_from_markdown_code_block(
response_fixture_language_enclosed, "python"
response_fixture_language_enclosed
)
assert (
markdown_code_block
Expand All @@ -49,7 +49,7 @@ def test_response_parser_language_enclosed(response_fixture_language_enclosed):
@pytest.mark.usefixtures("response_fixture_with_text")
def test_response_parser_with_text(response_fixture_with_text):
markdown_code_block = utils.extract_content_from_markdown_code_block(
response_fixture_with_text, "python"
response_fixture_with_text
)
assert (
markdown_code_block
Expand Down

0 comments on commit d4a7204

Please sign in to comment.