mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-14 13:35:05 +00:00
Update upscale model code to latest Chainner model code.
Don't add SRFormer because the code license is incompatible with the GPL. Remove MAT because it's unused and the license is incompatible with GPL.
This commit is contained in:
@@ -1,13 +1,14 @@
|
||||
import logging as logger
|
||||
|
||||
from .architecture.DAT import DAT
|
||||
from .architecture.face.codeformer import CodeFormer
|
||||
from .architecture.face.gfpganv1_clean_arch import GFPGANv1Clean
|
||||
from .architecture.face.restoreformer_arch import RestoreFormer
|
||||
from .architecture.HAT import HAT
|
||||
from .architecture.LaMa import LaMa
|
||||
from .architecture.MAT import MAT
|
||||
from .architecture.OmniSR.OmniSR import OmniSR
|
||||
from .architecture.RRDB import RRDBNet as ESRGAN
|
||||
from .architecture.SCUNet import SCUNet
|
||||
from .architecture.SPSR import SPSRNet as SPSR
|
||||
from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
|
||||
from .architecture.SwiftSRGAN import Generator as SwiftSRGAN
|
||||
@@ -33,7 +34,6 @@ def load_state_dict(state_dict) -> PyTorchModel:
|
||||
state_dict = state_dict["params"]
|
||||
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
|
||||
# SRVGGNet Real-ESRGAN (v2)
|
||||
if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys:
|
||||
model = RealESRGANv2(state_dict)
|
||||
@@ -46,12 +46,14 @@ def load_state_dict(state_dict) -> PyTorchModel:
|
||||
and "initial.cnn.depthwise.weight" in state_dict["model"].keys()
|
||||
):
|
||||
model = SwiftSRGAN(state_dict)
|
||||
# HAT -- be sure it is above swinir
|
||||
elif "layers.0.residual_group.blocks.0.conv_block.cab.0.weight" in state_dict_keys:
|
||||
model = HAT(state_dict)
|
||||
# SwinIR
|
||||
# SwinIR, Swin2SR, HAT
|
||||
elif "layers.0.residual_group.blocks.0.norm1.weight" in state_dict_keys:
|
||||
if "patch_embed.proj.weight" in state_dict_keys:
|
||||
if (
|
||||
"layers.0.residual_group.blocks.0.conv_block.cab.0.weight"
|
||||
in state_dict_keys
|
||||
):
|
||||
model = HAT(state_dict)
|
||||
elif "patch_embed.proj.weight" in state_dict_keys:
|
||||
model = Swin2SR(state_dict)
|
||||
else:
|
||||
model = SwinIR(state_dict)
|
||||
@@ -78,12 +80,15 @@ def load_state_dict(state_dict) -> PyTorchModel:
|
||||
or "generator.model.1.bn_l.running_mean" in state_dict_keys
|
||||
):
|
||||
model = LaMa(state_dict)
|
||||
# MAT
|
||||
elif "synthesis.first_stage.conv_first.conv.resample_filter" in state_dict_keys:
|
||||
model = MAT(state_dict)
|
||||
# Omni-SR
|
||||
elif "residual_layer.0.residual_layer.0.layer.0.fn.0.weight" in state_dict_keys:
|
||||
model = OmniSR(state_dict)
|
||||
# SCUNet
|
||||
elif "m_head.0.weight" in state_dict_keys and "m_tail.0.weight" in state_dict_keys:
|
||||
model = SCUNet(state_dict)
|
||||
# DAT
|
||||
elif "layers.0.blocks.2.attn.attn_mask_0" in state_dict_keys:
|
||||
model = DAT(state_dict)
|
||||
# Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1
|
||||
else:
|
||||
try:
|
||||
|
Reference in New Issue
Block a user