Add cosmos_rflow option to ModelSamplingContinuousEDM node. (#8523)

This is for the cosmos predict2 model.
This commit is contained in:
comfyanonymous 2025-06-13 14:47:52 -07:00 committed by GitHub
parent c69af655aa
commit 5bf69bde35
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -189,7 +189,7 @@ class ModelSamplingContinuousEDM:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",), return {"required": { "model": ("MODEL",),
"sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps"],), "sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps", "cosmos_rflow"],),
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
}} }}
@ -202,6 +202,7 @@ class ModelSamplingContinuousEDM:
def patch(self, model, sampling, sigma_max, sigma_min): def patch(self, model, sampling, sigma_max, sigma_min):
m = model.clone() m = model.clone()
sampling_base = comfy.model_sampling.ModelSamplingContinuousEDM
latent_format = None latent_format = None
sigma_data = 1.0 sigma_data = 1.0
if sampling == "eps": if sampling == "eps":
@ -215,8 +216,11 @@ class ModelSamplingContinuousEDM:
sampling_type = comfy.model_sampling.EDM sampling_type = comfy.model_sampling.EDM
sigma_data = 0.5 sigma_data = 0.5
latent_format = comfy.latent_formats.SDXL_Playground_2_5() latent_format = comfy.latent_formats.SDXL_Playground_2_5()
elif sampling == "cosmos_rflow":
sampling_type = comfy.model_sampling.COSMOS_RFLOW
sampling_base = comfy.model_sampling.ModelSamplingCosmosRFlow
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type): class ModelSamplingAdvanced(sampling_base, sampling_type):
pass pass
model_sampling = ModelSamplingAdvanced(model.model.model_config) model_sampling = ModelSamplingAdvanced(model.model.model_config)