博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
L2范数惩罚项,高维线性回归
阅读量:5926 次
发布时间:2019-06-19

本文共 2298 字,大约阅读时间需要 7 分钟。

%matplotlib inlineimport mxnetfrom mxnet import nd,autogradfrom mxnet import gluon,initfrom mxnet.gluon import data as gdata,loss as gloss,nnimport gluonbook as gbn_train, n_test, num_inputs = 20,100,200true_w = nd.ones((num_inputs, 1)) * 0.01true_b = 0.05features = nd.random.normal(shape=(n_train+n_test, num_inputs))labels = nd.dot(features,true_w) + true_blabels += nd.random.normal(scale=0.01, shape=labels.shape)train_feature = features[:n_train,:]test_feature = features[n_train:,:]train_labels = labels[:n_train]test_labels = labels[n_train:]#print(features,train_feature,test_feature)# 初始化模型参数def init_params():    w = nd.random.normal(scale=1, shape=(num_inputs, 1))    b = nd.zeros(shape=(1,))    w.attach_grad()    b.attach_grad()    return [w,b]# 定义,训练,测试batch_size = 1num_epochs = 100lr = 0.03train_iter = gdata.DataLoader(gdata.ArrayDataset(train_feature,train_labels),batch_size=batch_size,shuffle=True)# 定义网络def linreg(X, w, b):    return nd.dot(X,w) + b# 损失函数def squared_loss(y_hat, y):    """Squared loss."""    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2# L2 范数惩罚def l2_penalty(w):    return (w**2).sum() / 2def sgd(params, lr, batch_size):    for param in params:        param[:] = param - lr * param.grad / batch_sizedef fit_and_plot(lambd):    w, b = init_params()    train_ls, test_ls = [], []    for _ in range(num_epochs):        for X, y in train_iter:            with autograd.record():                # 添加了 L2 范数惩罚项。                l = squared_loss(linreg(X, w, b), y) + lambd * l2_penalty(w)            l.backward()            sgd([w, b], lr, batch_size)        train_ls.append(squared_loss(linreg(train_feature, w, b),                             train_labels).mean().asscalar())        test_ls.append(squared_loss(linreg(test_feature, w, b),                            test_labels).mean().asscalar())    gb.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',                range(1, num_epochs + 1), test_ls, ['train', 'test'])    print('L2 norm of w:', w.norm().asscalar())
fit_and_plot(0) fit_and_plot(3)

训练集太少,容易出现过拟合,即训练集loss远小于测试集loss,解决方案,权重衰减——(L2范数正则化)

例如线性回归:

loss(w1,w2,b) = 1/n * sum(x1w1 + x2w2 + b - y)^2 /2 ,平方损失函数。

权重参数 w = [w1,w2],

新损失函数 loss(w1,w2,b) += lambd / 2n *||w||^2

迭代方程:

转载于:https://www.cnblogs.com/TreeDream/p/10027139.html

你可能感兴趣的文章
比较好的Dapper封装的仓储实现类 来源:https://www.cnblogs.com/liuchang/articles/4220671.html...
查看>>
Myeclipse优化配置
查看>>
C#源代码生成器
查看>>
2015年终总结
查看>>
MyBatis学习总结(17)——Mybatis分页插件PageHelper
查看>>
HTML5:理解head
查看>>
一维条形码***技术(Badbarcode)
查看>>
认清几种视频接口标准---无私奉献版
查看>>
Vim的配置
查看>>
bigpipe merge对F5做批量配置
查看>>
为什么这个SQL Server DBA学习PowerShell--SQL任务
查看>>
boost pool内存池库使用简要介绍
查看>>
Ansible 一步一步从入门到精通(一)
查看>>
Linux内核驱动GPIO的使用
查看>>
zabbix2.0安装与配置
查看>>
oracle用户名密码过期引起的网站后台无法登录
查看>>
TinyUI组件开发示例
查看>>
crond定时任务详细分析
查看>>
MySQL查询,按拼音首字母排序
查看>>
VmWare5.5主机Citrix桌面实施方案(二)
查看>>