forked from cbitosc/HTF24-Team-155
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
35dfe36
commit 49dadda
Showing
8 changed files
with
732 additions
and
85 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,28 @@ | ||
from flask import Flask, request, jsonify | ||
from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
||
# Initialize Flask app | ||
app = Flask(__name__) | ||
|
||
@app.route('/') | ||
def home(): | ||
return "<h1>Welcome to the Mental Health Chatbot</h1><p>Use /chat to interact.</p>" | ||
|
||
@app.route('/chat', methods=['POST']) | ||
def chat(): | ||
user_input = request.json.get('message') | ||
response = get_response(user_input) # Call the response function | ||
return jsonify({'response': response}) | ||
|
||
def get_response(user_input): | ||
# Simple logic based on keywords | ||
if "anxiety" in user_input.lower(): | ||
return "It's okay to feel anxious. Consider practicing deep breathing." | ||
elif "stress" in user_input.lower(): | ||
return "Stress can be managed with meditation and exercise." | ||
elif "depression" in user_input.lower(): | ||
return "It's important to talk to someone. You're not alone." | ||
else: | ||
return "I'm here to help. Can you tell me more about what you're feeling?" | ||
# Load the fine-tuned model and tokenizer | ||
# Change this line: | ||
|
||
# To this line (removing the './'): | ||
model = AutoModelForCausalLM.from_pretrained('fine_tuned_model') | ||
|
||
tokenizer = AutoTokenizer.from_pretrained('./fine_tuned_model') | ||
|
||
@app.route('/predict', methods=['POST']) | ||
def predict(): | ||
data = request.get_json() | ||
input_text = data['text'] # Get the input text from the request | ||
inputs = tokenizer.encode(input_text, return_tensors='pt') | ||
|
||
# Generate predictions | ||
outputs = model.generate(inputs, max_length=100) | ||
predicted_text = tokenizer.decode(outputs[0], skip_special_tokens=True) | ||
|
||
return jsonify({'response': predicted_text}) | ||
|
||
if __name__ == '__main__': | ||
app.run(debug=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.