TCD Scheduler + LoRA IPAdapter SDXL #76
base: main
@@ -1,3 +1,5 @@
import hashlib
import logging
from enum import Enum
from pathlib import Path

@@ -75,15 +77,30 @@ def __init__(self, engine_dir: str):
                'loader': lambda path, cuda_stream, **kwargs: str(path)
            }
        }

    def _lora_signature(self, lora_dict: Dict[str, float]) -> str:
        """Create a short, stable signature for a set of LoRAs.

        Uses sorted basenames and weights, hashed to a short hex to avoid
        long/invalid paths while keeping cache keys stable across runs.
        """
        # Build canonical string of basename:weight pairs
        parts = []
        for path, weight in sorted(lora_dict.items(), key=lambda x: str(x[0])):
            base = Path(str(path)).name  # basename only
            parts.append(f"{base}:{weight}")
        canon = "|".join(parts)
        h = hashlib.sha1(canon.encode("utf-8")).hexdigest()[:10]
        return f"{len(lora_dict)}-{h}"

    def get_engine_path(self,
                        engine_type: EngineType,
                        model_id_or_path: str,
                        max_batch_size: int,
                        min_batch_size: int,
                        mode: str,
                        use_lcm_lora: bool,
                        use_tiny_vae: bool,
                        lora_dict: Optional[Dict[str, float]] = None,
                        ipadapter_scale: Optional[float] = None,
                        ipadapter_tokens: Optional[int] = None,
                        controlnet_model_id: Optional[str] = None,

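For reference, the signature scheme can be exercised on its own. The sketch below mirrors the _lora_signature logic from the hunk above; the file paths and weights are made up for illustration:

import hashlib
from pathlib import Path
from typing import Dict

def lora_signature(lora_dict: Dict[str, float]) -> str:
    # Same scheme as _lora_signature above: sorted basename:weight pairs,
    # SHA-1 hashed and truncated to 10 hex chars, prefixed with the LoRA count.
    parts = [f"{Path(str(p)).name}:{w}"
             for p, w in sorted(lora_dict.items(), key=lambda x: str(x[0]))]
    digest = hashlib.sha1("|".join(parts).encode("utf-8")).hexdigest()[:10]
    return f"{len(lora_dict)}-{digest}"

# Hypothetical LoRA set; the result has the form "2-<10 hex chars>".
print(lora_signature({"/models/lora/detail_tweaker_xl.safetensors": 0.8,
                      "styles/pixel_art_xl.safetensors": 1.0}))

Because only basenames and weights feed the hash, moving the LoRA files between directories does not change the cache key, while changing a weight does.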
@@ -114,14 +131,18 @@ def get_engine_path(self,
        base_name = maybe_path.stem if maybe_path.exists() else model_id_or_path

        # Create prefix (from wrapper.py lines 1005-1013)
        prefix = f"{base_name}--lcm_lora-{use_lcm_lora}--tiny_vae-{use_tiny_vae}--min_batch-{min_batch_size}--max_batch-{max_batch_size}"
        prefix = f"{base_name}--tiny_vae-{use_tiny_vae}--min_batch-{min_batch_size}--max_batch-{max_batch_size}"
Review comment: This will cause engines to rebuild - so it's easiest to remove
Reply: ty for the heads up on this

        # IP-Adapter differentiation: add type and (optionally) tokens
        # Keep scale out of identity for runtime control, but include a type flag to separate caches
        if is_faceid is True:
            prefix += f"--fid"
        if ipadapter_tokens is not None:
            prefix += f"--tokens{ipadapter_tokens}"

        # Fused Loras - use concise hashed signature to avoid long/invalid paths
        if lora_dict is not None and len(lora_dict) > 0:
            prefix += f"--lora-{self._lora_signature(lora_dict)}"

        prefix += f"--mode-{mode}"

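Putting the pieces together, a hypothetical SDXL FaceID configuration with two fused LoRAs would produce a cache-key prefix along these lines (every value below, including the LoRA hash, is illustrative):

# Hypothetical values, for illustration only
base_name, use_tiny_vae = "sd_xl_base_1.0", True
min_batch_size, max_batch_size = 1, 4
ipadapter_tokens, mode = 16, "img2img"
lora_sig = "2-ab12cd34ef"  # made-up _lora_signature() result

prefix = f"{base_name}--tiny_vae-{use_tiny_vae}--min_batch-{min_batch_size}--max_batch-{max_batch_size}"
prefix += "--fid"                        # FaceID variant gets its own cache
prefix += f"--tokens{ipadapter_tokens}"
prefix += f"--lora-{lora_sig}"
prefix += f"--mode-{mode}"
# -> "sd_xl_base_1.0--tiny_vae-True--min_batch-1--max_batch-4--fid--tokens16--lora-2-ab12cd34ef--mode-img2img"

Keeping ipadapter_scale out of the key, as the comment above notes, lets the scale change at runtime without forcing an engine rebuild.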
@@ -287,7 +308,6 @@ def get_or_load_controlnet_engine(self,
            max_batch_size=max_batch_size,
            min_batch_size=min_batch_size,
            mode="",  # Not used for ControlNet
            use_lcm_lora=False,  # Not used for ControlNet
            use_tiny_vae=False,  # Not used for ControlNet
            controlnet_model_id=model_id
        )

@@ -360,6 +360,29 @@ def reset_cuda_graph(self):
        self.graph = None

    def infer(self, feed_dict, stream, use_cuda_graph=False):
        # Filter inputs to only those the engine actually exposes to avoid binding errors

Review comment: Not 100% sure about this

        try:
            allowed_inputs = set()
            for idx in range(self.engine.num_io_tensors):
                name = self.engine.get_tensor_name(idx)
                if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                    allowed_inputs.add(name)

            # Drop any extra keys (e.g., text_embeds/time_ids) that the engine was not built to accept
            if allowed_inputs:
                filtered_feed_dict = {k: v for k, v in feed_dict.items() if k in allowed_inputs}
                if len(filtered_feed_dict) != len(feed_dict):
                    missing = [k for k in feed_dict.keys() if k not in allowed_inputs]
                    if missing:
                        logger.debug(
                            "TensorRT Engine: filtering unsupported inputs %s (allowed=%s)",
                            missing, sorted(list(allowed_inputs))
                        )
                feed_dict = filtered_feed_dict
        except Exception:
            # Be permissive if engine query fails; proceed with original dict
            pass

        for name, buf in feed_dict.items():
            self.tensors[name].copy_(buf)

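The engine-introspection part of this filtering can be read as a standalone helper. This is a minimal sketch assuming the TensorRT 8.5+ named-tensor API (num_io_tensors, get_tensor_name, get_tensor_mode); the engine_input_names name is hypothetical:

import tensorrt as trt

def engine_input_names(engine: "trt.ICudaEngine") -> set:
    # Collect the input tensor names the engine was actually built with,
    # mirroring the allowed_inputs loop in the diff above.
    names = set()
    for idx in range(engine.num_io_tensors):
        name = engine.get_tensor_name(idx)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            names.add(name)
    return names

# Hypothetical usage with an SDXL UNet engine built without the added-condition
# inputs; keys like text_embeds / time_ids would be dropped before binding:
# allowed = engine_input_names(engine)
# feed_dict = {k: v for k, v in feed_dict.items() if k in allowed}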
Review comment: These dep versions changed for Windows support.