Label Smoothing

Posted by Packy on January 6, 2020

Label Smoothing

def cal_loss(pred, gold, smoothing=True, eps=0.2):
    """Calculate cross-entropy loss, applying label smoothing if requested.

    Args:
        pred: Float tensor of shape (N, n_class) holding unnormalized logits.
        gold: Integer tensor of ground-truth class indices; flattened to (N,).
        smoothing: If True, smooth the one-hot targets before computing the
            loss; otherwise use standard cross entropy.
        eps: Smoothing factor — the true class keeps probability 1 - eps and
            the remaining eps is spread uniformly over the other classes.

    Returns:
        Scalar tensor with the mean loss over the batch.
    """
    gold = gold.contiguous().view(-1)

    if smoothing:
        n_class = pred.size(1)

        # Build smoothed targets: 1 - eps on the gold class,
        # eps / (n_class - 1) on every other class.
        one_hot = torch.zeros_like(pred).scatter_(1, gold.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)

        # Cross entropy against the soft targets:
        # mean over the batch of -sum_c q(c) * log p(c).
        log_prb = F.log_softmax(pred, dim=1)
        loss = -(one_hot * log_prb).sum(dim=1).mean()
    else:
        loss = F.cross_entropy(pred, gold, reduction='mean')

    return loss

让硬的 one-hot label 变为软 label（Label smoothing turns hard one-hot labels into soft labels）。