-
Notifications
You must be signed in to change notification settings - Fork 112
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #316 from kartikbhtt7/dev
Implemented BERTopic Model for topic segmentation
- Loading branch information
Showing
7 changed files
with
166 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Base image | ||
FROM python:3.9 | ||
|
||
# Set the working directory | ||
WORKDIR /app | ||
|
||
# Install dependencies | ||
COPY requirements.txt . | ||
RUN pip install -r requirements.txt | ||
|
||
# Copy all source code | ||
COPY . . | ||
COPY . /app/ | ||
|
||
# Expose port for the server | ||
EXPOSE 8000 | ||
|
||
# Command to run the server | ||
CMD ["hypercorn", "--bind", "0.0.0.0:8000", "api:app"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
## BERTopic Topic Extraction Model | ||
|
||
### Purpose : | ||
Model to extract meaningful segmentations of a query dataset | ||
|
||
### Testing the model deployment : | ||
To run for testing of the model for topic head generation, follow the given below steps: | ||
|
||
- Git clone the repo | ||
- Go to current folder location i.e. ``` cd src/topic_modelling/BERTopic ``` | ||
- Create docker image file and test the api: | ||
#### (IMP) The input .csv file must have one column having preprocessed text and column name as 'text' | ||
''' | ||
docker build -t testmodel . | ||
docker run -p 8000:8000 testmodel | ||
curl -X POST -F "test.csv" http://localhost:8000/embed -o output4.csv | ||
''' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .request import * | ||
from .model import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import os | ||
import io | ||
import json | ||
import pandas as pd | ||
from quart import Quart, request, Response, send_file | ||
from model import Model | ||
from request import ModelRequest | ||
|
||
app = Quart(__name__) | ||
|
||
# Initialize the model to be used for inference. | ||
model = None | ||
|
||
@app.before_serving | ||
async def startup(): | ||
"""This function is called once before the server starts to initialize the model.""" | ||
global model | ||
model = Model(app) | ||
|
||
@app.route('/embed', methods=['POST']) | ||
async def embed(): | ||
"""This endpoint receives a CSV file, extracts text data from it, and uses the model to generate embeddings and topic information.""" | ||
global model | ||
|
||
files = await request.files # Get the uploaded files | ||
uploaded_file = files.get('file') # Get the uploaded CSV file | ||
|
||
if not uploaded_file: | ||
return Response(json.dumps({"error": "No file uploaded"}), status=400, mimetype='application/json') | ||
|
||
# Read the CSV file into a DataFrame | ||
csv_data = pd.read_csv(io.BytesIO(uploaded_file.stream.read())) | ||
|
||
# Extract the text data | ||
text_data = csv_data['text'].tolist() | ||
|
||
# Create a ModelRequest object with the extracted text data | ||
req = ModelRequest(text=text_data) | ||
|
||
# Call the model's inference method and get the response | ||
response = await model.inference(req) | ||
|
||
if response is None: | ||
# If an error occurred during inference, return an error response | ||
return Response(json.dumps({"error": "Inference error"}), status=500, mimetype='application/json') | ||
|
||
# Convert the CSV string from the response into a DataFrame | ||
df = pd.read_csv(io.StringIO(response)) | ||
|
||
# Save the DataFrame to a CSV file | ||
output_file_path = 'output.csv' | ||
df.to_csv(output_file_path, index=False) | ||
|
||
# Send the CSV file back as a download response | ||
return await send_file(output_file_path, mimetype='text/csv', as_attachment=True, attachment_filename='output.csv') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import pandas as pd | ||
from sentence_transformers import SentenceTransformer | ||
from bertopic import BERTopic | ||
from umap import UMAP | ||
from sklearn.feature_extraction.text import CountVectorizer | ||
import json | ||
import nltk | ||
from request import ModelRequest | ||
|
||
nltk.download("punkt") | ||
|
||
class Model: | ||
def __init__(self, context): | ||
self.context = context | ||
self.sentence_model = SentenceTransformer("all-MiniLM-L6-v2") | ||
self.vectorizer_model = CountVectorizer(stop_words="english") | ||
self.umap_model = UMAP(n_neighbors=15, min_dist=0.0, metric="cosine", random_state=69) | ||
# self.hdbscan_model = HDBSCAN(min_cluster_size=15, metric="euclidean", prediction_data=True) | ||
self.topic_model = BERTopic( | ||
umap_model = self.umap_model, | ||
# hdbscan_model = self.hdbscan_model, | ||
vectorizer_model = self.vectorizer_model, | ||
) | ||
|
||
async def inference(self, request: ModelRequest): | ||
text = request.text | ||
try: | ||
# Encode the text using SentenceTransformer | ||
corpus_embeddings = self.sentence_model.encode(text) | ||
|
||
# Fit the topic model | ||
topics, probabilities = self.topic_model.fit_transform(text, corpus_embeddings) | ||
|
||
# Get topic information and cluster labels | ||
df_classes = self.topic_model.get_topic_info() | ||
cluster_labels, _ = self.topic_model.transform(text, corpus_embeddings) | ||
|
||
df_result = pd.DataFrame({ | ||
"document_text": text, | ||
"predicted_class_label": cluster_labels, | ||
"probabilities": probabilities, | ||
}) | ||
|
||
# Mapping cluster names to topic labels | ||
cluster_names_map = dict(zip(df_classes["Topic"], df_classes["Name"])) | ||
df_result["predicted_class_name"] = df_result["predicted_class_label"].map(cluster_names_map) | ||
|
||
csv_string = df_result.to_csv(index=False) | ||
|
||
except Exception as e: | ||
# Log & print the error | ||
print(f"Error during inference: {e}") | ||
return None | ||
|
||
return csv_string | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import json | ||
|
||
class ModelRequest(): | ||
def __init__(self, text): | ||
self.text = text | ||
|
||
def to_json(self): | ||
return json.dumps(self, default=lambda o: o.__dict__, | ||
sort_keys=True, indent=4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
quart | ||
aiohttp | ||
pandas | ||
bertopic | ||
sentence_transformers | ||
numpy | ||
nltk | ||
scikit-learn |