Skip to content

Commit

Permalink
Default model settings (#5850)
Browse files Browse the repository at this point in the history
* UI in MM to create trigger phrases

* add scheduler and vaePrecision to config

* UI for configuring default settings for models'

* hook MM default model settings up to API

* add button to set default settings in parameters

* pull out trigger phrases

* back-end for default settings

* lint

* remove log;
gi

* ruff

* ruff format

---------

Co-authored-by: Mary Hipp <[email protected]>
  • Loading branch information
maryhipp and Mary Hipp authored Mar 4, 2024
1 parent 893bcd1 commit 8b34f52
Show file tree
Hide file tree
Showing 28 changed files with 1,122 additions and 120 deletions.
43 changes: 43 additions & 0 deletions invokeai/app/api/routers/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing_extensions import Annotated

from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_metadata.metadata_store_base import ModelMetadataChanges
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
Expand All @@ -32,6 +33,7 @@
)
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.metadata.metadata_base import BaseMetadata
from invokeai.backend.model_manager.search import ModelSearch

from ..dependencies import ApiDependencies
Expand Down Expand Up @@ -243,6 +245,47 @@ async def get_model_metadata(
return result


@model_manager_router.patch(
"/i/{key}/metadata",
operation_id="update_model_metadata",
responses={
201: {
"description": "The model metadata was updated successfully",
"content": {"application/json": {"example": example_model_metadata}},
},
400: {"description": "Bad request"},
},
)
async def update_model_metadata(
key: str = Path(description="Key of the model repo metadata to fetch."),
changes: ModelMetadataChanges = Body(description="The changes"),
) -> Optional[AnyModelRepoMetadata]:
"""Updates or creates a model metadata object."""
record_store = ApiDependencies.invoker.services.model_manager.store
metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store

try:
original_metadata = record_store.get_metadata(key)
if original_metadata:
if changes.default_settings:
original_metadata.default_settings = changes.default_settings

metadata_store.update_metadata(key, original_metadata)
else:
metadata_store.add_metadata(
key, BaseMetadata(name="", author="", default_settings=changes.default_settings)
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"An error occurred while updating the model metadata: {e}",
)

result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)

return result


@model_manager_router.get(
"/tags",
operation_id="list_tags",
Expand Down
18 changes: 17 additions & 1 deletion invokeai/app/services/model_metadata/metadata_store_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,25 @@
"""

from abc import ABC, abstractmethod
from typing import List, Set, Tuple
from typing import List, Optional, Set, Tuple

from pydantic import Field

from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.metadata.metadata_base import ModelDefaultSettings


class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"):
"""A set of changes to apply to model metadata.
Only limited changes are valid:
- `default_settings`: the user-configured default settings for this model
"""

default_settings: Optional[ModelDefaultSettings] = Field(
default=None, description="The user-configured default settings for this model"
)
"""The user-configured default settings for this model"""


class ModelMetadataStoreBase(ABC):
Expand Down
73 changes: 37 additions & 36 deletions invokeai/app/services/model_metadata/metadata_store_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,44 +179,45 @@ def search_by_name(self, name: str) -> Set[str]:
)
return {x[0] for x in self._cursor.fetchall()}

def _update_tags(self, model_key: str, tags: Set[str]) -> None:
def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None:
"""Update tags for the model referenced by model_key."""
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)

for tag in tags:
if tags:
# remove previous tags from this model
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
DELETE FROM model_tags
WHERE model_id=?;
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
(model_key,),
)

for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)
15 changes: 14 additions & 1 deletion invokeai/backend/model_manager/metadata/metadata_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from requests.sessions import Session
from typing_extensions import Annotated

from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.backend.model_manager import ModelRepoVariant

from ..util import select_hf_files
Expand Down Expand Up @@ -68,12 +69,24 @@ class RemoteModelFile(BaseModel):
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)


class ModelDefaultSettings(BaseModel):
vae: str | None
vae_precision: str | None
scheduler: SCHEDULER_NAME_VALUES | None
steps: int | None
cfg_scale: float | None
cfg_rescale_multiplier: float | None


class ModelMetadataBase(BaseModel):
"""Base class for model metadata information."""

name: str = Field(description="model's name")
author: str = Field(description="model's author")
tags: Set[str] = Field(description="tags provided by model source")
tags: Optional[Set[str]] = Field(description="tags provided by model source", default=None)
default_settings: Optional[ModelDefaultSettings] = Field(
description="default settings for this model", default=None
)


class BaseMetadata(ModelMetadataBase):
Expand Down
7 changes: 7 additions & 0 deletions invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
"aboutDesc": "Using Invoke for work? Check out:",
"aboutHeading": "Own Your Creative Power",
"accept": "Accept",
"add": "Add",
"advanced": "Advanced",
"advancedOptions": "Advanced Options",
"ai": "ai",
Expand Down Expand Up @@ -734,6 +735,8 @@
"customConfig": "Custom Config",
"customConfigFileLocation": "Custom Config File Location",
"customSaveLocation": "Custom Save Location",
"defaultSettings": "Default Settings",
"defaultSettingsSaved": "Default Settings Saved",
"delete": "Delete",
"deleteConfig": "Delete Config",
"deleteModel": "Delete Model",
Expand Down Expand Up @@ -768,6 +771,7 @@
"mergedModelName": "Merged Model Name",
"mergedModelSaveLocation": "Save Location",
"mergeModels": "Merge Models",
"metadata": "Metadata",
"model": "Model",
"modelAdded": "Model Added",
"modelConversionFailed": "Model Conversion Failed",
Expand Down Expand Up @@ -839,9 +843,12 @@
"statusConverting": "Converting",
"syncModels": "Sync Models",
"syncModelsDesc": "If your models are out of sync with the backend, you can refresh them up using this option. This is generally handy in cases where you add models to the InvokeAI root folder or autoimport directory after the application has booted.",
"triggerPhrases": "Trigger Phrases",
"typePhraseHere": "Type phrase here",
"upcastAttention": "Upcast Attention",
"updateModel": "Update Model",
"useCustomConfig": "Use Custom Config",
"useDefaultSettings": "Use Default Settings",
"v1": "v1",
"v2_768": "v2 (768px)",
"v2_base": "v2 (512px)",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store';

import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';

export const listenerMiddleware = createListenerMiddleware();

export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
Expand Down Expand Up @@ -153,3 +155,5 @@ addUpscaleRequestedListener(startAppListening);

// Dynamic prompts
addDynamicPromptsListener(startAppListening);

addSetDefaultSettingsListener(startAppListening);
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { setDefaultSettings } from 'features/parameters/store/actions';
import {
setCfgRescaleMultiplier,
setCfgScale,
setScheduler,
setSteps,
vaePrecisionChanged,
vaeSelected,
} from 'features/parameters/store/generationSlice';
import {
isParameterCFGRescaleMultiplier,
isParameterCFGScale,
isParameterPrecision,
isParameterScheduler,
isParameterSteps,
zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { map } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models';

export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: setDefaultSettings,
effect: async (action, { dispatch, getState }) => {
const state = getState();

const currentModel = state.generation.model;

if (!currentModel) {
return;
}

const metadata = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(currentModel.key)).unwrap();

if (!metadata || !metadata.default_settings) {
return;
}

const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = metadata.default_settings;

if (vae) {
// we store this as "default" within default settings
// to distinguish it from no default set
if (vae === 'default') {
dispatch(vaeSelected(null));
} else {
const { data } = modelsApi.endpoints.getVaeModels.select()(state);
const vaeArray = map(data?.entities);
const validVae = vaeArray.find((model) => model.key === vae);

const result = zParameterVAEModel.safeParse(validVae);
if (!result.success) {
return;
}
dispatch(vaeSelected(result.data));
}
}

if (vae_precision) {
if (isParameterPrecision(vae_precision)) {
dispatch(vaePrecisionChanged(vae_precision));
}
}

if (cfg_scale) {
if (isParameterCFGScale(cfg_scale)) {
dispatch(setCfgScale(cfg_scale));
}
}

if (cfg_rescale_multiplier) {
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
}
}

if (steps) {
if (isParameterSteps(steps)) {
dispatch(setSteps(steps));
}
}

if (scheduler) {
if (isParameterScheduler(scheduler)) {
dispatch(setScheduler(scheduler));
}
}

dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) })));
},
});
};
3 changes: 3 additions & 0 deletions invokeai/frontend/web/src/app/types/invokeai.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import type { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import type { O } from 'ts-toolbelt';

Expand Down Expand Up @@ -82,6 +83,8 @@ export type AppConfig = {
guidance: NumericalParameterConfig;
cfgRescaleMultiplier: NumericalParameterConfig;
img2imgStrength: NumericalParameterConfig;
scheduler?: ParameterScheduler;
vaePrecision?: ParameterPrecision;
// Canvas
boundingBoxHeight: NumericalParameterConfig; // initial value comes from model
boundingBoxWidth: NumericalParameterConfig; // initial value comes from model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ export const ModelPane = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
return (
<Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full">
{selectedModelKey ? <Model /> : <ImportModels />}
{selectedModelKey ? <Model key={selectedModelKey} /> : <ImportModels />}
</Box>
);
};
Loading

0 comments on commit 8b34f52

Please sign in to comment.