TensorLog with TensorFlow: Getting Started
If you're used to writing learners in Tensorflow, you probably want to use Tensorlog's ability to compile logic into Tensorflow operations. Here's an example of how to do that. The source for this example is in `Tensorlog/datasets/grid/demo.py` if you want to see it in full.
As an example, we'll use a simple task from the Tensorlog paper: navigation on a grid. The database defines a 16-by-16 grid, with each cell named according to its x,y position: for instance, the string `3,5` names the cell in row 3 and column 5. Each cell is connected to its 8 closest neighbors, and also to itself, by the `edge` relation. We put a small initial weight on these edges (here, a uniform weight of 1/5), so the triples look something like this:
```
edge	3,5	2,4	0.200000
edge	3,5	2,5	0.200000
edge	3,5	2,6	0.200000
...
edge	3,5	4,6	0.200000
```
The format for this database file, named `grid.cfacts`, is tab-separated, one triple per line, with the last column being optional and indicating the weight for that fact.
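If you'd rather generate this file than write it by hand, here's a minimal sketch, assuming the 16-by-16 grid and uniform 1/5 weights described above (this helper is not part of the Tensorlog distribution):

```
# sketch: write grid.cfacts for a 16-by-16 grid (helper code, not from Tensorlog)
with open("grid.cfacts", "w") as fp:
    for i in range(1, 17):
        for j in range(1, 17):
            # connect each cell to itself and its (up to) 8 in-bounds neighbors
            for di in (-1, 0, 1):
                for dj in (-1, 0, 1):
                    ni, nj = i + di, j + dj
                    if 1 <= ni <= 16 and 1 <= nj <= 16:
                        fp.write("edge\t%d,%d\t%d,%d\t%f\n" % (i, j, ni, nj, 0.2))
```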
The program defines transitive closure on this graph: there is a path from X to Y if there's an edge from X to Y, or an edge from X to some Z that has a path to Y.

```
path(X,Y) <= edge(X,Y)
path(X,Y) <= edge(X,Z) & path(Z,Y)
```
The learning task is to reweight the edges so that when you issue a query of the form `path(cell,Y)`, where `cell` is a string like "3,5" naming a grid cell, the top-ranked Y is the closest corner of the grid: e.g., for `3,5` the top-ranked Y should be cell `1,1`. So examples in the training and test datasets look something like this:
```
...
path	3,5	1,1
path	3,6	1,1
path	3,9	1,16
path	3,10	1,16
...
```
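Such a file is also easy to generate. Here's a minimal sketch; the file name `grid.exam` is an assumption, and the real demo splits the cells between a training file and a test file rather than writing them all to one place:

```
# sketch: map every cell to its nearest corner and write the examples
def nearest_corner(i, j):
    ci = 1 if i <= 8 else 16
    cj = 1 if j <= 8 else 16
    return "%d,%d" % (ci, cj)

# file name is an assumption; split these lines into train/test files as you like
with open("grid.exam", "w") as fp:
    for i in range(1, 17):
        for j in range(1, 17):
            fp.write("path\t%d,%d\t%s\n" % (i, j, nearest_corner(i, j)))
```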
Most of what you need for using Tensorlog within Tensorflow is encapsulated in the `simple` subpackage. So start out by importing:
```
import tensorflow as tf
from tensorlog import simple
```
Now you will want to load in your database and your rule set. Rules are more or less part of the model, and in Tensorflow programs you usually define the model in Python, so it's most natural to use the `simple.Builder` class to construct rules. You might do something like this:
```
# generate the rules - for transitive closure
b = simple.Builder()
path,edge = b.predicates("path,edge")
X,Y,Z = b.variables("X,Y,Z")
b.rules += path(X,Y) <= edge(X,Y)
b.rules += path(X,Y) <= edge(X,Z) & path(Z,Y)
```
Next you construct a `simple.Compiler` object:

```
tlog = simple.Compiler(db="grid.cfacts", prog=b.rules)
```
The `tlog` Compiler object has instance variables called `db` and `prog`, which point to compiled versions of your database and program, respectively. Here we're going to configure these by saying we want to learn the edge weights, and also by limiting the depth of recursion to 16.
```
# configure the database so that edge weights are a parameter
tlog.prog.db.markAsParameter('edge',2)
# configure the program so that maximum recursive depth is 16
tlog.prog.maxDepth = 16
```
Now, to get the loss function associated with this learning process, we do the following:
mode = "path/io" unregularized_loss = tlog.loss(mode)
The `unregularized_loss` variable now points into the Tensorflow computation graph. If you wanted to, you could have specified the inputs by saying `unregularized_loss = tlog.loss(mode, inputs=[x])`, where `x` is some other Tensorlog computation. If you don't specify this, the input will be a new `Placeholder`, whose string name can be found with `tlog.input_placeholder_name(mode)`.
To explain the argument `mode`, notice that in general there could be many types of predicates defined in your program, and many types of queries. This loss function is specific to queries of the form `path(cell,Y)`, where the first argument is an input and the second is an output (hence the "/io").
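The same mode string keys everything else the compiler gives you; for instance, it names the input placeholder you'll feed during training. A small illustration, using only calls shown in this walk-through:

```
# a mode string is a predicate name plus one letter per argument:
# "i" marks an input argument, "o" an output argument
mode = "path/io"
# look up the feedable name of the input placeholder for this mode
print tlog.input_placeholder_name(mode)
```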
Some minimal code for learning is below:
```
optimizer = tf.train.AdagradOptimizer(1.0)
train_step = optimizer.minimize(unregularized_loss)
session = tf.Session()
session.run(tf.global_variables_initializer())
# trainFile names a tab-separated file of labeled examples, as described above
trainData = tlog.load_small_dataset(trainFile)
# run the optimizer for 20 epochs
(tx,ty) = trainData[mode]
train_fd = {tlog.input_placeholder_name(mode):tx,
            tlog.target_output_placeholder_name(mode):ty}
for i in range(20):
    session.run(train_step, feed_dict=train_fd)
```
When Tensorlog compiles the loss function, it also compiles an "inference" function, which shares structure with the loss function. You can use this to monitor performance, use the learned classifier, and so on. As a more complete example, here's a learner that also prints training loss and accuracy as it runs.
```
mode = 'path/io'
predicted_y = tlog.inference(mode)

# accuracy metric, for testing
actual_y = tlog.target_output_placeholder(mode)
correct_predictions = tf.equal(tf.argmax(actual_y,1), tf.argmax(predicted_y,1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

# learn to optimize the loss...
unregularized_loss = tlog.loss(mode)
optimizer = tf.train.AdagradOptimizer(1.0)
train_step = optimizer.minimize(unregularized_loss)

# set up the session
session = tf.Session()
session.run(tf.global_variables_initializer())

# load the training and test data
trainData = tlog.load_small_dataset(trainFile)
testData = tlog.load_small_dataset(testFile)

# run the optimizer for 20 epochs, reporting progress as we go
(tx,ty) = trainData[mode]
train_fd = {tlog.input_placeholder_name(mode):tx,
            tlog.target_output_placeholder_name(mode):ty}
for i in range(20):
    session.run(train_step, feed_dict=train_fd)
    print 'epoch',i+1,'train loss and accuracy',session.run([unregularized_loss,accuracy], feed_dict=train_fd)

# output the test set performance
(ux,uy) = testData[mode]
test_fd = {tlog.input_placeholder_name(mode):ux,
           tlog.target_output_placeholder_name(mode):uy}
print 'test acc', session.run(accuracy, feed_dict=test_fd)
```
The output should look like:
```
[misc warnings/status messages from tf, np, etc]
...
epoch 1 train loss and accuracy [399.21759, 0.03773585]
epoch 2 train loss and accuracy [1361.6881, 0.03773585]
epoch 3 train loss and accuracy [1024.2174, 0.10062893]
...
epoch 18 train loss and accuracy [20.040112, 0.98742139]
epoch 19 train loss and accuracy [14.722947, 0.99371076]
epoch 20 train loss and accuracy [11.698922, 0.99371076]
test acc 1.
```
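After training you can use the same inference function to answer individual queries, and copy the learned weights back into the database to save them. The sketch below is modeled on Tensorlog's bundled demos but involves several assumptions: that the database object exposes `onehot` and `rowAsSymbolDict` (from Tensorlog's `matrixdb` API) for encoding a cell and decoding the answer vector, that the Compiler provides `set_all_db_params_to_learned_values` and `serialize_db`, and the output file name.

```
import scipy.sparse as sp

# sketch: query the learned model for the corner nearest to cell 3,5
# (assumes matrixdb's onehot() helper, which returns a sparse row vector)
qx = tlog.db.onehot('3,5').todense()
answer = session.run(predicted_y, feed_dict={tlog.input_placeholder_name(mode): qx})
# decode the dense answer back to symbols; rowAsSymbolDict expecting a
# sparse row is also a matrixdb API assumption
print tlog.db.rowAsSymbolDict(sp.csr_matrix(answer))  # top symbol should be '1,1'

# sketch: copy the learned edge weights back into the database and save it
# (Compiler methods used in Tensorlog's bundled demos; file name is an assumption)
tlog.set_all_db_params_to_learned_values(session)
tlog.serialize_db('learned-grid.db')
```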
A similar walk-through for another problem can be found here: