-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pather.py
82 lines (60 loc) · 2.39 KB
/
er.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
def add_layer(inputs, in_size, out_size, n_layer, activation_function=None):
    """Build one fully connected layer and return its output tensor.

    The layer computes ``activation_function(inputs @ W + b)``, or just the
    affine part ``inputs @ W + b`` when ``activation_function`` is None.
    All ops are grouped under a ``layer<n_layer>`` name scope. Histogram
    summaries of the weights, biases and outputs are recorded for TensorBoard.

    Args:
        inputs: Tensor multiplied on the left of the weight matrix; presumably
            shaped (batch, in_size) -- confirm against callers.
        in_size: Number of rows of the weight matrix (input features).
        out_size: Number of columns of the weight matrix (units in this layer).
        n_layer: Value formatted into the scope/summary prefix ``'layer%s'``.
        activation_function: Optional callable applied to the affine output;
            None yields a linear layer.

    Returns:
        The layer's output tensor.
    """
    scope_name = 'layer%s' % n_layer
    with tf.name_scope(scope_name):
        with tf.name_scope('Weights'):
            # Weight matrix initialised from tf.random_normal's defaults.
            weight = tf.Variable(tf.random_normal([in_size, out_size]), name='W')
            tf.summary.histogram(scope_name + '/weights', weight)
        with tf.name_scope('biases'):
            # [1, out_size] row added to every row of the matmul result;
            # every entry starts at 0.1.
            bias = tf.Variable(tf.zeros([1, out_size]) + 0.1)
            tf.summary.histogram(scope_name + '/biases', bias)
        with tf.name_scope('Wx_plus_b'):
            affine = tf.matmul(inputs, weight) + bias
        if activation_function is not None:
            outputs = activation_function(affine)
        else:
            outputs = affine
        tf.summary.histogram(scope_name + '/outputs', outputs)
    return outputs
# Synthetic regression data: 300 evenly spaced inputs in [-1, 1], stored as
# a column vector, with targets y = x**2 - 0.5 plus Gaussian noise of
# standard deviation 0.05.
x_data = np.linspace(-1, 1, 300).reshape(-1, 1)
noise = np.random.normal(loc=0, scale=0.05, size=x_data.shape)
y_data = x_data ** 2 - 0.5 + noise
# Plot the noisy training points once. Interactive mode lets the training
# loop redraw the fitted curve over them without blocking on show().
fig, ax = plt.subplots()
ax.scatter(x_data, y_data, s=10, color='b')
plt.ion()
plt.show()
# Placeholders for the training data. The leading None dimension accepts any
# number of samples; each sample has a single feature.
with tf.name_scope('inputs'):
    xs = tf.placeholder(tf.float32, [None, 1], name='x_input')
    ys = tf.placeholder(tf.float32, [None, 1], name='y_input')

# Network: 1 input -> 10 tanh hidden units -> 1 linear output.
# Input layer to hidden layer.
l1 = add_layer(xs, 1, 10, n_layer=1, activation_function=tf.nn.tanh)
# Hidden layer to output layer (no activation: plain regression output).
prediction = add_layer(l1, 10, 1, n_layer=2, activation_function=None)

with tf.name_scope('loss'):
    # Sum the squared error over axis 1, then average over samples.
    # ``axis`` replaces the deprecated ``reduction_indices`` alias.
    loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), axis=[1]))
    tf.summary.scalar('loss', loss)

with tf.name_scope('train'):
    # Plain gradient descent with learning rate 0.1.
    train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
init = tf.global_variables_initializer()
sess = tf.Session()
merged = tf.summary.merge_all()
writer = tf.summary.FileWriter('logs/', sess.graph)
sess.run(init)

# Every step trains on the full data set, so build the feed dict once.
feed = {xs: x_data, ys: y_data}
# Handle returned by ax.plot for the currently drawn prediction curve;
# None until the first curve is drawn.
lines = None
try:
    for i in range(2001):
        sess.run(train_step, feed_dict=feed)
        if i % 50 == 0:
            # Fetch the merged summaries and the current loss in one run.
            result, loss_value = sess.run([merged, loss], feed_dict=feed)
            writer.add_summary(result, i)
            print(loss_value)
            # Replace the previous curve. Artist.remove() is the supported
            # way to detach a line. The old ax.lines.remove() call sat behind
            # a blanket `except Exception: pass`, which also hid failures on
            # newer matplotlib releases (where Axes.lines no longer supports
            # remove), letting stale curves pile up.
            if lines is not None:
                lines[0].remove()
            prediction_value = sess.run(prediction, feed_dict={xs: x_data})
            lines = ax.plot(x_data, prediction_value, '', lw=5)
            plt.pause(0.1)
finally:
    # Flush any buffered summary events to logs/ and release the session,
    # even if training is interrupted.
    writer.close()
    sess.close()