Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

examples: Add test run for lv2v #283

Merged
merged 31 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
1a7c557
examples: Add possible example for lv2v
victorges Nov 20, 2024
2225784
Merge branch 'feat/test-live-video-to-video' of https://github.com/li…
varshith15 Nov 21, 2024
1b3ed77
go.mod: Update pkg/errors to fix build
victorges Nov 21, 2024
5835437
feat: noop webcam lv2v example
varshith15 Nov 21, 2024
347722e
Merge branch 'feat/test-live-video-to-video' of https://github.com/li…
varshith15 Nov 21, 2024
3c6489d
temp: dummy trickle client
varshith15 Nov 22, 2024
47feacb
temp: docker network host
varshith15 Nov 23, 2024
3c2def5
feat: example for lv2v noop with zmq
varshith15 Nov 27, 2024
4810532
fix: stream protocol as param
varshith15 Nov 28, 2024
fc26de1
fix: zmq bind change
varshith15 Nov 28, 2024
9d47065
feat: fps monitor init
varshith15 Nov 29, 2024
b7e4a39
fix: monitor to async, revert stream_protocol param,
varshith15 Dec 2, 2024
6cfaf4f
fix: kafka revert, ci added for noop
varshith15 Dec 3, 2024
aaf8338
Merge branch 'main' into feat/test-live-video-to-video
varshith15 Dec 4, 2024
0d23630
Merge branch 'main' of https://github.com/livepeer/ai-worker into fea…
varshith15 Dec 5, 2024
81da469
fix: working ci test
varshith15 Dec 5, 2024
6c51f8e
fix: possible caching issue
varshith15 Dec 5, 2024
0ffc438
fix: clean up
varshith15 Dec 6, 2024
5b087ef
fix: remove client resizing
varshith15 Dec 6, 2024
d89d4e6
feat: comfyui test
varshith15 Dec 9, 2024
a3217a3
workflow: Switch over to self-hosted gpu runner (#345)
hjpotter92 Dec 11, 2024
bc50ddb
Merge branch 'main' into feat/test-live-video-to-video
varshith15 Dec 11, 2024
eb14c69
fix: review fixes
varshith15 Dec 11, 2024
e7fc022
fix: run on the same node
varshith15 Dec 11, 2024
e5123b6
fix: workflows
varshith15 Dec 12, 2024
88bd756
fix: add symlink
varshith15 Dec 12, 2024
dfd732e
revert: workflow interlink
varshith15 Dec 12, 2024
a5c973e
Merge branch 'main' of https://github.com/livepeer/ai-worker into fea…
varshith15 Dec 12, 2024
c2f76e0
Merge branch 'main' of https://github.com/livepeer/ai-worker into fea…
varshith15 Dec 12, 2024
afe82dc
fix: revert fps test
varshith15 Dec 12, 2024
38a5e4a
fix: revert ci test
varshith15 Jan 14, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ output
aiModels.json
models
checkpoints
runner/run-lv2v.log

# IDE
.vscode
Expand Down
374 changes: 374 additions & 0 deletions cmd/examples/live-video-to-video/main.go
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great stuff!

Original file line number Diff line number Diff line change
@@ -0,0 +1,374 @@
package main

import (
"bytes"
"image"
"context"
"errors"
"flag"
"log/slog"
"os"
"path/filepath"
"sync"
"time"
"fmt"
"math"
"sort"
"strings"

"github.com/pebbe/zmq4"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
docker "github.com/docker/docker/client"
"github.com/livepeer/ai-worker/worker"
)

const defaultPrompt = `{
"1": {
"inputs": {
"images": ["2", 0]
},
"class_type": "SaveTensor",
"_meta": {
"title": "SaveTensor"
}
},
"2": {
"inputs": {
"engine": "depth_anything_vitl14-fp16.engine",
"images": ["3", 0]
},
"class_type": "DepthAnythingTensorrt",
"_meta": {
"title": "Depth Anything Tensorrt"
}
},
"3": {
"inputs": {},
"class_type": "LoadTensor",
"_meta": {
"title": "LoadTensor"
}
}
}`

func sendImages(ctx context.Context, imageDir string, inputFps int) error {
publisher, err := zmq4.NewSocket(zmq4.PUB)
if err != nil {
return fmt.Errorf("failed to create ZMQ PUB socket: %v", err)
}
defer publisher.Close()

sendAddress := "tcp://*:5555"
err = publisher.Bind(sendAddress)
if err != nil {
return fmt.Errorf("failed to bind ZMQ PUB socket: %v", err)
}

var preprocessedImages [][]byte
files, err := os.ReadDir(imageDir)
if err != nil {
return fmt.Errorf("failed to read image directory: %v", err)
}

for _, file := range files {
if !file.IsDir() {
ext := strings.ToLower(filepath.Ext(file.Name()))
if ext == ".jpg" || ext == ".jpeg" || ext == ".png" {
imagePath := filepath.Join(imageDir, file.Name())

fileBytes, err := os.ReadFile(imagePath)
if err != nil {
slog.Error("Failed to read image file", slog.String("path", imagePath), slog.String("error", err.Error()))
continue
}

preprocessedImages = append(preprocessedImages, fileBytes)
}
}
}

if len(preprocessedImages) == 0 {
return fmt.Errorf("no image files found in directory")
}

interval := time.Second / time.Duration(inputFps)

slog.Info(fmt.Sprintf("Sending images at %d FPS to %s", inputFps, sendAddress))

ticker := time.NewTicker(interval)
defer ticker.Stop()

for {
for _, imageBytes := range preprocessedImages {
select {
case <-ctx.Done():
return nil
case <-ticker.C:
_, err = publisher.SendBytes(imageBytes, 0)
if err != nil {
slog.Error("Failed to send image bytes", slog.String("error", err.Error()))
}
}
}
}
}



func printFPSStatistics(fpsList []float64, expOutputFps int) {
if len(fpsList) < 5 {
slog.Info("Not enough FPS values collected (minimum 5 required)")
return
}

// Remove first 5 sec values (warm up)
fpsList = fpsList[5:]

if len(fpsList) == 0 {
slog.Info("No FPS values remaining after removing first 5 values")
return
}

sorted := make([]float64, len(fpsList))
copy(sorted, fpsList)
sort.Float64s(sorted)

min := sorted[0]
max := sorted[len(sorted)-1]

var sum float64
for _, v := range sorted {
sum += v
}
avg := sum / float64(len(sorted))

p1 := calculatePercentile(sorted, 1)
p5 := calculatePercentile(sorted, 5)
p10 := calculatePercentile(sorted, 10)

slog.Info(fmt.Sprintf("FPS Statistics:"+
"\nMin: %.2f"+
"\nMax: %.2f"+
"\nAvg: %.2f"+
"\nP1: %.2f"+
"\nP5: %.2f"+
"\nP10: %.2f\n",
min, max, avg, p1, p5, p10))

if p1 >= float64(expOutputFps) {
slog.Info("TEST PASSED!")
} else {
slog.Info("TEST FAILED!")
}
}

func calculatePercentile(sorted []float64, percentile float64) float64 {
index := (percentile / 100.0) * float64(len(sorted)-1)
i := int(math.Floor(index))
fraction := index - float64(i)

if i+1 >= len(sorted) {
return sorted[i]
}

return sorted[i] + fraction*(sorted[i+1]-sorted[i])
}


func receiveImages(ctx context.Context, expOutputFps int) error {
subscriber, err := zmq4.NewSocket(zmq4.SUB)
if err != nil {
return fmt.Errorf("failed to create ZMQ SUB socket: %v", err)
}
defer subscriber.Close()

receiveAddress := "tcp://*:5556"
err = subscriber.Bind(receiveAddress)
if err != nil {
return fmt.Errorf("failed to connect ZMQ SUB socket: %v", err)
}

err = subscriber.SetSubscribe("")
if err != nil {
return fmt.Errorf("failed to subscribe to all messages: %v", err)
}

slog.Info(fmt.Sprintf("Receiving images on %s", receiveAddress))

startTime := time.Now()
numImages := 0
var fpsList []float64

for {
select {
case <-ctx.Done():
printFPSStatistics(fpsList, expOutputFps)
return nil
default:
imageBytes, err := subscriber.RecvBytes(0)
if err != nil {
slog.Error("Failed to receive image bytes", slog.String("error", err.Error()))
continue
}

reader := bytes.NewReader(imageBytes)
_, _, err = image.Decode(reader)
if err != nil {
slog.Error("Failed to decode received image", slog.String("error", err.Error()))
continue
}

numImages++

currentTime := time.Now()
elapsedTime := currentTime.Sub(startTime)
if elapsedTime >= time.Second {
currentFPS := float64(numImages) / elapsedTime.Seconds()
fpsList = append(fpsList, currentFPS)
slog.Info(fmt.Sprintf("Receiving FPS: %.2f", currentFPS))
startTime = currentTime
numImages = 0
}
}
}
}

func main() {
aiModelsDir := flag.String("aimodelsdir", "runner/models", "path to the models directory")
inputFps := flag.Int("inputfps", 30, "Frames per second to send")
modelID := flag.String("modelid", "liveportrait", "Model ID for the live pipeline")
imageDir := flag.String("imagedir", "runner/example_data/live-video-to-video/", "Path to the image to send")
expOutputFps := flag.Int("expoutputfps", 27, "Minimum expected output FPS")
comfyuiPrompt := flag.String("comfyuiprompt", defaultPrompt, "Prompt to be used in comfyui pipeline")
flag.Parse()

pipeline := "live-video-to-video"
defaultImageID := "livepeer/ai-runner:latest"
gpus := []string{"0"}

modelsDir, err := filepath.Abs(*aiModelsDir)
if errors.Is(err, os.ErrNotExist) {
slog.Error("Directory does not exist", slog.String("path", *aiModelsDir))
return
} else if err != nil {
slog.Error("Error getting absolute path for 'aiModelsDir'", slog.String("error", err.Error()))
return
}

w, err := worker.NewWorker(defaultImageID, gpus, modelsDir)
if err != nil {
slog.Error("Error creating worker", slog.String("error", err.Error()))
return
}

dockerClient, err := docker.NewClientWithOpts(docker.FromEnv, docker.WithAPIVersionNegotiation())
if err != nil {
slog.Error("Error creating docker client", slog.String("error", err.Error()))
return
}

ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
defer cancel()

existingContainers, err := dockerClient.ContainerList(ctx, container.ListOptions{
Filters: filters.NewArgs(
filters.Arg("name", "^"+pipeline+"_"+*modelID),
),
All: true,
})
if err != nil {
slog.Error("Error listing existing containers", slog.String("error", err.Error()))
return
}
for _, _container := range existingContainers {
slog.Info("Removing existing container", slog.String("container_id", _container.ID))
err := dockerClient.ContainerRemove(ctx, _container.ID, container.RemoveOptions{
Force: true,
})
if err != nil {
slog.Error("Error removing container", slog.String("container_id", _container.ID), slog.String("error", err.Error()))
return
}
}

slog.Info("Warming container")

optimizationFlags := worker.OptimizationFlags{
"STREAM_PROTOCOL": "zeromq",
}
Comment on lines +297 to +299
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was so confused. I always thought these "optimization flags" were something for docker not the app lol


if err := w.Warm(ctx, pipeline, *modelID, worker.RunnerEndpoint{}, optimizationFlags); err != nil {
slog.Error("Error warming container", slog.String("error", err.Error()))
return
}

slog.Info("Warm container is up")

req := worker.GenLiveVideoToVideoJSONRequestBody{
ModelId: modelID,
SubscribeUrl: "tcp://172.17.0.1:5555",
PublishUrl: "tcp://172.17.0.1:5556",
Params: &map[string]interface{}{
"prompt": *comfyuiPrompt,
},
}

slog.Info("Running live-video-to-video")

resp, err := w.LiveVideoToVideo(ctx, req)
if err != nil {
slog.Error("Error running live-video-to-video", slog.String("error", err.Error()))
return
}

slog.Info("Got response", slog.Any("response", resp))

var wg sync.WaitGroup

wg.Add(1)
go func() {
defer wg.Done()
select {
case <-time.After(10 * time.Second):
err := sendImages(ctx, *imageDir, *inputFps)
if err != nil {
slog.Error("Error in sendImages", slog.String("error", err.Error()))
}
case <-ctx.Done():
return
}
}()

wg.Add(1)
go func() {
defer wg.Done()
select {
case <-time.After(10 * time.Second):
err := receiveImages(ctx, *expOutputFps)
if err != nil {
slog.Error("Error in receiveImages", slog.String("error", err.Error()))
}
case <-ctx.Done():
return
}
}()

// Wait for either the context to be done or the goroutines to finish
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()

select {
case <-ctx.Done():
slog.Info("Context done, waiting 10 sec for goroutines")
time.Sleep(10 * time.Second)
slog.Info("10 sec waiting done, stopping")
case <-done:
slog.Info("All goroutines finished")
}

w.Stop(ctx)
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ require (
github.com/go-chi/chi/v5 v5.1.0
github.com/oapi-codegen/runtime v1.1.1
github.com/opencontainers/image-spec v1.1.0
github.com/pebbe/zmq4 v1.2.11
github.com/stretchr/testify v1.9.0
github.com/vincent-petithory/dataurl v1.0.0
)
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
github.com/pebbe/zmq4 v1.2.11 h1:Ua5mgIaZeabUGnH7tqswkUcjkL7JYGai5e8v4hpEU9Q=
github.com/pebbe/zmq4 v1.2.11/go.mod h1:nqnPueOapVhE2wItZ0uOErngczsJdLOGkebMxaO8r48=
github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s=
github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
Expand Down
Loading
Loading