From 639d88afd6cfafc5039a668f09e974c8a218e1c7 Mon Sep 17 00:00:00 2001
From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com>
Date: Wed, 5 Jul 2023 16:39:15 +1200
Subject: [PATCH 1/2] revert: inference_mode to no_grad
---
invokeai/app/invocations/compel.py | 2 +-
invokeai/app/invocations/latent.py | 8 ++++----
invokeai/backend/model_management/lora.py | 4 ++--
3 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index d77269da208..d4ba7efedae 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -56,7 +56,7 @@ class Config(InvocationConfig):
},
}
- @torch.inference_mode()
+ @torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
tokenizer_info = context.services.model_manager.get_model(
**self.clip.tokenizer.dict(),
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 50c901f15f3..3e691c934e8 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -285,7 +285,7 @@ def prep_control_data(
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data
- @torch.inference_mode()
+ @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
@@ -369,7 +369,7 @@ class Config(InvocationConfig):
},
}
- @torch.inference_mode()
+ @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name)
@@ -461,7 +461,7 @@ class Config(InvocationConfig):
},
}
- @torch.inference_mode()
+ @torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.services.latents.get(self.latents.latents_name)
@@ -599,7 +599,7 @@ class Config(InvocationConfig):
},
}
- @torch.inference_mode()
+ @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
# image = context.services.images.get(
# self.image.image_type, self.image.image_name
diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py
index bcd47ff00af..5d27555ab3a 100644
--- a/invokeai/backend/model_management/lora.py
+++ b/invokeai/backend/model_management/lora.py
@@ -526,7 +526,7 @@ def apply_lora(
):
original_weights = dict()
try:
- with torch.inference_mode():
+ with torch.no_grad():
for lora, lora_weight in loras:
#assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items():
@@ -552,7 +552,7 @@ def apply_lora(
yield # wait for context manager exit
finally:
- with torch.inference_mode():
+ with torch.no_grad():
for module_key, weight in original_weights.items():
model.get_submodule(module_key).weight.copy_(weight)
From 1a29a3fe39aeb8b4623f69196e3fb952045c95d4 Mon Sep 17 00:00:00 2001
From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com>
Date: Wed, 5 Jul 2023 16:39:28 +1200
Subject: [PATCH 2/2] feat: Add Lora to Canvas
---
.../graphBuilders/buildCanvasInpaintGraph.ts | 3 +++
.../UnifiedCanvas/UnifiedCanvasParameters.tsx | 18 ++++++++++--------
2 files changed, 13 insertions(+), 8 deletions(-)
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts
index 82912de2198..c4f9415067a 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts
@@ -8,6 +8,7 @@ import {
RangeOfSizeInvocation,
} from 'services/api/types';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
+import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import {
INPAINT,
@@ -194,6 +195,8 @@ export const buildCanvasInpaintGraph = (
],
};
+ addLoRAsToGraph(graph, state, INPAINT);
+
// Add VAE
addVAEToGraph(graph, state);
diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx
index 061ebb962e2..63ed4cc1cfe 100644
--- a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx
+++ b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx
@@ -1,14 +1,15 @@
-import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
-import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
-import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
+import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
+import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse';
import ParamSeamCorrectionCollapse from 'features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse';
-import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters';
-import { memo } from 'react';
-import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
-import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
-import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
+import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
+import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
+import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
+import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
+import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
+import { memo } from 'react';
+import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters';
const UnifiedCanvasParameters = () => {
return (
@@ -17,6 +18,7 @@ const UnifiedCanvasParameters = () => {
+