You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
231 lines
6.0 KiB
231 lines
6.0 KiB
1 year ago
|
import math
|
||
|
import textwrap
|
||
|
import sys
|
||
|
import pytest
|
||
|
import threading
|
||
|
import traceback
|
||
|
import time
|
||
|
|
||
|
import numpy as np
|
||
|
from numpy.testing import IS_PYPY
|
||
|
from . import util
|
||
|
|
||
|
|
||
|
class TestF77Callback(util.F2PyTest):
    """Call-back tests against the module built from ``callback/foo.f``.

    ``self.module`` is presumably the extension module that
    ``util.F2PyTest`` builds from ``sources`` -- confirm against ``util``.
    """

    sources = [util.getpath("tests", "src", "callback", "foo.f")]

    @pytest.mark.parametrize("name", ["t", "t2"])
    def test_all(self, name):
        """Run the shared call-back checks for the wrapped routine `name`."""
        self.check_function(name)

    @pytest.mark.xfail(IS_PYPY,
                       reason="PyPy cannot modify tp_doc after PyType_Ready")
    def test_docstring(self):
        """The docstring of the wrapper ``t`` must match exactly."""
        expected = textwrap.dedent("""\
        a = t(fun,[fun_extra_args])

        Wrapper for ``t``.

        Parameters
        ----------
        fun : call-back function

        Other Parameters
        ----------------
        fun_extra_args : input tuple, optional
            Default: ()

        Returns
        -------
        a : int

        Notes
        -----
        Call-back functions::

            def fun(): return a
            Return objects:
                a : int
        """)
        assert self.module.t.__doc__ == expected

    def check_function(self, name):
        """Call the wrapped routine `name` with many kinds of callables.

        Not a test itself; it is shared by ``test_all`` and
        ``test_threadsafety``.
        """
        t = getattr(self.module, name)

        # Plain lambdas, with and without extra arguments.
        r = t(lambda: 4)
        assert r == 4
        r = t(lambda a: 5, fun_extra_args=(6, ))
        assert r == 5
        r = t(lambda a: a, fun_extra_args=(6, ))
        assert r == 6
        r = t(lambda a: 5 + a, fun_extra_args=(7, ))
        assert r == 12

        # The same computation through a lambda and through the builtin
        # directly; both call styles are exercised on purpose.
        r = t(lambda a: math.degrees(a), fun_extra_args=(math.pi, ))
        assert r == 180
        r = t(math.degrees, fun_extra_args=(math.pi, ))
        assert r == 180

        # Attributes of the built module used as call-backs, including the
        # ``_cpointer`` attribute of one of them.
        r = t(self.module.func, fun_extra_args=(6, ))
        assert r == 17
        r = t(self.module.func0)
        assert r == 11
        r = t(self.module.func0._cpointer)
        assert r == 11

        class A:
            def __call__(self):
                return 7

            def mth(self):
                return 9

        # A callable instance and a bound method.
        a = A()
        r = t(a)
        assert r == 7
        r = t(a.mth)
        assert r == 9

    @pytest.mark.skipif(sys.platform == 'win32',
                        reason='Fails with MinGW64 Gfortran (Issue #9673)')
    def test_string_callback(self):
        """A call-back that receives a string argument."""

        def callback(code):
            if code == "r":
                return 0
            else:
                return 1

        f = self.module.string_callback
        r = f(callback)
        assert r == 0

    @pytest.mark.skipif(sys.platform == 'win32',
                        reason='Fails with MinGW64 Gfortran (Issue #9673)')
    def test_string_callback_array(self):
        """A call-back that receives an array of strings."""
        # See gh-10027
        cu1 = np.zeros((1, ), "S8")
        cu2 = np.zeros((1, 8), "c")
        cu3 = np.array([""], "S8")

        def callback(cu, lencu):
            # A distinct non-zero code per failed check shows which one
            # tripped.
            if cu.shape != (lencu,):
                return 1
            if cu.dtype != "S8":
                return 2
            if not np.all(cu == b""):
                return 3
            return 0

        f = self.module.string_callback_array
        for cu in [cu1, cu2, cu3]:
            res = f(callback, cu, cu.size)
            assert res == 0

    def test_threadsafety(self):
        """Hammer the call-back machinery from many threads at once."""
        # Segfaults if the callback handling is not threadsafe

        errors = []

        def cb():
            # Sleep here to make it more likely for another thread
            # to call their callback at the same time.
            time.sleep(1e-3)

            # Check reentrancy
            r = self.module.t(lambda: 123)
            assert r == 123

            return 42

        def runner(name):
            # Failures inside a worker thread would otherwise be lost, so
            # collect the formatted tracebacks for the main thread.
            try:
                for _ in range(50):
                    r = self.module.t(cb)
                    assert r == 42
                    self.check_function(name)
            except Exception:
                errors.append(traceback.format_exc())

        threads = [
            threading.Thread(target=runner, args=(arg, ))
            for arg in ("t", "t2") for _ in range(20)
        ]

        for thread in threads:
            thread.start()

        for thread in threads:
            thread.join()

        if errors:
            raise AssertionError("\n\n".join(errors))

    def test_hidden_callback(self):
        """Call-backs looked up as the module attribute ``global_f``.

        The error cases use ``pytest.raises`` so that a call which
        unexpectedly succeeds fails the test instead of passing silently.
        """
        # Start from a clean state: a previous run in the same process may
        # have left ``global_f`` set on the shared module object.
        if hasattr(self.module, "global_f"):
            del self.module.global_f

        with pytest.raises(Exception) as excinfo:
            self.module.hidden_callback(2)
        assert str(excinfo.value).startswith("Callback global_f not defined")

        with pytest.raises(Exception) as excinfo:
            self.module.hidden_callback2(2)
        assert str(excinfo.value).startswith(
            "cb: Callback global_f not defined")

        self.module.global_f = lambda x: x + 1
        r = self.module.hidden_callback(2)
        assert r == 3

        # Rebinding the attribute must be picked up by the next call.
        self.module.global_f = lambda x: x + 2
        r = self.module.hidden_callback(2)
        assert r == 4

        # Removing the attribute must bring the error back.
        del self.module.global_f
        with pytest.raises(Exception) as excinfo:
            self.module.hidden_callback(2)
        assert str(excinfo.value).startswith("Callback global_f not defined")

        self.module.global_f = lambda x=0: x + 3
        r = self.module.hidden_callback(2)
        assert r == 5

        # reproducer of gh18341
        r = self.module.hidden_callback2(2)
        assert r == 3
|
||
|
|
||
|
|
||
|
class TestF77CallbackPythonTLS(TestF77Callback):
    """Re-run every ``TestF77Callback`` test with ``-DF2PY_USE_PYTHON_TLS``.

    The callback tests then use Python thread-local storage instead of the
    compiler-provided thread-local storage.
    """

    # Extra option used when building the test module.
    options = ["-DF2PY_USE_PYTHON_TLS"]
|
||
|
|
||
|
|
||
|
class TestF90Callback(util.F2PyTest):
    """Call-back test against the module built from ``callback/gh17797.f90``."""

    sources = [util.getpath("tests", "src", "callback", "gh17797.f90")]

    def test_gh17797(self):
        """Pass a Python call-back together with an int64 array argument."""

        def incr(x):
            return x + 123

        values = np.array([1, 2, 3], dtype=np.int64)
        result = self.module.gh17797(incr, values)
        # Expected value: the call-back's offset plus the array elements.
        assert result == 123 + 1 + 2 + 3
|
||
|
|
||
|
|
||
|
class TestGH18335(util.F2PyTest):
    """Reproducer for gh-18335.

    Reproducing the reported issue requires a specific input, and extending
    this class could break the conditions that trigger it, so the reproducer
    lives in its own test class. Do not extend this test with other tests!
    """

    sources = [util.getpath("tests", "src", "callback", "gh18335.f90")]

    def test_gh18335(self):
        """The call-back modifies its argument in place and returns nothing."""

        def bump_first(x):
            x[0] += 1

        result = self.module.gh18335(bump_first)
        assert result == 123 + 1
|