Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for mongodb and added codec for object class #340

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,12 @@
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<optional>true</optional>
<version>1.18.30</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-mongodb</artifactId>
</dependency>
</dependencies>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ public static class TaskStore {
/**
* 任务存储方式: redis(默认)、in_memory.
*/
private Type type = Type.IN_MEMORY;
private Type type = Type.MONGODB;

public enum Type {
/**
Expand All @@ -157,7 +157,11 @@ public enum Type {
/**
* in_memory.
*/
IN_MEMORY
IN_MEMORY,
/**
* mongodb
*/
MONGODB
}
}

Expand Down
152 changes: 152 additions & 0 deletions src/main/java/com/github/novicezk/midjourney/codecs/ObjectCodec.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package com.github.novicezk.midjourney.codecs;

import java.lang.reflect.Field;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

import org.bson.BsonReader;
import org.bson.BsonType;
import org.bson.BsonWriter;
import org.bson.codecs.Codec;
import org.bson.codecs.DecoderContext;
import org.bson.codecs.EncoderContext;
import org.bson.types.Decimal128;

public class ObjectCodec implements Codec<Object> {

@Override
public void encode(BsonWriter writer, Object value, EncoderContext encoderContext) {
writeValue(writer, value);
}

@Override
public Class<Object> getEncoderClass() {
return Object.class;
}

@Override
public Object decode(BsonReader reader, DecoderContext decoderContext) {
return readValue(reader);
}

private static Object readValue(BsonReader bsonReader) {
var type = bsonReader.getCurrentBsonType();
switch (type) {
case ARRAY:
ArrayList<Object> array = new ArrayList<>();
bsonReader.readStartArray();
while (bsonReader.readBsonType() != BsonType.END_OF_DOCUMENT) {
array.add(readValue(bsonReader));
}
bsonReader.readEndArray();
return array;
case BINARY:
return bsonReader.readBinaryData();
case BOOLEAN:
return bsonReader.readBoolean();
case DATE_TIME:
return bsonReader.readDateTime();
case DB_POINTER:
return bsonReader.readDBPointer();
case DECIMAL128:
return bsonReader.readDecimal128();
case DOCUMENT:
HashMap<String, Object> nestedMap = new HashMap<>();
bsonReader.readStartDocument();
while (bsonReader.readBsonType() != BsonType.END_OF_DOCUMENT) {
nestedMap.put(bsonReader.readName(), readValue(bsonReader));
}
bsonReader.readEndDocument();
return nestedMap;
case DOUBLE:
return bsonReader.readDouble();
case INT32:
return bsonReader.readInt32();
case INT64:
return bsonReader.readInt64();
case NULL:
return null;
case OBJECT_ID:
return bsonReader.readObjectId();
case STRING:
return bsonReader.readString();
case TIMESTAMP:
return bsonReader.readTimestamp();
case UNDEFINED:
return null;
default:
return null;

}
}

private static void writeValue(BsonWriter bsonWriter, Object value) {
if (value instanceof String) {
bsonWriter.writeString(value.toString());
} else if (value instanceof Integer) {
bsonWriter.writeInt32((Integer) value);
} else if (value instanceof Long) {
bsonWriter.writeInt64((Long) value);
} else if (value instanceof BigDecimal) {
bsonWriter.writeDecimal128(Decimal128.parse(value.toString()));
} else if (value instanceof Double) {
bsonWriter.writeDouble((Double) value);
} else if (value instanceof Boolean) {
bsonWriter.writeBoolean((Boolean) value);
} else if (value instanceof HashMap) {
// Recursively handle HashMap for nesting
bsonWriter.writeStartDocument();
HashMap<?, ?> nestedMap = (HashMap<?, ?>) value;
for (Map.Entry<?, ?> entry : nestedMap.entrySet()) {
String key = entry.getKey().toString();
Object nestedValue = entry.getValue();
bsonWriter.writeName(key);
writeValue(bsonWriter, nestedValue);
}
bsonWriter.writeEndDocument();
} else if (value instanceof ArrayList) {
ArrayList<?> arrayList = (ArrayList<?>) value;
bsonWriter.writeStartArray();
for (Object item : arrayList) {
writeValue(bsonWriter, item);
}
bsonWriter.writeEndArray();
} else if (value.getClass().isArray()) {
bsonWriter.writeStartArray();
int length = java.lang.reflect.Array.getLength(value);
for (int i = 0; i < length; i++) {
Object item = java.lang.reflect.Array.get(value, i);
writeValue(bsonWriter, item);
}
bsonWriter.writeEndArray();
} else {
try {
Class<?> clazz = value.getClass();
bsonWriter.writeStartDocument();
while (clazz != null) {
Field[] fields = clazz.getDeclaredFields();
for (Field field : fields) {
int modifiers = field.getModifiers();
if (!java.lang.reflect.Modifier.isFinal(modifiers)) {
field.setAccessible(true);
String fieldName = field.getName();
if (fieldName.equals("id")) {
fieldName = "_id";
}
Object fieldValue = field.get(value);
bsonWriter.writeName(fieldName);
writeValue(bsonWriter, fieldValue);
}
}
clazz = clazz.getSuperclass();
}
bsonWriter.writeEndDocument();

} catch (IllegalAccessException e) {
e.printStackTrace();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import io.swagger.annotations.ApiModelProperty;
import lombok.Getter;
import lombok.Setter;
import org.springframework.data.annotation.Id;
import org.springframework.data.annotation.Transient;

import java.io.Serializable;
import java.util.HashMap;
Expand All @@ -15,12 +17,14 @@ public class DomainObject implements Serializable {
@Getter
@Setter
@ApiModelProperty("ID")
@Id
protected String id;

@Setter
protected Map<String, Object> properties; // 扩展属性,仅支持基本类型

@JsonIgnore
@Transient
private final transient Object lock = new Object();

public void sleep() throws InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package com.github.novicezk.midjourney.service.store;

import com.github.novicezk.midjourney.service.TaskStoreService;
import com.github.novicezk.midjourney.support.Task;
import com.github.novicezk.midjourney.support.TaskCondition;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.core.query.Query;

import java.util.List;


public class MongoTaskStoreServiceImpl implements TaskStoreService {

private final MongoTemplate mongoTemplate;

public MongoTaskStoreServiceImpl(MongoTemplate mongoTemplate){
this.mongoTemplate = mongoTemplate;
}


@Override
public void save(Task task) {
mongoTemplate.save(task);
}

@Override
public void delete(String id) {
Query query = new Query(Criteria.where("_id").is(id));
mongoTemplate.remove(query,Task.class);
}

@Override
public Task get(String id) {
Query query = new Query(Criteria.where("_id").is(id));
return mongoTemplate.findOne(query,Task.class);
}

@Override
public List<Task> list() {
return mongoTemplate.findAll(Task.class);
}

@Override
public List<Task> list(TaskCondition condition) {
return list().stream().filter(condition).toList();
}

@Override
public Task findOne(TaskCondition condition) {
return list().stream().filter(condition).findFirst().orElse(null);
}
}
89 changes: 47 additions & 42 deletions src/main/java/com/github/novicezk/midjourney/support/Task.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,62 +7,67 @@
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.springframework.data.annotation.Id;
import org.springframework.data.mongodb.core.mapping.Document;

import java.io.Serial;
import java.util.LinkedList;
import java.util.List;

@Data
@EqualsAndHashCode(callSuper = true)
@ApiModel("任务")
@Document(collection = "discordTask")
public class Task extends DomainObject {
@Serial
private static final long serialVersionUID = -674915748204390789L;
@Serial
private static final long serialVersionUID = -674915748204390789L;

@ApiModelProperty("任务类型")
private TaskAction action;
@ApiModelProperty("任务状态")
private TaskStatus status = TaskStatus.NOT_START;
@ApiModelProperty("任务类型")
private TaskAction action;
@ApiModelProperty("任务状态")
private TaskStatus status = TaskStatus.NOT_START;

@ApiModelProperty("提示词")
private String prompt;
@ApiModelProperty("提示词-英文")
private String promptEn;
@ApiModelProperty("提示词")
private String prompt;
@ApiModelProperty("提示词-英文")
private String promptEn;

@ApiModelProperty("任务描述")
private String description;
@ApiModelProperty("自定义参数")
private String state;
@ApiModelProperty("任务描述")
private String description;
@ApiModelProperty("自定义参数")
private String state;

@ApiModelProperty("提交时间")
private Long submitTime;
@ApiModelProperty("开始执行时间")
private Long startTime;
@ApiModelProperty("结束时间")
private Long finishTime;
@ApiModelProperty("提交时间")
private Long submitTime;
@ApiModelProperty("开始执行时间")
private Long startTime;
@ApiModelProperty("结束时间")
private Long finishTime;

@ApiModelProperty("图片url")
private String imageUrl;
@ApiModelProperty("图片url")
private String imageUrl;

@ApiModelProperty("任务进度")
private String progress;
@ApiModelProperty("失败原因")
private String failReason;
@ApiModelProperty("任务进度")
private String progress;
@ApiModelProperty("失败原因")
private String failReason;

public void start() {
this.startTime = System.currentTimeMillis();
this.status = TaskStatus.SUBMITTED;
this.progress = "0%";
}
public void start() {
this.startTime = System.currentTimeMillis();
this.status = TaskStatus.SUBMITTED;
this.progress = "0%";
}

public void success() {
this.finishTime = System.currentTimeMillis();
this.status = TaskStatus.SUCCESS;
this.progress = "100%";
}
public void success() {
this.finishTime = System.currentTimeMillis();
this.status = TaskStatus.SUCCESS;
this.progress = "100%";
}

public void fail(String reason) {
this.finishTime = System.currentTimeMillis();
this.status = TaskStatus.FAILURE;
this.failReason = reason;
this.progress = "";
}
public void fail(String reason) {
this.finishTime = System.currentTimeMillis();
this.status = TaskStatus.FAILURE;
this.failReason = reason;
this.progress = "";
}
}
Loading