+19
-10
lines changedFilter options
+19
-10
lines changed Original file line number Diff line number Diff line change
@@ -88,13 +88,10 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: Tuple[int, ...],
88
88
try:
89
89
hash(static_arg)
90
90
except TypeError:
91
-
logging.warning(
92
-
"Static argument (index %s) of type %s for function %s is "
93
-
"non-hashable. As this can lead to unexpected cache-misses, it "
94
-
"will raise an error in a near future.", i, type(static_arg),
95
-
f.__name__)
96
-
# e.g. ndarrays, DeviceArrays
97
-
fixed_args[i] = WrapHashably(static_arg) # type: ignore
91
+
raise ValueError(
92
+
"Non-hashable static arguments are not supported, as this can lead "
93
+
f"to unexpected cache-misses. Static argument (index {i}) of type "
94
+
f"{type(static_arg)} for function {f.__name__} is non-hashable.")
98
95
else:
99
96
fixed_args[i] = Hashable(static_arg) # type: ignore
100
97
Original file line number Diff line number Diff line change
@@ -415,6 +415,18 @@ def test_jit_reference_dropping(self):
415
415
del g # no more references to x
416
416
assert x() is None # x is gone
417
417
418
+
def test_jit_raises_on_first_invocation_on_non_hashable_static_argnum(self):
419
+
if self.jit != jax.api._python_jit:
420
+
raise unittest.SkipTest("this test only applies to _python_jit")
421
+
f = lambda x, y: x + 3
422
+
jitted_f = self.jit(f, static_argnums=(1,))
423
+
424
+
msg = ("Non-hashable static arguments are not supported, as this can lead "
425
+
"to unexpected cache-misses. Static argument (index 1) of type "
426
+
"<class 'numpy.ndarray'> for function <lambda> is non-hashable.")
427
+
with self.assertRaisesRegex(ValueError, re.escape(msg)):
428
+
jitted_f(1, np.asarray(1))
429
+
418
430
def test_cpp_jit_raises_on_non_hashable_static_argnum(self):
419
431
if version < (0, 1, 58):
420
432
raise unittest.SkipTest("Disabled because it depends on some future "
@@ -428,9 +440,9 @@ def test_cpp_jit_raises_on_non_hashable_static_argnum(self):
428
440
429
441
jitted_f(1, 1)
430
442
431
-
msg = (
432
-
"""Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class 'numpy.ndarray'>, 1. The error was:
433
-
TypeError: unhashable type: 'numpy.ndarray'""")
443
+
msg = ("Non-hashable static arguments are not supported. An error occured "
444
+
"while trying to hash an object of type <class 'numpy.ndarray'>, 1. "
445
+
"The error was:\nTypeError: unhashable type: 'numpy.ndarray'")
434
446
435
447
with self.assertRaisesRegex(ValueError, re.escape(msg)):
436
448
jitted_f(1, np.asarray(1))
You can’t perform that action at this time.
RetroSearch is an open source project built by @garambo | Open a GitHub Issue
Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo
HTML:
3.2
| Encoding:
UTF-8
| Version:
0.7.4