From 14a301985fb24a38b9259488ad7db117aac133f6 Mon Sep 17 00:00:00 2001 From: Romuald Rousseau Date: Tue, 27 Aug 2024 13:59:59 +0800 Subject: [PATCH] chore: Bring back code --- .../romualdrousseau/any2json/ModelDB.java | 4 +- .../romualdrousseau/any2json/ModelDB.java | 4 +- .../any2json/examples/Common.java | 12 +-- .../romualdrousseau/any2json/ModelDB.java | 4 +- .../any2json/classifier/NetTagClassifier.java | 90 ++++++++++--------- .../romualdrousseau/any2json/ModelDB.java | 4 +- .../romualdrousseau/any2json/ModelDB.java | 4 +- .../any2json/modeldata/JsonModelBuilder.java | 8 ++ 8 files changed, 76 insertions(+), 54 deletions(-) diff --git a/any2json-csv/src/test/java/com/github/romualdrousseau/any2json/ModelDB.java b/any2json-csv/src/test/java/com/github/romualdrousseau/any2json/ModelDB.java index 9c360f24..adccd327 100644 --- a/any2json-csv/src/test/java/com/github/romualdrousseau/any2json/ModelDB.java +++ b/any2json-csv/src/test/java/com/github/romualdrousseau/any2json/ModelDB.java @@ -3,11 +3,13 @@ import java.io.IOException; import java.net.URISyntaxException; +import com.github.romualdrousseau.any2json.modeldata.JsonModelBuilder; + public class ModelDB { public static Model createConnection(final String modelName) { try { - return new ModelBuilder() + return new JsonModelBuilder() .fromResource(new ModelDB().getClass(), String.format("/data/%s.json", modelName)) .build(); } catch (final URISyntaxException | IOException x) { diff --git a/any2json-dbf/src/test/java/com/github/romualdrousseau/any2json/ModelDB.java b/any2json-dbf/src/test/java/com/github/romualdrousseau/any2json/ModelDB.java index 9c360f24..adccd327 100644 --- a/any2json-dbf/src/test/java/com/github/romualdrousseau/any2json/ModelDB.java +++ b/any2json-dbf/src/test/java/com/github/romualdrousseau/any2json/ModelDB.java @@ -3,11 +3,13 @@ import java.io.IOException; import java.net.URISyntaxException; +import com.github.romualdrousseau.any2json.modeldata.JsonModelBuilder; + public class ModelDB { public static Model createConnection(final String modelName) { try { - return new ModelBuilder() + return new JsonModelBuilder() .fromResource(new ModelDB().getClass(), String.format("/data/%s.json", modelName)) .build(); } catch (final URISyntaxException | IOException x) { diff --git a/any2json-examples/src/main/java/com/github/romualdrousseau/any2json/examples/Common.java b/any2json-examples/src/main/java/com/github/romualdrousseau/any2json/examples/Common.java index 2e998a85..4129247b 100644 --- a/any2json-examples/src/main/java/com/github/romualdrousseau/any2json/examples/Common.java +++ b/any2json-examples/src/main/java/com/github/romualdrousseau/any2json/examples/Common.java @@ -9,7 +9,7 @@ import org.slf4j.LoggerFactory; import com.github.romualdrousseau.any2json.Header; -import com.github.romualdrousseau.any2json.ModelBuilder; +import com.github.romualdrousseau.any2json.modeldata.JsonModelBuilder; import com.github.romualdrousseau.any2json.Row; import com.github.romualdrousseau.any2json.Sheet; import com.github.romualdrousseau.any2json.event.BitmapGeneratedEvent; @@ -21,15 +21,15 @@ public class Common { private static final Logger LOGGER = LoggerFactory.getLogger(Common.class); private static final String REPO_BASE_URL = "https://raw.githubusercontent.com/RomualdRousseau/Any2Json-Models/main"; - public static ModelBuilder loadModelBuilder(final String modelName, final Class clazz) { - return new ModelBuilder().fromPath(Common.getResourcePath(String.format("/models/%s.json", modelName), clazz)); + public static JsonModelBuilder loadModelBuilder(final String modelName, final Class clazz) { + return new JsonModelBuilder().fromPath(Common.getResourcePath(String.format("/models/%s.json", modelName), clazz)); } - public static ModelBuilder loadModelBuilderFromGitHub(final String modelName) { + public static JsonModelBuilder loadModelBuilderFromGitHub(final String modelName) { try { LOGGER.info("Loaded model: " + modelName); - final var uri = String.format("%1$s/%2$s/%2$s.json", REPO_BASE_URL, modelName); - return new ModelBuilder().fromURI(uri); + final var url = String.format("%1$s/%2$s/%2$s.json", REPO_BASE_URL, modelName); + return new JsonModelBuilder().fromURL(url); } catch (final IOException | InterruptedException x) { throw new RuntimeException(x); } diff --git a/any2json-excel/src/test/java/com/github/romualdrousseau/any2json/ModelDB.java b/any2json-excel/src/test/java/com/github/romualdrousseau/any2json/ModelDB.java index 9c360f24..adccd327 100644 --- a/any2json-excel/src/test/java/com/github/romualdrousseau/any2json/ModelDB.java +++ b/any2json-excel/src/test/java/com/github/romualdrousseau/any2json/ModelDB.java @@ -3,11 +3,13 @@ import java.io.IOException; import java.net.URISyntaxException; +import com.github.romualdrousseau.any2json.modeldata.JsonModelBuilder; + public class ModelDB { public static Model createConnection(final String modelName) { try { - return new ModelBuilder() + return new JsonModelBuilder() .fromResource(new ModelDB().getClass(), String.format("/data/%s.json", modelName)) .build(); } catch (final URISyntaxException | IOException x) { diff --git a/any2json-net-classifier/src/main/java/com/github/romualdrousseau/any2json/classifier/NetTagClassifier.java b/any2json-net-classifier/src/main/java/com/github/romualdrousseau/any2json/classifier/NetTagClassifier.java index 188ffeaf..529b0d96 100644 --- a/any2json-net-classifier/src/main/java/com/github/romualdrousseau/any2json/classifier/NetTagClassifier.java +++ b/any2json-net-classifier/src/main/java/com/github/romualdrousseau/any2json/classifier/NetTagClassifier.java @@ -13,6 +13,7 @@ import java.util.Map; import java.util.Optional; import java.util.stream.Stream; +import java.util.stream.StreamSupport; import org.tensorflow.SavedModelBundle; import org.tensorflow.SessionFunction; @@ -20,14 +21,16 @@ import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.types.TFloat32; +import com.github.romualdrousseau.any2json.Header; import com.github.romualdrousseau.any2json.HeaderTag; import com.github.romualdrousseau.any2json.Model; +import com.github.romualdrousseau.any2json.Table; import com.github.romualdrousseau.any2json.TagClassifier; import com.github.romualdrousseau.any2json.util.Disk; import com.github.romualdrousseau.any2json.util.TempFile; +import com.github.romualdrousseau.shuju.types.Tensor; import com.github.romualdrousseau.shuju.commons.PythonManager; import com.github.romualdrousseau.shuju.json.JSON; -import com.github.romualdrousseau.shuju.types.Tensor; import com.github.romualdrousseau.shuju.preprocessing.Text; import com.github.romualdrousseau.shuju.preprocessing.hasher.VocabularyHasher; import com.github.romualdrousseau.shuju.preprocessing.tokenizer.NgramTokenizer; @@ -43,7 +46,6 @@ public class NetTagClassifier extends SimpleTagClassifier implements Trainable { private final List vocabulary; private final int ngrams; private final int wordMinSize; - private final List lexicon; private final Text.ITokenizer tokenizer; private final Text.IHasher hasher; private final boolean isModelTemp; @@ -54,14 +56,14 @@ public class NetTagClassifier extends SimpleTagClassifier implements Trainable { public NetTagClassifier(final List vocabulary, final int ngrams, final int wordMinSize, final List lexicon, final Optional modelPath) { + this.setLexicon(lexicon); this.vocabulary = vocabulary; this.ngrams = ngrams; this.wordMinSize = wordMinSize; - this.lexicon = lexicon; this.tokenizer = (ngrams == 0) - ? new ShingleTokenizer(this.lexicon, this.wordMinSize) + ? new ShingleTokenizer(this.getLexicon(), this.wordMinSize) : new NgramTokenizer(ngrams); this.hasher = new VocabularyHasher(this.vocabulary); this.isModelTemp = modelPath.filter(x -> x.toFile().exists()).isEmpty(); @@ -71,12 +73,12 @@ public NetTagClassifier(final List vocabulary, final int ngrams, final i public NetTagClassifier(final Model model, final TagClassifier.TagStyle tagStyle) { this( model.getData().getList("vocabulary"), - model.getData().getInt("ngrams"), - model.getData().getInt("wordMinSize"), + model.getData().get("ngrams").orElse(0), + model.getData().get("wordMinSize").orElse(2), model.getData().getList("lexicon"), - Optional.ofNullable(model.getAttributes().get("modelPath")).map(Path::of)); + Optional.ofNullable(model.getModelAttributes().get("modelPath")).map(Path::of)); - assert this.isModelTemp && model.getData().getString("model") != null : "model element must exist"; + assert this.isModelTemp && model.getData().get("model").isPresent() : "model element must exist"; this.setModel(model); this.setTagStyle(tagStyle); @@ -90,17 +92,21 @@ public void close() throws Exception { @Override public void updateModelData() { this.getModel().getData().setList("vocabulary", this.vocabulary); - this.getModel().getData().setInt("ngrams", this.ngrams); - this.getModel().getData().setInt("wordMinSize", this.wordMinSize); - this.getModel().getData().setList("lexicon", this.lexicon); + this.getModel().getData().set("ngrams", this.ngrams); + this.getModel().getData().set("wordMinSize", this.wordMinSize); + this.getModel().getData().setList("lexicon", this.getLexicon()); if (!this.isModelTemp && this.modelPath.isPresent()) { - this.getModel().getAttributes().put("modelPath", this.modelPath.get().toString()); - this.getModel().getData().setString("model", this.serializeModelML(this.modelPath.get())); + this.getModel().getModelAttributes().put("modelPath", this.modelPath.get().toString()); + this.getModel().getData().set("model", this.serializeModelML(this.modelPath.get())); } } @Override - public String predict(final String name, final List entities, final List context) { + public String predict(final Table table, final Header header) { + final var name = header.getName(); + final var entities = StreamSupport.stream(header.entities().spliterator(), false).toList(); + final var context = StreamSupport.stream(table.getHeaderNames().spliterator(), false).toList(); + if (!this.loadModelML()) { return HeaderTag.None.getValue(); } @@ -121,6 +127,32 @@ public String predict(final String name, final List entities, final List return this.getModel().getTagList().get((int) result.argmax(1).item(0)); } + @Override + public List getInputVector(final String name, final List entities, + final List context) { + final var part1 = Text.to_categorical(entities, this.getModel().getEntityList()); + final var part2 = Text.one_hot(name, this.getModel().getFilterList(), this.tokenizer, this.hasher); + final var part3 = context.stream() + .filter(x -> !x.equals(name)) + .flatMap(x -> Text.one_hot(x, this.getModel().getFilterList(), this.tokenizer, this.hasher).stream()) + .distinct().sorted().toList(); + return Stream.of( + Text.pad_sequence(part1, IN_ENTITY_SIZE).subList(0, IN_ENTITY_SIZE), + Text.pad_sequence(part2, IN_NAME_SIZE).subList(0, IN_NAME_SIZE), + Text.pad_sequence(part3, IN_CONTEXT_SIZE).subList(0, IN_CONTEXT_SIZE)) + .flatMap(Collection::stream) + .toList(); + } + + @Override + public List getOutputVector(final String label) { + return Text.pad_sequence(Text.to_categorical(label, this.getModel().getTagList()), OUT_TAG_SIZE); + } + + public List getVocabulary() { + return this.vocabulary; + } + public Process fit(final List trainingSet, final List validationSet) throws IOException, InterruptedException, URISyntaxException { this.closeModelML(); @@ -148,38 +180,10 @@ public Process fit(final List trainingSet, final List getVocabulary() { - return this.vocabulary; - } - - public List getLexicon() { - return this.lexicon; - } - - public List getInputVector(final String name, final List entities, - final List context) { - final var part1 = Text.to_categorical(entities, this.getModel().getEntityList()); - final var part2 = Text.one_hot(name, this.getModel().getFilters(), this.tokenizer, this.hasher); - final var part3 = context.stream() - .filter(x -> !x.equals(name)) - .flatMap(x -> Text.one_hot(x, this.getModel().getFilters(), this.tokenizer, this.hasher).stream()) - .distinct().sorted().toList(); - return Stream.of( - Text.pad_sequence(part1, IN_ENTITY_SIZE).subList(0, IN_ENTITY_SIZE), - Text.pad_sequence(part2, IN_NAME_SIZE).subList(0, IN_NAME_SIZE), - Text.pad_sequence(part3, IN_CONTEXT_SIZE).subList(0, IN_CONTEXT_SIZE)) - .flatMap(Collection::stream) - .toList(); - } - - public List getOutputVector(final String label) { - return Text.pad_sequence(Text.to_categorical(label, this.getModel().getTagList()), OUT_TAG_SIZE); - } - private boolean loadModelML() { try { if (this.modelPath.isEmpty()) { - final var modelString = this.getModel().getData().getString("model"); + final var modelString = this.getModel().getData().get("model").get(); this.modelPath = Optional.of(this.unserializeModelML(modelString)); } if (this.tagClassifierModel == null) { diff --git a/any2json-parquet/src/test/java/com/github/romualdrousseau/any2json/ModelDB.java b/any2json-parquet/src/test/java/com/github/romualdrousseau/any2json/ModelDB.java index 9c360f24..adccd327 100644 --- a/any2json-parquet/src/test/java/com/github/romualdrousseau/any2json/ModelDB.java +++ b/any2json-parquet/src/test/java/com/github/romualdrousseau/any2json/ModelDB.java @@ -3,11 +3,13 @@ import java.io.IOException; import java.net.URISyntaxException; +import com.github.romualdrousseau.any2json.modeldata.JsonModelBuilder; + public class ModelDB { public static Model createConnection(final String modelName) { try { - return new ModelBuilder() + return new JsonModelBuilder() .fromResource(new ModelDB().getClass(), String.format("/data/%s.json", modelName)) .build(); } catch (final URISyntaxException | IOException x) { diff --git a/any2json-pdf/src/test/java/com/github/romualdrousseau/any2json/ModelDB.java b/any2json-pdf/src/test/java/com/github/romualdrousseau/any2json/ModelDB.java index 9c360f24..adccd327 100644 --- a/any2json-pdf/src/test/java/com/github/romualdrousseau/any2json/ModelDB.java +++ b/any2json-pdf/src/test/java/com/github/romualdrousseau/any2json/ModelDB.java @@ -3,11 +3,13 @@ import java.io.IOException; import java.net.URISyntaxException; +import com.github.romualdrousseau.any2json.modeldata.JsonModelBuilder; + public class ModelDB { public static Model createConnection(final String modelName) { try { - return new ModelBuilder() + return new JsonModelBuilder() .fromResource(new ModelDB().getClass(), String.format("/data/%s.json", modelName)) .build(); } catch (final URISyntaxException | IOException x) { diff --git a/any2json/src/main/java/com/github/romualdrousseau/any2json/modeldata/JsonModelBuilder.java b/any2json/src/main/java/com/github/romualdrousseau/any2json/modeldata/JsonModelBuilder.java index d81d54d5..677f824e 100644 --- a/any2json/src/main/java/com/github/romualdrousseau/any2json/modeldata/JsonModelBuilder.java +++ b/any2json/src/main/java/com/github/romualdrousseau/any2json/modeldata/JsonModelBuilder.java @@ -78,11 +78,19 @@ public JsonModelBuilder fromURL(final String url) throws IOException, Interrupte return this.fromModelData(new JsonModelData(JSON.objectOf(response.body()))); } + public List getEntityList() { + return this.entities; + } + public JsonModelBuilder setEntityList(final List entities) { this.entities = entities; return this; } + public Map getPatternMap() { + return this.patterns; + } + public JsonModelBuilder setPatternMap(final Map patterns) { this.patterns = patterns; return this;