"""Unit tests of posydon/grids/scrubbing.py
"""
__authors__ = [
"Matthias Kruckow <Matthias.Kruckow@unige.ch>"
]
# import the module which will be tested
import posydon.grids.scrubbing as totest
# aliases
np = totest.np
# import other code needed by the tests that is not already imported by the
# module under test
from pytest import fixture, raises, warns
from inspect import isroutine
#from posydon.grids.psygrid import PSyGrid
from posydon.utils.posydonwarning import InappropriateValueWarning
# define test classes collecting several test functions
class TestElements:
    """Check the namespace of the tested module and the type of its
    functions."""

    # check for objects, which should be an element of the tested module
    def test_dir(self):
        """Compare the names in the tested module with the expected set.

        Fails if a name vanished from the module or if a new one appeared,
        so that this unit test gets updated whenever the module's namespace
        changes.
        """
        elements = {'LG_MTRANSFER_RATE_THRESHOLD', 'Pwarn',
                    'RL_RELATIVE_OVERFLOW_THRESHOLD',
                    'THRESHOLD_CENTRAL_ABUNDANCE',
                    'THRESHOLD_CENTRAL_ABUNDANCE_LOOSE_C', '__authors__',
                    '__builtins__', '__cached__', '__doc__', '__file__',
                    '__loader__', '__name__', '__package__', '__spec__',
                    'keep_after_RLO', 'keep_till_central_abundance_He_C',
                    'np', 'scrub'}
        totest_elements = set(dir(totest))
        missing_in_test = elements - totest_elements
        assert not missing_in_test, (
            f"There are missing objects in {totest.__name__}: "
            f"{missing_in_test}. Please check, whether they have been "
            "removed on purpose and update this unit test.")
        new_in_test = totest_elements - elements
        assert not new_in_test, (
            f"There are new objects in {totest.__name__}: {new_in_test}. "
            "Please check, whether they have been added on purpose and "
            "update this unit test.")

    def test_instance_scrub(self):
        """``scrub`` is a routine."""
        assert isroutine(totest.scrub)

    def test_instance_keep_after_RLO(self):
        """``keep_after_RLO`` is a routine."""
        assert isroutine(totest.keep_after_RLO)

    def test_instance_keep_till_central_abundance_He_C(self):
        """``keep_till_central_abundance_He_C`` is a routine."""
        assert isroutine(totest.keep_till_central_abundance_He_C)
class TestFunctions:
    """Tests of ``scrub``, ``keep_after_RLO`` and
    ``keep_till_central_abundance_He_C`` on small synthetic arrays."""

    # fixtures
    @fixture
    def tables(self):
        """Return a five-row structured table with a single ``mass`` field."""
        return np.array([1.0, 0.9, 0.8, 0.7, 0.6], dtype=[('mass', '<f8')])

    @fixture
    def models(self):
        """Return the model numbers 0..4, one per row of ``tables``."""
        return np.array(range(5))

    @fixture
    def ages(self, models):
        """Return the ages ``10**models``, one per row of ``tables``."""
        return 10**models

    @fixture
    def star_history(self):
        """Return a temporary three-row star history for testing."""
        return np.array([(1.0, 0.2, 0.0), (1.0e+2, 0.9, 0.0),
                         (1.0e+3, 0.2, 0.8)],
                        dtype=[('star_age', '<f8'), ('center_he4', '<f8'),
                               ('center_c12', '<f8')])

    @fixture
    def star_history2(self):
        """Return a second temporary three-row star history for testing;
        its ``center_he4`` is 0.0 from the second row on."""
        return np.array([(1.0, 1.0, 0.0), (1.0e+2, 0.0, 0.4),
                         (1.0e+3, 0.0, 0.0)],
                        dtype=[('star_age', '<f8'), ('center_he4', '<f8'),
                               ('center_c12', '<f8')])

    @fixture
    def binary_history(self):
        """Return a temporary three-row binary history for testing.

        Row 0 is below all thresholds. Row 1 has star 1 and row 2 has star 2
        at ``RL_RELATIVE_OVERFLOW_THRESHOLD``; rows 1 and 2 are at
        ``LG_MTRANSFER_RATE_THRESHOLD``.
        """
        rl_thr = totest.RL_RELATIVE_OVERFLOW_THRESHOLD
        mt_thr = totest.LG_MTRANSFER_RATE_THRESHOLD
        return np.array([(rl_thr-1.0, rl_thr-1.0, mt_thr-1.0, 1.0),
                         (rl_thr, rl_thr-1.0, mt_thr, 1.0e+2),
                         (rl_thr-1.0, rl_thr, mt_thr, 1.0e+3)],
                        dtype=[('rl_relative_overflow_1', '<f8'),
                               ('rl_relative_overflow_2', '<f8'),
                               ('lg_mtransfer_rate', '<f8'), ('age', '<f8')])

    # test functions
    def test_scrub(self, tables, models, ages):
        """Check ``scrub`` on bad/trivial input and on repeated ages/models."""
        # missing argument
        with raises(TypeError, match="missing 3 required positional "
                                     "arguments: 'tables', 'models', and "
                                     "'ages'"):
            totest.scrub()
        # bad input
        with raises(TypeError, match="'NoneType' object is not iterable"):
            totest.scrub(None, None, None)
        # examples: nothing
        assert totest.scrub([None], [None], [None]) == [None]
        assert totest.scrub([None], [models], [ages]) == [None]
        assert totest.scrub([tables], [None], [ages]) == [None]
        assert totest.scrub([tables], [models], [None]) == [None]
        assert np.array_equal(totest.scrub([np.array([])], [np.array([])],
                                           [np.array([])]), [np.array([])])
        # examples: no scrubbing for two tables and a None type object
        expected = [tables, tables, None]
        scrubbed = totest.scrub([tables, tables, None],
                                [models, models, None],
                                [ages, ages, None])
        # zip() silently stops at the shorter input, so compare lengths first
        assert len(scrubbed) == len(expected)
        for (t, r) in zip(scrubbed, expected):
            if isinstance(t, np.ndarray) and isinstance(r, np.ndarray):
                assert np.array_equal(t, r)
            else:
                assert t == r
        # examples: scrub element 1 on age (rows 1 and 2 now share an age)
        ages[2] = ages[1]
        assert np.array_equal(totest.scrub([tables], [models], [ages])[0],
                              tables[np.array([True, False, True, True,
                                               True])])
        # examples: additionally scrub element 3 on model (rows 3 and 4 now
        # share a model number)
        models[4] = models[3]
        assert np.array_equal(totest.scrub([tables], [models], [ages])[0],
                              tables[np.array([True, False, True, False,
                                               True])])

    def test_keep_after_RLO(self, star_history, binary_history):
        """Check ``keep_after_RLO`` on bad/trivial input, on cutting the
        histories and on the numerical precision fix of the ages."""
        rlo_cols = ['rl_relative_overflow_1', 'lg_mtransfer_rate', 'age']
        # missing argument
        with raises(TypeError, match="missing 3 required positional "
                                     "arguments: 'bh', 'h1', and 'h2'"):
            totest.keep_after_RLO()
        # bad input
        with raises(ValueError, match="No `lg_mtransfer_rate` in binary "
                                      "history."):
            totest.keep_after_RLO(binary_history[['rl_relative_overflow_1']],
                                  None, None)
        with raises(ValueError, match="No `rl_relative_overflow` of any star "
                                      "in binary history."):
            totest.keep_after_RLO(binary_history[['lg_mtransfer_rate']],
                                  None, None)
        # examples: nothing
        for h1 in [None, np.array([])]:
            for h2 in [None, np.array([])]:
                assert totest.keep_after_RLO(None, h1, h2) == (None, h1, h2)
        # examples: cut out element 0 and history of star 1; it should be
        # noted that the function may also modify the later part of the
        # input objects
        bh, h1, h2 = totest.keep_after_RLO(binary_history[rlo_cols],
                                           star_history, None)
        assert np.array_equal(binary_history[rlo_cols][1:], bh)
        assert np.array_equal(star_history[1:], h1)
        assert h2 is None
        # examples: cut out element 0 and history of star 2; test numerical
        # correction, incl. fail
        binary_history['rl_relative_overflow_1'][0] = \
            totest.RL_RELATIVE_OVERFLOW_THRESHOLD
        # prepare the failing ages outside of `raises`, so that only the
        # call under test can satisfy the expected exception
        binary_history['age'] = np.array([0.3, 1.0, 1.0])
        with raises(Exception, match="Numerical precision fix failed."):
            totest.keep_after_RLO(binary_history[rlo_cols], None,
                                  star_history)
        # 0.99999999999999994 is the float just below 1.0, i.e. differs
        # from the next age only by rounding
        binary_history['age'] = np.array([0.3, 0.99999999999999994, 1.0])
        bh, h1, h2 = totest.keep_after_RLO(binary_history[rlo_cols], None,
                                           star_history)
        assert np.array_equal(binary_history[rlo_cols], bh)
        assert h1 is None
        assert np.array_equal(star_history, h2)
        # examples: no RLO; set every row below the thresholds (assigning a
        # scalar to a structured field broadcasts it to all rows)
        binary_history['rl_relative_overflow_1'] = \
            totest.RL_RELATIVE_OVERFLOW_THRESHOLD - 1.0
        binary_history['rl_relative_overflow_2'] = \
            totest.RL_RELATIVE_OVERFLOW_THRESHOLD - 1.0
        binary_history['lg_mtransfer_rate'] = \
            totest.LG_MTRANSFER_RATE_THRESHOLD - 1.0
        assert totest.keep_after_RLO(binary_history, None, None) is None

    def test_keep_till_central_abundance_He_C(self, star_history,
                                              star_history2, binary_history):
        """Check ``keep_till_central_abundance_He_C`` on bad/trivial input
        and on cutting the histories of depleted stars."""
        # missing argument
        with raises(TypeError, match="missing 3 required positional "
                                     "arguments: 'bh', 'h1', and 'h2'"):
            totest.keep_till_central_abundance_He_C()
        # bad input
        with raises(AttributeError, match="'list' object has no attribute "
                                          "'dtype'"):
            totest.keep_till_central_abundance_He_C([], [], [])
        # examples: nothing; every output equals its input (None stays None)
        for args in [(None, star_history, star_history),
                     (binary_history, None, star_history),
                     (binary_history, star_history, None),
                     (binary_history[['lg_mtransfer_rate']],
                      star_history[['center_he4']], star_history)]:
            bh, h1, h2, tf = totest.keep_till_central_abundance_He_C(*args)
            for result, given in zip((bh, h1, h2), args):
                if given is None:
                    assert result is None
                else:
                    assert np.array_equal(result, given)
            assert tf == ""
        for cols in [['star_age'], ['star_age', 'center_he4', 'center_c12']]:
            bh, h1, h2, tf = totest.keep_till_central_abundance_He_C(
                binary_history[['age']], star_history[cols],
                star_history[cols])
            assert np.array_equal(binary_history[['age']], bh)
            assert np.array_equal(star_history[cols], h1)
            assert np.array_equal(star_history[cols], h2)
            assert tf == ""
        # examples: one star is depleted
        bh, h1, h2, tf = totest.keep_till_central_abundance_He_C(
            binary_history[['age']], star_history2, star_history)
        assert np.array_equal(binary_history[['age']][:2], bh)
        assert np.array_equal(star_history2[:2], h1)
        assert np.array_equal(star_history[:2], h2)
        assert tf == "Primary got stopped before central carbon depletion"
        bh, h1, h2, tf = totest.keep_till_central_abundance_He_C(
            binary_history[['age']], star_history, star_history2)
        assert np.array_equal(binary_history[['age']][:2], bh)
        assert np.array_equal(star_history[:2], h1)
        assert np.array_equal(star_history2[:2], h2)
        assert tf == "Secondary got stopped before central carbon depletion"
        # examples: both stars are depleted
        bh, h1, h2, tf = totest.keep_till_central_abundance_He_C(
            binary_history[['age']], star_history2, star_history2)
        assert np.array_equal(binary_history[['age']][:2], bh)
        assert np.array_equal(star_history2[:2], h1)
        assert np.array_equal(star_history2[:2], h2)
        assert tf == "Primary got stopped before central carbon depletion"
        # zero the abundances in the last row of star_history and pass a
        # custom XCstop
        star_history['center_he4'][-1] = 0.0
        star_history['center_c12'][-1] = 0.0
        cols = ['center_he4', 'center_c12']
        bh, h1, h2, tf = totest.keep_till_central_abundance_He_C(
            binary_history[['age']], star_history[cols],
            star_history2[cols], XCstop=0.5)
        assert np.array_equal(binary_history[['age']][:2], bh)
        assert np.array_equal(star_history[cols][:2], h1)
        assert np.array_equal(star_history2[cols][:2], h2)
        assert tf == "Secondary got stopped before central carbon depletion"
        # examples: stars have history of length 0
        bh, h1, h2, tf = totest.keep_till_central_abundance_He_C(
            binary_history[['age']], star_history[0:0], star_history[0:0])
        assert np.array_equal(binary_history[['age']], bh)
        assert np.array_equal(star_history[0:0], h1)
        assert np.array_equal(star_history[0:0], h2)
        assert tf == ""