Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jan 31, 2025
Parent: caf4cde · Commit: fe9a700
Showing 1 changed file with 4 additions and 3 deletions.
examples/llm/tech_qa.py

@@ -254,14 +254,15 @@ def train(args, data_lists):
                          drop_last=False, pin_memory=True, shuffle=False)
     gnn = GAT(in_channels=768, hidden_channels=hidden_channels,
               out_channels=1024, num_layers=num_gnn_layers, heads=4)
-    if args. llm_generator_mode == "full":
+    if args.llm_generator_mode == "full":
         llm = LLM(model_name=args.llm_generator_name, dtype=torch.float32)
         model = GRetriever(llm=llm, gnn=gnn)
-    elif args. llm_generator_mode == "lora":
+    elif args.llm_generator_mode == "lora":
         llm = LLM(model_name=args.llm_generator_name, dtype=torch.float32)
         model = GRetriever(llm=llm, gnn=gnn, use_lora=True)
     else:
-        llm = LLM(model_name=args.llm_generator_name, dtype=torch.float32).eval()
+        llm = LLM(model_name=args.llm_generator_name,
+                  dtype=torch.float32).eval()
         for _, p in llm.named_parameters():
             p.requires_grad = False
         model = GRetriever(llm=llm, gnn=gnn)
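
For context, the hunk branches on args.llm_generator_mode to choose between a fully fine-tuned, LoRA-adapted, or frozen LLM generator inside GRetriever. Below is a minimal, self-contained sketch of how those arguments might be wired up; the flag names follow the attributes read in the hunk, while the choices, defaults, and layer sizes are assumptions, not necessarily the script's actual values.

# Hedged sketch: how the arguments used in the hunk above could be defined and
# consumed. The real argparse setup lives outside this hunk in
# examples/llm/tech_qa.py and may use different defaults or choices.
import argparse

import torch
from torch_geometric.nn import GAT
from torch_geometric.nn.models import GRetriever
from torch_geometric.nn.nlp import LLM

parser = argparse.ArgumentParser()
# Flag names mirror the attributes read in the hunk; defaults are assumptions.
parser.add_argument("--llm_generator_mode", default="full",
                    choices=["full", "lora", "frozen"])
parser.add_argument("--llm_generator_name",
                    default="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
args = parser.parse_args()

# GNN encoder, mirroring the hunk (hidden size and depth are illustrative).
gnn = GAT(in_channels=768, hidden_channels=1024, out_channels=1024,
          num_layers=4, heads=4)
llm = LLM(model_name=args.llm_generator_name, dtype=torch.float32)

if args.llm_generator_mode == "full":
    model = GRetriever(llm=llm, gnn=gnn)                  # fine-tune the whole LLM
elif args.llm_generator_mode == "lora":
    model = GRetriever(llm=llm, gnn=gnn, use_lora=True)   # train LoRA adapters only
else:
    llm.eval()                                            # freeze the generator
    for _, p in llm.named_parameters():
        p.requires_grad = False
    model = GRetriever(llm=llm, gnn=gnn)

In the frozen branch, the LLM weights stay fixed, so only the GAT encoder and the projection inside GRetriever receive gradients during training.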
