博客
关于我
强烈建议你试试无所不能的 ChatGPT，快点击我
测试1
阅读量：5290 次
发布时间：2019-06-14

本文共 3536 字，阅读大约需要 11 分钟。

 

# encoding: UTF-8
import os

import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data as mnist_data
from tensorflow.python.platform import gfile

print("Tensorflow version " + tf.__version__)
print(tf.__path__)

# Disabled reference example: a single-layer softmax classifier trained on
# MNIST with plain gradient descent. Uncomment to run it.
# tf.set_random_seed(0)
# # Load the MNIST data set
# mnist = mnist_data.read_data_sets("data", one_hot=True)
# # Input placeholders
# x = tf.placeholder("float", [None, 784])
# y_ = tf.placeholder("float", [None, 10])
# # Weights and bias
# W = tf.Variable(tf.zeros([784, 10]))
# b = tf.Variable(tf.zeros([10]))
# # Network output
# y = tf.nn.softmax(tf.matmul(x, W) + b)
# # Cross-entropy loss
# cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
# # Training op
# learningRate = 0.005
# train_step = tf.train.GradientDescentOptimizer(learningRate).minimize(cross_entropy)
# init = tf.initialize_all_variables()
# sess = tf.Session()
# sess.run(init)
# itnum = 1000
# batch_size = 100
# for i in range(itnum):
#     if i % 100 == 0:
#         print("the index " + str(i + 1) + " train")
#     batch_xs, batch_ys = mnist.train.next_batch(batch_size)
#     sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
# correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
# accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
# print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))


def train(height=28, width=28, inchannel=1, outchannel=2, wkernel=3, stride=1,
          export_dir="log", step=200):
    """Run one 'SAME'-padded conv2d on deterministic data and export it.

    Builds an NCHW input filled with 0, 1, 2, ... and an OIHW weight tensor
    filled the same way, converts both to the layouts ``tf.nn.conv2d``
    expects (NHWC input, HWIO filter), runs the convolution, prints the
    result in NCHW order, saves a checkpoint and writes the graph to an
    event file in ``export_dir``.

    Args:
        height: Input image height.
        width: Input image width.
        inchannel: Number of input channels.
        outchannel: Number of output channels (filters).
        wkernel: Side length of the square kernel.
        stride: Stride applied to both spatial axes.
        export_dir: Directory for the checkpoint and event file; created if
            it does not exist.
        step: Global step appended to the checkpoint file name.

    Returns:
        The convolution output as a NumPy array of shape
        ``(1, outchannel, out_height, out_width)``.
    """
    # Deterministic test tensors: data is NCHW, weights are OIHW.
    w = np.arange(wkernel * wkernel * inchannel * outchannel).reshape(
        (outchannel, inchannel, wkernel, wkernel))
    data = np.arange(height * width * inchannel).reshape(
        (1, inchannel, height, width))
    print('input:', data)
    print('weight:', w)

    # NCHW -> NHWC and OIHW -> HWIO. Transposing with (0, 3, 2, 1) and
    # (3, 2, 1, 0) would also swap H and W; the printed result would be the
    # same, but the variables saved to the checkpoint would be spatially
    # transposed relative to the arrays printed above.
    data_nhwc = data.transpose(0, 2, 3, 1)
    w_hwio = w.transpose(2, 3, 1, 0)

    input_var = tf.Variable(data_nhwc, dtype=np.float32, name="input")
    filter_var = tf.Variable(w_hwio, dtype=np.float32, name="weight")
    # With the defaults (3x3 kernel, stride 1), 'SAME' pads one pixel per side.
    conv = tf.nn.conv2d(input_var, filter_var,
                        strides=[1, stride, stride, 1],
                        padding='SAME', name="conv")
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

    # Saver.save() will not create a missing parent directory.
    gfile.MakeDirs(export_dir)

    with tf.Session() as sess:
        sess.run(init)
        # NHWC -> NCHW so the output lines up with the printed input/weights.
        conv_nchw = sess.run(conv).transpose(0, 3, 1, 2)
        print("conv_reshape: \n", conv_nchw)

        # save() returns the real checkpoint prefix (e.g. "log/model.ckpt-200"),
        # so the .meta path follows whatever `step` is.
        checkpoint_prefix = saver.save(
            sess, os.path.join(export_dir, 'model.ckpt'), global_step=step)

        # Check the exported .meta loads, but in a separate graph: importing
        # it into this session's graph would duplicate every node in it.
        with tf.Graph().as_default():
            tf.train.import_meta_graph(checkpoint_prefix + '.meta')

        summary_writer = tf.summary.FileWriter(export_dir, sess.graph)
        # Close explicitly so the event file is flushed to disk.
        summary_writer.close()

    return conv_nchw


if __name__ == '__main__':
    train()

 

转载于：https://www.cnblogs.com/adong7639/p/9227033.html

你可能感兴趣的文章
P2822 组合数问题
查看>>
P1385 密令
查看>>
U65320 Oak的预算方案
查看>>
P1064 金明的预算方案
查看>>
P2902 [USACO08MAR]珍珠配对Pearl Pairing
查看>>
P1852 [国家集训队]跳跳棋
查看>>
P1318 积水面积
查看>>
P1572 计算分数
查看>>
P1056 排座椅
查看>>
P5535 【XR-3】小道消息
查看>>
P1033 自由落体
查看>>
P2706 巧克力
查看>>
#575. 列车调度
查看>>
P1144 最短路计数
查看>>
P1209 [USACO1.3]修理牛棚 Barn Repair
查看>>
P2904 [USACO08MAR]跨河River Crossing
查看>>
P4379 [USACO18OPEN]Lemonade Line
查看>>
P1109 学生分组
查看>>
P3650 [USACO1.3]滑雪课程设计Ski Course Design
查看>>
P2789 直线交点数
查看>>