-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathMusicLM.py
147 lines (114 loc) · 5.1 KB
/
MusicLM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import requests
import os
import dotenv
import datetime
import json
import base64
import logging
from time import sleep
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from seleniumbase import Driver
logging.basicConfig(level=logging.INFO)
dotenv_file = dotenv.find_dotenv()
dotenv.load_dotenv(dotenv_file)
class Music:
musiclm_url = "https://content-aisandbox-pa.googleapis.com/v1:soundDemo?alt=json"
url = 'https://aitestkitchen.withgoogle.com/experiments/music-lm'
def __init__(self):
self.email = os.environ["EMAIL"]
self.password = os.environ["PASSWORD"]
if os.environ["TOKEN"] == "":
self.token = self.get_token()
elif self.token_refresh():
self.token = self.get_token()
else:
self.token = os.environ["TOKEN"]
def get_tracks(self, input, generationCount):
if not isinstance(generationCount, int):
generationCount = 2
generationCount = min(8, max(1, generationCount))
payload = json.dumps({
"generationCount": generationCount,
"input": {
"textInput": input
},
"soundLengthSeconds": 30 # this doesn't change anything
})
headers = {
'Authorization': f'Bearer {self.token}'
}
try:
response = requests.post(self.musiclm_url, headers=headers, data=payload)
except requests.exceptions.ConnectionError:
logging.error("Can't connect to the server.")
# Bad Gateway
return 502
if response.status_code == 400:
logging.error("Oops, can't generate audio for that.")
# Bad Request
return 400
tracks = []
for sound in response.json()['sounds']:
tracks.append(sound["data"])
return tracks
def b64toMP3(self, tracks_list, filename):
count = 0
new_filename = filename
while os.path.exists(new_filename):
count += 1
new_filename = f'{filename} ({count})'
os.mkdir(new_filename)
for i, track in enumerate(tracks_list):
with open(f"{new_filename}/track{i+1}.mp3", "wb") as f:
f.write(base64.b64decode(track))
logging.info("Tracks successfully generated!")
# Successful Request
return 200
def get_token(self):
driver = Driver(uc=True, headless=True)
try:
driver.get(self.url)
try:
WebDriverWait(driver, 20).until(EC.presence_of_element_located((By.XPATH, "//button[contains(@class, 'evTGAR')]"))).click()
WebDriverWait(driver, 20).until(EC.presence_of_element_located((By.XPATH, "//button[contains(@class, 'lgEhqE')]"))).click()
driver.switch_to.window(driver.window_handles[1])
logging.info('Logging in...')
WebDriverWait(driver, 20).until(EC.visibility_of_element_located((By.NAME, 'identifier'))).send_keys(f'{self.email}\n')
WebDriverWait(driver, 20).until(EC.visibility_of_element_located((By.NAME, 'Passwd'))).send_keys(f'{self.password}\n')
logging.info('Successfully logged in')
driver.switch_to.window(driver.window_handles[0])
except Exception as e:
logging.ERROR("An error occurred while interacting with the webpage, details: " + str(e))
raise Exception("Unable to fetch token due to Selenium interaction error")
sleep(5)
logging.info('Getting OAuth 2.0 token')
cookies = driver.get_cookies()
except Exception as e:
logging.ERROR("An error occurred while fetching the token, details: " + str(e))
raise Exception("Unable to fetch token due to browser error")
finally:
driver.quit()
token_cookie = next((cookie['value'] for cookie in cookies if cookie['name'] == 'TOKEN'), None)
if token_cookie is None:
raise Exception("Unable to obtain token")
start_sub = "ya29"
end_sub = "%22"
start_idx = token_cookie.index(start_sub)
end_idx = token_cookie.index(end_sub, start_idx)
token = token_cookie[start_idx:end_idx]
dotenv.set_key(dotenv_file, "TOKEN", str(token))
os.environ["TOKEN"] = str(token)
dotenv.set_key(dotenv_file, "EXPIRATION_TIMESTAMP", str(datetime.datetime.now() + datetime.timedelta(minutes=59)))
os.environ["EXPIRATION_TIMESTAMP"] = str(datetime.datetime.now() + datetime.timedelta(minutes=59))
logging.info('OAuth 2.0 token obtained')
return token
def token_refresh(self):
current_timestamp = datetime.datetime.now().replace(microsecond=0)
expiration_timestamp = datetime.datetime.strptime(os.environ['EXPIRATION_TIMESTAMP'], '%Y-%m-%d %H:%M:%S.%f')
difference = current_timestamp - expiration_timestamp
if difference >= datetime.timedelta(minutes=0):
return True
else:
return False