Skip to content

Commit

Permalink
Merge pull request #10 from gatewayd-io/v3-cnn-bilstm
Browse files Browse the repository at this point in the history
Model v3: CNN BiLSTM
  • Loading branch information
mostafa authored May 21, 2024
2 parents 430917f + f9adeb6 commit fce6114
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 10 deletions.
Binary file added sqli_model/3/assets.extra/tf_serving_warmup_requests
Binary file not shown.
1 change: 1 addition & 0 deletions sqli_model/3/fingerprint.pb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
��鯴���'������n�������! ����Ɔ��K(����䈽��2
Binary file added sqli_model/3/saved_model.pb
Binary file not shown.
Binary file not shown.
Binary file added sqli_model/3/variables/variables.index
Binary file not shown.
26 changes: 16 additions & 10 deletions training/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,18 @@
"model_path": "sqli_model/2",
"index": 1,
}
# Test configuration for model v3 (the CNN-BiLSTM model added in this commit).
# NOTE(review): "index" presumably selects this model's entry in each sample's
# list of expected scores (third position) — confirm against test_sqli_model.
MODELV3 = {
"dataset": "dataset/sqli_dataset2.csv",
"model_path": "sqli_model/3",
"index": 2,
}


@pytest.fixture(
params=[
MODELV1,
MODELV2,
MODELV3,
],
)
def model(request):
Expand All @@ -47,18 +53,18 @@ def model(request):
@pytest.mark.parametrize(
"sample",
[
("select * from users where id=1 or 1=1;", [99.99, 97.40]),
("select * from users where id='1' or 1=1--", [92.02, 97.40]),
("select * from users", [0.077, 0.015]),
("select * from users where id=10000", [14.83, 88.93]),
("select '1' union select 'a'; -- -'", [99.99, 97.36]),
("select * from users where id=1 or 1=1;", [99.99, 97.40, 11.96]),
("select * from users where id='1' or 1=1--", [92.02, 97.40, 11.96]),
("select * from users", [0.077, 0.015, 0.002]),
("select * from users where id=10000", [14.83, 88.93, 0.229]),
("select '1' union select 'a'; -- -'", [99.99, 97.32, 99.97]),
(
"""select '' union select 'malicious php code' \g /var/www/test.php; -- -';""",
[99.99, 80.65],
"select '' union select 'malicious php code' \g /var/www/test.php; -- -';",
[99.99, 80.65, 99.98],
),
(
"""select '' || pg_sleep((ascii((select 'a' limit 1)) - 32) / 2); -- -';""",
[99.99, 99.99],
"select '' || pg_sleep((ascii((select 'a' limit 1)) - 32) / 2); -- -';",
[99.99, 99.99, 99.93],
),
],
)
Expand All @@ -73,7 +79,7 @@ def test_sqli_model(model, sample):
# Scale up to 100
output = "dense"
if "output_0" in predictions:
output = "output_0" # Model v2 uses output_0 instead of dense
output = "output_0" # Model v2 and v3 use output_0 instead of dense

print(predictions[output].numpy() * 100) # Debugging purposes (prints on error)
assert predictions[output].numpy() * 100 == pytest.approx(
Expand Down
157 changes: 157 additions & 0 deletions training/train_v3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import sys
import pandas as pd
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
Bidirectional,
Conv1D,
Dense,
Embedding,
Flatten,
LSTM,
MaxPooling1D,
)
from tensorflow.keras.metrics import Accuracy, Recall, Precision
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
accuracy_score,
recall_score,
precision_score,
f1_score,
confusion_matrix,
)
import numpy as np
import matplotlib.pyplot as plt


# Both positional arguments are required: the CSV dataset to train on and the
# directory that receives the exported model. Abort with a usage hint (exit
# status 1) when the argument count is wrong.
if len(sys.argv) != 3:
    # The hint names this script (train_v3.py), not the older train.py.
    print("Usage: python train_v3.py <input_file> <output_dir>")
    sys.exit(1)

# Tokenization hyper-parameters
MAX_WORDS = 10000  # passed to Tokenizer(num_words=...) and used as the Embedding input size
MAX_LEN = 100  # length every encoded query is padded/truncated to

# Read the CSV named by the first command-line argument; the columns used
# below are "Query" (raw SQL text) and "Label" (training target).
data = pd.read_csv(sys.argv[1])
queries = data["Query"]

# Fit the vocabulary on the queries and encode them as integer sequences.
# filters="" turns off the Tokenizer's character filtering.
tokenizer = Tokenizer(num_words=MAX_WORDS, filters="")
tokenizer.fit_on_texts(queries)
sequences = tokenizer.texts_to_sequences(queries)

# Bring every sequence to the same length
X = pad_sequences(sequences, maxlen=MAX_LEN)
y = data["Label"]

# Hold out 20% of the rows; the fixed seed keeps the split reproducible
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# Create CNN-BiLSTM model:
#   Embedding -> Conv1D -> MaxPooling1D -> Bidirectional(LSTM) -> Dense(sigmoid)
model = Sequential()
model.add(Embedding(MAX_WORDS, 128))
model.add(Conv1D(filters=64, kernel_size=3, padding="same", activation="relu"))
model.add(MaxPooling1D(pool_size=2))
model.add(Bidirectional(LSTM(64, dropout=0.2, recurrent_dropout=0.2)))
# NOTE(review): the BiLSTM above does not return sequences, so its output
# should already be 2-D and this Flatten looks like a no-op. It is kept so the
# layer stack of the exported model is unchanged.
model.add(Flatten())
# Single sigmoid unit: the model outputs one probability per query
model.add(Dense(1, activation="sigmoid"))

model.compile(
    loss="binary_crossentropy",
    optimizer="adam",
    metrics=[
        # The Accuracy() metric class counts exact matches between the label
        # and the raw sigmoid probability, which is almost never true. The
        # "accuracy" string lets Keras choose the thresholded binary accuracy
        # for this single-unit output, and the history key stays "accuracy".
        "accuracy",
        Recall(),
        Precision(),
    ],
)

# Early stopping: watch the validation loss, allow 5 epochs without
# improvement (patience), then stop and restore the best weights seen.
early_stopping = EarlyStopping(
    monitor="val_loss",
    patience=5,
    restore_best_weights=True,
)

# Fit for at most 50 epochs; the callback above may end training sooner.
# The held-out split (X_test, y_test) is what Keras validates against.
history = model.fit(
    X_train,
    y_train,
    validation_data=(X_test, y_test),
    batch_size=32,
    epochs=50,  # upper bound on the number of epochs
    callbacks=[early_stopping],
    verbose=1,
)

# Predict test set: one sigmoid probability per sample, shape (n_samples, 1)
y_pred = model.predict(X_test, verbose=1)
# Turn probabilities into 0/1 class labels with a 0.5 threshold.
# np.argmax(..., axis=1) must not be used here: with a single output column
# it is always 0, which would label every sample as negative.
y_pred_classes = (y_pred > 0.5).astype(int).ravel()

# Calculate model performance indicators
accuracy = accuracy_score(y_test, y_pred_classes)
recall = recall_score(y_test, y_pred_classes, zero_division=1)
precision = precision_score(y_test, y_pred_classes, zero_division=1)
f1 = f1_score(y_test, y_pred_classes, zero_division=1)
# For a binary problem the flattened confusion matrix is (tn, fp, fn, tp)
tn, fp, fn, tp = confusion_matrix(y_test, y_pred_classes).ravel()

# Output performance indicators
print(f"Accuracy: {accuracy * 100:.2f}%")
print(f"Recall: {recall * 100:.2f}%")
print(f"Precision: {precision * 100:.2f}%")
print(f"F1-score: {f1 * 100:.2f}%")
print(f"Specificity: {tn / (tn + fp) * 100:.2f}%")
# NOTE(review): tp / (tp + fn) is the true positive rate (the same quantity as
# recall), not a ROC AUC; the "ROC" label is kept so the output format is unchanged.
print(f"ROC: {tp / (tp + fn) * 100:.2f}%")

# Export the trained model to the directory named by the second command-line
# argument (the original comment describes the output as SavedModel format).
output_dir = sys.argv[2]
model.export(output_dir)


# Plot the training history
def plot_history(history):
    """Plot training vs. validation curves and save them as an image.

    Draws a 2x2 grid with one panel each for loss, accuracy, precision and
    recall. Each panel shows ``history.history["<metric>"]`` against
    ``history.history["val_<metric>"]``. The figure is written to
    ``training_history.png`` in the current working directory.

    Args:
        history: Object whose ``history`` attribute maps metric names to
            per-epoch value sequences (here, the return value of
            ``model.fit``).
    """
    panel_titles = ("Loss", "Accuracy", "Precision", "Recall")

    plt.figure(figsize=(12, 8))

    for position, title in enumerate(panel_titles, start=1):
        # History keys are the lower-cased panel titles ("loss", "accuracy", ...)
        metric = title.lower()

        plt.subplot(2, 2, position)
        plt.plot(history.history[metric], label=f"Training {title}")
        plt.plot(history.history[f"val_{metric}"], label=f"Validation {title}")
        plt.title(title)
        plt.xlabel("Epochs")
        plt.ylabel(title)
        plt.legend()

    plt.tight_layout()
    plt.savefig("training_history.png")


# Plot and save the curves for the training run recorded in `history`
plot_history(history)

0 comments on commit fce6114

Please sign in to comment.