app.py — Flask + Celery (Redis) service exposing GPT-3 text- and image-generation endpoints.
122 lines (94 loc) · 2.96 KB
import os
import requests
from dotenv import load_dotenv
from flask import Flask, request, jsonify, render_template
from flask_cors import cross_origin
from celery import Celery
import openai
app = Flask(__name__)
# Load environment variables from a local .env file before any os.getenv calls.
load_dotenv()
# Configure Celery; the broker/result backend URLs (Redis) come from the environment.
app.config['CELERY_BROKER_URL'] = os.getenv('CELERY_BROKER_URL')
app.config['CELERY_RESULT_BACKEND'] = os.getenv('CELERY_RESULT_BACKEND')
celery = Celery(app.name, broker=app.config['CELERY_BROKER_URL'])
celery.conf.update(app.config)
# GPT-3 endpoint and credentials.
# gpt3_endpoint is used by generate_text via raw requests; gpt3_image_endpoint
# appears unused here (generate_image goes through the openai client instead).
gpt3_endpoint = "https://api.openai.com/v1/engines/text-davinci-003/completions"
gpt3_image_endpoint = "https://api.openai.com/v1/images/generations"
gpt3_api_key = os.getenv("OPEN_AI_KEY")
openai.api_key = gpt3_api_key
@celery.task
def generate_text(prompt):
    """Celery task: request a GPT-3 completion for ``prompt``.

    Returns the parsed JSON response from the OpenAI completions API;
    callers read ``response['choices'][0]['text']``.

    Raises ``requests.HTTPError`` on a non-2xx response so the task is
    marked failed, instead of returning an error payload that would
    later cause a KeyError in the /result endpoint.
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {gpt3_api_key}"
    }
    data = {
        "prompt": prompt,
        "temperature": 0.5,
        "max_tokens": 128
    }
    # A timeout keeps the worker from hanging forever if the API stalls
    # (requests has no default timeout).
    response = requests.post(gpt3_endpoint, headers=headers, json=data, timeout=30)
    response.raise_for_status()
    return response.json()
@celery.task
def generate_image(prompt, number, image_size, image_width):
    """Celery task: generate ``number`` images for ``prompt`` via the
    OpenAI Images API, sized ``image_size`` x ``image_width``.

    Returns the list of image descriptors from the response's ``data``
    field (each entry contains an image URL).
    """
    dimensions = f"{image_size}x{image_width}"
    api_response = openai.Image.create(
        prompt=prompt,
        n=number,
        size=dimensions
    )
    return api_response['data']
# app.template_folder = 'templates'
# @app.route("/")
# def index():
# """
# Index / Main page
# :return: html
# """
#
# return render_template('chatbot.html', name="landing")
@app.route('/chat', methods=['POST'])
@cross_origin()
def chat():
    """Queue a GPT-3 completion for the posted prompt.

    Expects a JSON body with a ``prompt`` key. Responds with the Celery
    task id that the client polls via ``/result/<task_id>``, or 400 if
    the prompt is missing.
    """
    # silent=True: treat a non-JSON body as an empty payload rather
    # than raising, so we can return a clean 400 below.
    payload = request.get_json(silent=True) or {}
    prompt = payload.get('prompt')
    if not prompt:
        # Reject early instead of queueing a task that is bound to fail.
        return jsonify({'error': "Missing 'prompt' in request body"}), 400
    # Run the GPT-3 call asynchronously on a Celery worker.
    task = generate_text.apply_async(args=(prompt,))
    return jsonify({'task_id': task.id})
@app.route('/result/<task_id>', methods=['GET'])
@cross_origin()
def result(task_id):
    """Return the generated text for a /chat task.

    Responds 202 with ``{'status': 'pending'}`` while the task is still
    running, and ``{'data': <text>}`` once it has finished.
    """
    task = generate_text.AsyncResult(task_id)
    if not task.ready():
        # Don't block the HTTP worker on .get(); let the client poll.
        return jsonify({'status': 'pending'}), 202
    # .get() re-raises any exception the task raised, surfacing failures.
    response = task.get()
    result = response['choices'][0]['text']
    return jsonify({
        "data": result
    })
@app.route('/image_chat', methods=['POST'])
@cross_origin()
def image_chat():
    """Queue an image-generation task for the posted prompt.

    Reads ``prompt``, ``number``, ``image_size`` and ``image_width``
    from the JSON body (with defaults) and responds with the Celery
    task id that the client polls via ``/image/<task_id>``.
    """
    body = request.json
    prompt = body.get('prompt', "cartoon cat")
    number = body.get('number', 1)
    image_size = body.get('image_size', 1024)
    image_width = body.get('image_width', 1024)
    # Hand the slow OpenAI call off to a Celery worker.
    task = generate_image.apply_async(args=(prompt, number, image_size, image_width))
    return jsonify({'task_id': task.id})
@app.route('/image/<task_id>', methods=['GET'])
@cross_origin()
def image_result(task_id):
    """Return the generated image descriptors for an /image_chat task.

    Responds 202 with ``{'status': 'pending'}`` while the task is still
    running, and ``{'data': [...]}`` once it has finished.
    """
    task = generate_image.AsyncResult(task_id)
    if not task.ready():
        # Don't block the HTTP worker on .get(); let the client poll.
        return jsonify({'status': 'pending'}), 202
    return jsonify({
        "data": task.get()
    })
if __name__ == '__main__':
    # SECURITY: debug=True enables the Werkzeug interactive debugger,
    # which allows remote code execution when bound to 0.0.0.0. Only
    # enable it when FLASK_DEBUG=1 is set explicitly.
    app.run(host='0.0.0.0', port=5000, debug=os.getenv('FLASK_DEBUG') == '1')