This commit is contained in:
lvmin 2023-08-09 15:09:40 -07:00
parent c104f30137
commit 94db86402c
1 changed files with 2 additions and 2 deletions

View File

@ -272,8 +272,8 @@ class SiLU(nn.Module):
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
self.weight.to(torch.float32)
self.bias.to(torch.float32)
self.weight = self.weight.float()
self.bias = self.bias.float()
return super().forward(x.float()).type(x.dtype)