You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
97 lines
2.6 KiB
97 lines
2.6 KiB
from __future__ import annotations
|
|
|
|
import contextlib
|
|
import inspect
|
|
import os
|
|
import re
|
|
from typing import TYPE_CHECKING
|
|
import warnings
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Generator
|
|
|
|
|
|
@contextlib.contextmanager
def rewrite_exception(old_name: str, new_name: str) -> Generator[None, None, None]:
    """
    Substitute a name in the message of an exception raised by the managed block.

    Any ``Exception`` escaping the ``with`` body is intercepted, the first
    element of its ``args`` is converted to ``str`` with every occurrence of
    ``old_name`` replaced by ``new_name``, and the same exception object is
    re-raised. Remaining ``args`` entries are kept unchanged.

    Parameters
    ----------
    old_name : str
        Text to look for in the exception message.
    new_name : str
        Text that replaces each occurrence of ``old_name``.
    """
    try:
        yield
    except Exception as exc:
        # Exceptions created without arguments have no message to rewrite;
        # they fall straight through to the re-raise below.
        if exc.args:
            first, *rest = exc.args
            exc.args = (str(first).replace(old_name, new_name), *rest)
        raise
|
|
|
|
|
|
def find_stack_level() -> int:
    """
    Find the first place in the stack that is not inside pandas
    (tests notwithstanding).

    Returns
    -------
    int
        Number of consecutive frames, counted from this function's own frame
        outwards, whose source file lies inside the pandas package directory
        but outside its ``tests`` sub-directory. The walk stops at the first
        frame that does not satisfy this, or when the stack is exhausted.
    """

    import pandas as pd

    pkg_dir = os.path.dirname(pd.__file__)
    test_dir = os.path.join(pkg_dir, "tests")

    # Walk frames by hand instead of calling inspect.stack(), which is slow:
    # https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow
    frame = inspect.currentframe()
    try:
        n = 0
        while frame:
            fname = inspect.getfile(frame)
            if fname.startswith(pkg_dir) and not fname.startswith(test_dir):
                frame = frame.f_back
                n += 1
            else:
                break
    finally:
        # Holding on to a frame object can create a reference cycle (the
        # frame references its own locals, which reference the frame).
        # Drop the reference explicitly, as recommended in
        # https://docs.python.org/3/library/inspect.html#the-interpreter-stack
        del frame
    return n
|
|
|
|
|
|
@contextlib.contextmanager
def rewrite_warning(
    target_message: str,
    target_category: type[Warning],
    new_message: str,
    new_category: type[Warning] | None = None,
) -> Generator[None, None, None]:
    """
    Rewrite the message of a warning.

    Warnings raised inside the managed block are recorded rather than shown.
    Once the block finishes, each recorded warning is emitted again with its
    original filename and line number. Warnings whose category is exactly
    ``target_category`` and whose message matches ``target_message`` are
    emitted with ``new_message`` and ``new_category``; all others are emitted
    unchanged.

    Parameters
    ----------
    target_message : str
        Warning message to match. Used as a regular expression and searched
        for anywhere in the warning text.
    target_category : Warning
        Warning type to match.
    new_message : str
        New warning message to emit.
    new_category : Warning or None, default None
        New warning type to emit. When None, will be the same as target_category.
    """
    replacement_category = target_category if new_category is None else new_category

    with warnings.catch_warnings(record=True) as captured:
        yield

    # Nothing was recorded, so there is nothing to re-emit.
    if not captured:
        return

    pattern = re.compile(target_message)
    for entry in captured:
        # Category is compared by identity, so subclasses of target_category
        # are passed through untouched.
        is_target = entry.category is target_category and pattern.search(
            str(entry.message)
        )
        if is_target:
            out_message: Warning | str = new_message
            out_category = replacement_category
        else:
            out_message = entry.message
            out_category = entry.category
        warnings.warn_explicit(
            message=out_message,
            category=out_category,
            filename=entry.filename,
            lineno=entry.lineno,
        )
|
|
|