diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 71a652ffa..ae5d2c563 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -189,7 +189,7 @@ class ModelSamplingContinuousEDM: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps"],), + "sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps", "cosmos_rflow"],), "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), }} @@ -202,6 +202,7 @@ class ModelSamplingContinuousEDM: def patch(self, model, sampling, sigma_max, sigma_min): m = model.clone() + sampling_base = comfy.model_sampling.ModelSamplingContinuousEDM latent_format = None sigma_data = 1.0 if sampling == "eps": @@ -215,8 +216,11 @@ class ModelSamplingContinuousEDM: sampling_type = comfy.model_sampling.EDM sigma_data = 0.5 latent_format = comfy.latent_formats.SDXL_Playground_2_5() + elif sampling == "cosmos_rflow": + sampling_type = comfy.model_sampling.COSMOS_RFLOW + sampling_base = comfy.model_sampling.ModelSamplingCosmosRFlow - class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type): + class ModelSamplingAdvanced(sampling_base, sampling_type): pass model_sampling = ModelSamplingAdvanced(model.model.model_config)