Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ui,api): support for HF tokens in UI, handle Unauthorized and Forbidden responses from HF #7201

Merged
merged 3 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions invokeai/app/api/routers/model_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2023 Lincoln D. Stein
"""FastAPI route for model configuration records."""

import contextlib
import io
import pathlib
import shutil
Expand All @@ -10,6 +11,7 @@
from tempfile import TemporaryDirectory
from typing import List, Optional, Type

import huggingface_hub
from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.routing import APIRouter
Expand All @@ -27,6 +29,7 @@
ModelRecordChanges,
UnknownModelException,
)
from invokeai.app.util.suppress_output import SuppressOutput
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
Expand Down Expand Up @@ -923,3 +926,51 @@ async def get_stats() -> Optional[CacheStats]:
"""Return performance statistics on the model manager's RAM cache. Will return null if no models have been loaded."""

return ApiDependencies.invoker.services.model_manager.load.ram_cache.stats


class HFTokenStatus(str, Enum):
VALID = "valid"
INVALID = "invalid"
UNKNOWN = "unknown"


class HFTokenHelper:
@classmethod
def get_status(cls) -> HFTokenStatus:
try:
if huggingface_hub.get_token_permission(huggingface_hub.get_token()):
# Valid token!
return HFTokenStatus.VALID
# No token set
return HFTokenStatus.INVALID
except Exception:
return HFTokenStatus.UNKNOWN

@classmethod
def set_token(cls, token: str) -> HFTokenStatus:
with SuppressOutput(), contextlib.suppress(Exception):
huggingface_hub.login(token=token, add_to_git_credential=False)
return cls.get_status()


@model_manager_router.get("/hf_login", operation_id="get_hf_login_status", response_model=HFTokenStatus)
async def get_hf_login_status() -> HFTokenStatus:
token_status = HFTokenHelper.get_status()

if token_status is HFTokenStatus.UNKNOWN:
ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")

return token_status


@model_manager_router.post("/hf_login", operation_id="do_hf_login", response_model=HFTokenStatus)
async def do_hf_login(
token: str = Body(description="Hugging Face token to use for login", embed=True),
) -> HFTokenStatus:
HFTokenHelper.set_token(token)
token_status = HFTokenHelper.get_status()

if token_status is HFTokenStatus.UNKNOWN:
ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")

return token_status
12 changes: 11 additions & 1 deletion invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,17 @@
"huggingFacePlaceholder": "owner/model-name",
"huggingFaceRepoID": "HuggingFace Repo ID",
"huggingFaceHelper": "If multiple models are found in this repo, you will be prompted to select one to install.",
"hfToken": "HuggingFace Token",
"hfTokenLabel": "HuggingFace Token (Required for some models)",
"hfTokenHelperText": "A HF token is required to use some models. Click here to create or get your token.",
"hfTokenInvalid": "Invalid or Missing HF Token",
"hfForbidden": "You do not have access to this HF model",
"hfForbiddenErrorMessage": "We recommend visiting the repo page on HuggingFace.com. The owner may require acceptance of terms in order to download.",
"hfTokenInvalidErrorMessage": "Invalid or missing HuggingFace token.",
"hfTokenRequired": "You are trying to download a model that requires a valid HuggingFace Token.",
"hfTokenInvalidErrorMessage2": "Update it in the ",
"hfTokenUnableToVerify": "Unable to Verify HF Token",
"hfTokenUnableToVerifyErrorMessage": "Unable to verify HuggingFace token. This is likely due to a network error. Please try again later.",
"hfTokenSaved": "HF Token Saved",
"imageEncoderModelId": "Image Encoder Model ID",
"includesNModels": "Includes {{n}} models and their dependencies",
"installQueue": "Install Queue",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import { ExternalLink, Text, useToast } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { atom } from 'nanostores';
import { useEffect } from 'react';
import { useTranslation } from 'react-i18next';

const TOAST_ID = 'hfForbidden';

/**
* Tracks whether or not the HF Login toast is showing
*/
export const $isHFForbiddenToastOpen = atom<{ isEnabled: boolean; source?: string }>({ isEnabled: false });

export const useHFForbiddenToast = () => {
const { t } = useTranslation();
const toast = useToast();
const isHFForbiddenToastOpen = useStore($isHFForbiddenToastOpen);

useEffect(() => {
if (!isHFForbiddenToastOpen.isEnabled) {
toast.close(TOAST_ID);
return;
}

if (isHFForbiddenToastOpen.isEnabled) {
toast({
id: TOAST_ID,
title: t('modelManager.hfForbidden'),
description: (
<Text fontSize="md">
{t('modelManager.hfForbiddenErrorMessage')}
<ExternalLink
label={isHFForbiddenToastOpen.source || ''}
href={`https://huggingface.co/${isHFForbiddenToastOpen.source}`}
/>
</Text>
),
status: 'error',
isClosable: true,
duration: null,
onCloseComplete: () => $isHFForbiddenToastOpen.set({ isEnabled: false }),
});
}
}, [isHFForbiddenToastOpen, t, toast]);
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import { Button, Text, useToast } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch } from 'app/store/storeHooks';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { t } from 'i18next';
import { atom } from 'nanostores';
import { useCallback, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetHFTokenStatusQuery } from 'services/api/endpoints/models';
import type { S } from 'services/api/types';

const FEATURE_ID = 'hfToken';
const TOAST_ID = 'hfTokenLogin';
/**
* Tracks whether or not the HF Login toast is showing
*/
export const $isHFLoginToastOpen = atom<boolean>(false);

const getTitle = (token_status: S['HFTokenStatus']) => {
switch (token_status) {
case 'invalid':
return t('modelManager.hfTokenInvalid');
case 'unknown':
return t('modelManager.hfTokenUnableToVerify');
}
};

export const useHFLoginToast = () => {
const isEnabled = useFeatureStatus(FEATURE_ID);
const { data } = useGetHFTokenStatusQuery(isEnabled ? undefined : skipToken);
const toast = useToast();
const isHFLoginToastOpen = useStore($isHFLoginToastOpen);

useEffect(() => {
if (!isHFLoginToastOpen) {
toast.close(TOAST_ID);
return;
}

if (isHFLoginToastOpen && data) {
const title = getTitle(data);
toast({
id: TOAST_ID,
title,
description: <ToastDescription token_status={data} />,
status: 'error',
isClosable: true,
duration: null,
onCloseComplete: () => $isHFLoginToastOpen.set(false),
});
}
}, [isHFLoginToastOpen, data, toast]);
};

type Props = {
token_status: S['HFTokenStatus'];
};

const ToastDescription = ({ token_status }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const toast = useToast();

const onClick = useCallback(() => {
dispatch(setActiveTab('models'));
toast.close(FEATURE_ID);
}, [dispatch, toast]);

if (token_status === 'invalid') {
return (
<Text fontSize="md">
{t('modelManager.hfTokenInvalidErrorMessage')} {t('modelManager.hfTokenRequired')}{' '}
{t('modelManager.hfTokenInvalidErrorMessage2')}
<Button onClick={onClick} variant="link" color="base.50" flexGrow={0}>
{t('modelManager.modelManager')}.
</Button>
</Text>
);
}

if (token_status === 'unknown') {
return (
<Text fontSize="md">
{t('modelManager.hfTokenUnableToErrorMessage')}{' '}
<Button onClick={onClick} variant="link" color="base.50" flexGrow={0}>
{t('modelManager.modelManager')}.
</Button>
</Text>
);
}
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import {
Button,
ExternalLink,
Flex,
FormControl,
FormErrorMessage,
FormHelperText,
FormLabel,
Input,
useToast,
} from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { $isHFLoginToastOpen } from 'features/modelManagerV2/hooks/useHFLoginToast';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import type { ChangeEvent } from 'react';
import { useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetHFTokenStatusQuery, useSetHFTokenMutation } from 'services/api/endpoints/models';

export const HFToken = () => {
const { t } = useTranslation();
const isHFTokenEnabled = useFeatureStatus('hfToken');
const [token, setToken] = useState('');
const { currentData } = useGetHFTokenStatusQuery(isHFTokenEnabled ? undefined : skipToken);
const [trigger, { isLoading, isUninitialized }] = useSetHFTokenMutation();
const toast = useToast();
const onChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setToken(e.target.value);
}, []);
const onClick = useCallback(() => {
trigger({ token })
.unwrap()
.then((res) => {
if (res === 'valid') {
setToken('');
toast({
title: t('modelManager.hfTokenSaved'),
status: 'success',
duration: 3000,
});
$isHFLoginToastOpen.set(false);
}
});
}, [t, toast, token, trigger]);

const error = useMemo(() => {
if (!currentData || isUninitialized || isLoading) {
return null;
}
if (currentData === 'invalid') {
return t('modelManager.hfTokenInvalidErrorMessage');
}
if (currentData === 'unknown') {
return t('modelManager.hfTokenUnableToVerifyErrorMessage');
}
return null;
}, [currentData, isLoading, isUninitialized, t]);

if (!currentData || currentData === 'valid') {
return null;
}

return (
<Flex borderRadius="base" w="full">
<FormControl isInvalid={!isUninitialized && Boolean(error)} orientation="vertical">
<FormLabel>{t('modelManager.hfTokenLabel')}</FormLabel>
<Flex gap={3} alignItems="center" w="full">
<Input type="password" value={token} onChange={onChange} />
<Button onClick={onClick} size="sm" isDisabled={token.trim().length === 0} isLoading={isLoading}>
{t('common.save')}
</Button>
</Flex>
<FormHelperText>
<ExternalLink label={t('modelManager.hfTokenHelperText')} href="https://huggingface.co/settings/tokens" />
</FormHelperText>
<FormErrorMessage>{error}</FormErrorMessage>
</FormControl>
</Flex>
);
};
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import type { ChangeEventHandler } from 'react';
import { memo, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
import { useGetHFTokenStatusQuery, useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';

import { HFToken } from './HFToken';
import { HuggingFaceResults } from './HuggingFaceResults';

export const HuggingFaceForm = memo(() => {
const [huggingFaceRepo, setHuggingFaceRepo] = useState('');
const [displayResults, setDisplayResults] = useState(false);
const [errorMessage, setErrorMessage] = useState('');
const { t } = useTranslation();
const isHFTokenEnabled = useFeatureStatus('hfToken');
const { currentData } = useGetHFTokenStatusQuery(isHFTokenEnabled ? undefined : skipToken);

const [_getHuggingFaceModels, { isLoading, data }] = useLazyGetHuggingFaceModelsQuery();
const [installModel] = useInstallModel();
Expand Down Expand Up @@ -41,7 +46,7 @@ export const HuggingFaceForm = memo(() => {
}, []);

return (
<Flex flexDir="column" height="100%" gap={3}>
<Flex flexDir="column" height="100%" gap={4}>
<FormControl isInvalid={!!errorMessage.length} w="full" orientation="vertical" flexShrink={0}>
<FormLabel>{t('modelManager.huggingFaceRepoID')}</FormLabel>
<Flex gap={3} alignItems="center" w="full">
Expand All @@ -63,6 +68,7 @@ export const HuggingFaceForm = memo(() => {
<FormHelperText>{t('modelManager.huggingFaceHelper')}</FormHelperText>
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
</FormControl>
{currentData !== 'valid' && <HFToken />}
{data && data.urls && displayResults && <HuggingFaceResults results={data.urls} />}
</Flex>
);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useHFForbiddenToast } from 'features/modelManagerV2/hooks/useHFForbiddenToast';
import { useHFLoginToast } from 'features/modelManagerV2/hooks/useHFLoginToast';
import { selectSelectedModelKey, setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
Expand All @@ -16,6 +18,9 @@ export const ModelManager = memo(() => {
}, [dispatch]);
const selectedModelKey = useAppSelector(selectSelectedModelKey);

useHFLoginToast();
useHFForbiddenToast();

return (
<Flex flexDir="column" layerStyle="first" p={4} gap={4} borderRadius="base" w="50%" h="full">
<Flex w="full" gap={4} justifyContent="space-between" alignItems="center">
Expand Down
Loading
Loading