diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index 29ec37799b5..206e95bb801 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -1,6 +1,7 @@ # Copyright (c) 2023 Lincoln D. Stein """FastAPI route for model configuration records.""" +import contextlib import io import pathlib import shutil @@ -10,6 +11,7 @@ from tempfile import TemporaryDirectory from typing import List, Optional, Type +import huggingface_hub from fastapi import Body, Path, Query, Response, UploadFile from fastapi.responses import FileResponse, HTMLResponse from fastapi.routing import APIRouter @@ -27,6 +29,7 @@ ModelRecordChanges, UnknownModelException, ) +from invokeai.app.util.suppress_output import SuppressOutput from invokeai.backend.model_manager.config import ( AnyModelConfig, BaseModelType, @@ -923,3 +926,51 @@ async def get_stats() -> Optional[CacheStats]: """Return performance statistics on the model manager's RAM cache. Will return null if no models have been loaded.""" return ApiDependencies.invoker.services.model_manager.load.ram_cache.stats + + +class HFTokenStatus(str, Enum): + VALID = "valid" + INVALID = "invalid" + UNKNOWN = "unknown" + + +class HFTokenHelper: + @classmethod + def get_status(cls) -> HFTokenStatus: + try: + if huggingface_hub.get_token_permission(huggingface_hub.get_token()): + # Valid token! + return HFTokenStatus.VALID + # No token set + return HFTokenStatus.INVALID + except Exception: + return HFTokenStatus.UNKNOWN + + @classmethod + def set_token(cls, token: str) -> HFTokenStatus: + with SuppressOutput(), contextlib.suppress(Exception): + huggingface_hub.login(token=token, add_to_git_credential=False) + return cls.get_status() + + +@model_manager_router.get("/hf_login", operation_id="get_hf_login_status", response_model=HFTokenStatus) +async def get_hf_login_status() -> HFTokenStatus: + token_status = HFTokenHelper.get_status() + + if token_status is HFTokenStatus.UNKNOWN: + ApiDependencies.invoker.services.logger.warning("Unable to verify HF token") + + return token_status + + +@model_manager_router.post("/hf_login", operation_id="do_hf_login", response_model=HFTokenStatus) +async def do_hf_login( + token: str = Body(description="Hugging Face token to use for login", embed=True), +) -> HFTokenStatus: + HFTokenHelper.set_token(token) + token_status = HFTokenHelper.get_status() + + if token_status is HFTokenStatus.UNKNOWN: + ApiDependencies.invoker.services.logger.warning("Unable to verify HF token") + + return token_status diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 4acdc8df102..093b5c547d4 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -733,7 +733,17 @@ "huggingFacePlaceholder": "owner/model-name", "huggingFaceRepoID": "HuggingFace Repo ID", "huggingFaceHelper": "If multiple models are found in this repo, you will be prompted to select one to install.", - "hfToken": "HuggingFace Token", + "hfTokenLabel": "HuggingFace Token (Required for some models)", + "hfTokenHelperText": "A HF token is required to use some models. Click here to create or get your token.", + "hfTokenInvalid": "Invalid or Missing HF Token", + "hfForbidden": "You do not have access to this HF model", + "hfForbiddenErrorMessage": "We recommend visiting the repo page on HuggingFace.com. The owner may require acceptance of terms in order to download.", + "hfTokenInvalidErrorMessage": "Invalid or missing HuggingFace token.", + "hfTokenRequired": "You are trying to download a model that requires a valid HuggingFace Token.", + "hfTokenInvalidErrorMessage2": "Update it in the ", + "hfTokenUnableToVerify": "Unable to Verify HF Token", + "hfTokenUnableToVerifyErrorMessage": "Unable to verify HuggingFace token. This is likely due to a network error. Please try again later.", + "hfTokenSaved": "HF Token Saved", "imageEncoderModelId": "Image Encoder Model ID", "includesNModels": "Includes {{n}} models and their dependencies", "installQueue": "Install Queue", diff --git a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useHFForbiddenToast.tsx b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useHFForbiddenToast.tsx new file mode 100644 index 00000000000..8dc64420ed3 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useHFForbiddenToast.tsx @@ -0,0 +1,45 @@ +import { ExternalLink, Text, useToast } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; +import { atom } from 'nanostores'; +import { useEffect } from 'react'; +import { useTranslation } from 'react-i18next'; + +const TOAST_ID = 'hfForbidden'; + +/** + * Tracks whether or not the HF Login toast is showing + */ +export const $isHFForbiddenToastOpen = atom<{ isEnabled: boolean; source?: string }>({ isEnabled: false }); + +export const useHFForbiddenToast = () => { + const { t } = useTranslation(); + const toast = useToast(); + const isHFForbiddenToastOpen = useStore($isHFForbiddenToastOpen); + + useEffect(() => { + if (!isHFForbiddenToastOpen.isEnabled) { + toast.close(TOAST_ID); + return; + } + + if (isHFForbiddenToastOpen.isEnabled) { + toast({ + id: TOAST_ID, + title: t('modelManager.hfForbidden'), + description: ( + + {t('modelManager.hfForbiddenErrorMessage')} + + + ), + status: 'error', + isClosable: true, + duration: null, + onCloseComplete: () => $isHFForbiddenToastOpen.set({ isEnabled: false }), + }); + } + }, [isHFForbiddenToastOpen, t, toast]); +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useHFLoginToast.tsx b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useHFLoginToast.tsx new file mode 100644 index 00000000000..fb831dc2b99 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useHFLoginToast.tsx @@ -0,0 +1,93 @@ +import { Button, Text, useToast } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; +import { skipToken } from '@reduxjs/toolkit/query'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import { setActiveTab } from 'features/ui/store/uiSlice'; +import { t } from 'i18next'; +import { atom } from 'nanostores'; +import { useCallback, useEffect } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useGetHFTokenStatusQuery } from 'services/api/endpoints/models'; +import type { S } from 'services/api/types'; + +const FEATURE_ID = 'hfToken'; +const TOAST_ID = 'hfTokenLogin'; +/** + * Tracks whether or not the HF Login toast is showing + */ +export const $isHFLoginToastOpen = atom(false); + +const getTitle = (token_status: S['HFTokenStatus']) => { + switch (token_status) { + case 'invalid': + return t('modelManager.hfTokenInvalid'); + case 'unknown': + return t('modelManager.hfTokenUnableToVerify'); + } +}; + +export const useHFLoginToast = () => { + const isEnabled = useFeatureStatus(FEATURE_ID); + const { data } = useGetHFTokenStatusQuery(isEnabled ? undefined : skipToken); + const toast = useToast(); + const isHFLoginToastOpen = useStore($isHFLoginToastOpen); + + useEffect(() => { + if (!isHFLoginToastOpen) { + toast.close(TOAST_ID); + return; + } + + if (isHFLoginToastOpen && data) { + const title = getTitle(data); + toast({ + id: TOAST_ID, + title, + description: , + status: 'error', + isClosable: true, + duration: null, + onCloseComplete: () => $isHFLoginToastOpen.set(false), + }); + } + }, [isHFLoginToastOpen, data, toast]); +}; + +type Props = { + token_status: S['HFTokenStatus']; +}; + +const ToastDescription = ({ token_status }: Props) => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const toast = useToast(); + + const onClick = useCallback(() => { + dispatch(setActiveTab('models')); + toast.close(FEATURE_ID); + }, [dispatch, toast]); + + if (token_status === 'invalid') { + return ( + + {t('modelManager.hfTokenInvalidErrorMessage')} {t('modelManager.hfTokenRequired')}{' '} + {t('modelManager.hfTokenInvalidErrorMessage2')} + + + ); + } + + if (token_status === 'unknown') { + return ( + + {t('modelManager.hfTokenUnableToErrorMessage')}{' '} + + + ); + } +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HFToken.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HFToken.tsx new file mode 100644 index 00000000000..81d71209eff --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HFToken.tsx @@ -0,0 +1,80 @@ +import { + Button, + ExternalLink, + Flex, + FormControl, + FormErrorMessage, + FormHelperText, + FormLabel, + Input, + useToast, +} from '@invoke-ai/ui-library'; +import { skipToken } from '@reduxjs/toolkit/query'; +import { $isHFLoginToastOpen } from 'features/modelManagerV2/hooks/useHFLoginToast'; +import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import type { ChangeEvent } from 'react'; +import { useCallback, useMemo, useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useGetHFTokenStatusQuery, useSetHFTokenMutation } from 'services/api/endpoints/models'; + +export const HFToken = () => { + const { t } = useTranslation(); + const isHFTokenEnabled = useFeatureStatus('hfToken'); + const [token, setToken] = useState(''); + const { currentData } = useGetHFTokenStatusQuery(isHFTokenEnabled ? undefined : skipToken); + const [trigger, { isLoading, isUninitialized }] = useSetHFTokenMutation(); + const toast = useToast(); + const onChange = useCallback((e: ChangeEvent) => { + setToken(e.target.value); + }, []); + const onClick = useCallback(() => { + trigger({ token }) + .unwrap() + .then((res) => { + if (res === 'valid') { + setToken(''); + toast({ + title: t('modelManager.hfTokenSaved'), + status: 'success', + duration: 3000, + }); + $isHFLoginToastOpen.set(false); + } + }); + }, [t, toast, token, trigger]); + + const error = useMemo(() => { + if (!currentData || isUninitialized || isLoading) { + return null; + } + if (currentData === 'invalid') { + return t('modelManager.hfTokenInvalidErrorMessage'); + } + if (currentData === 'unknown') { + return t('modelManager.hfTokenUnableToVerifyErrorMessage'); + } + return null; + }, [currentData, isLoading, isUninitialized, t]); + + if (!currentData || currentData === 'valid') { + return null; + } + + return ( + + + {t('modelManager.hfTokenLabel')} + + + + + + + + {error} + + + ); +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceForm.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceForm.tsx index 905063f900c..ee6b6714dd0 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceForm.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceForm.tsx @@ -1,10 +1,13 @@ import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library'; +import { skipToken } from '@reduxjs/toolkit/query'; import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel'; +import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import type { ChangeEventHandler } from 'react'; import { memo, useCallback, useState } from 'react'; import { useTranslation } from 'react-i18next'; -import { useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models'; +import { useGetHFTokenStatusQuery, useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models'; +import { HFToken } from './HFToken'; import { HuggingFaceResults } from './HuggingFaceResults'; export const HuggingFaceForm = memo(() => { @@ -12,6 +15,8 @@ export const HuggingFaceForm = memo(() => { const [displayResults, setDisplayResults] = useState(false); const [errorMessage, setErrorMessage] = useState(''); const { t } = useTranslation(); + const isHFTokenEnabled = useFeatureStatus('hfToken'); + const { currentData } = useGetHFTokenStatusQuery(isHFTokenEnabled ? undefined : skipToken); const [_getHuggingFaceModels, { isLoading, data }] = useLazyGetHuggingFaceModelsQuery(); const [installModel] = useInstallModel(); @@ -41,7 +46,7 @@ export const HuggingFaceForm = memo(() => { }, []); return ( - + {t('modelManager.huggingFaceRepoID')} @@ -63,6 +68,7 @@ export const HuggingFaceForm = memo(() => { {t('modelManager.huggingFaceHelper')} {!!errorMessage.length && {errorMessage}} + {currentData !== 'valid' && } {data && data.urls && displayResults && } ); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx index 8133960c516..59449c0781b 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx @@ -1,5 +1,7 @@ import { Button, Flex, Heading } from '@invoke-ai/ui-library'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useHFForbiddenToast } from 'features/modelManagerV2/hooks/useHFForbiddenToast'; +import { useHFLoginToast } from 'features/modelManagerV2/hooks/useHFLoginToast'; import { selectSelectedModelKey, setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; @@ -16,6 +18,9 @@ export const ModelManager = memo(() => { }, [dispatch]); const selectedModelKey = useAppSelector(selectSelectedModelKey); + useHFLoginToast(); + useHFForbiddenToast(); + return ( diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index a32bc494fae..3180e932b76 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -3,7 +3,7 @@ import { createEntityAdapter } from '@reduxjs/toolkit'; import { getSelectorsOptions } from 'app/store/createMemoizedSelector'; import queryString from 'query-string'; import type { operations, paths } from 'services/api/schema'; -import type { AnyModelConfig } from 'services/api/types'; +import type { AnyModelConfig, GetHFTokenStatusResponse, SetHFTokenArg, SetHFTokenResponse } from 'services/api/types'; import type { ApiTagDescription } from '..'; import { api, buildV2Url, LIST_TAG } from '..'; @@ -259,6 +259,22 @@ export const modelsApi = api.injectEndpoints({ query: () => buildModelsUrl('starter_models'), providesTags: [{ type: 'ModelConfig', id: LIST_TAG }], }), + getHFTokenStatus: build.query({ + query: () => buildModelsUrl('hf_login'), + providesTags: ['HFTokenStatus'], + }), + setHFToken: build.mutation({ + query: (body) => ({ url: buildModelsUrl('hf_login'), method: 'POST', body }), + invalidatesTags: ['HFTokenStatus'], + onQueryStarted: async (_, { dispatch, queryFulfilled }) => { + try { + const { data } = await queryFulfilled; + dispatch(modelsApi.util.updateQueryData('getHFTokenStatus', undefined, () => data)); + } catch { + // no-op + } + }, + }), }), }); @@ -277,6 +293,8 @@ export const { useCancelModelInstallMutation, usePruneCompletedModelInstallsMutation, useGetStarterModelsQuery, + useGetHFTokenStatusQuery, + useSetHFTokenMutation, } = modelsApi; export const selectModelConfigsQuery = modelsApi.endpoints.getModelConfigs.select(); diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index ff8c4f1164a..8426a6c263a 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -344,6 +344,24 @@ export type paths = { patch?: never; trace?: never; }; + "/api/v2/models/hf_login": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + /** Get Hf Login Status */ + get: operations["get_hf_login_status"]; + put?: never; + /** Do Hf Login */ + post: operations["do_hf_login"]; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; "/api/v1/download_queue/": { parameters: { query?: never; @@ -2157,6 +2175,14 @@ export type components = { */ image_names: string[]; }; + /** Body_do_hf_login */ + Body_do_hf_login: { + /** + * Token + * @description Hugging Face token to use for login + */ + token: string; + }; /** Body_download */ Body_download: { /** @@ -7322,6 +7348,11 @@ export type components = { */ type: "hf"; }; + /** + * HFTokenStatus + * @enum {string} + */ + HFTokenStatus: "valid" | "invalid" | "unknown"; /** HTTPValidationError */ HTTPValidationError: { /** Detail */ @@ -18274,6 +18305,59 @@ export interface operations { }; }; }; + get_hf_login_status: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["HFTokenStatus"]; + }; + }; + }; + }; + do_hf_login: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody: { + content: { + "application/json": components["schemas"]["Body_do_hf_login"]; + }; + }; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["HFTokenStatus"]; + }; + }; + /** @description Validation Error */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; list_downloads: { parameters: { query?: never; diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index dcbd8575815..5c667190fee 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -244,3 +244,12 @@ export type PostUploadAction = export type BoardRecordOrderBy = S['BoardRecordOrderBy']; export type StarterModel = S['StarterModel']; + +export type GetHFTokenStatusResponse = + paths['/api/v2/models/hf_login']['get']['responses']['200']['content']['application/json']; +export type SetHFTokenResponse = NonNullable< + paths['/api/v2/models/hf_login']['post']['responses']['200']['content']['application/json'] +>; +export type SetHFTokenArg = NonNullable< + paths['/api/v2/models/hf_login']['post']['requestBody']['content']['application/json'] +>; diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.tsx b/invokeai/frontend/web/src/services/events/setEventListeners.tsx index 8ef22e08a7c..d181ebb0a35 100644 --- a/invokeai/frontend/web/src/services/events/setEventListeners.tsx +++ b/invokeai/frontend/web/src/services/events/setEventListeners.tsx @@ -7,6 +7,8 @@ import { $queueId } from 'app/store/nanostores/queueId'; import type { AppStore } from 'app/store/store'; import type { SerializableObject } from 'common/types'; import { deepClone } from 'common/util/deepClone'; +import { $isHFForbiddenToastOpen } from 'features/modelManagerV2/hooks/useHFForbiddenToast'; +import { $isHFLoginToastOpen } from 'features/modelManagerV2/hooks/useHFLoginToast'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { zNodeStatus } from 'features/nodes/types/invocation'; import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription'; @@ -295,6 +297,14 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis const { id, error, error_type } = data; const installs = selectModelInstalls(getState()).data; + if (error === 'Unauthorized') { + $isHFLoginToastOpen.set(true); + } + + if (error === 'Forbidden') { + $isHFForbiddenToastOpen.set({ isEnabled: true, source: data.source }); + } + if (!installs?.find((install) => install.id === id)) { dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); } else {