1
+
import itertools
2
+
from typing import Sequence, Mapping
3
+
from comfy.graph import DynamicPrompt
4
+
5
+
import nodes
6
+
7
+
from comfy.graph_utils import is_link
8
+
9
+
class CacheKeySet:
10
+
def __init__(self, dynprompt, node_ids, is_changed_cache):
11
+
self.keys = {}
12
+
self.subcache_keys = {}
13
+
14
+
def add_keys(self, node_ids):
15
+
raise NotImplementedError()
16
+
17
+
def all_node_ids(self):
18
+
return set(self.keys.keys())
19
+
20
+
def get_used_keys(self):
21
+
return self.keys.values()
22
+
23
+
def get_used_subcache_keys(self):
24
+
return self.subcache_keys.values()
25
+
26
+
def get_data_key(self, node_id):
27
+
return self.keys.get(node_id, None)
28
+
29
+
def get_subcache_key(self, node_id):
30
+
return self.subcache_keys.get(node_id, None)
31
+
32
+
class Unhashable:
33
+
def __init__(self):
34
+
self.value = float("NaN")
35
+
36
+
def to_hashable(obj):
37
+
# So that we don't infinitely recurse since frozenset and tuples
38
+
# are Sequences.
39
+
if isinstance(obj, (int, float, str, bool, type(None))):
40
+
return obj
41
+
elif isinstance(obj, Mapping):
42
+
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
43
+
elif isinstance(obj, Sequence):
44
+
return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
45
+
else:
46
+
# TODO - Support other objects like tensors?
47
+
return Unhashable()
48
+
49
+
class CacheKeySetID(CacheKeySet):
50
+
def __init__(self, dynprompt, node_ids, is_changed_cache):
51
+
super().__init__(dynprompt, node_ids, is_changed_cache)
52
+
self.dynprompt = dynprompt
53
+
self.add_keys(node_ids)
54
+
55
+
def add_keys(self, node_ids):
56
+
for node_id in node_ids:
57
+
if node_id in self.keys:
58
+
continue
59
+
node = self.dynprompt.get_node(node_id)
60
+
self.keys[node_id] = (node_id, node["class_type"])
61
+
self.subcache_keys[node_id] = (node_id, node["class_type"])
62
+
63
+
class CacheKeySetInputSignature(CacheKeySet):
64
+
def __init__(self, dynprompt, node_ids, is_changed_cache):
65
+
super().__init__(dynprompt, node_ids, is_changed_cache)
66
+
self.dynprompt = dynprompt
67
+
self.is_changed_cache = is_changed_cache
68
+
self.add_keys(node_ids)
69
+
70
+
def include_node_id_in_input(self) -> bool:
71
+
return False
72
+
73
+
def add_keys(self, node_ids):
74
+
for node_id in node_ids:
75
+
if node_id in self.keys:
76
+
continue
77
+
node = self.dynprompt.get_node(node_id)
78
+
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
79
+
self.subcache_keys[node_id] = (node_id, node["class_type"])
80
+
81
+
def get_node_signature(self, dynprompt, node_id):
82
+
signature = []
83
+
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
84
+
signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
85
+
for ancestor_id in ancestors:
86
+
signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
87
+
return to_hashable(signature)
88
+
89
+
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
90
+
node = dynprompt.get_node(node_id)
91
+
class_type = node["class_type"]
92
+
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
93
+
signature = [class_type, self.is_changed_cache.get(node_id)]
94
+
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT):
95
+
signature.append(node_id)
96
+
inputs = node["inputs"]
97
+
for key in sorted(inputs.keys()):
98
+
if is_link(inputs[key]):
99
+
(ancestor_id, ancestor_socket) = inputs[key]
100
+
ancestor_index = ancestor_order_mapping[ancestor_id]
101
+
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
102
+
else:
103
+
signature.append((key, inputs[key]))
104
+
return signature
105
+
106
+
# This function returns a list of all ancestors of the given node. The order of the list is
107
+
# deterministic based on which specific inputs the ancestor is connected by.
108
+
def get_ordered_ancestry(self, dynprompt, node_id):
109
+
ancestors = []
110
+
order_mapping = {}
111
+
self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
112
+
return ancestors, order_mapping
113
+
114
+
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
115
+
inputs = dynprompt.get_node(node_id)["inputs"]
116
+
input_keys = sorted(inputs.keys())
117
+
for key in input_keys:
118
+
if is_link(inputs[key]):
119
+
ancestor_id = inputs[key][0]
120
+
if ancestor_id not in order_mapping:
121
+
ancestors.append(ancestor_id)
122
+
order_mapping[ancestor_id] = len(ancestors) - 1
123
+
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
124
+
125
+
class BasicCache:
126
+
def __init__(self, key_class):
127
+
self.key_class = key_class
128
+
self.initialized = False
129
+
self.dynprompt: DynamicPrompt
130
+
self.cache_key_set: CacheKeySet
131
+
self.cache = {}
132
+
self.subcaches = {}
133
+
134
+
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
135
+
self.dynprompt = dynprompt
136
+
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
137
+
self.is_changed_cache = is_changed_cache
138
+
self.initialized = True
139
+
140
+
def all_node_ids(self):
141
+
assert self.initialized
142
+
node_ids = self.cache_key_set.all_node_ids()
143
+
for subcache in self.subcaches.values():
144
+
node_ids = node_ids.union(subcache.all_node_ids())
145
+
return node_ids
146
+
147
+
def _clean_cache(self):
148
+
preserve_keys = set(self.cache_key_set.get_used_keys())
149
+
to_remove = []
150
+
for key in self.cache:
151
+
if key not in preserve_keys:
152
+
to_remove.append(key)
153
+
for key in to_remove:
154
+
del self.cache[key]
155
+
156
+
def _clean_subcaches(self):
157
+
preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())
158
+
159
+
to_remove = []
160
+
for key in self.subcaches:
161
+
if key not in preserve_subcaches:
162
+
to_remove.append(key)
163
+
for key in to_remove:
164
+
del self.subcaches[key]
165
+
166
+
def clean_unused(self):
167
+
assert self.initialized
168
+
self._clean_cache()
169
+
self._clean_subcaches()
170
+
171
+
def _set_immediate(self, node_id, value):
172
+
assert self.initialized
173
+
cache_key = self.cache_key_set.get_data_key(node_id)
174
+
self.cache[cache_key] = value
175
+
176
+
def _get_immediate(self, node_id):
177
+
if not self.initialized:
178
+
return None
179
+
cache_key = self.cache_key_set.get_data_key(node_id)
180
+
if cache_key in self.cache:
181
+
return self.cache[cache_key]
182
+
else:
183
+
return None
184
+
185
+
def _ensure_subcache(self, node_id, children_ids):
186
+
subcache_key = self.cache_key_set.get_subcache_key(node_id)
187
+
subcache = self.subcaches.get(subcache_key, None)
188
+
if subcache is None:
189
+
subcache = BasicCache(self.key_class)
190
+
self.subcaches[subcache_key] = subcache
191
+
subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
192
+
return subcache
193
+
194
+
def _get_subcache(self, node_id):
195
+
assert self.initialized
196
+
subcache_key = self.cache_key_set.get_subcache_key(node_id)
197
+
if subcache_key in self.subcaches:
198
+
return self.subcaches[subcache_key]
199
+
else:
200
+
return None
201
+
202
+
def recursive_debug_dump(self):
203
+
result = []
204
+
for key in self.cache:
205
+
result.append({"key": key, "value": self.cache[key]})
206
+
for key in self.subcaches:
207
+
result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()})
208
+
return result
209
+
210
+
class HierarchicalCache(BasicCache):
211
+
def __init__(self, key_class):
212
+
super().__init__(key_class)
213
+
214
+
def _get_cache_for(self, node_id):
215
+
assert self.dynprompt is not None
216
+
parent_id = self.dynprompt.get_parent_node_id(node_id)
217
+
if parent_id is None:
218
+
return self
219
+
220
+
hierarchy = []
221
+
while parent_id is not None:
222
+
hierarchy.append(parent_id)
223
+
parent_id = self.dynprompt.get_parent_node_id(parent_id)
224
+
225
+
cache = self
226
+
for parent_id in reversed(hierarchy):
227
+
cache = cache._get_subcache(parent_id)
228
+
if cache is None:
229
+
return None
230
+
return cache
231
+
232
+
def get(self, node_id):
233
+
cache = self._get_cache_for(node_id)
234
+
if cache is None:
235
+
return None
236
+
return cache._get_immediate(node_id)
237
+
238
+
def set(self, node_id, value):
239
+
cache = self._get_cache_for(node_id)
240
+
assert cache is not None
241
+
cache._set_immediate(node_id, value)
242
+
243
+
def ensure_subcache_for(self, node_id, children_ids):
244
+
cache = self._get_cache_for(node_id)
245
+
assert cache is not None
246
+
return cache._ensure_subcache(node_id, children_ids)
247
+
248
+
class LRUCache(BasicCache):
249
+
def __init__(self, key_class, max_size=100):
250
+
super().__init__(key_class)
251
+
self.max_size = max_size
252
+
self.min_generation = 0
253
+
self.generation = 0
254
+
self.used_generation = {}
255
+
self.children = {}
256
+
257
+
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
258
+
super().set_prompt(dynprompt, node_ids, is_changed_cache)
259
+
self.generation += 1
260
+
for node_id in node_ids:
261
+
self._mark_used(node_id)
262
+
263
+
def clean_unused(self):
264
+
while len(self.cache) > self.max_size and self.min_generation < self.generation:
265
+
self.min_generation += 1
266
+
to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation]
267
+
for key in to_remove:
268
+
del self.cache[key]
269
+
del self.used_generation[key]
270
+
if key in self.children:
271
+
del self.children[key]
272
+
self._clean_subcaches()
273
+
274
+
def get(self, node_id):
275
+
self._mark_used(node_id)
276
+
return self._get_immediate(node_id)
277
+
278
+
def _mark_used(self, node_id):
279
+
cache_key = self.cache_key_set.get_data_key(node_id)
280
+
if cache_key is not None:
281
+
self.used_generation[cache_key] = self.generation
282
+
283
+
def set(self, node_id, value):
284
+
self._mark_used(node_id)
285
+
return self._set_immediate(node_id, value)
286
+
287
+
def ensure_subcache_for(self, node_id, children_ids):
288
+
# Just uses subcaches for tracking 'live' nodes
289
+
super()._ensure_subcache(node_id, children_ids)
290
+
291
+
self.cache_key_set.add_keys(children_ids)
292
+
self._mark_used(node_id)
293
+
cache_key = self.cache_key_set.get_data_key(node_id)
294
+
self.children[cache_key] = []
295
+
for child_id in children_ids:
296
+
self._mark_used(child_id)
297
+
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
298
+
return self
299
+
RetroSearch is an open source project built by @garambo | Open a GitHub Issue
Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo
HTML:
3.2
| Encoding:
UTF-8
| Version:
0.7.4