Skip to content

Commit

Permalink
ONNX Support (#3562)
Browse files Browse the repository at this point in the history
Note: this branch based on #3548, not on main

While find out what needs to be done to implement onnx, found that I can
do draft of it pretty quickly, so... here it is)
Supports LoRA and TI.
As example - cat with sadcatmeme lora:

![image](https://github.com/invoke-ai/InvokeAI/assets/7768370/dbd1a5df-0629-4741-94b3-8e09f4b4d5a3)

![image](https://github.com/invoke-ai/InvokeAI/assets/7768370/d918836c-fdc7-43c0-aa81-dde9182f2e0f)
  • Loading branch information
hipsterusername authored Jul 31, 2023
2 parents ae0f4ef + 746afcd commit 81654da
Show file tree
Hide file tree
Showing 41 changed files with 2,134 additions and 394 deletions.
7 changes: 5 additions & 2 deletions installer/lib/installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def get_torch_source() -> (Union[str, None], str):
device = graphical_accelerator()

url = None
optional_modules = None
optional_modules = "[onnx]"
if OS == "Linux":
if device == "rocm":
url = "https://download.pytorch.org/whl/rocm5.4.2"
Expand All @@ -464,7 +464,10 @@ def get_torch_source() -> (Union[str, None], str):

if device == "cuda":
url = "https://download.pytorch.org/whl/cu117"
optional_modules = "[xformers]"
optional_modules = "[xformers,onnx-cuda]"
if device == "cuda_and_dml":
url = "https://download.pytorch.org/whl/cu117"
optional_modules = "[xformers,onnx-directml]"

# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13

Expand Down
6 changes: 5 additions & 1 deletion installer/lib/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ def graphical_accelerator():
"an [gold1 b]NVIDIA[/] GPU (using CUDA™)",
"cuda",
)
nvidia_with_dml = (
"an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX) -- ALPHA",
"cuda_and_dml",
)
amd = (
"an [gold1 b]AMD[/] GPU (using ROCm™)",
"rocm",
Expand All @@ -181,7 +185,7 @@ def graphical_accelerator():
)

if OS == "Windows":
options = [nvidia, cpu]
options = [nvidia, nvidia_with_dml, cpu]
if OS == "Linux":
options = [nvidia, amd, cpu]
elif OS == "Darwin":
Expand Down
8 changes: 8 additions & 0 deletions invokeai/app/invocations/compel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
from typing import Literal, Optional, Union, List, Annotated
from pydantic import BaseModel, Field
import re

from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from .model import ClipField

from ...backend.util.devices import torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.model_management import BaseModelType, ModelType, SubModelType, ModelPatcher

import torch
from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
Expand Down
1 change: 1 addition & 0 deletions invokeai/app/invocations/latent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.model_management import ModelPatcher
from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
Expand Down
1 change: 1 addition & 0 deletions invokeai/app/invocations/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class MainModelField(BaseModel):

model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Model Type")


class LoRAModelField(BaseModel):
Expand Down
Loading

0 comments on commit 81654da

Please sign in to comment.