Update upscale model code to latest Chainner model code.

Don't add SRFormer because the code license is incompatible with the GPL.

Remove MAT because it's unused and the license is incompatible with GPL.
This commit is contained in:
comfyanonymous
2023-09-02 22:25:12 -04:00
parent 62efc78a4b
commit 766c7b3815
12 changed files with 2084 additions and 2511 deletions

View File

@@ -1,13 +1,14 @@
import logging as logger
from .architecture.DAT import DAT
from .architecture.face.codeformer import CodeFormer
from .architecture.face.gfpganv1_clean_arch import GFPGANv1Clean
from .architecture.face.restoreformer_arch import RestoreFormer
from .architecture.HAT import HAT
from .architecture.LaMa import LaMa
from .architecture.MAT import MAT
from .architecture.OmniSR.OmniSR import OmniSR
from .architecture.RRDB import RRDBNet as ESRGAN
from .architecture.SCUNet import SCUNet
from .architecture.SPSR import SPSRNet as SPSR
from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
from .architecture.SwiftSRGAN import Generator as SwiftSRGAN
@@ -33,7 +34,6 @@ def load_state_dict(state_dict) -> PyTorchModel:
state_dict = state_dict["params"]
state_dict_keys = list(state_dict.keys())
# SRVGGNet Real-ESRGAN (v2)
if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys:
model = RealESRGANv2(state_dict)
@@ -46,12 +46,14 @@ def load_state_dict(state_dict) -> PyTorchModel:
and "initial.cnn.depthwise.weight" in state_dict["model"].keys()
):
model = SwiftSRGAN(state_dict)
# HAT -- be sure it is above swinir
elif "layers.0.residual_group.blocks.0.conv_block.cab.0.weight" in state_dict_keys:
model = HAT(state_dict)
# SwinIR
# SwinIR, Swin2SR, HAT
elif "layers.0.residual_group.blocks.0.norm1.weight" in state_dict_keys:
if "patch_embed.proj.weight" in state_dict_keys:
if (
"layers.0.residual_group.blocks.0.conv_block.cab.0.weight"
in state_dict_keys
):
model = HAT(state_dict)
elif "patch_embed.proj.weight" in state_dict_keys:
model = Swin2SR(state_dict)
else:
model = SwinIR(state_dict)
@@ -78,12 +80,15 @@ def load_state_dict(state_dict) -> PyTorchModel:
or "generator.model.1.bn_l.running_mean" in state_dict_keys
):
model = LaMa(state_dict)
# MAT
elif "synthesis.first_stage.conv_first.conv.resample_filter" in state_dict_keys:
model = MAT(state_dict)
# Omni-SR
elif "residual_layer.0.residual_layer.0.layer.0.fn.0.weight" in state_dict_keys:
model = OmniSR(state_dict)
# SCUNet
elif "m_head.0.weight" in state_dict_keys and "m_tail.0.weight" in state_dict_keys:
model = SCUNet(state_dict)
# DAT
elif "layers.0.blocks.2.attn.attn_mask_0" in state_dict_keys:
model = DAT(state_dict)
# Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1
else:
try: