博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
gluon实现softmax分类FashionMNIST
阅读量:5748 次
发布时间:2019-06-18

本文共 1940 字,大约阅读时间需要 6 分钟。

from mxnet import gluon,initfrom mxnet.gluon import loss as gloss,nnfrom mxnet.gluon import data as gdatafrom mxnet import autograd,ndimport gluonbook as gbimport sys# 读取数据mnist_train = gdata.vision.FashionMNIST(train=True)mnist_test = gdata.vision.FashionMNIST(train=False)batch_size = 256transformer = gdata.vision.transforms.ToTensor()if sys.platform.startswith('win'):    num_workers = 0else:    num_workers = 4# 小批量数据迭代器train_iter = gdata.DataLoader(mnist_train.transform_first(transformer),batch_size=batch_size,shuffle=True,num_workers=num_workers)test_iter = gdata.DataLoader(mnist_test.transform_first(transformer),batch_size=batch_size,shuffle=False,num_workers=num_workers)# 模型参数初始化net = nn.Sequential()net.add(nn.Dense(10))net.initialize(init.Normal(sigma=0.01))# 损失函数loss = gloss.SoftmaxCrossEntropyLoss()# 优化算法trainer = gluon.Trainer(net.collect_params(),'sgd',{
'learning_rate':0.1})def accuracy(y_hat, y): return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar()def evaluate_accuracy(data_iter, net): acc = 0 for X, y in data_iter: acc += accuracy(net(X), y) return acc / len(data_iter)num_epochs = 5def train(net,train_iter,test_iter,loss,num_epochs,batch_size,params=None,lr=None,trainer=None): for epoch in range(num_epochs): train_l_sum = 0 train_acc_sum = 0 for X,y in train_iter: with autograd.record(): y_hat = net(X) l = loss(y_hat,y) l.backward() if trainer is None: gb.sgd(params,lr,batch_size) else: trainer.step(batch_size) train_l_sum += l.mean().asscalar() test_acc = evaluate_accuracy(test_iter,net) print('epoch %d,loss %.4f,test acc %.3f'%(epoch+1,train_l_sum / len(train_iter),test_acc))train(net,train_iter,test_iter,loss,num_epochs,batch_size,None,None,trainer)

 

转载于:https://www.cnblogs.com/TreeDream/p/10033155.html

你可能感兴趣的文章
时间助理 时之助
查看>>
英国征召前黑客组建“网络兵团”
查看>>
PHP 命令行模式实战之cli+mysql 模拟队列批量发送邮件(在Linux环境下PHP 异步执行脚本发送事件通知消息实际案例)...
查看>>
pyjamas build AJAX apps in Python (like Google did for Java)
查看>>
centos5.9使用RPM包搭建lamp平台
查看>>
Javascript String类的属性及方法
查看>>
[LeetCode] Merge Intervals
查看>>
Struts2 学习小结
查看>>
在 Linux 系统中安装Load Generator ,并在windows 调用
查看>>
chm文件打开,有目录无内容
查看>>
whereis、find、which、locate的区别
查看>>
一点不懂到小白的linux系统运维经历分享
查看>>
桌面支持--打不开网页上的pdf附件解决办法(ie-tools-compatibility)
查看>>
nagios监控windows 改了NSclient++默认端口 注意事项
查看>>
干货 | JAVA代码引起的NATIVE野指针问题(上)
查看>>
POI getDataFormat() 格式对照
查看>>
好的产品原型具有哪些特点?
查看>>
实现java导出文件弹出下载框让用户选择路径
查看>>
刨根问底--技术--jsoup登陆网站
查看>>
OSChina 五一劳动节乱弹 ——女孩子晚上不要出门,发生了这样的事情
查看>>