diff --git a/helpers/prompts.py b/helpers/prompts.py index 8ebaaf8..b166051 100644 --- a/helpers/prompts.py +++ b/helpers/prompts.py @@ -7,7 +7,7 @@ pipe = pipeline( "text-generation", model="HuggingFaceH4/zephyr-7b-gemma-v0.1", - device_map="auto", + #device_map="auto", torch_dtype=torch.bfloat16, )