diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 9b777f217..8a0c5f435 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -1,8 +1,8 @@ import datetime import json +import logging import math import os -import logging import numpy as np import safetensors @@ -10,13 +10,15 @@ import torch from PIL import Image, ImageDraw, ImageFont from PIL.PngImagePlugin import PngInfo -import comfy -import comfy_extras +import comfy.samplers +import comfy.utils +import comfy_extras.nodes_custom_sampler import folder_paths import node_helpers from comfy.cli_args import args from comfy.comfy_types.node_typing import IO + class TrainSampler(comfy.samplers.Sampler): def __init__(self, loss_fn, optimizer, loss_callback=None):