mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-10-25 07:54:30 +00:00
92 lines
2.4 KiB
Python
92 lines
2.4 KiB
Python
import torch
|
|
|
|
class NestedTensor:
|
|
def __init__(self, tensors):
|
|
self.tensors = list(tensors)
|
|
self.is_nested = True
|
|
|
|
def _copy(self):
|
|
return NestedTensor(self.tensors)
|
|
|
|
def apply_operation(self, other, operation):
|
|
o = self._copy()
|
|
if isinstance(other, NestedTensor):
|
|
for i, t in enumerate(o.tensors):
|
|
o.tensors[i] = operation(t, other.tensors[i])
|
|
else:
|
|
for i, t in enumerate(o.tensors):
|
|
o.tensors[i] = operation(t, other)
|
|
return o
|
|
|
|
def __add__(self, b):
|
|
return self.apply_operation(b, lambda x, y: x + y)
|
|
|
|
def __sub__(self, b):
|
|
return self.apply_operation(b, lambda x, y: x - y)
|
|
|
|
def __mul__(self, b):
|
|
return self.apply_operation(b, lambda x, y: x * y)
|
|
|
|
# def __itruediv__(self, b):
|
|
# return self.apply_operation(b, lambda x, y: x / y)
|
|
|
|
def __truediv__(self, b):
|
|
return self.apply_operation(b, lambda x, y: x / y)
|
|
|
|
def __getitem__(self, *args, **kwargs):
|
|
return self.apply_operation(None, lambda x, y: x.__getitem__(*args, **kwargs))
|
|
|
|
def unbind(self):
|
|
return self.tensors
|
|
|
|
def to(self, *args, **kwargs):
|
|
o = self._copy()
|
|
for i, t in enumerate(o.tensors):
|
|
o.tensors[i] = t.to(*args, **kwargs)
|
|
return o
|
|
|
|
def new_ones(self, *args, **kwargs):
|
|
return self.tensors[0].new_ones(*args, **kwargs)
|
|
|
|
def float(self):
|
|
return self.to(dtype=torch.float)
|
|
|
|
def chunk(self, *args, **kwargs):
|
|
return self.apply_operation(None, lambda x, y: x.chunk(*args, **kwargs))
|
|
|
|
def size(self):
|
|
return self.tensors[0].size()
|
|
|
|
@property
|
|
def shape(self):
|
|
return self.tensors[0].shape
|
|
|
|
@property
|
|
def ndim(self):
|
|
dims = 0
|
|
for t in self.tensors:
|
|
dims = max(t.ndim, dims)
|
|
return dims
|
|
|
|
@property
|
|
def device(self):
|
|
return self.tensors[0].device
|
|
|
|
@property
|
|
def dtype(self):
|
|
return self.tensors[0].dtype
|
|
|
|
@property
|
|
def layout(self):
|
|
return self.tensors[0].layout
|
|
|
|
|
|
def cat_nested(tensors, *args, **kwargs):
|
|
cated_tensors = []
|
|
for i in range(len(tensors[0].tensors)):
|
|
tens = []
|
|
for j in range(len(tensors)):
|
|
tens.append(tensors[j].tensors[i])
|
|
cated_tensors.append(torch.cat(tens, *args, **kwargs))
|
|
return NestedTensor(cated_tensors)
|