Avoid missing packages and attn_mask dtype error #992

Open · wants to merge 1 commit into base: master
14 changes: 9 additions & 5 deletions FlagEmbedding/visual/modeling.py
@@ -1,16 +1,16 @@
-import os
 import logging
+import os
 from dataclasses import dataclass
 from typing import Optional, Tuple
 
 import torch
 import torch.distributed as dist
-from torch import nn, Tensor
-from transformers import AutoModel, AutoTokenizer, AutoConfig
+from PIL import Image
+from torch import Tensor, nn
+from transformers import AutoConfig, AutoModel, AutoTokenizer
 from transformers.file_utils import ModelOutput
 
-
 from FlagEmbedding.visual.eva_clip import create_eva_vision_and_transforms
-from PIL import Image
+
 
 logger = logging.getLogger(__name__)

@@ -200,6 +200,10 @@ def encode_text(self, texts):
             inputs_embeds=None,
             past_key_values_length=0,
         )
+
+        # Ensure the attention mask has the same dtype as the query tensor
+        extended_attention_mask = extended_attention_mask.to(embedding_output.dtype)
+
        encoder_outputs = self.bge_encoder(
             embedding_output,
             attention_mask=extended_attention_mask,
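Why the cast matters: when the model runs in half precision, `embedding_output` is `float16`/`bfloat16`, while the additive attention mask built earlier in the method typically stays `float32`. Adding a `float32` mask to `float16` attention scores silently promotes them back to `float32`, and some kernels reject the dtype mismatch outright. A minimal sketch of the problem and the fix (shapes and the `-1e4` mask value are illustrative, not taken from FlagEmbedding):

```python
import torch

# Half-precision activations next to a float32 additive mask, following the
# usual Hugging Face convention (0.0 = attend, large negative = masked).
embedding_output = torch.randn(1, 4, 8, dtype=torch.float16)
attention_scores = torch.randn(1, 4, 4, dtype=torch.float16)
extended_attention_mask = torch.tensor([[0.0, 0.0, -1e4, -1e4]])  # float32

mixed = attention_scores + extended_attention_mask
print(mixed.dtype)  # torch.float32 (silently promoted out of fp16)

# The fix from the diff: align the mask dtype with the query tensor.
extended_attention_mask = extended_attention_mask.to(embedding_output.dtype)
fixed = attention_scores + extended_attention_mask
print(fixed.dtype)  # torch.float16
```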
24 changes: 13 additions & 11 deletions setup.py
@@ -1,22 +1,24 @@
-from setuptools import setup, find_packages
+from setuptools import find_packages, setup
 
 with open("README.md", mode="r", encoding="utf-8") as readme_file:
     readme = readme_file.read()
 
 setup(
-    name='FlagEmbedding',
-    version='1.2.10',
-    description='FlagEmbedding',
+    name="FlagEmbedding",
+    version="1.2.10",
+    description="FlagEmbedding",
     long_description=readme,
     long_description_content_type="text/markdown",
-    author_email='[email protected]',
-    url='https://github.com/FlagOpen/FlagEmbedding',
+    author_email="[email protected]",
+    url="https://github.com/FlagOpen/FlagEmbedding",
     packages=find_packages(),
     install_requires=[
-        'torch>=1.6.0',
-        'transformers>=4.33.0',
-        'datasets',
-        'accelerate>=0.20.1',
-        'sentence_transformers',
+        "torch>=1.6.0",
+        "transformers>=4.33.0",
+        "datasets",
+        "accelerate>=0.20.1",
+        "sentence_transformers",
+        "peft",
+        "sentencepiece",
     ],
 )
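On the dependency side, a fresh `pip install FlagEmbedding` now pulls in `peft` and `sentencepiece`, the "missing packages" from the PR title that otherwise surface as `ModuleNotFoundError` at import time. A small, hypothetical check for an existing environment (not part of the PR):

```python
# Verify that the two dependencies newly listed in install_requires
# are importable in the current environment.
import importlib

for module in ("peft", "sentencepiece"):
    try:
        importlib.import_module(module)
        print(f"{module}: OK")
    except ModuleNotFoundError:
        print(f"{module}: missing, run `pip install {module}`")
```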