【pytorch基础教程7】多维特征input(学不会来打我啊)-凯发app官方网站

凯发app官方网站-凯发k8官网下载客户端中心 | | 凯发app官方网站-凯发k8官网下载客户端中心
  • 博客访问: 3600095
  • 博文数量: 365
  • 博客积分: 0
  • 博客等级: 民兵
  • 技术积分: 2522
  • 用 户 组: 普通用户
  • 注册时间: 2019-10-28 13:40
文章分类

(365)

  • (365)
文章存档

(8)

(130)

(155)

(50)

(22)

我的朋友
相关博文
  • ·
  • ·
  • ·
  • ·
  • ·
  • ·
  • ·
  • ·
  • ·
  • ·

分类: python/ruby

2021-10-19 17:27:16

# -*- coding: utf-8 -*-

"""

created on mon oct 18 10:18:24 2021

@author: 86493

"""

import torch

import torch.nn as nn

import numpy as np

import matplotlib.pyplot as plt

# 这里的type不用double,特斯拉gpudouble

xy = np.loadtxt('diabetes.csv',

                delimiter = ' ',

                dtype = np.float32)

# 最后一列不要

x_data = torch.from_numpy(xy[: , : -1])

# [-1]则拿出来的是一个矩阵,去了中括号则拿出向量

y_data = torch.from_numpy(xy[:, [-1]])

losslst = []

class model(nn.module):

    def __init__(self):

        super(model, self).__init__()

        self.linear1 = nn.linear(9, 6)

        self.linear2 = nn.linear(6, 4)

        self.linear3 = nn.linear(4, 1)         

        # 外汇跟单gendan5.com上次logistic是调用nn.functionalsigmoid

        self.sigmoid = nn.sigmoid()

        # 这个也是继承module,没有参数,比上次写法不容易出错

    def forward(self, x):

        x = self.sigmoid(self.linear1(x))

        x = self.sigmoid(self.linear2(x))

        x = self.sigmoid(self.linear3(x))

        return x

model = model()

# 使用交叉熵作损失函数

criterion = nn.bceloss(size_average = false)

optimizer = torch.optim.sgd(model.parameters(),

                            lr = 0.01)

# 训练,下面没有用mini-batch,后面讲dataloader再说

for epoch in range(10):

    y_predict = model(x_data)

    loss = criterion(y_predict, y_data)

    # 打印loss对象会自动调用__str__

    print(epoch, loss.item())

    losslst.append(loss.item())

    # 梯度清零后反向传播

    optimizer.zero_grad()

    loss.backward()

    # 更新权重

    optimizer.step()

# 画图

plt.plot(range(10), losslst)

plt.ylabel('loss')

plt.xlabel('epoch')

plt.show()

阅读(5330) | 评论(0) | 转发(0) |
0

上一篇:

下一篇:

给主人留下些什么吧!~~
")); function link(t){ var href= $(t).attr('href'); href ="?url=" encodeuricomponent(location.href); $(t).attr('href',href); //setcookie("returnouturl", location.href, 60, "/"); }
网站地图