博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
深度学习原理与框架-Tensorflow基本操作-实现线性拟合
阅读量:6824 次
发布时间:2019-06-26

本文共 1539 字,大约阅读时间需要 5 分钟。

代码:使用tensorflow进行数据点的线性拟合操作

第一步:使用np.random.normal生成正态分布的数据

第二步:将数据分为X_data 和 y_data

第三步:对参数W和b, 使用tf.Variable()进行初始化,对于参数W,使用tf.random_normal([1], -1.0, 1.0)构造初始值,对于参数b,使用tf.zeros([1]) 构造初始值

第四步:使用W * X_data + b 构造出预测值y_pred 

第五步:使用均分误差来表示loss损失值,即tf.reduce_mean(tf.square(y_data - y_pred))

第六步:使用opt = tf.train.GradientDescentOptimizer(0.5).minimize(loss) 梯度下架来降低损失值

第七步:循环,使用sess.run(opt) 执行梯度降低损失值的操作,并打印w,b和loss

第八步:进行作图操作,画出散点图和拟合好的曲线图

import numpy as npimport matplotlib.pyplot as pltimport tensorflow as tf# 第一步:使用np.random.normal创建数据,即y = 0.1*x + 0.3data = []num_data = 1000for i in range(num_data):    x_data = np.random.normal(0.0, 0.55)    y_data = 0.1 * x_data + 0.3 + np.random.normal(0.0, 0.03)    data.append([x_data, y_data])# 第二步:将数据进行分配,分成特征和标签X_data = [v[0] for v in data]y_data = [v[1] for v in data]# 第三步:使用tf.Variable进行参数的初始化操作W = tf.Variable(tf.random_normal([1], -1.0, 1.0), name='W')b = tf.Variable(tf.zeros([1]))# 第四步:使用X_data * W + b 计算损失值y_pred = X_data * W + b# 第五步:使用均分误差来作为损失值loss = tf.reduce_mean(tf.square(y_data - y_pred))# 第六步:使用梯度下降来降低损失值opt = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(loss)# 参数初始化操作sess = tf.Session()init = tf.global_variables_initializer()sess.run(init)for i in range(20):    # 第七步:循环,执行梯度下降操作,打印w,b和loss    sess.run(opt)    print('W:%g b:%g loss:%g'%(sess.run(W), sess.run(b), sess.run(loss)))# 第八步: 画图操作plt.scatter(X_data, y_data, c='r')plt.plot(X_data, sess.run(W) * X_data + sess.run(b))plt.show()

 

转载于:https://www.cnblogs.com/my-love-is-python/p/10520532.html

你可能感兴趣的文章
jquery基础
查看>>
C# 集合已修改;可能无法执行枚举操作
查看>>
FSM Code Generator
查看>>
JDBC学习笔记——事务、存储过程以及批量处理
查看>>
JVM内存结构
查看>>
Java 锁
查看>>
7、索引在什么情况下遵循最左前缀的规则?
查看>>
c#中委托与事件
查看>>
mysql数据库备份之主从同步配置
查看>>
angularJs(1)指令篇
查看>>
自定义Xadmin
查看>>
jsp页面表单的遍历要怎么写
查看>>
循环引用,看我就对了
查看>>
软件工程——第一周作业
查看>>
ubuntu14.04安装vmware workstation
查看>>
ArcGIS API for Silverlight部署本地地图服务
查看>>
小知识点
查看>>
python mongodb MapReduce
查看>>
python-数据类型
查看>>
Google MapReduce/GFS/BigTable三大技术的论文中译版
查看>>