This commit is contained in:
lvmin 2023-08-09 15:11:01 -07:00
parent 94db86402c
commit 06a96f70ab
1 changed files with 1 additions and 2 deletions

View File

@ -272,8 +272,7 @@ class SiLU(nn.Module):
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
self.weight = self.weight.float()
self.bias = self.bias.float()
self.to(torch.float32)
return super().forward(x.float()).type(x.dtype)