Loading...
首页问答    

PyTorch---感知机 损失函数计算的数值一直没变化

Bingo~
Bingo~  发布于 2020-09-29 14:31:43 383

在训练过程中,损失函数的值一直没有变化,这是为啥啊?恳请大佬指点一下问题出在哪儿

pyTorch —-> 感知机

激活函数为 sign函数

import torch
import pandas
from torch import nn, optim
from torch.autograd import Variable

def get_data(path): # 将 csv 中的 每一条数据,放进一个列表中 —>每个列表为一个样本点 , y单独 放进一个列表中
data_x1 = pandas.read_csv(path)[‘x1’] # 特征 x1
data_x2 = pandas.read_csv(path)[‘x2’] # 特征 x2
target = pandas.read_csv(path)[‘y’] # 所属分类
target = list(target)
data_train = []
for i in range(len(target)):
if target[i] == 0:
target[i] = -1
for i in range(len(data_x1)):
temp_data = []
temp_data.append(data_x1[i])
temp_data.append(data_x2[i])
data_train.append(temp_data)
return data_train, target

def trans_list_variable(data_train, target): # list 转成 Tensor 再转成 Variable

# 列表转成Tensor
data_train = torch.Tensor(data_train)
target = [target]
target = torch.Tensor(target)
target =target.T
# Tensor 转成 Variable
data_train = Variable(data_train, requires_grad=True)
target = Variable(target, requires_grad=True)
return data_train, target

class LineRegression(nn.Module):
def init(self):
super(LineRegression, self).init()
self.linear_function = nn.Linear(2, 1)
def forward(self, data_train):
result = []
for i in data_train:
out = self.linear_function(i)
out = torch.sign(out)
result.append(out)

    # 将列表转换成 Variable
    result = [result]
    result = torch.Tensor(result)
    result = result.T
    result = Variable(result)

    return result

def train(data_train, target):
model = LineRegression()
loss_Model = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

epoch = 40000
for i in range(epoch):  # 训练40000次

    result = model(data_train)            
    loss_function = loss_Model(result, target)
    optimizer.zero_grad()
    loss_function.backward()
    optimizer.step()

    if (i+1) % 200 == 0:
        out = loss_function.item()
        print(out)

if name == ‘main‘:
path = r’C:\Users\sel\Desktop\ML_data.csv’
data_train, target = get_data(path) # 获得数据
data_train, target = trans_list_variable(data_train, target)
train(data_train, target)

添加附件:文件小于20M
上传
易百纳技术社区
确定要删除此文章、专栏、评论吗?
确定
取消
易百纳技术社区