热门IT资讯网

学习日志---线性回归实现

发表于:2024-11-24 作者:热门IT资讯网编辑
编辑最后更新 2024年11月24日,由对偏导数的计算可以得到w的计算公式:如下假定输入数据存放在矩阵x中,而回归系数存放在向量w中。那么对于给定的数据,预测结果将会通过给出。对于x和y,如何找到w?常用的方法是找到平方误差最小的w。平方

由对偏导数的计算可以得到w的计算公式:如下

假定输入数据存放在矩阵x中,而回归系数存放在向量w中。那么对于给定的数据,预测结果将会通过给出。对于x和y,如何找到w?常用的方法是找到平方误差最小的w。

平方误差可以写做:


用矩阵表示还可以写做。对w求导,解得w如下:

采用的数据是在UCI上下载的回归汽车msg性能的数据集;

由于下载的数据格式不标准,因此这里自己写了一段java代码将数据集的格式进行了重新的规整,代码如下:

import java.io.BufferedReader;import java.io.BufferedWriter;import java.io.File;import java.io.FileInputStream;import java.io.FileOutputStream;import java.io.InputStreamReader;import java.io.OutputStreamWriter;public class MyMaze {        public static void main(String[] args) throws Exception {        FileInputStream fileInputStream = new FileInputStream(new File("E:\\DataRegression.txt"));        BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(fileInputStream));        File file = new File("E:\\result.txt");        FileOutputStream fileOutputStream = new FileOutputStream(file);        BufferedWriter bufferedWriter = new BufferedWriter(new OutputStreamWriter(fileOutputStream));        String line;        String newline = null;        while((line = bufferedReader.readLine())!=null)        {            if(line == null)            {                break;            }            int length = line.length();            for(int i = 0; i
输出的文件是每个变量之间都有两个空格的数据集,其中第一项是因变量,也就是msg。

下面是采用python方法对数据集进行线性回归:

import numpy as npimport matplotlib.pyplot as pltnumFeat = len(open('result.txt').readline().split('  '))dataMat = []; labelMat = []fr = open('result.txt')//这里对每行的数据进行分割,提取每行的数据for line in fr.readlines():    lineArr=[]    curline = line.split('  ')    for i in range(1,numFeat):        lineArr.append(float(curline[i]))    dataMat.append(lineArr)    labelMat.append(float(curline[0]))//将序列转为矩阵xMat = np.mat(dataMat)yMat = np.mat(labelMat).TxTx = xMat.T*xMat/判断行列式的值是否为0if np.linalg.det(xTx) == 0.0:    print "wrong"//利用公式求参数ws = xTx.I*(xMat.T*yMat)//利用matplotLib画图,制定在fig中fig = plt.figure()ax = fig.add_subplot(111)xCopy = xMat.copy()xCopy.sort(0)yHat = xCopy*ws//这里是找x矩阵中某一项与yHat的关系,如这里是第二项ax.plot(xCopy[:,1],yHat)//展示图像plt.show()//这里是求出相关系数的函数,越接近1越好yHat = xMat*wsprint yHat.T.shapeprint yMat.shapeprint np.corrcoef(yHat.T, yMat.T)


附件:http://down.51cto.com/data/2366089
0