Deep learning is not hard to understand at a high level, but the process is tedious: there are a great many hyperparameters, it is easy to get one wrong, and some of them are quite sensitive to tuning.
Another problem is that TensorFlow, as a low-level deep learning library, exposes a very fine-grained API: wiring up a network takes a lot of repetitive, modular code that could be efficiently wrapped and reused.
This post implements a simple NAND gate to test such a wrapper. Jumping straight to the final code, it is very concise: one "input layer" (input_data) and a stack of "fully connected layers" (fully_connected).
import tensorflow as tf
from aipack.tfgraph import layers, trainer

# Truth-table inputs and one-hot NAND labels ([1, 0] means 1, [0, 1] means 0)
X = [[0., 0.], [0., 1.], [1., 0.], [1., 1.]]
Y_nand = [[1., 0.], [1., 0.], [1., 0.], [0., 1.]]
Y_or = [[0.], [1.], [1.], [1.]]

g = layers.input_data(shape=[None, 2])
# NAND operator definition
g_nand = layers.fully_connected(g, 32)
g_nand = layers.fully_connected(g_nand, 32)
g_nand = layers.fully_connected(g_nand, 2, activation=tf.nn.softmax)
# OR operator definition (alternative, commented out)
#g_or = layers.fully_connected(g, 32)
#g_or = layers.fully_connected(g_or, 32)
#g_or = layers.fully_connected(g_or, 1, activation=tf.nn.sigmoid)
model = trainer.Trainer(g_nand, learning_rate=0.1)
model.fit(X, Y_nand, validation_set=(X, Y_nand), n_epoch=400, batch_size=4)
print(model.predict([[1., 0.]]))  # NAND(1, 0) -> close to [1, 0], i.e. 1
print(model.predict([[1., 1.]]))  # NAND(1, 1) -> close to [0, 1], i.e. 0
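The Trainer class itself is not shown in this post, but here is a minimal sketch of what such a wrapper might look like (the loss, optimizer, and batching choices are assumptions of mine, not the actual aipack.tfgraph code):

import tensorflow as tf

class Trainer(object):
    def __init__(self, net, learning_rate=0.1):
        self.net = net
        # Fetch the placeholder that input_data stored in the graph collection
        self.X = tf.get_collection('input_x')[0]
        self.Y = tf.placeholder(tf.float32, net.get_shape().as_list())
        # Simple MSE loss and plain gradient descent (assumptions for this sketch)
        self.loss = tf.reduce_mean(tf.square(net - self.Y))
        self.train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(self.loss)
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

    def fit(self, X, Y, validation_set=None, n_epoch=400, batch_size=4):
        # The toy dataset is tiny, so this sketch ignores batch_size
        # and feeds the full set every epoch
        for _ in range(n_epoch):
            self.sess.run(self.train_op, feed_dict={self.X: X, self.Y: Y})

    def predict(self, X):
        return self.sess.run(self.net, feed_dict={self.X: X})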
Let's look at the input_data layer wrapper:
import tensorflow as tf
from tensorflow.contrib import rnn  # reserved for later rnn/lstm wrappers
import numpy as np
import math

def input_data(shape=None, dtype=tf.float32, name='InputData'):
    with tf.name_scope(name):
        # Create the placeholder that will receive the training data
        placeholder = tf.placeholder(dtype=dtype, shape=shape, name='X')
        # Stash it in a graph collection so the trainer can retrieve it
        # later without it being passed around explicitly
        tf.add_to_collection('input_x', placeholder)
        return placeholder
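Called with shape=[None, 2], this creates a float32 placeholder whose first (batch) dimension is left open, and registers it in a graph collection so the trainer can find it without the caller threading it through. For example (the printed repr is the typical TF 1.x form):

g = layers.input_data(shape=[None, 2])
print(g)  # Tensor("InputData/X:0", shape=(?, 2), dtype=float32)
# Anyone holding the graph can recover the placeholder later:
x = tf.get_collection('input_x')[0]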
Now the fully connected layer. See? It is really just xW + b:
def fully_connected(incoming, n_units, activation=None, name='FullyConnected'):
    input_shape = incoming.get_shape().as_list()
    # If the input has more than 2 dimensions, multiply every dimension
    # after the first (batch) one together -- i.e. "flatten" it
    n_inputs = int(np.prod(input_shape[1:]))
    if len(input_shape) > 2:
        incoming = tf.reshape(incoming, [-1, n_inputs])
    with tf.name_scope(name):
        weights = tf.Variable(
            tf.truncated_normal([n_inputs, n_units],
                                stddev=1.0 / math.sqrt(float(n_inputs))),
            name='weights')
        biases = tf.Variable(tf.zeros([n_units]),
                             name='biases')
        logits = tf.matmul(incoming, weights) + biases  # x*W + b
        if activation is not None:
            logits = activation(logits)
        return logits
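Two details worth noting: the truncated-normal initialization with stddev 1.0/sqrt(n_inputs) keeps the initial activations at roughly the same scale whatever the layer width, and the flatten step lets the same layer accept higher-dimensional inputs. A hypothetical example (the 28x28 shape is illustrative, not from the post):

# A 4-D image batch [None, 28, 28, 1] flattens to
# n_inputs = 28*28*1 = 784 before the matmul
img = layers.input_data(shape=[None, 28, 28, 1])
h = layers.fully_connected(img, 128, activation=tf.nn.relu)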
There are a couple of tricks here. The learning_rate is very sensitive: setting it to 2, 0.1, or 0.01 makes a big difference in how well the network converges.
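To see that sensitivity for yourself, a quick sweep like the sketch below (reusing X and Y_nand from above, rebuilding the graph on each run) makes the difference obvious:

# Hypothetical sweep over learning rates; each run rebuilds the graph
for lr in (2.0, 0.1, 0.01):
    tf.reset_default_graph()
    g = layers.input_data(shape=[None, 2])
    net = layers.fully_connected(g, 32)
    net = layers.fully_connected(net, 32)
    net = layers.fully_connected(net, 2, activation=tf.nn.softmax)
    model = trainer.Trainer(net, learning_rate=lr)
    model.fit(X, Y_nand, validation_set=(X, Y_nand), n_epoch=400, batch_size=4)
    print(lr, model.predict([[1., 1.]]))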
Later on, rnn/lstm, gru, and conv layers can be wrapped in the same way, which makes day-to-day TensorFlow work much more efficient.
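As a taste of what that could look like, here is a hypothetical lstm wrapper in the same style (not the author's actual code), built on the contrib rnn module already imported above:

def lstm(incoming, n_units, name='LSTM'):
    with tf.name_scope(name):
        cell = rnn.BasicLSTMCell(n_units)
        # incoming: [batch, time, features]; outputs: [batch, time, n_units]
        outputs, state = tf.nn.dynamic_rnn(cell, incoming, dtype=tf.float32)
        # Return the last time step as a fixed-size summary of the sequence
        return outputs[:, -1, :]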
About the author: Wei Jiabin, internet product/technology director, MBA from Peking University's Guanghua School of Management, Chartered Financial Analyst (CFA), veteran product manager and coder. Partial to Python, closely following internet trends, artificial intelligence, and AI-driven quantitative finance. Committed to using the most cutting-edge cognitive technologies to understand this complex world.