Skip to content

Commit fce6114

Browse files
authored
Merge pull request #10 from gatewayd-io/v3-cnn-bilstm
Model v3: CNN BiLSTM
2 parents 430917f + f9adeb6 commit fce6114

File tree

7 files changed

+174
-10
lines changed

7 files changed

+174
-10
lines changed
Binary file not shown.

sqli_model/3/fingerprint.pb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
��鯴���'������n�������! ����Ɔ��K(����䈽��2

sqli_model/3/saved_model.pb

181 KB
Binary file not shown.
Binary file not shown.
1.59 KB
Binary file not shown.

training/test_train.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,18 @@
1818
"model_path": "sqli_model/2",
1919
"index": 1,
2020
}
21+
MODELV3 = {
22+
"dataset": "dataset/sqli_dataset2.csv",
23+
"model_path": "sqli_model/3",
24+
"index": 2,
25+
}
2126

2227

2328
@pytest.fixture(
2429
params=[
2530
MODELV1,
2631
MODELV2,
32+
MODELV3,
2733
],
2834
)
2935
def model(request):
@@ -47,18 +53,18 @@ def model(request):
4753
@pytest.mark.parametrize(
4854
"sample",
4955
[
50-
("select * from users where id=1 or 1=1;", [99.99, 97.40]),
51-
("select * from users where id='1' or 1=1--", [92.02, 97.40]),
52-
("select * from users", [0.077, 0.015]),
53-
("select * from users where id=10000", [14.83, 88.93]),
54-
("select '1' union select 'a'; -- -'", [99.99, 97.36]),
56+
("select * from users where id=1 or 1=1;", [99.99, 97.40, 11.96]),
57+
("select * from users where id='1' or 1=1--", [92.02, 97.40, 11.96]),
58+
("select * from users", [0.077, 0.015, 0.002]),
59+
("select * from users where id=10000", [14.83, 88.93, 0.229]),
60+
("select '1' union select 'a'; -- -'", [99.99, 97.32, 99.97]),
5561
(
56-
"""select '' union select 'malicious php code' \g /var/www/test.php; -- -';""",
57-
[99.99, 80.65],
62+
# Raw string: "\g" is not a valid escape sequence in a normal literal.
r"select '' union select 'malicious php code' \g /var/www/test.php; -- -';",
63+
[99.99, 80.65, 99.98],
5864
),
5965
(
60-
"""select '' || pg_sleep((ascii((select 'a' limit 1)) - 32) / 2); -- -';""",
61-
[99.99, 99.99],
66+
"select '' || pg_sleep((ascii((select 'a' limit 1)) - 32) / 2); -- -';",
67+
[99.99, 99.99, 99.93],
6268
),
6369
],
6470
)
@@ -73,7 +79,7 @@ def test_sqli_model(model, sample):
7379
# Scale up to 100
7480
output = "dense"
7581
if "output_0" in predictions:
76-
output = "output_0" # Model v2 uses output_0 instead of dense
82+
output = "output_0" # Model v2 and v3 use output_0 instead of dense
7783

7884
print(predictions[output].numpy() * 100) # Debugging purposes (prints on error)
7985
assert predictions[output].numpy() * 100 == pytest.approx(

training/train_v3.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
import sys
2+
import pandas as pd
3+
from tensorflow.keras.preprocessing.text import Tokenizer
4+
from tensorflow.keras.preprocessing.sequence import pad_sequences
5+
from tensorflow.keras.models import Sequential
6+
from tensorflow.keras.layers import (
7+
Bidirectional,
8+
Conv1D,
9+
Dense,
10+
Embedding,
11+
Flatten,
12+
LSTM,
13+
MaxPooling1D,
14+
)
15+
from tensorflow.keras.metrics import Accuracy, Recall, Precision
16+
from tensorflow.keras.callbacks import EarlyStopping
17+
from sklearn.model_selection import train_test_split
18+
from sklearn.metrics import (
19+
accuracy_score,
20+
recall_score,
21+
precision_score,
22+
f1_score,
23+
confusion_matrix,
24+
)
25+
import numpy as np
26+
import matplotlib.pyplot as plt
27+
28+
29+
# Check if the input file and output directory are provided
30+
if len(sys.argv) != 3:
    # Report the script name actually invoked so the hint stays correct for
    # this file (train_v3.py) instead of the hard-coded "train.py".
    print(f"Usage: python {sys.argv[0]} <input_file> <output_dir>")
    sys.exit(1)
33+
34+
# Load dataset
35+
data = pd.read_csv(sys.argv[1])
36+
37+
# Define parameters
38+
MAX_WORDS = 10000
39+
MAX_LEN = 100
40+
41+
# Use Tokenizer to encode text
42+
tokenizer = Tokenizer(num_words=MAX_WORDS, filters="")
43+
tokenizer.fit_on_texts(data["Query"])
44+
sequences = tokenizer.texts_to_sequences(data["Query"])
45+
46+
# Pad the text sequence
47+
X = pad_sequences(sequences, maxlen=MAX_LEN)
48+
49+
# Split the training set and test set
50+
y = data["Label"]
51+
X_train, X_test, y_train, y_test = train_test_split(
52+
X, y, test_size=0.2, random_state=42
53+
)
54+
55+
# Create CNN-BiLSTM model
56+
model = Sequential()
57+
model.add(Embedding(MAX_WORDS, 128))
58+
model.add(Conv1D(filters=64, kernel_size=3, padding="same", activation="relu"))
59+
model.add(MaxPooling1D(pool_size=2))
60+
model.add(Bidirectional(LSTM(64, dropout=0.2, recurrent_dropout=0.2)))
61+
model.add(Flatten())
62+
model.add(Dense(1, activation="sigmoid"))
63+
64+
# The network ends in a single sigmoid unit, so accuracy has to be the
# thresholded (binary) kind. keras.metrics.Accuracy() counts exact equality
# between the float probabilities and the 0/1 labels, which stays near zero.
# The "accuracy" shortcut lets Keras pick the accuracy metric matching the
# output shape, and the history keys remain "accuracy"/"val_accuracy".
model.compile(
    loss="binary_crossentropy",
    optimizer="adam",
    metrics=[
        "accuracy",
        Recall(),
        Precision(),
    ],
)
73+
74+
# Define early stopping callback: stop after 5 epochs without val_loss improvement and restore the best weights
75+
early_stopping = EarlyStopping(
76+
monitor="val_loss", patience=5, restore_best_weights=True
77+
)
78+
79+
# Train model with early stopping
80+
history = model.fit(
81+
X_train,
82+
y_train,
83+
epochs=50, # Maximum number of epochs
84+
batch_size=32,
85+
validation_data=(X_test, y_test),
86+
callbacks=[early_stopping],
87+
verbose=1,
88+
)
89+
90+
# Predict test set
91+
y_pred = model.predict(X_test, verbose=1)
92+
# The model ends in Dense(1, sigmoid), so y_pred has a single column and
# argmax over axis 1 would always yield class 0. Threshold the predicted
# probability at 0.5 instead and flatten to a 1-D label vector.
y_pred_classes = (y_pred > 0.5).astype(int).ravel()
93+
94+
# Calculate model performance indicators
95+
accuracy = accuracy_score(y_test, y_pred_classes)
96+
recall = recall_score(y_test, y_pred_classes, zero_division=1)
97+
precision = precision_score(y_test, y_pred_classes, zero_division=1)
98+
f1 = f1_score(y_test, y_pred_classes, zero_division=1)
99+
tn, fp, fn, tp = confusion_matrix(y_test, y_pred_classes).ravel()
100+
101+
# Output performance indicators
102+
print("Accuracy: {:.2f}%".format(accuracy * 100))
103+
print("Recall: {:.2f}%".format(recall * 100))
104+
print("Precision: {:.2f}%".format(precision * 100))
105+
print("F1-score: {:.2f}%".format(f1 * 100))
106+
print("Specificity: {:.2f}%".format(tn / (tn + fp) * 100))
107+
# tp / (tp + fn) is the true-positive rate (sensitivity), not a ROC value.
print("Sensitivity (TPR): {:.2f}%".format(tp / (tp + fn) * 100))
108+
109+
# Save model as SavedModel format
110+
model.export(sys.argv[2])
111+
112+
113+
# Plot the training history
114+
def plot_history(history):
    """Save a 2x2 grid of training/validation curves to training_history.png.

    Reads the per-epoch series for loss, accuracy, precision and recall,
    together with their ``val_`` counterparts, from ``history.history``.
    """
    # (history key, caption) for each panel, in subplot order.
    panels = (
        ("loss", "Loss"),
        ("accuracy", "Accuracy"),
        ("precision", "Precision"),
        ("recall", "Recall"),
    )

    plt.figure(figsize=(12, 8))
    curves = history.history

    for slot, (key, caption) in enumerate(panels, start=1):
        plt.subplot(2, 2, slot)
        plt.plot(curves[key], label=f"Training {caption}")
        plt.plot(curves[f"val_{key}"], label=f"Validation {caption}")
        plt.title(caption)
        plt.xlabel("Epochs")
        plt.ylabel(caption)
        plt.legend()

    plt.tight_layout()
    plt.savefig("training_history.png")
155+
156+
157+
plot_history(history)

0 commit comments

Comments
 (0)