Skip to content
This repository has been archived by the owner on Apr 2, 2024. It is now read-only.

Commit

Permalink
Add a flag to specify the intents file, fixes #70
Browse files Browse the repository at this point in the history
  • Loading branch information
hugolgst committed Mar 25, 2020
1 parent f4fbf88 commit 5a82b39
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 24 deletions.
8 changes: 4 additions & 4 deletions analysis/intents.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ type Document struct {
}

// SerializeIntents returns a list of intents retrieved from `res/intents.json`
func SerializeIntents() []Intent {
func SerializeIntents(intentsPath string) []Intent {
var intents []Intent

err := json.Unmarshal(util.ReadFile("res/intents.json"), &intents)
err := json.Unmarshal(util.ReadFile(intentsPath), &intents)
if err != nil {
panic(err)
}
Expand All @@ -54,9 +54,9 @@ func SerializeModulesIntents() []Intent {

// Organize intents with an array of all words, an array with a representative word of each tag
// and an array of Documents which contains a word list associated with a tag
func Organize() (words, classes []string, documents []Document) {
func Organize(intentsPath string) (words, classes []string, documents []Document) {
// Append the modules intents to the intents from res/intents.json
intents := append(SerializeIntents(), SerializeModulesIntents()...)
intents := append(SerializeIntents(intentsPath), SerializeModulesIntents()...)

for _, intent := range intents {
for _, pattern := range intent.Patterns {
Expand Down
14 changes: 7 additions & 7 deletions analysis/sentence.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ func NewSentence(content string) (sentence Sentence) {
}

// PredictTag classifies the sentence with the model
func (sentence Sentence) PredictTag(neuralNetwork network.Network) string {
words, classes, _ := Organize()
func (sentence Sentence) PredictTag(neuralNetwork network.Network, intentsPath string) string {
words, classes, _ := Organize(intentsPath)

// Predict with the model
predict := neuralNetwork.Predict(sentence.WordsBag(words))
Expand All @@ -61,13 +61,13 @@ func (sentence Sentence) PredictTag(neuralNetwork network.Network) string {

// RandomizeResponse takes the entry message, the response tag and the token and returns a random
// message from res/intents.json where the triggers are applied
func RandomizeResponse(entry string, tag string, token string) (string, string) {
func RandomizeResponse(intentsPath, entry, tag, token string) (string, string) {
if tag == DontUnderstand {
return DontUnderstand, util.GetMessage(tag)
}

// Append the modules intents to the intents from res/intents.json
intents := append(SerializeIntents(), SerializeModulesIntents()...)
intents := append(SerializeIntents(intentsPath), SerializeModulesIntents()...)

for _, intent := range intents {
if intent.Tag != tag {
Expand Down Expand Up @@ -97,16 +97,16 @@ func RandomizeResponse(entry string, tag string, token string) (string, string)
}

// Calculate send the sentence content to the neural network and returns a response with the matching tag
func (sentence Sentence) Calculate(cache gocache.Cache, neuralNetwork network.Network, token string) (string, string) {
func (sentence Sentence) Calculate(cache gocache.Cache, neuralNetwork network.Network, intentsPath, token string) (string, string) {
tag, found := cache.Get(sentence.Content)

// Predict tag with the neural network if the sentence isn't in the cache
if !found {
tag = sentence.PredictTag(neuralNetwork)
tag = sentence.PredictTag(neuralNetwork, intentsPath)
cache.Set(sentence.Content, tag, gocache.DefaultExpiration)
}

return RandomizeResponse(sentence.Content, tag.(string), token)
return RandomizeResponse(intentsPath, sentence.Content, tag.(string), token)
}

// LogResults print in the console the sentence and its tags sorted by prediction
Expand Down
21 changes: 16 additions & 5 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,35 @@
package main

import (
"flag"
"fmt"
"os"

"github.com/gookit/color"

"github.com/olivia-ai/olivia/network"

"github.com/olivia-ai/olivia/server"
"github.com/olivia-ai/olivia/training"
)

var (
// Initialize the neural network by training it
neuralNetwork = training.CreateNeuralNetwork()
)
var neuralNetwork network.Network

func main() {
intentsPath := flag.String("intents", "res/intents.json", "The path for intents file.")
flag.Parse()

magenta := color.FgMagenta.Render
fmt.Printf("Using %s as intents file.\n", magenta(*intentsPath))

neuralNetwork = training.CreateNeuralNetwork(*intentsPath)

port := "8080"
// Get port from environment variables if there is
if os.Getenv("PORT") != "" {
port = os.Getenv("PORT")
}

// Serves the server
server.Serve(neuralNetwork, port)
server.Serve(neuralNetwork, port, *intentsPath)
}
2 changes: 1 addition & 1 deletion res/training.json

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions server/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@ import (
var (
// Create the neural network variable to use it everywhere
neuralNetwork network.Network
// Initiatizes the cache with a 5 minute lifetime
// Initializes the cache with a 5 minute lifetime
cache = gocache.New(5*time.Minute, 5*time.Minute)
// Set the intents file path
intentsPath string
)

// Serve serves the server in the given port
func Serve(_neuralNetwork network.Network, port string) {
func Serve(_neuralNetwork network.Network, port, _intentsPath string) {
// Set the current global network as a global variable
neuralNetwork = _neuralNetwork
intentsPath = _intentsPath

// Initializes the router
router := mux.NewRouter()
Expand Down
2 changes: 1 addition & 1 deletion server/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func Reply(request RequestMessage) []byte {
} else {
responseTag, responseSentence = analysis.NewSentence(
request.Content,
).Calculate(*cache, neuralNetwork, request.Token)
).Calculate(*cache, neuralNetwork, intentsPath, request.Token)
}

// Marshall the response in json
Expand Down
8 changes: 4 additions & 4 deletions training/training.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (
)

// TrainData returns the inputs and outputs for the neural network
func TrainData() (inputs, outputs [][]float64) {
words, classes, documents := analysis.Organize()
func TrainData(intentsPath string) (inputs, outputs [][]float64) {
words, classes, documents := analysis.Organize(intentsPath)

for _, document := range documents {
outputRow := make([]float64, len(classes))
Expand All @@ -30,14 +30,14 @@ func TrainData() (inputs, outputs [][]float64) {

// CreateNeuralNetwork returns a new neural network which is loaded from res/training.json or
// trained from TrainData() inputs and targets.
func CreateNeuralNetwork() (neuralNetwork network.Network) {
func CreateNeuralNetwork(intentsPath string) (neuralNetwork network.Network) {
// Decide if the network is created by the save or is a new one
saveFile := "res/training.json"

_, err := os.Open(saveFile)
// Train the model if there is no training file
if err != nil {
inputs, outputs := TrainData()
inputs, outputs := TrainData(intentsPath)

neuralNetwork = network.CreateNetwork(0.1, inputs, outputs, 50)
neuralNetwork.Train(1000)
Expand Down

0 comments on commit 5a82b39

Please sign in to comment.