From 5782f8b3994e15d18cba0f34b3056fc683bcb928 Mon Sep 17 00:00:00 2001 From: dan Date: Thu, 31 Oct 2024 20:58:55 -0500 Subject: [PATCH] Modify perplexity test to use SDPA by forcing attention_kernel="torch" --- sharktank/sharktank/evaluate/perplexity_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sharktank/sharktank/evaluate/perplexity_torch.py b/sharktank/sharktank/evaluate/perplexity_torch.py index fc3aa5fca..c7d90a19e 100644 --- a/sharktank/sharktank/evaluate/perplexity_torch.py +++ b/sharktank/sharktank/evaluate/perplexity_torch.py @@ -111,7 +111,7 @@ def load_model(self, dataset, tokenizer, tensor_parallelism_size, attention_kern attention_dtype=self.attention_dtype, tensor_parallelism_size=tensor_parallelism_size, ) - + config.attention_kernel="torch" if config.tensor_parallelism_size > 1: dataset.root_theta = shard_theta(dataset.root_theta, config)