Skip to content

Commit

Permalink
Create PredictiveAnalytics.java
Browse files Browse the repository at this point in the history
  • Loading branch information
KOSASIH authored Jul 25, 2024
1 parent ca3c607 commit f831f3f
Showing 1 changed file with 40 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package com.sidra;

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class PredictiveAnalytics {
public static void main(String[] args) {
// Set up a neural network
NeuralNetConfiguration config = new NeuralNetConfiguration.Builder()
.seed(42)
.weightInit(WeightInit.XAVIER)
.updater(new Nesterovs(0.01))
.list()
.layer(new DenseLayer.Builder()
.nIn(784)
.nOut(256)
.activation(Activation.RELU)
.build())
.layer(new DenseLayer.Builder()
.nIn(256)
.nOut(10)
.activation(Activation.SOFTMAX)
.build())
.pretrain(false)
.backprop(true)
.build();

MultiLayerNetwork model = new MultiLayerNetwork(config);
model.init();

// Train the model
DataSetIterator iterator = new DataSetIterator();
model.fit(iterator);
}
}

0 comments on commit f831f3f

Please sign in to comment.