损失函数是什么
在PyTorch中,损失函数(Loss Function)是用于衡量模型在训练数据集上的预测误差的函数。它可用于指导模型参数的更新,使模型在训练数据上的效果越来越好。损失函数通常与网络的最后一层连用,通过自动微分,计算每个参数对损失的影响,指导参数优化方向。
PyTorch中常用的损失函数有:
MSELoss - 平均平方误差损失,用于回归问题
CrossEntropyLoss - 交叉熵损失,用于分类问题,将softmax激活和负对数似然损失合并在一起计算。
NLLLoss - 负对数似然损失,用于多分类问题。
BCELoss - 二进制交叉熵损失,用于二分类问题。
L1Loss - L1范数损失,使模型对异常点更鲁棒。用于线性回归,逻辑回归。
SmoothL1Loss - 平滑L1损失,综合平方误差损失和L1损失的优点。用于目标检测。
什么是回归问题,分类问题,多分类问题,二分类问题
回归问题: 回归问题的目标是预测连续型的数值目标变量,预测结果为一个连续的值。如房价预测、销量预测等。比如之前写的股价预测,其中的损失函数就是MSELoss
分类问题: 分类问题的目标是预测离散的类别标签,一般输出是一个类别。如图像分类(猫or狗)、垃圾邮件分类。
多分类问题: 多分类问题的目标是对样本进行多类别分类,预测结果是多个类别中的一个。分类类别数目大于2。如手写数字识别(0~9类别)。
二分类问题: 二分类问题的目标是对样本进行二分类,只有两个类别,如是与否分类、垃圾邮件检测(垃圾or非垃圾)。
使用损失函数
在PyTorch中使用损失函数非常简单,可以如下实例化一个损失函数对象:
loss_fn = nn.MSELoss()
loss = loss_fn(prediction, target)
选择合适的损失函数对模型性能和训练速度有很大影响。实践中可以测试不同的损失函数,选择验证指标效果最好的。