This commit is contained in:
parent
94db86402c
commit
06a96f70ab
|
|
@ -272,8 +272,7 @@ class SiLU(nn.Module):
|
|||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
self.weight = self.weight.float()
|
||||
self.bias = self.bias.float()
|
||||
self.to(torch.float32)
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue