Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve model preview. #59

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions civitai/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,15 +409,22 @@ def clear_hypernetwork():
#endregion

#region Resource Management
def update_resource_preview(hash: str, preview_url: str):
def update_resource_preview(hash: str, to_update: dict):
resources = load_resource_list([])
matches = [resource for resource in resources if hash.lower() == resource['hash']]
if len(matches) == 0: return

for resource in matches:
# download image and save to resource['path'] - ext + '.preview.png'
preview_path = os.path.splitext(resource['path'])[0] + '.preview.png'
download_file(preview_url, preview_path)
if 'preview_url' in to_update:
# download image and save to resource['path'] - ext + '.preview.png'
preview_path = os.path.splitext(resource['path'])[0] + '.preview.png'
if not os.path.isfile(preview_path):
download_file(to_update['preview_url'], preview_path)
if 'triggers' in to_update:
trigger_path = os.path.splitext(resource['path'])[0] + '.txt'
if not os.path.isfile(trigger_path):
with open(trigger_path, 'w') as f:
f.write(to_update['triggers'])

#endregion Selecting Resources

Expand Down
44 changes: 29 additions & 15 deletions scripts/previews.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@
previewable_types = ['LORA', 'Hypernetwork', 'TextualInversion', 'Checkpoint']
def load_previews():
download_missing_previews = shared.opts.data.get('civitai_download_previews', True)
if not download_missing_previews: return
download_missing_triggers = shared.opts.data.get('civitai_download_triggers', True)
if not download_missing_previews and not download_missing_triggers: return
nsfw_previews = shared.opts.data.get('civitai_nsfw_previews', True)

civitai.log(f"Check resources for missing preview images")
civitai.log(f"Check resources for missing preview images or trigger words")
resources = civitai.load_resource_list()
resources = [r for r in resources if r['type'] in previewable_types]

# get all resources that are missing previews
missing_previews = [r for r in resources if r['hasPreview'] is False]
civitai.log(f"Found {len(missing_previews)} resources missing preview images")
hashes = [r['hash'] for r in missing_previews]
missing_preview_hashes = [r['hash'] for r in resources if r['hasPreview'] is False]
missing_trigger_hashes = [r['hash'] for r in resources if r['hasInfo'] is False]
hashes = list(set(missing_preview_hashes + missing_trigger_hashes))
civitai.log(f"Found {len(hashes)} resources missing preview images or trigger words")

# split hashes into batches of 100 and fetch into results
results = []
Expand All @@ -28,11 +30,11 @@ def load_previews():
batch = hashes[i:i + 100]
results.extend(civitai.get_all_by_hash(batch))
except:
civitai.log("Failed to fetch preview images from Civitai")
civitai.log("Failed to fetch preview images and/or trigger words from Civitai")
return

if len(results) == 0:
civitai.log("No preview images found on Civitai")
civitai.log("No preview images and/or trigger words found on Civitai")
return

civitai.log(f"Found {len(results)} hash matches")
Expand All @@ -43,17 +45,29 @@ def load_previews():
if (r is None): continue

for file in r['files']:
if not 'hashes' in file or not 'SHA256' in file['hashes']: continue
if not 'hashes' in file or not 'SHA256' in file['hashes']:
continue
hash = file['hashes']['SHA256']
if hash.lower() not in hashes: continue
images = r['images']
if (nsfw_previews is False): images = [i for i in images if i['nsfw'] is False]
if (len(images) == 0): continue
image_url = images[0]['url']
civitai.update_resource_preview(hash, image_url)
to_update = {}
if hash.lower() in missing_preview_hashes:
images = r['images']
if (nsfw_previews is False):
images = [i for i in images if i['nsfw'] is False]
if (len(images) > 0):
to_update['image_url'] = images[0]['url']

if hash.lower() in missing_trigger_hashes:
orig_len = len(r['trainedWords'])
triggers = [w for w in r['trainedWords'] if w.isprintable()]
if orig_len != len(triggers):
info = (file['name'], file['hashes']['AutoV2'])
civitai.log("Skipped non-ascii trigger(s): model %s, %s" % info)
if (len(triggers) > 0):
to_update['triggers'] = ', '.join(triggers)
civitai.update_resource_preview(hash, to_update)
updated += 1

civitai.log(f"Updated {updated} preview images")
civitai.log(f"Updated {updated} preview images and/or trigger words")

# Automatically pull model with corresponding hash from Civitai
def start_load_previews(demo: gr.Blocks, app):
Expand Down
3 changes: 2 additions & 1 deletion scripts/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ def on_ui_settings():
shared.opts.add_option("civitai_link_logging", shared.OptionInfo(True, "Show Civitai Link events in the console", section=section))
shared.opts.add_option("civitai_api_key", shared.OptionInfo("", "Your Civitai API Key", section=section))
shared.opts.add_option("civitai_download_previews", shared.OptionInfo(True, "Download missing preview images on startup", section=section))
shared.opts.add_option("civitai_download_triggers", shared.OptionInfo(True, "Download missing trigger text files on startup", section=section))
shared.opts.add_option("civitai_nsfw_previews", shared.OptionInfo(False, "Download NSFW (adult) preview images", section=section))
shared.opts.add_option("civitai_download_missing_models", shared.OptionInfo(True, "Download missing models upon reading generation parameters from prompt", section=section))
shared.opts.add_option("civitai_hashify_resources", shared.OptionInfo(True, "Include resource hashes in image metadata (for resource auto-detection on Civitai)", section=section))
shared.opts.add_option("civitai_folder_lora", shared.OptionInfo("", "LoRA directory (if not default)", section=section))


script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_ui_settings(on_ui_settings)