A RetroSearch Logo

Home - News ( United States | United Kingdom | Italy | Germany ) - Football scores

Search Query:

Showing content from https://github.com/jax-ml/jax/commit/cb48f42 below:

Raise an error on non-hashable static arguments for jax.jit and xla_c… · jax-ml/jax@cb48f42 · GitHub

File tree Expand file treeCollapse file tree 2 files changed

+19

-10

lines changed

Filter options

Expand file treeCollapse file tree 2 files changed

+19

-10

lines changed Original file line number Diff line number Diff line change

@@ -88,13 +88,10 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: Tuple[int, ...],

88 88

try:

89 89

hash(static_arg)

90 90

except TypeError:

91 -

logging.warning(

92 -

"Static argument (index %s) of type %s for function %s is "

93 -

"non-hashable. As this can lead to unexpected cache-misses, it "

94 -

"will raise an error in a near future.", i, type(static_arg),

95 -

f.__name__)

96 -

# e.g. ndarrays, DeviceArrays

97 -

fixed_args[i] = WrapHashably(static_arg) # type: ignore

91 +

raise ValueError(

92 +

"Non-hashable static arguments are not supported, as this can lead "

93 +

f"to unexpected cache-misses. Static argument (index {i}) of type "

94 +

f"{type(static_arg)} for function {f.__name__} is non-hashable.")

98 95

else:

99 96

fixed_args[i] = Hashable(static_arg) # type: ignore

100 97 Original file line number Diff line number Diff line change

@@ -415,6 +415,18 @@ def test_jit_reference_dropping(self):

415 415

del g # no more references to x

416 416

assert x() is None # x is gone

417 417 418 +

def test_jit_raises_on_first_invocation_on_non_hashable_static_argnum(self):

419 +

if self.jit != jax.api._python_jit:

420 +

raise unittest.SkipTest("this test only applies to _python_jit")

421 +

f = lambda x, y: x + 3

422 +

jitted_f = self.jit(f, static_argnums=(1,))

423 + 424 +

msg = ("Non-hashable static arguments are not supported, as this can lead "

425 +

"to unexpected cache-misses. Static argument (index 1) of type "

426 +

"<class 'numpy.ndarray'> for function <lambda> is non-hashable.")

427 +

with self.assertRaisesRegex(ValueError, re.escape(msg)):

428 +

jitted_f(1, np.asarray(1))

429 + 418 430

def test_cpp_jit_raises_on_non_hashable_static_argnum(self):

419 431

if version < (0, 1, 58):

420 432

raise unittest.SkipTest("Disabled because it depends on some future "

@@ -428,9 +440,9 @@ def test_cpp_jit_raises_on_non_hashable_static_argnum(self):

428 440 429 441

jitted_f(1, 1)

430 442 431 -

msg = (

432 -

"""Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class 'numpy.ndarray'>, 1. The error was:

433 -

TypeError: unhashable type: 'numpy.ndarray'""")

443 +

msg = ("Non-hashable static arguments are not supported. An error occured "

444 +

"while trying to hash an object of type <class 'numpy.ndarray'>, 1. "

445 +

"The error was:\nTypeError: unhashable type: 'numpy.ndarray'")

434 446 435 447

with self.assertRaisesRegex(ValueError, re.escape(msg)):

436 448

jitted_f(1, np.asarray(1))

You can’t perform that action at this time.


RetroSearch is an open source project built by @garambo | Open a GitHub Issue

Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo

HTML: 3.2 | Encoding: UTF-8 | Version: 0.7.4