Skip to content

Commit 46a9c00

Browse files
committed
add --increment-seed argument
1 parent c1fd82a commit 46a9c00

File tree

1 file changed

+33
-16
lines changed

1 file changed

+33
-16
lines changed

swift/StableDiffusionCLI/main.swift

+33-16
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ struct StableDiffusionSample: ParsableCommand {
6868
@Flag(help: "Reduce memory usage")
6969
var reduceMemory: Bool = false
7070

71+
@Flag(help: "Increse random seed by 1 for each image")
72+
var incrementSeed: Bool = false
73+
7174
mutating func run() throws {
7275
guard FileManager.default.fileExists(atPath: resourcePath) else {
7376
throw RunError.resources("Resource path does not exist \(resourcePath)")
@@ -89,24 +92,38 @@ struct StableDiffusionSample: ParsableCommand {
8992
let sampleTimer = SampleTimer()
9093
sampleTimer.start()
9194

92-
let images = try pipeline.generateImages(
93-
prompt: prompt,
94-
negativePrompt: negativePrompt,
95-
imageCount: imageCount,
96-
stepCount: stepCount,
97-
seed: seed,
98-
guidanceScale: guidanceScale,
99-
scheduler: scheduler.stableDiffusionScheduler
100-
) { progress in
101-
sampleTimer.stop()
102-
handleProgress(progress,sampleTimer)
103-
if progress.stepCount != progress.step {
104-
sampleTimer.start()
95+
let loops = incrementSeed ? imageCount : 1
96+
let imageCountPerBatch = incrementSeed ? 1 : imageCount
97+
98+
for i in 0 ..< loops {
99+
if (incrementSeed) {
100+
log("Generating image \(i+1) of \(imageCount) with seed \(seed)\n")
101+
log("\n")
105102
}
106-
return true
107-
}
108103

109-
_ = try saveImages(images, logNames: true)
104+
let images = try pipeline.generateImages(
105+
prompt: prompt,
106+
negativePrompt: negativePrompt,
107+
imageCount: imageCount,
108+
stepCount: stepCount,
109+
seed: seed,
110+
guidanceScale: guidanceScale,
111+
scheduler: scheduler.stableDiffusionScheduler
112+
) { progress in
113+
sampleTimer.stop()
114+
handleProgress(progress,sampleTimer)
115+
if progress.stepCount != progress.step {
116+
sampleTimer.start()
117+
}
118+
return true
119+
}
120+
121+
_ = try saveImages(images, logNames: true)
122+
123+
if (incrementSeed) {
124+
seed += 1
125+
}
126+
}
110127
}
111128

112129
func handleProgress(

0 commit comments

Comments
 (0)