Pytorch自定义损失函数

本文以自己实现交叉熵损失函数为例(不彻底的实现),来简单说说在pytorch中自定义损失函数。
首先我们看一下pytorch中的交叉熵损失函数是什么样的,这里我们简单放张图,网上还是有不少说这个的,可自行百度。
Framework
简单说一下,以文本多分类为例,x是batch_size*class_num的Tensor,class是这个batch的真实标签。
我们通常使用criterion = nn.CrossEntropyLoss()这行代码来申明我们使用的损失函数,那么,如果我们自己实现该如何写呢?

  • step1:自定义一个类,继承自nn.Module,比如叫myloss,接着在该类中实现即可。
  • step2:将criterion = nn.CrossEntropyLoss()改成criterion=myloss()
    看上去很简单,但实际上有坑。
    我先放上我根据上图第一个公式中的第一个等式来写的损失函数(实际是个类)。
    Framework
    为了比较,我把官方的loss1=F.cross_entropy(pred_res,true_res)打印出来,发现在第一次计算值和官方值一致,但是他并没有反向传播更新参数,体现在实验准确率很低且不变。问题出在当我们继承nn.Module,在forward中实现loss定义时,所有的数学操作都必须使用tensor提供的数学操作。于是,我们的问题在于使用的是math.log,应该改成torch.log。除此之外,还需要注意的就是返回值得是0维的tensor(可以理解为tensor类型的标量)。
    最后,如果你发现那你需要的数学运算tensor没有提供,那么你可以继承于nn.autograd.function,然后自己写forward和backward函数,这部分我也没弄过,就不说了。
相关文章
评论
分享
Please check the comment setting in config.yml of hexo-theme-Annie!