Skip to content

Commit

Permalink
chore: Bring back code
Browse files Browse the repository at this point in the history
  • Loading branch information
Romuald Rousseau committed Aug 27, 2024
1 parent 66349be commit 14a3019
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import java.io.IOException;
import java.net.URISyntaxException;

import com.github.romualdrousseau.any2json.modeldata.JsonModelBuilder;

public class ModelDB {

public static Model createConnection(final String modelName) {
try {
return new ModelBuilder()
return new JsonModelBuilder()
.fromResource(new ModelDB().getClass(), String.format("/data/%s.json", modelName))
.build();
} catch (final URISyntaxException | IOException x) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import java.io.IOException;
import java.net.URISyntaxException;

import com.github.romualdrousseau.any2json.modeldata.JsonModelBuilder;

public class ModelDB {

public static Model createConnection(final String modelName) {
try {
return new ModelBuilder()
return new JsonModelBuilder()
.fromResource(new ModelDB().getClass(), String.format("/data/%s.json", modelName))
.build();
} catch (final URISyntaxException | IOException x) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import org.slf4j.LoggerFactory;

import com.github.romualdrousseau.any2json.Header;
import com.github.romualdrousseau.any2json.ModelBuilder;
import com.github.romualdrousseau.any2json.modeldata.JsonModelBuilder;
import com.github.romualdrousseau.any2json.Row;
import com.github.romualdrousseau.any2json.Sheet;
import com.github.romualdrousseau.any2json.event.BitmapGeneratedEvent;
Expand All @@ -21,15 +21,15 @@ public class Common {
private static final Logger LOGGER = LoggerFactory.getLogger(Common.class);
private static final String REPO_BASE_URL = "https://raw.githubusercontent.com/RomualdRousseau/Any2Json-Models/main";

public static <T> ModelBuilder loadModelBuilder(final String modelName, final Class<T> clazz) {
return new ModelBuilder().fromPath(Common.getResourcePath(String.format("/models/%s.json", modelName), clazz));
public static <T> JsonModelBuilder loadModelBuilder(final String modelName, final Class<T> clazz) {
return new JsonModelBuilder().fromPath(Common.getResourcePath(String.format("/models/%s.json", modelName), clazz));
}

public static ModelBuilder loadModelBuilderFromGitHub(final String modelName) {
public static JsonModelBuilder loadModelBuilderFromGitHub(final String modelName) {
try {
LOGGER.info("Loaded model: " + modelName);
final var uri = String.format("%1$s/%2$s/%2$s.json", REPO_BASE_URL, modelName);
return new ModelBuilder().fromURI(uri);
final var url = String.format("%1$s/%2$s/%2$s.json", REPO_BASE_URL, modelName);
return new JsonModelBuilder().fromURL(url);
} catch (final IOException | InterruptedException x) {
throw new RuntimeException(x);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import java.io.IOException;
import java.net.URISyntaxException;

import com.github.romualdrousseau.any2json.modeldata.JsonModelBuilder;

public class ModelDB {

public static Model createConnection(final String modelName) {
try {
return new ModelBuilder()
return new JsonModelBuilder()
.fromResource(new ModelDB().getClass(), String.format("/data/%s.json", modelName))
.build();
} catch (final URISyntaxException | IOException x) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,24 @@
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import org.tensorflow.SavedModelBundle;
import org.tensorflow.SessionFunction;
import org.tensorflow.Signature;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.types.TFloat32;

import com.github.romualdrousseau.any2json.Header;
import com.github.romualdrousseau.any2json.HeaderTag;
import com.github.romualdrousseau.any2json.Model;
import com.github.romualdrousseau.any2json.Table;
import com.github.romualdrousseau.any2json.TagClassifier;
import com.github.romualdrousseau.any2json.util.Disk;
import com.github.romualdrousseau.any2json.util.TempFile;
import com.github.romualdrousseau.shuju.types.Tensor;
import com.github.romualdrousseau.shuju.commons.PythonManager;
import com.github.romualdrousseau.shuju.json.JSON;
import com.github.romualdrousseau.shuju.types.Tensor;
import com.github.romualdrousseau.shuju.preprocessing.Text;
import com.github.romualdrousseau.shuju.preprocessing.hasher.VocabularyHasher;
import com.github.romualdrousseau.shuju.preprocessing.tokenizer.NgramTokenizer;
Expand All @@ -43,7 +46,6 @@ public class NetTagClassifier extends SimpleTagClassifier implements Trainable {
private final List<String> vocabulary;
private final int ngrams;
private final int wordMinSize;
private final List<String> lexicon;
private final Text.ITokenizer tokenizer;
private final Text.IHasher hasher;
private final boolean isModelTemp;
Expand All @@ -54,14 +56,14 @@ public class NetTagClassifier extends SimpleTagClassifier implements Trainable {

public NetTagClassifier(final List<String> vocabulary, final int ngrams, final int wordMinSize,
final List<String> lexicon, final Optional<Path> modelPath) {
this.setLexicon(lexicon);

this.vocabulary = vocabulary;
this.ngrams = ngrams;
this.wordMinSize = wordMinSize;
this.lexicon = lexicon;

this.tokenizer = (ngrams == 0)
? new ShingleTokenizer(this.lexicon, this.wordMinSize)
? new ShingleTokenizer(this.getLexicon(), this.wordMinSize)
: new NgramTokenizer(ngrams);
this.hasher = new VocabularyHasher(this.vocabulary);
this.isModelTemp = modelPath.filter(x -> x.toFile().exists()).isEmpty();
Expand All @@ -71,12 +73,12 @@ public NetTagClassifier(final List<String> vocabulary, final int ngrams, final i
public NetTagClassifier(final Model model, final TagClassifier.TagStyle tagStyle) {
this(
model.getData().getList("vocabulary"),
model.getData().getInt("ngrams"),
model.getData().getInt("wordMinSize"),
model.getData().<Integer>get("ngrams").orElse(0),
model.getData().<Integer>get("wordMinSize").orElse(2),
model.getData().getList("lexicon"),
Optional.ofNullable(model.getAttributes().get("modelPath")).map(Path::of));
Optional.ofNullable(model.getModelAttributes().get("modelPath")).map(Path::of));

assert this.isModelTemp && model.getData().getString("model") != null : "model element must exist";
assert this.isModelTemp && model.getData().<String>get("model").isPresent() : "model element must exist";

this.setModel(model);
this.setTagStyle(tagStyle);
Expand All @@ -90,17 +92,21 @@ public void close() throws Exception {
@Override
public void updateModelData() {
this.getModel().getData().setList("vocabulary", this.vocabulary);
this.getModel().getData().setInt("ngrams", this.ngrams);
this.getModel().getData().setInt("wordMinSize", this.wordMinSize);
this.getModel().getData().setList("lexicon", this.lexicon);
this.getModel().getData().set("ngrams", this.ngrams);
this.getModel().getData().set("wordMinSize", this.wordMinSize);
this.getModel().getData().setList("lexicon", this.getLexicon());
if (!this.isModelTemp && this.modelPath.isPresent()) {
this.getModel().getAttributes().put("modelPath", this.modelPath.get().toString());
this.getModel().getData().setString("model", this.serializeModelML(this.modelPath.get()));
this.getModel().getModelAttributes().put("modelPath", this.modelPath.get().toString());
this.getModel().getData().set("model", this.serializeModelML(this.modelPath.get()));
}
}

@Override
public String predict(final String name, final List<String> entities, final List<String> context) {
public String predict(final Table table, final Header header) {
final var name = header.getName();
final var entities = StreamSupport.stream(header.entities().spliterator(), false).toList();
final var context = StreamSupport.stream(table.getHeaderNames().spliterator(), false).toList();

if (!this.loadModelML()) {
return HeaderTag.None.getValue();
}
Expand All @@ -121,6 +127,32 @@ public String predict(final String name, final List<String> entities, final List
return this.getModel().getTagList().get((int) result.argmax(1).item(0));
}

/**
 * Builds the flat integer input vector for one header.
 *
 * <p>The result is the concatenation of three segments, in this order:
 * <ol>
 * <li>the entities, encoded with {@code Text.to_categorical} against the model's entity list;</li>
 * <li>the header name, encoded with {@code Text.one_hot} using the model's filter list and this
 * classifier's tokenizer and hasher;</li>
 * <li>every context entry not equal to {@code name}, encoded the same way as the name, then
 * merged, de-duplicated and sorted.</li>
 * </ol>
 * Each segment is passed through {@code Text.pad_sequence} and then cut to its first
 * {@code IN_ENTITY_SIZE}, {@code IN_NAME_SIZE} and {@code IN_CONTEXT_SIZE} elements respectively.
 *
 * @param name     the header name to encode
 * @param entities the entities attached to the header
 * @param context  the surrounding names; entries equal to {@code name} are ignored
 * @return the concatenated entity, name and context segments
 */
@Override
public List<Integer> getInputVector(final String name, final List<String> entities,
        final List<String> context) {
    final var entityCodes = Text.to_categorical(entities, this.getModel().getEntityList());
    final var nameCodes = Text.one_hot(name, this.getModel().getFilterList(), this.tokenizer, this.hasher);
    final var contextCodes = context.stream()
            .filter(other -> !other.equals(name))
            .flatMap(other -> Text
                    .one_hot(other, this.getModel().getFilterList(), this.tokenizer, this.hasher)
                    .stream())
            .distinct()
            .sorted()
            .toList();

    // Pad first, then cut, so each segment is limited to its configured size.
    final var entitySegment = Text.pad_sequence(entityCodes, IN_ENTITY_SIZE).subList(0, IN_ENTITY_SIZE);
    final var nameSegment = Text.pad_sequence(nameCodes, IN_NAME_SIZE).subList(0, IN_NAME_SIZE);
    final var contextSegment = Text.pad_sequence(contextCodes, IN_CONTEXT_SIZE).subList(0, IN_CONTEXT_SIZE);

    return Stream.of(entitySegment, nameSegment, contextSegment)
            .flatMap(Collection::stream)
            .toList();
}

/**
 * Builds the output vector for a label.
 *
 * <p>The label is encoded with {@code Text.to_categorical} against the model's tag list and the
 * result is passed through {@code Text.pad_sequence} with {@code OUT_TAG_SIZE}.
 *
 * @param label the tag label to encode
 * @return the padded categorical encoding of {@code label}
 */
@Override
public List<Integer> getOutputVector(final String label) {
    final var tagList = this.getModel().getTagList();
    final var encodedLabel = Text.to_categorical(label, tagList);
    return Text.pad_sequence(encodedLabel, OUT_TAG_SIZE);
}

/**
 * Returns the vocabulary list held by this classifier.
 *
 * <p>The field is returned directly; no copy is made.
 *
 * @return the vocabulary list
 */
public List<String> getVocabulary() {
return this.vocabulary;
}

public Process fit(final List<TrainingEntry> trainingSet, final List<TrainingEntry> validationSet)
throws IOException, InterruptedException, URISyntaxException {
this.closeModelML();
Expand Down Expand Up @@ -148,38 +180,10 @@ public Process fit(final List<TrainingEntry> trainingSet, final List<TrainingEnt
"-m " + this.modelPath.get());
}

/**
 * Returns the vocabulary list held by this classifier.
 *
 * <p>The field is returned directly; no copy is made.
 *
 * @return the vocabulary list
 */
public List<String> getVocabulary() {
return this.vocabulary;
}

/**
 * Returns the lexicon list stored in this classifier's {@code lexicon} field.
 *
 * <p>The field is returned directly; no copy is made.
 *
 * @return the lexicon list
 */
public List<String> getLexicon() {
return this.lexicon;
}

/**
 * Builds the flat integer input vector for one header.
 *
 * <p>The result concatenates, in order, an entity segment, a name segment and a context segment,
 * each passed through {@code Text.pad_sequence} and then cut to its first {@code IN_ENTITY_SIZE},
 * {@code IN_NAME_SIZE} and {@code IN_CONTEXT_SIZE} elements respectively.
 *
 * @param name     the header name to encode
 * @param entities the entities attached to the header
 * @param context  the surrounding names; entries equal to {@code name} are ignored
 * @return the concatenated entity, name and context segments
 */
public List<Integer> getInputVector(final String name, final List<String> entities,
final List<String> context) {
// Entities are encoded against the model's entity list.
final var part1 = Text.to_categorical(entities, this.getModel().getEntityList());
// The name is encoded using the model's filters and this classifier's tokenizer and hasher.
final var part2 = Text.one_hot(name, this.getModel().getFilters(), this.tokenizer, this.hasher);
// Context entries other than the name are encoded the same way, then merged, de-duplicated and sorted.
final var part3 = context.stream()
.filter(x -> !x.equals(name))
.flatMap(x -> Text.one_hot(x, this.getModel().getFilters(), this.tokenizer, this.hasher).stream())
.distinct().sorted().toList();
// Pad each part, keep only its leading fixed-size window, and flatten the three windows into one list.
return Stream.of(
Text.pad_sequence(part1, IN_ENTITY_SIZE).subList(0, IN_ENTITY_SIZE),
Text.pad_sequence(part2, IN_NAME_SIZE).subList(0, IN_NAME_SIZE),
Text.pad_sequence(part3, IN_CONTEXT_SIZE).subList(0, IN_CONTEXT_SIZE))
.flatMap(Collection::stream)
.toList();
}

/**
 * Builds the output vector for a label.
 *
 * <p>The label is encoded with {@code Text.to_categorical} against the model's tag list and the
 * result is passed through {@code Text.pad_sequence} with {@code OUT_TAG_SIZE}.
 *
 * @param label the tag label to encode
 * @return the padded categorical encoding of {@code label}
 */
public List<Integer> getOutputVector(final String label) {
return Text.pad_sequence(Text.to_categorical(label, this.getModel().getTagList()), OUT_TAG_SIZE);
}

private boolean loadModelML() {
try {
if (this.modelPath.isEmpty()) {
final var modelString = this.getModel().getData().getString("model");
final var modelString = this.getModel().getData().<String>get("model").get();
this.modelPath = Optional.of(this.unserializeModelML(modelString));
}
if (this.tagClassifierModel == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import java.io.IOException;
import java.net.URISyntaxException;

import com.github.romualdrousseau.any2json.modeldata.JsonModelBuilder;

public class ModelDB {

public static Model createConnection(final String modelName) {
try {
return new ModelBuilder()
return new JsonModelBuilder()
.fromResource(new ModelDB().getClass(), String.format("/data/%s.json", modelName))
.build();
} catch (final URISyntaxException | IOException x) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import java.io.IOException;
import java.net.URISyntaxException;

import com.github.romualdrousseau.any2json.modeldata.JsonModelBuilder;

public class ModelDB {

public static Model createConnection(final String modelName) {
try {
return new ModelBuilder()
return new JsonModelBuilder()
.fromResource(new ModelDB().getClass(), String.format("/data/%s.json", modelName))
.build();
} catch (final URISyntaxException | IOException x) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,19 @@ public JsonModelBuilder fromURL(final String url) throws IOException, Interrupte
return this.fromModelData(new JsonModelData(JSON.objectOf(response.body())));
}

/**
 * Returns the entity list currently held by this builder.
 *
 * <p>The {@code entities} field is returned directly; no copy is made.
 *
 * @return the builder's entity list
 */
public List<String> getEntityList() {
return this.entities;
}

/**
 * Replaces the entity list held by this builder.
 *
 * <p>The given list is stored by reference; no copy is made.
 *
 * @param entities the entity list to store
 * @return this builder, so calls can be chained
 */
public JsonModelBuilder setEntityList(final List<String> entities) {
this.entities = entities;
return this;
}

/**
 * Returns the pattern map currently held by this builder.
 *
 * <p>The {@code patterns} field is returned directly; no copy is made.
 *
 * @return the builder's pattern map
 */
public Map<String, String> getPatternMap() {
return this.patterns;
}

public JsonModelBuilder setPatternMap(final Map<String, String> patterns) {
this.patterns = patterns;
return this;
Expand Down

0 comments on commit 14a3019

Please sign in to comment.