Make step index detection much more robust (#9392)

This commit is contained in:
Jedrzej Kosinski
2025-08-17 15:54:07 -07:00
committed by GitHub
parent d4e353a94e
commit 7f3b9b16c6

View File

@@ -164,8 +164,11 @@ class IndexListContextHandler(ContextHandlerABC):
return resized_cond
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
indexes = torch.where(model_options["transformer_options"]["sample_sigmas"] == timestep[0])
self._step = int(indexes[0])
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
matches = torch.nonzero(mask)
if torch.numel(matches) == 0:
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
self._step = int(matches[0].item())
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
full_length = x_in.size(self.dim) # TODO: choose dim based on model