33
44from langchain .base_language import BaseLanguageModel
55from langchain .chat_models .openai import ChatOpenAI
6+ from langchain .chat_models .anthropic import ChatAnthropic
67from langchain .schema import AIMessage , OutputParserException
78
8- from codeinterpreterapi .prompts import determine_modifications_function , determine_modifications_prompt
9+ from codeinterpreterapi .prompts import determine_modifications_prompt
910
1011
1112async def get_file_modifications (
@@ -15,44 +16,44 @@ async def get_file_modifications(
1516) -> Optional [List [str ]]:
1617 if retry < 1 :
1718 return None
18- messages = determine_modifications_prompt .format_prompt (code = code ).to_messages ()
19- message = await llm .apredict_messages (messages , functions = [determine_modifications_function ])
2019
21- if not isinstance (message , AIMessage ):
22- raise OutputParserException ("Expected an AIMessage" )
20+ prompt = determine_modifications_prompt .format (code = code )
2321
24- function_call = message . additional_kwargs . get ( "function_call" , None )
22+ result = await llm . apredict ( prompt , stop = "```" )
2523
26- if function_call is None :
24+
25+ try :
26+ result = json .loads (result )
27+ except json .JSONDecodeError :
28+ result = ""
29+ if not result or not isinstance (result , dict ) or "modifications" not in result :
2730 return await get_file_modifications (code , llm , retry = retry - 1 )
28- else :
29- function_call = json .loads (function_call ["arguments" ])
30- return function_call ["modifications" ]
31-
31+ return result ["modifications" ]
32+
3233
3334async def test ():
34- llm = ChatOpenAI (model = "gpt-3.5" ) # type: ignore
35-
36- code = """
37- import matplotlib.pyplot as plt
38-
39- x = list(range(1, 11))
40- y = [29, 39, 23, 32, 4, 43, 43, 23, 43, 77]
41-
42- plt.plot(x, y, marker='o')
43- plt.xlabel('Index')
44- plt.ylabel('Value')
45- plt.title('Data Plot')
46-
47- plt.show()
48- """
49-
35+ llm = ChatAnthropic (model = "claude-1.3" ) # type: ignore
36+
37+ code = \
38+ """
39+ import matplotlib.pyplot as plt
40+
41+ x = list(range(1, 11))
42+ y = [29, 39, 23, 32, 4, 43, 43, 23, 43, 77]
43+
44+ plt.plot(x, y, marker='o')
45+ plt.xlabel('Index')
46+ plt.ylabel('Value')
47+ plt.title('Data Plot')
48+
49+ plt.show()
50+ """
51+
5052 print (await get_file_modifications (code , llm ))
51-
53+
5254
5355if __name__ == "__main__" :
54- import asyncio
55- from dotenv import load_dotenv
56- load_dotenv ()
56+ import asyncio , dotenv
57+ dotenv .load_dotenv ()
5758
5859 asyncio .run (test ())
0 commit comments