From be3ef77ce3b4608471d4a73140c9d1a696c481e2 Mon Sep 17 00:00:00 2001 From: marnix Date: Thu, 25 Jul 2024 14:10:49 +0200 Subject: [PATCH] Avoid missing packages and attn_mask dtype error --- FlagEmbedding/visual/modeling.py | 14 +++++++++----- setup.py | 24 +++++++++++++----------- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/FlagEmbedding/visual/modeling.py b/FlagEmbedding/visual/modeling.py index 41fd97d3..0c9e66ac 100644 --- a/FlagEmbedding/visual/modeling.py +++ b/FlagEmbedding/visual/modeling.py @@ -1,16 +1,16 @@ -import os import logging +import os from dataclasses import dataclass from typing import Optional, Tuple + import torch import torch.distributed as dist -from torch import nn, Tensor -from transformers import AutoModel, AutoTokenizer, AutoConfig +from PIL import Image +from torch import Tensor, nn +from transformers import AutoConfig, AutoModel, AutoTokenizer from transformers.file_utils import ModelOutput - from FlagEmbedding.visual.eva_clip import create_eva_vision_and_transforms -from PIL import Image logger = logging.getLogger(__name__) @@ -200,6 +200,10 @@ def encode_text(self, texts): inputs_embeds=None, past_key_values_length=0, ) + + # Ensure the attention mask has the same dtype as the query tensor + extended_attention_mask = extended_attention_mask.to(embedding_output.dtype) + encoder_outputs = self.bge_encoder( embedding_output, attention_mask=extended_attention_mask, diff --git a/setup.py b/setup.py index 139b14cf..a78e0090 100644 --- a/setup.py +++ b/setup.py @@ -1,22 +1,24 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup with open("README.md", mode="r", encoding="utf-8") as readme_file: readme = readme_file.read() setup( - name='FlagEmbedding', - version='1.2.10', - description='FlagEmbedding', + name="FlagEmbedding", + version="1.2.10", + description="FlagEmbedding", long_description=readme, long_description_content_type="text/markdown", - author_email='2906698981@qq.com', - url='https://github.com/FlagOpen/FlagEmbedding', + author_email="2906698981@qq.com", + url="https://github.com/FlagOpen/FlagEmbedding", packages=find_packages(), install_requires=[ - 'torch>=1.6.0', - 'transformers>=4.33.0', - 'datasets', - 'accelerate>=0.20.1', - 'sentence_transformers', + "torch>=1.6.0", + "transformers>=4.33.0", + "datasets", + "accelerate>=0.20.1", + "sentence_transformers", + "peft", + "sentencepiece", ], )