"""States component."""
from typing import override
import numpy as np
from .base_spatial import SpatialComponent
from .base_component import Attribute
from .utils.decorators import apply_to_each_input
from .utils.plot_utils import show_slice_static, show_slice_interactive
STATE_ATTRIBUTES = ['PRESSURE', 'RS', 'SGAS', 'SOIL', 'SWAT']
[docs]
class States(SpatialComponent):
"""States component."""
_attributes_to_load: list[Attribute] = [
Attribute(attr, 'SOLUTION', attr, binary_file='UNRST', binary_section=attr, sequential=True) for attr
in STATE_ATTRIBUTES]
@property
def n_timesteps(self):
"""Effective number of timesteps."""
if not self.attributes:
return 0
return np.min([x.shape[0] for _, x in self.items()])
[docs]
@override
@apply_to_each_input
def apply(self, func, attr, *args, inplace=False, **kwargs):
"""Apply function to each timestamp of states attributes.
Parameters
----------
func : callable
A function to apply. Must accept data as its first argument.
attr : str, array-like
Attributes to get data from.
args : misc
Any additional positional arguments to ``func``.
inplace: bool
Modify сomponent inplace.
kwargs : misc
Any additional named arguments to ``func``.
Returns
-------
output : States
Transformed component.
"""
data = getattr(self, attr)
res = np.array([func(x, *args, **kwargs) for x in data])
if inplace:
setattr(self, attr, res)
return self
return res
[docs]
@apply_to_each_input
def to_spatial(self, attr, **kwargs):
"""Spatial order 'F' transformations."""
_ = kwargs
dimens = self.field.grid.dimens.values.ravel()
self.pad_na(attr=attr)
return self.reshape(attr=attr, newshape=(-1,) + tuple(dimens), order='F', inplace=True)
[docs]
@override
@apply_to_each_input
def ravel(self, attr):
"""Ravel order 'F' transformations."""
return self.reshape(attr=attr, newshape=(self.n_timesteps, -1), order='F', inplace=False)
[docs]
@apply_to_each_input
def pad_na(self, attr, fill_na=0., inplace=True):
"""Add dummy cells into the state vector in the positions of non-active cells if necessary.
Parameters
----------
attr: str, array-like
Attributes to be padded with non-active cells.
actnum: array-like of type bool
Vector representing a mask of active and non-active cells.
fill_na: float
Value to be used as filler.
inplace: bool
Modify сomponent inplace.
Returns
-------
output : component if inplace else padded attribute.
"""
data = getattr(self, attr)
if data is None:
return None
if np.prod(data.shape[1:]) == np.prod(self.field.grid.dimens.values):
return self if inplace else data
actnum = self.field.grid.actnum
if data.ndim > 2:
raise ValueError('Data should be raveled before padding.')
n_ts = data.shape[0]
actnum_ravel = actnum.ravel(order='F').astype(bool)
not_actnum_ravel = ~actnum_ravel
padded_data = np.empty(shape=(n_ts, actnum.size), dtype=float)
padded_data[..., actnum_ravel] = data
del data
padded_data[..., not_actnum_ravel] = fill_na
if inplace:
setattr(self, attr, padded_data)
return self
return padded_data
[docs]
@apply_to_each_input
def strip_na(self, attr):
"""Remove non-active cells from the state vector.
Parameters
----------
attr: str, array-like
Attributes to be stripped
actnum: array-like of type bool
Vector representing mask of active and non-active cells.
Returns
-------
output : stripped attribute.
Notes
-----
Outputs 1d array for each timestamp.
"""
data = self.ravel(attr)
actnum = self.field.grid.actnum
if data.shape[1] == np.sum(actnum):
return data
stripped_data = data[..., actnum.ravel(order='F')]
return stripped_data
def __getitem__(self, keys):
if isinstance(keys, str):
return super().__getitem__(keys)
out = self.__class__()
for attr, data in self.items():
data = data[keys].reshape((-1,) + data.shape[1:])
setattr(out, attr, data)
return out
[docs]
def show_slice(self, attr, t=None, i=None, j=None, k=None, figsize=None, **kwargs):
"""Visualize slices of 4D states arrays. If no slice is specified, spatial slices
will be shown with interactive slider widgets.
Parameters
----------
attr : str
Attribute to show.
t : int or None, optional
Timestamp to show.
i : int or None, optional
Slice along x-axis to show.
j : int or None, optional
Slice along y-axis to show.
k : int or None, optional
Slice along z-axis to show.
figsize : array-like, optional
Output plot size.
kwargs : dict, optional
Additional keyword arguments for plot.
"""
if np.all([t is None, i is None, j is None, k is None]):
show_slice_interactive(self, attr, figsize=figsize, **kwargs)
else:
show_slice_static(self, attr, t=t, i=i, j=j, k=k, figsize=figsize, **kwargs)
return self