Skip to content

Commit

Permalink
hmm
Browse files Browse the repository at this point in the history
  • Loading branch information
Parry-Parry committed Nov 6, 2024
1 parent 4ea210b commit ec6731b
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 13 deletions.
12 changes: 6 additions & 6 deletions rankers/train/data_arguments.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field, fields
from typing import Dict, Any
from typing import Dict, Any, Optional
import json
from enum import Enum
import torch
Expand All @@ -15,23 +15,23 @@ class RankerDataArguments:
training_dataset_file : str = field(
metadata={"help": "Path to the training dataset"}
)
teacher_file : str = field(
teacher_file : Optional[str] = field(
default=None,
metadata={"help": "Path to the teacher scores"}
)
validation_dataset_file : str = field(
validation_dataset_file : Optional[str] = field(
default=None,
metadata={"help": "Path to the validation dataset"}
)
test_dataset_file : str = field(
test_dataset_file : Optional[str] = field(
default=None,
metadata={"help": "Path to the test dataset"}
)
ir_dataset : str = field(
ir_dataset : Optional[str] = field(
default=None,
metadata={"help": "IR Dataset for text lookup"}
)
use_positive : bool = field(
use_positive : Optional[bool] = field(
default=False,
        metadata={"help": "Use positive samples located in 'doc_id_a' column, otherwise use solely 'doc_id_b'"}
)
Expand Down
9 changes: 5 additions & 4 deletions rankers/train/model_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Dict, Any
import json
from enum import Enum
from typing import Optional
import torch
from .. import is_torch_available

Expand Down Expand Up @@ -61,19 +62,19 @@ def to_sanitized_dict(self) -> Dict[str, Any]:

@dataclass
class RankerDotArguments(RankerModelArguments):
pooling : str = field(
pooling : Optional[str] = field(
default='cls',
metadata={"help": "Pooling strategy"}
)
use_pooler : bool = field(
use_pooler : Optional[bool] = field(
default=False,
metadata={"help": "Whether to use the pooler MLP"}
)
model_tied : bool = field(
model_tied : Optional[bool] = field(
default=False,
metadata={"help": "Whether to tie the weights of the query and document encoder"}
)
in_batch_loss : str = field(
in_batch_loss : Optional[str] = field(
default=None,
metadata={"help": "Loss function to use for in-batch negatives"}
)
Expand Down
7 changes: 4 additions & 3 deletions rankers/train/training_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from transformers.trainer_pt_utils import AcceleratorConfig
from transformers.utils import is_accelerate_available
from dataclasses import field, fields, dataclass
from typing import Optional
from enum import Enum
from .. import is_ir_measures_available, is_ir_datasets_available, seed_everything

Expand All @@ -21,15 +22,15 @@ def get_loss(loss_fn : str):

@dataclass
class RankerTrainingArguments(TrainingArguments):
group_size : int = field(
group_size : Optional[int] = field(
default=2,
metadata={"help": "Number of documents per query"}
)
ir_dataset : str = field(
ir_dataset : Optional[str] = field(
default=None,
metadata={"help": "IR Dataset for text lookup"}
)
wandb_project : str = field(
wandb_project : Optional[str] = field(
default=None,
metadata={"help": "Wandb project name"}
)
Expand Down

0 comments on commit ec6731b

Please sign in to comment.