A RetroSearch Logo

Home - News ( United States | United Kingdom | Italy | Germany ) - Football scores

Search Query:

Showing content from https://github.com/comfyanonymous/ComfyUI/commit/5cfe38f41c7091b0fd954877d9d7427a8b438b1a below:

Execution Model Inversion (#2666) · comfyanonymous/ComfyUI@5cfe38f · GitHub

1 +

import itertools

2 +

from typing import Sequence, Mapping

3 +

from comfy.graph import DynamicPrompt

4 + 5 +

import nodes

6 + 7 +

from comfy.graph_utils import is_link

8 + 9 +

class CacheKeySet:

10 +

def __init__(self, dynprompt, node_ids, is_changed_cache):

11 +

self.keys = {}

12 +

self.subcache_keys = {}

13 + 14 +

def add_keys(self, node_ids):

15 +

raise NotImplementedError()

16 + 17 +

def all_node_ids(self):

18 +

return set(self.keys.keys())

19 + 20 +

def get_used_keys(self):

21 +

return self.keys.values()

22 + 23 +

def get_used_subcache_keys(self):

24 +

return self.subcache_keys.values()

25 + 26 +

def get_data_key(self, node_id):

27 +

return self.keys.get(node_id, None)

28 + 29 +

def get_subcache_key(self, node_id):

30 +

return self.subcache_keys.get(node_id, None)

31 + 32 +

class Unhashable:

33 +

def __init__(self):

34 +

self.value = float("NaN")

35 + 36 +

def to_hashable(obj):

37 +

# So that we don't infinitely recurse since frozenset and tuples

38 +

# are Sequences.

39 +

if isinstance(obj, (int, float, str, bool, type(None))):

40 +

return obj

41 +

elif isinstance(obj, Mapping):

42 +

return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])

43 +

elif isinstance(obj, Sequence):

44 +

return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))

45 +

else:

46 +

# TODO - Support other objects like tensors?

47 +

return Unhashable()

48 + 49 +

class CacheKeySetID(CacheKeySet):

50 +

def __init__(self, dynprompt, node_ids, is_changed_cache):

51 +

super().__init__(dynprompt, node_ids, is_changed_cache)

52 +

self.dynprompt = dynprompt

53 +

self.add_keys(node_ids)

54 + 55 +

def add_keys(self, node_ids):

56 +

for node_id in node_ids:

57 +

if node_id in self.keys:

58 +

continue

59 +

node = self.dynprompt.get_node(node_id)

60 +

self.keys[node_id] = (node_id, node["class_type"])

61 +

self.subcache_keys[node_id] = (node_id, node["class_type"])

62 + 63 +

class CacheKeySetInputSignature(CacheKeySet):

64 +

def __init__(self, dynprompt, node_ids, is_changed_cache):

65 +

super().__init__(dynprompt, node_ids, is_changed_cache)

66 +

self.dynprompt = dynprompt

67 +

self.is_changed_cache = is_changed_cache

68 +

self.add_keys(node_ids)

69 + 70 +

def include_node_id_in_input(self) -> bool:

71 +

return False

72 + 73 +

def add_keys(self, node_ids):

74 +

for node_id in node_ids:

75 +

if node_id in self.keys:

76 +

continue

77 +

node = self.dynprompt.get_node(node_id)

78 +

self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)

79 +

self.subcache_keys[node_id] = (node_id, node["class_type"])

80 + 81 +

def get_node_signature(self, dynprompt, node_id):

82 +

signature = []

83 +

ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)

84 +

signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))

85 +

for ancestor_id in ancestors:

86 +

signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))

87 +

return to_hashable(signature)

88 + 89 +

def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):

90 +

node = dynprompt.get_node(node_id)

91 +

class_type = node["class_type"]

92 +

class_def = nodes.NODE_CLASS_MAPPINGS[class_type]

93 +

signature = [class_type, self.is_changed_cache.get(node_id)]

94 +

if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT):

95 +

signature.append(node_id)

96 +

inputs = node["inputs"]

97 +

for key in sorted(inputs.keys()):

98 +

if is_link(inputs[key]):

99 +

(ancestor_id, ancestor_socket) = inputs[key]

100 +

ancestor_index = ancestor_order_mapping[ancestor_id]

101 +

signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))

102 +

else:

103 +

signature.append((key, inputs[key]))

104 +

return signature

105 + 106 +

# This function returns a list of all ancestors of the given node. The order of the list is

107 +

# deterministic based on which specific inputs the ancestor is connected by.

108 +

def get_ordered_ancestry(self, dynprompt, node_id):

109 +

ancestors = []

110 +

order_mapping = {}

111 +

self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)

112 +

return ancestors, order_mapping

113 + 114 +

def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):

115 +

inputs = dynprompt.get_node(node_id)["inputs"]

116 +

input_keys = sorted(inputs.keys())

117 +

for key in input_keys:

118 +

if is_link(inputs[key]):

119 +

ancestor_id = inputs[key][0]

120 +

if ancestor_id not in order_mapping:

121 +

ancestors.append(ancestor_id)

122 +

order_mapping[ancestor_id] = len(ancestors) - 1

123 +

self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)

124 + 125 +

class BasicCache:

126 +

def __init__(self, key_class):

127 +

self.key_class = key_class

128 +

self.initialized = False

129 +

self.dynprompt: DynamicPrompt

130 +

self.cache_key_set: CacheKeySet

131 +

self.cache = {}

132 +

self.subcaches = {}

133 + 134 +

def set_prompt(self, dynprompt, node_ids, is_changed_cache):

135 +

self.dynprompt = dynprompt

136 +

self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)

137 +

self.is_changed_cache = is_changed_cache

138 +

self.initialized = True

139 + 140 +

def all_node_ids(self):

141 +

assert self.initialized

142 +

node_ids = self.cache_key_set.all_node_ids()

143 +

for subcache in self.subcaches.values():

144 +

node_ids = node_ids.union(subcache.all_node_ids())

145 +

return node_ids

146 + 147 +

def _clean_cache(self):

148 +

preserve_keys = set(self.cache_key_set.get_used_keys())

149 +

to_remove = []

150 +

for key in self.cache:

151 +

if key not in preserve_keys:

152 +

to_remove.append(key)

153 +

for key in to_remove:

154 +

del self.cache[key]

155 + 156 +

def _clean_subcaches(self):

157 +

preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())

158 + 159 +

to_remove = []

160 +

for key in self.subcaches:

161 +

if key not in preserve_subcaches:

162 +

to_remove.append(key)

163 +

for key in to_remove:

164 +

del self.subcaches[key]

165 + 166 +

def clean_unused(self):

167 +

assert self.initialized

168 +

self._clean_cache()

169 +

self._clean_subcaches()

170 + 171 +

def _set_immediate(self, node_id, value):

172 +

assert self.initialized

173 +

cache_key = self.cache_key_set.get_data_key(node_id)

174 +

self.cache[cache_key] = value

175 + 176 +

def _get_immediate(self, node_id):

177 +

if not self.initialized:

178 +

return None

179 +

cache_key = self.cache_key_set.get_data_key(node_id)

180 +

if cache_key in self.cache:

181 +

return self.cache[cache_key]

182 +

else:

183 +

return None

184 + 185 +

def _ensure_subcache(self, node_id, children_ids):

186 +

subcache_key = self.cache_key_set.get_subcache_key(node_id)

187 +

subcache = self.subcaches.get(subcache_key, None)

188 +

if subcache is None:

189 +

subcache = BasicCache(self.key_class)

190 +

self.subcaches[subcache_key] = subcache

191 +

subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)

192 +

return subcache

193 + 194 +

def _get_subcache(self, node_id):

195 +

assert self.initialized

196 +

subcache_key = self.cache_key_set.get_subcache_key(node_id)

197 +

if subcache_key in self.subcaches:

198 +

return self.subcaches[subcache_key]

199 +

else:

200 +

return None

201 + 202 +

def recursive_debug_dump(self):

203 +

result = []

204 +

for key in self.cache:

205 +

result.append({"key": key, "value": self.cache[key]})

206 +

for key in self.subcaches:

207 +

result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()})

208 +

return result

209 + 210 +

class HierarchicalCache(BasicCache):

211 +

def __init__(self, key_class):

212 +

super().__init__(key_class)

213 + 214 +

def _get_cache_for(self, node_id):

215 +

assert self.dynprompt is not None

216 +

parent_id = self.dynprompt.get_parent_node_id(node_id)

217 +

if parent_id is None:

218 +

return self

219 + 220 +

hierarchy = []

221 +

while parent_id is not None:

222 +

hierarchy.append(parent_id)

223 +

parent_id = self.dynprompt.get_parent_node_id(parent_id)

224 + 225 +

cache = self

226 +

for parent_id in reversed(hierarchy):

227 +

cache = cache._get_subcache(parent_id)

228 +

if cache is None:

229 +

return None

230 +

return cache

231 + 232 +

def get(self, node_id):

233 +

cache = self._get_cache_for(node_id)

234 +

if cache is None:

235 +

return None

236 +

return cache._get_immediate(node_id)

237 + 238 +

def set(self, node_id, value):

239 +

cache = self._get_cache_for(node_id)

240 +

assert cache is not None

241 +

cache._set_immediate(node_id, value)

242 + 243 +

def ensure_subcache_for(self, node_id, children_ids):

244 +

cache = self._get_cache_for(node_id)

245 +

assert cache is not None

246 +

return cache._ensure_subcache(node_id, children_ids)

247 + 248 +

class LRUCache(BasicCache):

249 +

def __init__(self, key_class, max_size=100):

250 +

super().__init__(key_class)

251 +

self.max_size = max_size

252 +

self.min_generation = 0

253 +

self.generation = 0

254 +

self.used_generation = {}

255 +

self.children = {}

256 + 257 +

def set_prompt(self, dynprompt, node_ids, is_changed_cache):

258 +

super().set_prompt(dynprompt, node_ids, is_changed_cache)

259 +

self.generation += 1

260 +

for node_id in node_ids:

261 +

self._mark_used(node_id)

262 + 263 +

def clean_unused(self):

264 +

while len(self.cache) > self.max_size and self.min_generation < self.generation:

265 +

self.min_generation += 1

266 +

to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation]

267 +

for key in to_remove:

268 +

del self.cache[key]

269 +

del self.used_generation[key]

270 +

if key in self.children:

271 +

del self.children[key]

272 +

self._clean_subcaches()

273 + 274 +

def get(self, node_id):

275 +

self._mark_used(node_id)

276 +

return self._get_immediate(node_id)

277 + 278 +

def _mark_used(self, node_id):

279 +

cache_key = self.cache_key_set.get_data_key(node_id)

280 +

if cache_key is not None:

281 +

self.used_generation[cache_key] = self.generation

282 + 283 +

def set(self, node_id, value):

284 +

self._mark_used(node_id)

285 +

return self._set_immediate(node_id, value)

286 + 287 +

def ensure_subcache_for(self, node_id, children_ids):

288 +

# Just uses subcaches for tracking 'live' nodes

289 +

super()._ensure_subcache(node_id, children_ids)

290 + 291 +

self.cache_key_set.add_keys(children_ids)

292 +

self._mark_used(node_id)

293 +

cache_key = self.cache_key_set.get_data_key(node_id)

294 +

self.children[cache_key] = []

295 +

for child_id in children_ids:

296 +

self._mark_used(child_id)

297 +

self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))

298 +

return self

299 +

RetroSearch is an open source project built by @garambo | Open a GitHub Issue

Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo

HTML: 3.2 | Encoding: UTF-8 | Version: 0.7.4