-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathdemo.py
186 lines (153 loc) · 6.91 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import time
import os
import openai
import hydra
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from selenium.webdriver.support.relative_locator import locate_with
from tqdm import tqdm
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.common.by import By
from selenium import webdriver
from selenium.webdriver.common.action_chains import ActionChains
import logging
from IPython import embed
# Module-level logger for this script, named after the module.
logger = logging.getLogger(name=__name__)
class actGPTEnv:
    """Selenium Chrome wrapper exposing the helpers that generated code may call.

    The underlying WebDriver is available as ``self.driver``.
    """

    def __init__(self, executable_path, driver=None, user_data_dir='user_data', headless=True):
        """Wrap ``driver``, or launch a new Chrome instance if none is given.

        Args:
            executable_path: Path to the chromedriver binary. Only used when
                ``driver`` is None.
            driver: An existing WebDriver to reuse. When given, no browser is
                launched and the remaining arguments are ignored.
            user_data_dir: Chrome profile directory, passed as ``user-data-dir``.
            headless: When True, Chrome is started with ``--headless``.
        """
        if driver is not None:
            self.driver = driver
            return
        # Local import: only needed when this class launches its own browser.
        from selenium.webdriver.chrome.service import Service

        chrome_options = webdriver.ChromeOptions()
        chrome_options.add_argument(f"user-data-dir={user_data_dir}")
        if headless:
            chrome_options.add_argument("--headless")
        # Selenium 4 takes the driver path through a Service object. Passing it
        # positionally was deprecated in 4.0 and breaks on 4.10+, where the
        # first positional parameter of webdriver.Chrome is ``options``.
        self.driver = webdriver.Chrome(
            service=Service(executable_path), options=chrome_options)

    def get(self, url):
        """Navigate to ``url``, then sleep a fixed 3 seconds.

        ``http://`` is prepended when ``url`` does not already start with "http".
        """
        if not url.startswith("http"):
            url = "http://" + url
        self.driver.get(url)
        time.sleep(3)

    def find_nearest_textbox(self, element):
        """Return the textbox located near ``element``.

        Tries a ``div[@role='textbox']`` first, then falls back to the nearest
        ``<input>``. If the fallback also fails, its Selenium exception
        propagates.
        """
        try:
            return self.driver.find_element(locate_with(
                By.XPATH, "//div[@role = 'textbox']").near(element))
        except Exception:
            # Narrowed from a bare ``except:`` so Ctrl-C / SystemExit propagate.
            return self.driver.find_element(
                locate_with(By.TAG_NAME, "input").near(element))

    def find_nearest_text(self, element):
        """Return the text of a non-empty-text element near ``element``, or "" if none is found."""
        try:
            nearest = self.driver.find_element(locate_with(
                By.XPATH, "//*[text() != '']").near(element))
        except Exception:
            return ""
        return nearest.text

    def find_nearest(self, e, xpath):
        """Return an element matching ``xpath`` near ``e``, falling back to one below ``e``."""
        try:
            return self.driver.find_element(locate_with(
                By.XPATH, xpath).near(e))
        except Exception:
            return self.driver.find_element(locate_with(
                By.XPATH, xpath).below(e))

    def send_keys(self, keys):
        """Send ``keys`` via ActionChains to the focused element, pausing 1 s before and after."""
        ActionChains(self.driver).pause(1).send_keys(keys).pause(1).perform()

    def click(self, element):
        """Move to ``element`` and click it, pausing 1 s before the move and before the click."""
        ActionChains(self.driver).pause(1).move_to_element(
            element).pause(1).click(element).perform()

    def get_observation(self):
        """Return every ``div`` with a non-empty role plus every ``<button>``.

        Each entry is ``{"type": "button", "text": ..., "element": ...}``.
        NOTE(review): every entry is labelled "button", even textbox roles.
        """
        elements = self.driver.find_elements(By.XPATH,
                                             "//div[@role != '']|//button")
        return [{"type": "button", "text": element.text, "element": element}
                for element in elements]

    def is_button(self, element):
        """True for a ``<button>`` tag or an element with ``role="button"``."""
        return element.tag_name == "button" or element.get_attribute("role") == "button"

    def is_textbox(self, element):
        """True for an ``<input>`` tag or an element with ``role="textbox"``."""
        return element.tag_name == "input" or element.get_attribute("role") == "textbox"

    def get_openai_response(self, prompt, model="text-davinci-003"):
        """Send ``prompt`` to the (legacy, openai<1.0) Completion API and return the text.

        Prompts containing ``'write code'`` are sent verbatim with temperature 0.
        All other prompts use temperature 0.7. If such a prompt has more than
        10 lines, it is joined into one line and truncated to 300 characters.

        Args:
            prompt: The prompt string.
            model: Completion model name (default ``text-davinci-003``).

        Returns:
            The text of the first choice. Generation uses ``stop=["```"]``.
        """
        if 'write code' in prompt:
            temperature = 0
        else:
            temperature = 0.7
            lines = prompt.splitlines()
            if len(lines) > 10:
                prompt = " ".join(lines)[:300]
        response = openai.Completion.create(
            model=model,
            prompt=prompt,
            temperature=temperature,
            max_tokens=512,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            stop=["```"],
            best_of=3,
        )
        text = response["choices"][0]["text"]
        # Log the raw text. The former ``text.format()`` raised KeyError /
        # IndexError on any braces (dict literals, f-strings) in the output.
        logger.info('\n' + colored(text, 'blue'))
        return text
def get_prompt_selenium(instruction, code=False):
    """Build the code-generation prompt for ``instruction``.

    The prompt lists the ``env`` / WebElement API the model may use, gives
    example XPaths and coding constraints, then appends the instruction and
    an opening ```` ```python ```` fence.

    Args:
        instruction: Free-text instruction inserted verbatim near the end.
        code: When True, the whole prompt is wrapped in triple double quotes.

    Returns:
        The prompt string.
    """
    template = f"""
You have an instance `env` with the following methods:
- `env.driver.find_elements(by='id', value=None)` which finds and returns list of WebElement. The arguement `by` is a string that specifies the locator strategy. The arguement `value` is a string that specifies the locator value. `by` is usually `xpath` and `value` is the xpath of the element.
- `env.find_nearest(e, xpath)` can only be used to locate an element that matches the xpath near element e.
- `env.send_keys(text)` is only used to type in string `text`. string ENTER is Keys.ENTER
- `env.get(url)` goes to url.
- `env.get_openai_response(text)` that ask AI about a string `text`.
- `env.click(element)` clicks the element.
WebElement has functions:
1. `element.text` returns the text of the element
2. `element.get_attribute(attr)` returns the value of the attribute of the element. If the attribute does not exist, it returns ''.
3. `element.find_elements(by='id', value=None)` it's the same as `env.driver.find_elements()` except that it only searches the children of the element.
4. `element.is_displayed()` returns if the element is visible
The xpath of a textbox is usually "//div[@role = 'textarea']|//div[@role = 'textbox']|//input".
The xpath of text is usually "//*[string-length(text()) > 0]".
The xpath for a button is usually "//div[@role = 'button']|//button".
The xpath for an element whose text is "text" is "//*[text() = 'text']".
The xpath for the tweet is "//span[contains(text(), '')]".
The xpath for the like button is "//div[@role != '' and @data-testid='like']|//button".
The xpath for the unlike button is "//div[@role != '' and @data-testid='unlike']|//button".
Your code must obey the following constraints:
1. respect the lowercase and uppercase letters in the instruction.
2. Does not call any functions besides those given above and those defined by the base language spec.
3. has correct indentation.
4. only write code
5. only do what I instructed you to do.
{instruction}
```python"""
    fence = '"' * 3
    return fence + template + fence if code else template
@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg: DictConfig):
    """Interactive loop: instruction -> OpenAI-generated Selenium code -> exec.

    Reads ``cfg.OPENAI_API_KEY``, ``cfg.executable_path`` and
    ``cfg.user_data_dir``, then opens a non-headless browser via ``actGPTEnv``.

    Each round reads stdin lines until an empty line and builds the prompt
    with ``get_prompt_selenium``. It strips ``` fences from the completion and
    executes the result with ``env`` in scope. Names defined by generated code
    persist into later rounds. Failures in a round are logged and the loop
    continues. EOF on stdin ends the loop, and the browser is quit on any exit.
    """
    openai.api_key = cfg.OPENAI_API_KEY
    env = actGPTEnv(cfg.executable_path,
                    user_data_dir=cfg.user_data_dir, headless=False)
    # One dict used as BOTH globals and locals for exec(). With separate dicts
    # (the former exec(text, globals(), ldict)), functions and lambdas in the
    # generated code (and comprehensions, at least before Python 3.12) resolve
    # free names through globals only, so a helper using ``env`` raised
    # NameError. Copying globals() keeps module names such as ``Keys`` and
    # ``By`` available to the generated code.
    namespace = dict(globals())
    namespace["env"] = env
    try:
        while True:
            print("\nenter instruction:")
            lines = []
            while True:
                try:
                    line = input()
                except EOFError:
                    return
                if not line:
                    break
                lines.append(line + '\n')
            inp = ''.join(lines)
            if not inp.strip():
                # Nothing typed: don't spend an API call on an empty prompt.
                continue
            prompt = get_prompt_selenium(inp, code=False)
            try:
                text = env.get_openai_response(prompt, model="text-davinci-003")
            except Exception:
                logger.exception("OpenAI request failed")
                continue
            text = text.replace("```", "")
            # SECURITY: this runs model-generated code unreviewed, with the
            # full privileges of this process. Use only in a trusted setting.
            try:
                exec(text, namespace)
            except Exception:
                # Keep the browser session alive when generated code fails.
                logger.exception("generated code raised an error:\n%s", text)
    finally:
        env.driver.quit()
if __name__ == "__main__":
    # Called without arguments; the @hydra.main decorator on main() is
    # expected to build and pass the `cfg` config object.
    main()