@@ -8364,6 +8364,59 @@ def get_output_names_callback(name, arr):
8364
8364
check_name(us_sym, ['data', 'pooling_data', 'pooling_output'])
8365
8365
del os.environ['MXNET_SUBGRAPH_BACKEND']
8366
8366
8367
+
@with_seed()
8368
+
def test_monitor_with_variable_input_shape():
8369
+
output = {}
8370
+
8371
+
def get_output_min_callback(name, arr):
8372
+
name = py_str(name)
8373
+
handle = ctypes.cast(arr, NDArrayHandle)
8374
+
arr = NDArray(handle, writable=False)
8375
+
min_val = mx.ndarray.min(arr).asscalar()
8376
+
if name in output:
8377
+
output[name] = min(output[name], min_val)
8378
+
else:
8379
+
output[name] = min_val
8380
+
8381
+
def check_result(output, names):
8382
+
assert len(output) > 0
8383
+
for k, v in output.items():
8384
+
assert k in names
8385
+
assert v is not None
8386
+
8387
+
is_windows = sys.platform.startswith('win')
8388
+
if (is_windows):
8389
+
# Windows doesn't support set environment variable on the fly, so disable it for now
8390
+
pass
8391
+
else:
8392
+
# Disable subgraph in case subgraph will replace symbol
8393
+
os.environ['MXNET_SUBGRAPH_BACKEND'] = "NONE"
8394
+
8395
+
batch_size = 1
8396
+
op_name = 'conv'
8397
+
dshape = (batch_size, 3, 10, 10)
8398
+
data = mx.sym.Variable('data', shape=dshape)
8399
+
sym = mx.sym.Convolution(data, kernel=(1, 1), num_filter=1, name=op_name)
8400
+
8401
+
mod = mx.module.Module(symbol=sym, label_names=None)
8402
+
mod.bind(for_training=False, data_shapes=[('data', dshape)])
8403
+
mod.init_params()
8404
+
mod._exec_group.execs[0].set_monitor_callback(get_output_min_callback, monitor_all=True)
8405
+
8406
+
new_dshape = dshape[:-1] + (dshape[-1] + 4,)
8407
+
new_data = mx.nd.random.uniform(shape=new_dshape)
8408
+
new_data = mx.io.NDArrayIter(data=new_data, batch_size=batch_size)
8409
+
new_data = DummyIter(new_data)
8410
+
8411
+
for batch in new_data:
8412
+
mod.forward(data_batch=batch, is_train=False)
8413
+
mx.nd.waitall()
8414
+
break
8415
+
8416
+
name_list = ['data', 'conv_data', 'conv_weight', 'conv_bias', 'conv_output']
8417
+
check_result(output, name_list)
8418
+
del os.environ['MXNET_SUBGRAPH_BACKEND']
8419
+
8367
8420
@with_seed()
8368
8421
@unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/13915")
8369
8422
def test_activation():
RetroSearch is an open source project built by @garambo | Open a GitHub Issue
Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo
HTML:
3.2
| Encoding:
UTF-8
| Version:
0.7.4