Added a RemoteModelWrapper class #809
base: master
@@ -0,0 +1,61 @@
""" | ||
RemoteModelWrapper class | ||
-------------------------- | ||
|
||
""" | ||
|
||
import requests | ||
import torch | ||
import numpy as np | ||
import transformers | ||
|
||
class RemoteModelWrapper(): | ||
This should inherit from the ModelWrapper abstract class if you're looking to use this as a ModelWrapper.
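A minimal sketch of what that could look like, assuming the `ModelWrapper` base class exported from `textattack.models.wrappers`; the `__call__` body is elided here since it appears in the diff below:

```python
from textattack.models.wrappers import ModelWrapper


class RemoteModelWrapper(ModelWrapper):
    """Queries a remote model endpoint; subclassing the abstract base keeps it
    usable anywhere TextAttack expects a ModelWrapper."""

    def __init__(self, api_url):
        self.api_url = api_url

    def __call__(self, text_input_list):
        # Query the remote endpoint, as in the rest of this diff.
        raise NotImplementedError
```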
    """This model wrapper queries a remote model with a list of text inputs. It sends the input to a remote endpoint provided in api_url.

    Args:
        api_url (:obj:`str`): URL of the remote model endpoint to query.
    """
    def __init__(self, api_url):
        self.api_url = api_url
        self.model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")
Why do you set this? The model variable isn't used elsewhere in this class.

    def __call__(self, text_input_list):
        predictions = []
        for text in text_input_list:
            params = dict()
            params["text"] = text
            # POST the text to the remote endpoint as a query parameter
            response = requests.post(self.api_url, params=params, timeout=10)
Is this kind of request format guaranteed to work for all endpoints? For example, OpenAI requires a specific kind of payload. I might suggest adding a parameter when you initialize the wrapper: accept a lambda as a param to massage the data into a viable payload format.
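A rough sketch of that suggestion; `build_payload`, `query_param_payload`, and `json_body_payload` are hypothetical names, not part of this PR:

```python
import requests


def query_param_payload(text):
    # Reproduces the PR's current behaviour: send the text as a query parameter.
    return {"params": {"text": text}}


def json_body_payload(text):
    # Example alternative for endpoints that expect a JSON body.
    return {"json": {"inputs": text}}


def call_endpoint(api_url, text, build_payload=query_param_payload):
    # The wrapper would store `build_payload` at init time and call it here,
    # so each caller can match their endpoint's expected payload format.
    return requests.post(api_url, timeout=10, **build_payload(text))
```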
            if response.status_code != 200:
                print(f"Response content: {response.text}")
Recommend using the package's logger instead of making print statements, especially when you mean to throw an error. This print statement might not even be necessary since you throw the error below anyways.
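A small sketch of the logging idea; a plain `logging.getLogger` stands in for whatever logger object this package actually exposes, and `check_response` is a hypothetical helper:

```python
import logging

import requests

logger = logging.getLogger(__name__)  # substitute the package's own logger


def check_response(response: requests.Response) -> None:
    # Log the failing body once, then raise; no print() needed.
    if response.status_code != 200:
        logger.error(
            "Remote model call failed (%s): %s", response.status_code, response.text
        )
        raise ValueError(f"API call failed with status {response.status_code}")
```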
                raise ValueError(f"API call failed with status {response.status_code}")
            result = response.json()
            # Assuming the API returns probabilities for positive and negative
            predictions.append([result["negative"], result["positive"]])
I see you're making these assumptions, but I'm not sure if this is so common as to be widely applicable to make this a good wrapper function. To alleviate this, you could add another lambda to massage the output.
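One way to sketch that: a caller-supplied parser (hypothetical name `parse_response`) maps an endpoint-specific JSON body to the `[negative, positive]` pair this wrapper currently hard-codes:

```python
def parse_named_scores(result):
    # Matches the assumption made in this PR: {"negative": ..., "positive": ...}
    return [result["negative"], result["positive"]]


def parse_label_list(result):
    # Example alternative shape: [{"label": "NEGATIVE", "score": ...}, ...]
    scores = {entry["label"].lower(): entry["score"] for entry in result}
    return [scores["negative"], scores["positive"]]


# Inside __call__, the hard-coded line would become:
# predictions.append(self.parse_response(response.json()))
```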
        return torch.tensor(predictions)
Any reason to cast this as a tensor?


"""
Example usage:

>>> import textattack

>>> # Define the remote model API endpoint
>>> api_url = "https://example.com"

>>> model_wrapper = RemoteModelWrapper(api_url)

>>> # Build the attack
>>> attack = textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper)

>>> # Define dataset and attack arguments
>>> dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")

>>> attack_args = textattack.AttackArgs(
...     num_examples=100,
...     log_to_csv="/textfooler.csv",
...     checkpoint_interval=5,
...     checkpoint_dir="checkpoints",
...     disable_stdout=True
... )

>>> # Run the attack
>>> attacker = textattack.Attacker(attack, dataset, attack_args)
>>> attacker.attack_dataset()
"""
NIT: run `make format` or the `black` formatter, as mentioned in the contribution guidelines for this repo.