# Source code for mmseg.models.losses.mse_loss
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import segmentation_models_pytorch as smp
from ..builder import LOSSES
class AverageMeter:
    """Track the latest value along with a running sum, count and mean.

    ``update(val, n)`` records ``val`` as the average of ``n`` samples, so the
    running mean ``avg`` is weighted by sample count across updates.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all statistics so the meter can be reused."""
        self.val = self.avg = self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` (the mean of ``n`` samples) and refresh ``avg``."""
        self.val = val
        # Keep augmented assignment: for tensor/array values it updates in place.
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def accuracy(outputs, targets):
    """Return top-1 accuracy of ``outputs`` against ``targets`` as a fraction.

    Args:
        outputs: Score tensor; the arg-max is taken along dim 1 (presumably
            shape ``(batch, num_classes)`` -- confirm against callers).
        targets: Tensor of class indices; flattened before comparison.

    Returns:
        float: Count of correct top-1 predictions divided by
        ``targets.size(0)``.

    Raises:
        ZeroDivisionError: If ``targets.size(0)`` is 0.
    """
    num_samples = targets.size(0)
    top1_idx = outputs.topk(1, 1, True)[1]
    hits = top1_idx.t().eq(targets.view(1, -1))
    return hits.float().sum().item() / num_samples
@LOSSES.register_module()
class MSELoss(nn.Module):
    """Mean-squared-error loss for single-channel regression outputs.

    Args:
        use_sigmoid (bool): Accepted for config compatibility; unused.
        loss_weight (float | None): Multiplier applied to the loss. ``None``
            (the default) is treated as ``1.0``; multiplying by ``None``
            would otherwise raise ``TypeError`` in ``forward``.
    """

    def __init__(self, use_sigmoid=False, loss_weight=None):
        super(MSELoss, self).__init__()
        self.criterion = nn.MSELoss()
        self.loss_weight = loss_weight

    def forward(self, out, target, ignore_index=-100):
        """Compute the weighted mean squared error between ``out`` and ``target``.

        Args:
            out (Tensor): Prediction. Dim 1 is squeezed (removed when it has
                size 1), so e.g. an ``(N, 1, H, W)`` map is compared as
                ``(N, H, W)``. It is not resized to ``target``.
            target (Tensor): Regression target; cast to float.
            ignore_index (int): Unused; accepted for interface compatibility.
                Elements equal to this value are NOT excluded.

        Returns:
            Tensor: Scalar loss, the mean squared error times ``loss_weight``.
        """
        target = target.float()
        weight = 1.0 if self.loss_weight is None else self.loss_weight
        return self.criterion(out.squeeze(dim=1), target) * weight
@LOSSES.register_module()
class NoNaNMSE(nn.Module):
    """Mean-squared-error loss that ignores elements whose difference is NaN.

    An element is skipped when ``output - target`` is NaN, i.e. when either
    operand is NaN (or for ``inf - inf``).

    Args:
        use_sigmoid (bool): Accepted for config compatibility; unused.
        loss_weight (float | None): Multiplier applied to the loss. ``None``
            (the default) is treated as ``1.0``; multiplying by ``None``
            would otherwise raise ``TypeError`` in ``forward``.
    """

    def __init__(self, use_sigmoid=False, loss_weight=None):
        super(NoNaNMSE, self).__init__()
        self.loss_weight = loss_weight

    def forward(self, output, target, ignore_index=-100):
        """Compute the weighted MSE over elements with a non-NaN difference.

        Args:
            output (Tensor): Prediction. All size-1 dims are squeezed before
                subtracting ``target`` (broadcasting applies).
                NOTE(review): ``MSELoss`` squeezes only dim 1. Squeezing every
                singleton dim can mis-broadcast when a spatial dim has size 1.
                Confirm the expected shapes.
            target (Tensor): Regression target; NaN entries are ignored.
            ignore_index (int): Unused; accepted for interface compatibility.

        Returns:
            Tensor: Scalar loss. If every element is NaN, this is a zero that
            stays connected to the autograd graph, instead of the NaN that
            ``torch.mean`` of an empty selection would produce.
        """
        diff = torch.squeeze(output) - target
        valid = diff.masked_select(~torch.isnan(diff))
        if valid.numel() == 0:
            # Sum of an empty selection is 0 and still has a grad_fn, so
            # backward() yields zero gradients instead of propagating NaN.
            return valid.sum()
        weight = 1.0 if self.loss_weight is None else self.loss_weight
        return torch.mean(valid ** 2) * weight