Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add seperate get_result and upload_result api for gpt task #11

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 74 additions & 14 deletions api/v1/inference_tasks/get_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"crynux_relay/api/v1/response"
"crynux_relay/config"
"crynux_relay/models"
"encoding/json"
"errors"
"os"
"path/filepath"
Expand All @@ -13,20 +14,20 @@ import (
"gorm.io/gorm"
)

type GetResultInput struct {
type GetSDResultInput struct {
ImageNum string `path:"image_num" json:"image_num" description:"Image number" validate:"required"`
TaskId uint64 `path:"task_id" json:"task_id" description:"Task id" validate:"required"`
}

type GetResultInputWithSignature struct {
GetResultInput
type GetSDResultInputWithSignature struct {
GetSDResultInput
Timestamp int64 `query:"timestamp" description:"Signature timestamp" validate:"required"`
Signature string `query:"signature" description:"Signature" validate:"required"`
}

func GetResult(ctx *gin.Context, in *GetResultInputWithSignature) error {
func GetSDResult(ctx *gin.Context, in *GetSDResultInputWithSignature) error {

match, address, err := ValidateSignature(in.GetResultInput, in.Timestamp, in.Signature)
match, address, err := ValidateSignature(in.GetSDResultInput, in.Timestamp, in.Signature)

if err != nil || !match {

Expand Down Expand Up @@ -57,18 +58,11 @@ func GetResult(ctx *gin.Context, in *GetResultInputWithSignature) error {

appConfig := config.GetConfig()

var fileExt string
if task.TaskType == models.TaskTypeSD {
fileExt = ".png"
} else {
fileExt = ".json"
}

resultFile := filepath.Join(
appConfig.DataDir.InferenceTasks,
task.GetTaskIdAsString(),
"results",
in.ImageNum+fileExt,
in.ImageNum+".png",
)

if _, err := os.Stat(resultFile); err != nil {
Expand All @@ -77,9 +71,75 @@ func GetResult(ctx *gin.Context, in *GetResultInputWithSignature) error {

ctx.Header("Content-Description", "File Transfer")
ctx.Header("Content-Transfer-Encoding", "binary")
ctx.Header("Content-Disposition", "attachment; filename="+in.ImageNum+fileExt)
ctx.Header("Content-Disposition", "attachment; filename="+in.ImageNum+".png")
ctx.Header("Content-Type", "application/octet-stream")
ctx.File(resultFile)

return nil
}

type GetGPTResultInput struct {
TaskId uint64 `path:"task_id" json:"task_id" description:"Task id" validate:"required"`
}

type GetGPTResultInputWithSignature struct {
GetGPTResultInput
Timestamp int64 `query:"timestamp" description:"Signature timestamp" validate:"required"`
Signature string `query:"signature" description:"Signature" validate:"required"`
}

func GetGPTResult(ctx *gin.Context, in *GetGPTResultInputWithSignature) (*GPTResultResponse, error) {
match, address, err := ValidateSignature(in.GetGPTResultInput, in.Timestamp, in.Signature)

if err != nil || !match {

if err != nil {
log.Debugln(err)
}

return nil, response.NewValidationErrorResponse("signature", "Invalid signature")
}

var task models.InferenceTask

if err := config.GetDB().Where(&models.InferenceTask{TaskId: in.TaskId}).First(&task).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, response.NewValidationErrorResponse("task_id", "Task not found")
} else {
return nil, response.NewExceptionResponse(err)
}
}

if task.Creator != address {
return nil, response.NewValidationErrorResponse("signature", "Signer not allowed")
}

if task.Status != models.InferenceTaskResultsUploaded {
return nil, response.NewValidationErrorResponse("task_id", "Task results not uploaded")
}

appConfig := config.GetConfig()

resultFile := filepath.Join(
appConfig.DataDir.InferenceTasks,
task.GetTaskIdAsString(),
"results",
"0.json",
)

if _, err := os.Stat(resultFile); err != nil {
return nil, response.NewValidationErrorResponse("image_num", "File not found")
}

resultContent, err := os.ReadFile(resultFile)
if err != nil {
return nil, response.NewExceptionResponse(err)
}

data := &models.GPTTaskResponse{}
if err := json.Unmarshal(resultContent, data); err != nil {
return nil, response.NewExceptionResponse(err)
}

return &GPTResultResponse{Data: *data}, nil
}
216 changes: 142 additions & 74 deletions api/v1/inference_tasks/get_result_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crynux_relay/models"
"crynux_relay/tests"
v1 "crynux_relay/tests/api/v1"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
Expand All @@ -18,105 +19,161 @@ import (
)

func TestUnauthorizedGetImage(t *testing.T) {
for _, taskType := range tests.TaskTypes {
addresses, privateKeys, err := tests.PrepareAccounts()
assert.Equal(t, nil, err, "prepare accounts error")
addresses, privateKeys, err := tests.PrepareAccounts()
assert.Equal(t, nil, err, "prepare accounts error")

_, task, err := tests.PrepareResultUploadedTask(taskType, addresses, config.GetDB())
assert.Equal(t, nil, err, "prepare task error")
_, task, err := tests.PrepareResultUploadedTask(models.TaskTypeSD, addresses, config.GetDB())
assert.Equal(t, nil, err, "prepare task error")

getResultInput := &inference_tasks.GetResultInput{
TaskId: task.TaskId,
ImageNum: "0",
}
getResultInput := &inference_tasks.GetSDResultInput{
TaskId: task.TaskId,
ImageNum: "0",
}

timestamp, signature, err := v1.SignData(getResultInput, privateKeys[1])
assert.Equal(t, nil, err, "sign data error")
timestamp, signature, err := v1.SignData(getResultInput, privateKeys[1])
assert.Equal(t, nil, err, "sign data error")

r := callGetImageApi(
task.GetTaskIdAsString(),
"0",
timestamp,
signature)
r := callGetImageApi(
task.GetTaskIdAsString(),
"0",
timestamp,
signature)

v1.AssertValidationErrorResponse(t, r, "signature", "Signer not allowed")
v1.AssertValidationErrorResponse(t, r, "signature", "Signer not allowed")

t.Cleanup(func() {
tests.ClearDB()
if err := tests.ClearDataFolders(); err != nil {
t.Error(err)
}
})
}
t.Cleanup(func() {
tests.ClearDB()
if err := tests.ClearDataFolders(); err != nil {
t.Error(err)
}
})
}

func TestGetImage(t *testing.T) {
for _, taskType := range tests.TaskTypes {
addresses, privateKeys, err := tests.PrepareAccounts()
assert.Equal(t, nil, err, "prepare accounts error")
func TestUnauthorizedGetGPTResponse(t *testing.T) {
addresses, privateKeys, err := tests.PrepareAccounts()
assert.Equal(t, nil, err, "prepare accounts error")

_, task, err := tests.PrepareResultUploadedTask(taskType, addresses, config.GetDB())
assert.Equal(t, nil, err, "prepare task error")
_, task, err := tests.PrepareResultUploadedTask(models.TaskTypeLLM, addresses, config.GetDB())
assert.Equal(t, nil, err, "prepare task error")

var imageNum, srcFile, dstFile string
getResultInput := &inference_tasks.GetGPTResultInput{
TaskId: task.TaskId,
}

if taskType == models.TaskTypeSD {
imageNum = "2"
srcFile = "2.png"
dstFile = "downloaded.png"
} else {
imageNum = "0"
srcFile = "0.json"
dstFile = "downloaded.json"
timestamp, signature, err := v1.SignData(getResultInput, privateKeys[1])
assert.Equal(t, nil, err, "sign data error")

r := callGetGPTResponseApi(
task.GetTaskIdAsString(),
timestamp,
signature)

v1.AssertValidationErrorResponse(t, r, "signature", "Signer not allowed")

t.Cleanup(func() {
tests.ClearDB()
if err := tests.ClearDataFolders(); err != nil {
t.Error(err)
}
})
}

getResultInput := &inference_tasks.GetResultInput{
TaskId: task.TaskId,
ImageNum: imageNum,
func TestGetImage(t *testing.T) {
t.Cleanup(func() {
tests.ClearDB()
if err := tests.ClearDataFolders(); err != nil {
t.Error(err)
}
})

addresses, privateKeys, err := tests.PrepareAccounts()
assert.Equal(t, nil, err, "prepare accounts error")

_, task, err := tests.PrepareResultUploadedTask(models.TaskTypeSD, addresses, config.GetDB())
assert.Equal(t, nil, err, "prepare task error")

imageNum := "2"
srcFile := "2.png"
dstFile := "downloaded.png"

getResultInput := &inference_tasks.GetSDResultInput{
TaskId: task.TaskId,
ImageNum: imageNum,
}

timestamp, signature, err := v1.SignData(getResultInput, privateKeys[0])
assert.Equal(t, nil, err, "sign data error")

r := callGetImageApi(
task.GetTaskIdAsString(),
imageNum,
timestamp,
signature)

assert.Equal(t, 200, r.Code, "wrong http status code. message: "+r.Body.String())

timestamp, signature, err := v1.SignData(getResultInput, privateKeys[0])
assert.Equal(t, nil, err, "sign data error")
appConfig := config.GetConfig()
imageFolder := filepath.Join(
appConfig.DataDir.InferenceTasks,
task.GetTaskIdAsString(),
"results",
)

r := callGetImageApi(
task.GetTaskIdAsString(),
imageNum,
timestamp,
signature)
out, err := os.Create(filepath.Join(imageFolder, dstFile))
assert.Equal(t, nil, err, "create tmp file error")

assert.Equal(t, 200, r.Code, "wrong http status code. message: "+string(r.Body.Bytes()))
_, err = io.Copy(out, r.Body)
assert.Equal(t, nil, err, "write tmp file error")

appConfig := config.GetConfig()
imageFolder := filepath.Join(
appConfig.DataDir.InferenceTasks,
task.GetTaskIdAsString(),
"results",
)
err = out.Close()
assert.Equal(t, nil, err, "close tmp file error")

out, err := os.Create(filepath.Join(imageFolder, dstFile))
assert.Equal(t, nil, err, "create tmp file error")
originalFile, err := os.Stat(filepath.Join(imageFolder, srcFile))
assert.Equal(t, nil, err, "read original file error")

_, err = io.Copy(out, r.Body)
assert.Equal(t, nil, err, "write tmp file error")
downloadedFile, err := os.Stat(filepath.Join(imageFolder, dstFile))
assert.Equal(t, nil, err, "read downloaded file error")

err = out.Close()
assert.Equal(t, nil, err, "close tmp file error")
assert.Equal(t, originalFile.Size(), downloadedFile.Size(), "different file sizes")
}

originalFile, err := os.Stat(filepath.Join(imageFolder, srcFile))
assert.Equal(t, nil, err, "read original file error")
func TestGetGPTResponse(t *testing.T) {
t.Cleanup(func() {
tests.ClearDB()
if err := tests.ClearDataFolders(); err != nil {
t.Error(err)
}
})

downloadedFile, err := os.Stat(filepath.Join(imageFolder, dstFile))
assert.Equal(t, nil, err, "read downloaded file error")
addresses, privateKeys, err := tests.PrepareAccounts()
assert.Equal(t, nil, err, "prepare accounts error")

assert.Equal(t, originalFile.Size(), downloadedFile.Size(), "different file sizes")
_, task, err := tests.PrepareResultUploadedTask(models.TaskTypeLLM, addresses, config.GetDB())
assert.Equal(t, nil, err, "prepare task error")

t.Cleanup(func() {
tests.ClearDB()
if err := tests.ClearDataFolders(); err != nil {
t.Error(err)
}
})
getResultInput := &inference_tasks.GetGPTResultInput{
TaskId: task.TaskId,
}

timestamp, signature, err := v1.SignData(getResultInput, privateKeys[0])
assert.Equal(t, nil, err, "sign data error")

r := callGetGPTResponseApi(
task.GetTaskIdAsString(),
timestamp,
signature)

assert.Equal(t, 200, r.Code, "wrong http status code. message: "+r.Body.String())

res := inference_tasks.GPTResultResponse{}
if err := json.Unmarshal(r.Body.Bytes(), &res); err != nil {
t.Error(err)
}
target := models.GPTTaskResponse{}
if err := json.Unmarshal([]byte(tests.GPTResponseStr), &target); err != nil {
t.Error(err)
}
assert.Equal(t, target, res.Data, "wrong returned gpt response")
}

func callGetImageApi(
Expand All @@ -125,7 +182,18 @@ func callGetImageApi(
timestamp int64,
signature string) *httptest.ResponseRecorder {

endpoint := "/v1/inference_tasks/" + taskIdStr + "/results/" + imageNum
endpoint := "/v1/inference_tasks/stable_diffusion/" + taskIdStr + "/results/" + imageNum
query := "?timestamp=" + strconv.FormatInt(timestamp, 10) + "&signature=" + signature

req, _ := http.NewRequest("GET", endpoint+query, nil)
w := httptest.NewRecorder()
tests.Application.ServeHTTP(w, req)

return w
}

func callGetGPTResponseApi(taskIdStr string, timestamp int64, signature string) *httptest.ResponseRecorder {
endpoint := "/v1/inference_tasks/gpt/" + taskIdStr + "/results"
query := "?timestamp=" + strconv.FormatInt(timestamp, 10) + "&signature=" + signature

req, _ := http.NewRequest("GET", endpoint+query, nil)
Expand Down
Loading