Doing some experimentation

This commit is contained in:
Jedrzej Kosinski
2025-09-02 22:19:12 -07:00
parent a40c5ae341
commit 295b49c165
3 changed files with 509 additions and 1 deletions

View File

@@ -25,6 +25,7 @@ def prepare_noise_wrapper(executor, *args, **kwargs):
start_index = relevant_sigmas[0][0]
end_index = relevant_sigmas[-1][0]
sb_holder.predict_ratios = torch.linspace(sb_holder.start_predict_ratio, sb_holder.end_predict_ratio, end_index - start_index + 1)
sb_holder.predict_start_index = start_index
return executor(*args, **kwargs)
finally:
@@ -95,7 +96,8 @@ def model_forward_wrapper(executor, *args, **kwargs):
# when 1: Select DiT blocks(4)
if sb_holder.step_modulus == 1:
predict_ratio = sb_holder.predict_ratios[sb_holder.active_steps-1]
predict_index = max(0, sb_holder.step_count - sb_holder.predict_start_index)
predict_ratio = sb_holder.predict_ratios[predict_index]
logging.info(f"Sortblock: for step {sb_holder.step_count}, selecting blocks for recomputation and prediction, predict_ratio: {predict_ratio}")
reuse_ratio = 1.0 - predict_ratio
for block_type, blocks in sb_holder.blocks_per_type.items():
@@ -322,6 +324,7 @@ class SortblockHolder:
self.do_sortblock = False
self.active_steps = 0
self.predict_ratios = []
self.predict_start_index = 0
# cache values
self.all_blocks = []
@@ -371,6 +374,7 @@ class SortblockHolder:
self.active_steps = 0
self.predict_ratios = []
self.do_sortblock = False
self.predict_start_index = 0
return self
def clone(self):