Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
henicosa committed May 1, 2024
2 parents e6eefe6 + 063a124 commit 1d3fbb1
Showing 1 changed file with 62 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from pathlib import Path

from tira.rest_api_client import Client
from tira.third_party_integrations import get_output_directory

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.pipeline import Pipeline




def evaluate_model(model, data, labels):
    """Score *model* on *data* against the reference *labels*.

    The model's ``predict`` method is called on ``data`` and the result is
    compared with ``labels`` via ``accuracy_score``.

    Args:
        model: Fitted estimator exposing a ``predict`` method.
        data: Inputs handed to ``model.predict``.
        labels: Ground-truth labels aligned with ``data``.

    Returns:
        The value returned by ``accuracy_score(labels, predictions)``.
    """
    return accuracy_score(labels, model.predict(data))



if __name__ == "__main__":
    # Entry point: train a TF-IDF + logistic regression baseline on the
    # training split, score it on the validation split, and write the
    # validation predictions to ``predictions.jsonl``.
    task = "nlpbuw-fsu-sose-24"
    train_dataset = "authorship-verification-train-20240408-training"
    validation_dataset = "authorship-verification-validation-20240408-training"

    tira = Client()

    # Training inputs and their ground-truth labels.
    text_train = tira.pd.inputs(task, train_dataset)
    targets_train = tira.pd.truths(task, train_dataset)

    # Validation inputs and labels (automatically replaced by test data when
    # run on tira).
    text_validation = tira.pd.inputs(task, validation_dataset)
    targets_validation = tira.pd.truths(task, validation_dataset)

    # Vectorizer capped at 1000 TF-IDF features feeding a default
    # LogisticRegression classifier.
    model = Pipeline(
        steps=[
            ("vectorizer", TfidfVectorizer(max_features=1000)),
            ("classifier", LogisticRegression()),
        ]
    )
    model.fit(text_train["text"], targets_train["generated"])

    # Validation accuracy is computed but intentionally not reported.
    val_accuracy = evaluate_model(
        model,
        text_validation["text"],
        targets_validation["generated"],
    )

    # Attach the predicted label to each validation row and keep only the
    # columns that are written out.
    text_validation["generated"] = model.predict(text_validation["text"])
    df = text_validation[["id", "generated"]]

    # Persist the predictions as JSON Lines in the output directory resolved
    # by get_output_directory for this script's folder.
    output_directory = get_output_directory(str(Path(__file__).parent))
    predictions_file = Path(output_directory) / "predictions.jsonl"
    df.to_json(predictions_file, orient="records", lines=True)

0 comments on commit 1d3fbb1

Please sign in to comment.