This commit is contained in:
parent
847c4f0c4d
commit
95bbc7825d
22
entry.py
22
entry.py
|
|
@ -41,11 +41,11 @@ def get_batch(keys, value_dict, N, device="cuda"):
|
|||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
)
|
||||
batch_uc["original_size_as_tuple"] = (
|
||||
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
|
||||
.to(device)
|
||||
.repeat(*N, 1) / 2
|
||||
)
|
||||
# batch_uc["original_size_as_tuple"] = (
|
||||
# torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
|
||||
# .to(device)
|
||||
# .repeat(*N, 1) / 2
|
||||
# )
|
||||
elif key == "crop_coords_top_left":
|
||||
batch["crop_coords_top_left"] = (
|
||||
torch.tensor(
|
||||
|
|
@ -70,11 +70,11 @@ def get_batch(keys, value_dict, N, device="cuda"):
|
|||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
)
|
||||
batch_uc["target_size_as_tuple"] = (
|
||||
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
|
||||
.to(device)
|
||||
.repeat(*N, 1) / 2.0
|
||||
)
|
||||
# batch_uc["target_size_as_tuple"] = (
|
||||
# torch.tensor([value_dict["target_height"], value_dict["target_width"]])
|
||||
# .to(device)
|
||||
# .repeat(*N, 1) / 2.0
|
||||
# )
|
||||
else:
|
||||
batch[key] = value_dict[key]
|
||||
|
||||
|
|
@ -100,7 +100,7 @@ sampler = EulerAncestralSampler(
|
|||
verbose=True,
|
||||
)
|
||||
|
||||
torch.manual_seed(123)
|
||||
torch.manual_seed(12345)
|
||||
|
||||
config_path = './sd_xl_base.yaml'
|
||||
config = OmegaConf.load(config_path)
|
||||
|
|
|
|||
Loading…
Reference in New Issue