Skip to content

Commit

Permalink
feat: add guided mode and tiktoken (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
fynnfluegge authored Aug 26, 2023
1 parent 026bc94 commit ed01c8e
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 6 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Focus on writing your code, let AI write the documentation for you. With just a

## ✨ Features
- Create documentation comment blocks for all methods in a file
- e.g. Javadoc, JSDoc, Docstring, Rustdoc
- Create inline documentation comments in method bodies
- Treesitter integration

Expand All @@ -29,6 +30,7 @@ Focus on writing your code, let AI write the documentation for you. With just a
- `aicomments <RELATIVE_FILE_PATH>`: Create documentations for any method in the file which doesn't have any yet.
- `aicomments <RELATIVE_FILE_PATH> --inline`: Create also documentation comments in the method body.
- `aicomments <RELATIVE_FILE_PATH> --gpt4`: Use GPT-4 model (Default is GPT-3.5).
- `aicomments <RELATIVE_FILE_PATH> --guided`: Guided mode, confirm documentation generation for each method.

## ⚙️ Supported Languages
- [x] Python
Expand Down
23 changes: 22 additions & 1 deletion doc_comments_ai/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import argparse
import sys
from doc_comments_ai import utils, llm, domain
from doc_comments_ai.llm import GptModel
from doc_comments_ai.treesitter.treesitter import (
Treesitter,
TreesitterNode,
Expand Down Expand Up @@ -31,6 +32,11 @@ def run():
action="store_true",
help="Uses GPT-4 (default GPT-3.5).",
)
parser.add_argument(
"--guided",
action="store_true",
help="User will get asked to confirm the doc generation for each method.",
)

if sys.argv.__len__() < 2:
sys.exit("Please provide a file")
Expand All @@ -46,7 +52,7 @@ def run():
sys.exit(f"File {utils.get_bold_text(file_name)} has unstaged changes")

if args.gpt4:
llm_wrapper = llm.LLM(model="gpt-4")
llm_wrapper = llm.LLM(model=GptModel.GPT_4)
else:
llm_wrapper = llm.LLM()

Expand All @@ -64,6 +70,12 @@ def run():

for node in treesitterNodes:
method_name = utils.get_bold_text(node.name)

if args.guided:
print(f"Generate doc for {utils.get_bold_text(method_name)}? (y/n)")
if not input().lower() == "y":
continue

if node.doc_comment:
print(
f"⚠️ Method {method_name} already has a doc comment. Skipping..."
Expand All @@ -72,6 +84,15 @@ def run():

method_source_code = get_source_from_node(node.node)

tokens = utils.count_tokens(method_source_code)
if tokens > 2048 and not args.gpt4:
print(
f"⚠️ Method {method_name} has too many tokens. "
f"Consider using {utils.get_bold_text('--gpt4')}. "
"Skipping for now..."
)
continue

spinner = yaspin(text=f"🔧 Generating doc comment for {method_name}...")
spinner.start()

Expand Down
11 changes: 9 additions & 2 deletions doc_comments_ai/llm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from enum import Enum
from langchain.chat_models import ChatOpenAI
from langchain import PromptTemplate, LLMChain


class GptModel(Enum):
GPT_35 = "gpt-3.5-turbo"
GPT_4 = "gpt-4"


class LLM:
def __init__(self, model: str = "gpt-3.5-turbo"):
self.llm = ChatOpenAI(temperature=0.9, max_tokens=2048, model=model)
def __init__(self, model: GptModel = GptModel.GPT_35):
max_tokens = 2048 if model == GptModel.GPT_35 else 4096
self.llm = ChatOpenAI(temperature=0.9, max_tokens=max_tokens, model=model.value)
self.template = (
"I have this {language} method:\n{code}\nAdd a doc comment to the method. "
"Return the method with the doc comment as a markdown code block. "
Expand Down
12 changes: 10 additions & 2 deletions doc_comments_ai/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import re
import subprocess
import tiktoken
from doc_comments_ai.constants import Language


Expand Down Expand Up @@ -97,10 +98,10 @@ def get_bold_text(text):
def has_unstaged_changes(file):
"""
Check if the given file has any unstaged changes in the Git repository.
Args:
file (str): The file to check for unstaged changes.
Returns:
bool: True if the file has unstaged changes, False otherwise.
"""
Expand All @@ -110,3 +111,10 @@ def has_unstaged_changes(file):
return False # No unstaged changes
except subprocess.CalledProcessError:
return True # Unstaged changes exist


# Return the number of tokens in a string
def count_tokens(text):
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
tokenized = encoding.encode(text)
return len(tokenized)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "doc-comments-ai"
version = "0.1.3"
version = "0.1.4"
description = ""
authors = ["fynnfluegge <[email protected]>"]
readme = "README.md"
Expand Down

0 comments on commit ed01c8e

Please sign in to comment.