From 89a04d54fcd3b667ef0ea5474709ffb86845f081 Mon Sep 17 00:00:00 2001 From: l3ra Date: Mon, 30 Dec 2024 15:37:19 +0000 Subject: [PATCH] Added a remotemodelwrapper class --- .../models/wrappers/remote_model_wrapper.py | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 textattack/models/wrappers/remote_model_wrapper.py diff --git a/textattack/models/wrappers/remote_model_wrapper.py b/textattack/models/wrappers/remote_model_wrapper.py new file mode 100644 index 00000000..0a0dc046 --- /dev/null +++ b/textattack/models/wrappers/remote_model_wrapper.py @@ -0,0 +1,63 @@ +""" +RemoteModelWrapper class +-------------------------- + +""" + +import requests +import torch +import numpy as np +import transformers + +class RemoteModelWrapper(): + """This model wrapper queries a remote model with a list of text inputs. + + It sends the input to a remote endpoint provided in api_url. + + + """ + def __init__(self, api_url): + self.api_url = api_url + self.model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb") + + def __call__(self, text_input_list): + predictions = [] + for text in text_input_list: + params = dict() + params["text"] = text + response = requests.post(self.api_url, params=params, timeout=10) # Use POST with JSON payload + if response.status_code != 200: + print(f"Response content: {response.text}") + raise ValueError(f"API call failed with status {response.status_code}") + result = response.json() + # Assuming the API returns probabilities for positive and negative + predictions.append([result["negative"], result["positive"]]) + return torch.tensor(predictions) + +''' +Example usage: + +# Define the remote model API endpoint and tokenizer +api_url = "https://x.com/predict" + +model_wrapper = RemoteModelWrapper(api_url) + +# Build the attack +attack = textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper) + +# Define dataset and attack arguments +dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test") + +attack_args = textattack.AttackArgs( + num_examples=100, + log_to_csv="/textfooler.csv", + checkpoint_interval=5, + checkpoint_dir="checkpoints", + disable_stdout=True +) + +# Run the attack +attacker = textattack.Attacker(attack, dataset, attack_args) +attacker.attack_dataset() + +''' \ No newline at end of file