# KerasGTN

An example implementation of Uber's Generative Teaching Network (GTN) with Keras (TensorFlow).
## Building the Docker image

GPU:

```bash
docker build -f gpuDockerfile -t kerasgtn .
```

CPU:

```bash
docker build -f cpuDockerfile -t kerasgtn .
```
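## Running the container

Start an interactive shell, with the current directory mounted at `/tf`:

```bash
docker run --gpus all -u $(id -u):$(id -g) -it --rm -v $PWD:/tf kerasgtn:latest bash
```

Remove `--gpus all` if you built the image from the CPU Dockerfile.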
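Run the image's default command with port 8888 published to the host and the current directory mounted at `/tf`:

```bash
docker run --gpus all -u $(id -u):$(id -g) -it --rm -p 8888:8888 -v $PWD:/tf kerasgtn:latest
```

Remove `--gpus all` if you built the image from the CPU Dockerfile.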
## Usage

Subclass `GTN` and implement `get_generator` and `get_learner`:

```python
from kerasgtn.gtn import GTN


class MyGTN(GTN):
    def __init__(self, **kwargs):
        super(MyGTN, self).__init__(**kwargs)

    def get_generator(self, input_layer):
        ...  # <implement>

    def get_learner(self, real_input, teacher):
        ...  # <implement>


gtn = MyGTN(input_shape=input_shape, n_classes=n_classes)
gtn.train(...)
```