Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clustering Component #2

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@ def main():
help="Evaluate & save model")
parser.add_argument('--prefix', type=str, default='',
help="Prefix for saving predictions")
parser.add_argument('--debug', action='store_true',
parser.add_argument('--debugCode', action='store_true',
help="Use a subset of data for debugging")
parser.add_argument('--debugTrain', action='store_true',
help="Use a subset of data for debugging")
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
Expand All @@ -111,7 +113,7 @@ def main():
help="some checkpoint pdb debug")

# parameters for SpanSeqGen
parser.add_argument("--top_k_passages", default=10, type=int)
parser.add_argument("--num_top_passages", default=10, type=int)
# "data/reranking_results/ambigqa"
parser.add_argument("--ranking_folder_path",
default=None) # "data/reranking_results/ambigqa"
Expand All @@ -130,7 +132,6 @@ def main():
parser.add_argument("--passage_clustering",
default=False, action="store_true")
parser.add_argument("--k_cluster", default = 10, type=int)
parser.add_argument("--rank_threshold", default=60, type=int)
parser.add_argument("--is_contrastive", default=False, action="store_true")


Expand Down Expand Up @@ -199,7 +200,11 @@ def main():
if args.model.lower() == "t5" and args.prepend_question_token == False:
logger.warning("t5 model needs prepending, it's adjusted now")
args.prepend_question_token = True

if args.debugCode and args.debugTrain:
raise ValueError("debug code and debug train mode are both turned"
" on. You need to either remove --debugCode or "
"--deubgTrain to make the script running.")

logger.info("Using {} gpus".format(args.n_gpu))
if args.device == "cuda":
assert args.n_gpu > 1, "if there is only one gpu, set args.device=0"
Expand Down
Loading