Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a remotemodelwrapper class #809

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions textattack/models/wrappers/remote_model_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
RemoteModelWrapper class
--------------------------

"""

import requests
import torch
import numpy as np
import transformers
Comment on lines +7 to +10
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: run make format or the black formatter, as mentioned in the contribution guidelines for this repo


class RemoteModelWrapper():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should inherit from the ModelWrapper abstract class if you're looking to be using this as a ModelWrapper

"""This model wrapper queries a remote model with a list of text inputs. It sends the input to a remote endpoint provided in api_url.

Args:
api_url (:obj:`<TYPE HERE>`): <DESCRIPTION HERE>
"""
def __init__(self, api_url):
self.api_url = api_url
self.model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you set this? The model variable isn't used elsewhere in this class


def __call__(self, text_input_list):
predictions = []
for text in text_input_list:
params = dict()
params["text"] = text
response = requests.post(self.api_url, params=params, timeout=10) # Use POST with JSON payload
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this kind of request format guaranteed to work for all endpoints?

For example, OpenAI requires a specific kind of payload. I might suggest adding a parameter when you initialize the wrapper you accept a lambda as a param to massage the data into a viable payload format

if response.status_code != 200:
print(f"Response content: {response.text}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recommend using the package's logger instead of making print statements, especially when you mean to throw an error. This print statement might not even be necessary since you throw the error below anyways

raise ValueError(f"API call failed with status {response.status_code}")
result = response.json()
# Assuming the API returns probabilities for positive and negative
predictions.append([result["negative"], result["positive"]])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see you're making these assumptions, but I'm not sure if this is so common as to be widely applicable to make this a good wrapper function. To alleviate this, you could add another lambda to massage the output

return torch.tensor(predictions)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason to cast this as a tensor?


"""
Example usage:

>>> # Define the remote model API endpoint
>>> api_url = "https://example.com"

>>> model_wrapper = RemoteModelWrapper(api_url)

>>> # Build the attack
>>> attack = textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper)

>>> # Define dataset and attack arguments
>>> dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")

>>> attack_args = textattack.AttackArgs(
... num_examples=100,
... log_to_csv="/textfooler.csv",
... checkpoint_interval=5,
... checkpoint_dir="checkpoints",
... disable_stdout=True
... )

>>> # Run the attack
>>> attacker = textattack.Attacker(attack, dataset, attack_args)
>>> attacker.attack_dataset()
"""