import sys
import pandas as pd
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
    Bidirectional,
    Conv1D,
    Dense,
    Embedding,
    LSTM,
    MaxPooling1D,
)
from tensorflow.keras.metrics import Recall, Precision
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    recall_score,
    precision_score,
    f1_score,
    confusion_matrix,
)
import numpy as np
import matplotlib.pyplot as plt


# Check that the input file and output directory are provided
if len(sys.argv) != 3:
    print("Usage: python train.py <input_file> <output_dir>")
    sys.exit(1)
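
# Example invocation (file names here are hypothetical):
#   python train.py queries.csv saved_model_dir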

# Load the dataset; it is expected to contain a text column named "Query"
# and a binary (0/1) target column named "Label"
data = pd.read_csv(sys.argv[1])

# Define parameters: keep the 10,000 most frequent tokens and fix every
# sequence at 100 tokens
MAX_WORDS = 10000
MAX_LEN = 100

# Use Tokenizer to encode the text as integer sequences; filters="" disables
# the default character filtering, so punctuation and special characters in
# the query strings are preserved as part of the tokens
tokenizer = Tokenizer(num_words=MAX_WORDS, filters="")
tokenizer.fit_on_texts(data["Query"])
sequences = tokenizer.texts_to_sequences(data["Query"])

# Pad (or truncate) each sequence to MAX_LEN; pad_sequences pads at the
# front by default
X = pad_sequences(sequences, maxlen=MAX_LEN)

# Split the data into training and test sets (80/20)
y = data["Label"]
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
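
# Note: if the labels are heavily imbalanced, passing stratify=y to
# train_test_split would preserve the class ratio in both splits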

# Create the CNN-BiLSTM model: the embedding feeds a 1D convolution that
# extracts local n-gram features, max pooling halves the sequence length,
# and a bidirectional LSTM summarizes the sequence into a single feature
# vector for the sigmoid output
model = Sequential()
model.add(Embedding(MAX_WORDS, 128))
model.add(Conv1D(filters=64, kernel_size=3, padding="same", activation="relu"))
model.add(MaxPooling1D(pool_size=2))
model.add(Bidirectional(LSTM(64, dropout=0.2, recurrent_dropout=0.2)))
model.add(Dense(1, activation="sigmoid"))
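
# To inspect layer output shapes before training, one option is to build the
# model explicitly and print a summary (optional debugging aid):
#   model.build(input_shape=(None, MAX_LEN))
#   model.summary()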

model.compile(
    loss="binary_crossentropy",
    optimizer="adam",
    metrics=[
        # "accuracy" resolves to binary accuracy for this loss; naming the
        # Precision/Recall metrics keeps the history keys stable
        # ("precision", "recall") for the plots below
        "accuracy",
        Precision(name="precision"),
        Recall(name="recall"),
    ],
)

# Stop training once val_loss has not improved for 5 epochs, and roll back
# to the best weights seen so far
early_stopping = EarlyStopping(
    monitor="val_loss", patience=5, restore_best_weights=True
)

# Train the model with early stopping
history = model.fit(
    X_train,
    y_train,
    epochs=50,  # Maximum number of epochs
    batch_size=32,
    validation_data=(X_test, y_test),
    callbacks=[early_stopping],
    verbose=1,
)
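
# Note: the test split doubles as the validation set here, so early stopping
# has indirectly seen the test data; a separate validation split would give
# a cleaner final evaluation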

# Predict probabilities on the test set and threshold at 0.5 to obtain
# class labels
y_pred = model.predict(X_test, verbose=1)
y_pred_classes = (y_pred.ravel() > 0.5).astype(int)

# Calculate model performance indicators; zero_division=1 makes the score
# 1.0 instead of raising a warning when a denominator is zero
accuracy = accuracy_score(y_test, y_pred_classes)
recall = recall_score(y_test, y_pred_classes, zero_division=1)
precision = precision_score(y_test, y_pred_classes, zero_division=1)
f1 = f1_score(y_test, y_pred_classes, zero_division=1)
tn, fp, fn, tp = confusion_matrix(y_test, y_pred_classes).ravel()

# Output performance indicators
print("Accuracy: {:.2f}%".format(accuracy * 100))
print("Recall: {:.2f}%".format(recall * 100))
print("Precision: {:.2f}%".format(precision * 100))
print("F1-score: {:.2f}%".format(f1 * 100))
print("Specificity: {:.2f}%".format(tn / (tn + fp) * 100))
print("Sensitivity (TPR): {:.2f}%".format(tp / (tp + fn) * 100))

# Export the model in the TensorFlow SavedModel format
model.export(sys.argv[2])
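
# model.export is the Keras 3 API for writing a plain TF SavedModel; under
# tf.keras 2.x, model.save(path) would be the rough equivalent. The exported
# directory can be reloaded outside Keras with, e.g.:
#   loaded = tf.saved_model.load(sys.argv[2])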


# Plot the training history
def plot_history(history):
    plt.figure(figsize=(12, 8))

    # Plot loss
    plt.subplot(2, 2, 1)
    plt.plot(history.history["loss"], label="Training Loss")
    plt.plot(history.history["val_loss"], label="Validation Loss")
    plt.title("Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()

    # Plot accuracy
    plt.subplot(2, 2, 2)
    plt.plot(history.history["accuracy"], label="Training Accuracy")
    plt.plot(history.history["val_accuracy"], label="Validation Accuracy")
    plt.title("Accuracy")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.legend()

    # Plot precision
    plt.subplot(2, 2, 3)
    plt.plot(history.history["precision"], label="Training Precision")
    plt.plot(history.history["val_precision"], label="Validation Precision")
    plt.title("Precision")
    plt.xlabel("Epochs")
    plt.ylabel("Precision")
    plt.legend()

    # Plot recall
    plt.subplot(2, 2, 4)
    plt.plot(history.history["recall"], label="Training Recall")
    plt.plot(history.history["val_recall"], label="Validation Recall")
    plt.title("Recall")
    plt.xlabel("Epochs")
    plt.ylabel("Recall")
    plt.legend()

    plt.tight_layout()
    plt.savefig("training_history.png")


plot_history(history)