Fix sample timesteps

This commit is contained in:
jerrister
2024-11-26 14:03:00 +08:00
parent c9b4d43a2b
commit 5aa4766033

View File

@@ -193,7 +193,7 @@ class CFM(nn.Module):
y0 = (1 - t_start) * y0 + t_start * test_cond
steps = int(steps * (1 - t_start))
t = torch.linspace(t_start, 1, steps, device=self.device, dtype=step_cond.dtype)
t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
if sway_sampling_coef is not None:
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)