Here is the critic (value) loss, `critic_loss_fn`:
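```python
import torch


# Method of the PPO trainer; self.cliprange_value is the value-clipping range.
def critic_loss_fn(self, values, old_values, returns, mask):
    # value loss
    # Keep the new value within ±cliprange_value of the old value.
    values_clipped = torch.clamp(
        values,
        old_values - self.cliprange_value,
        old_values + self.cliprange_value,
    )
    vf_loss1 = (values - returns) ** 2          # unclipped squared error
    vf_loss2 = (values_clipped - returns) ** 2  # clipped squared error
    # Element-wise max of the two losses, averaged over unmasked positions.
    vf_loss = 0.5 * torch.sum(
        torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()
    return vf_loss
```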
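Written as a formula, the function computes the following. The notation is introduced here for readability and is not from the original: $V_t$ = `values`, $V_t^{\text{old}}$ = `old_values`, $R_t$ = `returns`, $m_t$ = `mask`, $\epsilon$ = `self.cliprange_value`.

$$
\mathcal{L}^{VF} = \frac{1}{2}\,\frac{\sum_t m_t \max\!\left[(V_t - R_t)^2,\ \left(\operatorname{clip}\!\left(V_t,\ V_t^{\text{old}} - \epsilon,\ V_t^{\text{old}} + \epsilon\right) - R_t\right)^2\right]}{\sum_t m_t}
$$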
Why does `vf_loss` take the element-wise maximum of `vf_loss1` and `vf_loss2`? If the larger loss is always used, doesn't that make the clamp pointless?
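One way to probe the question is to run the same computation on a toy batch and check which branch `torch.max` selects and what gradient reaches `values`. This is a minimal sketch: `clipped_value_loss` is a hypothetical standalone copy of `critic_loss_fn`, and the `cliprange_value=0.2` default and the toy tensors are illustrative assumptions, not part of the original code.

```python
import torch


def clipped_value_loss(values, old_values, returns, mask, cliprange_value=0.2):
    # Hypothetical standalone copy of critic_loss_fn; the clip range is passed in
    # directly instead of being read from self.cliprange_value.
    values_clipped = torch.clamp(
        values, old_values - cliprange_value, old_values + cliprange_value
    )
    vf_loss1 = (values - returns) ** 2          # unclipped squared error
    vf_loss2 = (values_clipped - returns) ** 2  # clipped squared error
    return 0.5 * torch.sum(torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()


# Toy batch of three samples, all with old value 0 and return 1,
# so the clip range is [-0.2, 0.2].
old_values = torch.zeros(3)
returns = torch.ones(3)
mask = torch.ones(3)
# Sample 0: 0.1, inside the clip range, so the clamp changes nothing.
# Sample 1: 0.5, past +0.2 and closer to the return than the clipped value 0.2.
# Sample 2: -0.5, past -0.2 and farther from the return than the clipped value -0.2.
values = torch.tensor([0.1, 0.5, -0.5], requires_grad=True)

loss = clipped_value_loss(values, old_values, returns, mask)
loss.backward()
print(loss.item())
print(values.grad)  # per-sample gradient reaching `values`
```

In this toy batch, sample 1 receives zero gradient: `torch.max` picks the clipped loss, and the clamp is saturated there, so nothing flows back to `values`. Sample 2 still gets the full unclipped gradient, and sample 0 is unaffected by the clamp. So, at least here, the clamp is not a no-op: together with the max, it zeroes the gradient for samples whose value has left the clip range yet is closer to `returns` than the clipped value is.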