Source code for geocode.field.states

"""States component."""
from typing import override
import numpy as np

from .base_spatial import SpatialComponent
from .base_component import Attribute
from .utils.decorators import apply_to_each_input
from .utils.plot_utils import show_slice_static, show_slice_interactive

# Names of the state attributes that the States component declares for loading.
STATE_ATTRIBUTES = [
    'PRESSURE',
    'RS',
    'SGAS',
    'SOIL',
    'SWAT',
]


class States(SpatialComponent):
    """States component.

    Holds time-dependent attributes (see ``STATE_ATTRIBUTES``) whose first
    array axis is indexed by timestep.
    """

    # One loading descriptor per state attribute name.
    # NOTE(review): the meaning of the positional/keyword arguments is defined
    # by ``Attribute`` in ``base_component`` -- confirm there before changing.
    _attributes_to_load: list[Attribute] = [
        Attribute(
            attr,
            'SOLUTION',
            attr,
            binary_file='UNRST',
            binary_section=attr,
            sequential=True,
        )
        for attr in STATE_ATTRIBUTES
    ]

    @property
    def n_timesteps(self):
        """Effective number of timesteps.

        Returns
        -------
        output : int
            ``0`` when the component has no attributes, otherwise the smallest
            first-axis length among all stored attributes.
        """
        if not self.attributes:
            return 0
        return np.min([x.shape[0] for _, x in self.items()])

    @override
    @apply_to_each_input
    def apply(self, func, attr, *args, inplace=False, **kwargs):
        """Apply function to each timestamp of states attributes.

        Parameters
        ----------
        func : callable
            A function to apply. Must accept data as its first argument.
        attr : str, array-like
            Attributes to get data from.
        args : misc
            Any additional positional arguments to ``func``.
        inplace : bool
            Modify component inplace.
        kwargs : misc
            Any additional named arguments to ``func``.

        Returns
        -------
        output : States or numpy.ndarray
            The component itself if ``inplace`` is true, otherwise the array
            of per-timestamp results.
        """
        data = getattr(self, attr)
        # Iterating over ``data`` walks its first axis, i.e. one item per timestamp.
        res = np.array([func(x, *args, **kwargs) for x in data])
        if inplace:
            setattr(self, attr, res)
            return self
        return res

    @apply_to_each_input
    def to_spatial(self, attr, **kwargs):
        """Spatial order 'F' transformations.

        Pads the attribute with non-active cells if necessary and reshapes it
        inplace to ``(-1,) + grid dimensions`` using Fortran order.

        Parameters
        ----------
        attr : str, array-like
            Attributes to be transformed.
        kwargs : misc
            Accepted for interface compatibility; not used.

        Returns
        -------
        output
            Result of the inplace ``reshape`` call.
        """
        _ = kwargs
        dimens = self.field.grid.dimens.values.ravel()
        # Padding (inplace by default) makes the data match the full grid size
        # so that the reshape below is possible.
        self.pad_na(attr=attr)
        return self.reshape(attr=attr, newshape=(-1,) + tuple(dimens), order='F', inplace=True)

    @override
    @apply_to_each_input
    def ravel(self, attr):
        """Ravel order 'F' transformations.

        Parameters
        ----------
        attr : str, array-like
            Attributes to be raveled.

        Returns
        -------
        output
            Result of a non-inplace ``reshape`` to ``(n_timesteps, -1)`` in
            Fortran order.
        """
        return self.reshape(attr=attr, newshape=(self.n_timesteps, -1), order='F', inplace=False)

    @apply_to_each_input
    def pad_na(self, attr, fill_na=0., inplace=True):
        """Add dummy cells into the state vector in the positions of non-active cells if necessary.

        The mask of active cells is taken from ``self.field.grid.actnum``.

        Parameters
        ----------
        attr : str, array-like
            Attributes to be padded with non-active cells.
        fill_na : float
            Value to be used as filler.
        inplace : bool
            Modify component inplace.

        Returns
        -------
        output : component if inplace else padded attribute.
            ``None`` is returned when the attribute value is ``None``.

        Raises
        ------
        ValueError
            If padding is required but the data has more than two dimensions.
        """
        data = getattr(self, attr)
        if data is None:
            return None
        # Nothing to do when each timestep already covers the full grid.
        if np.prod(data.shape[1:]) == np.prod(self.field.grid.dimens.values):
            return self if inplace else data
        actnum = self.field.grid.actnum
        if data.ndim > 2:
            raise ValueError('Data should be raveled before padding.')
        n_ts = data.shape[0]
        actnum_ravel = actnum.ravel(order='F').astype(bool)
        not_actnum_ravel = ~actnum_ravel
        padded_data = np.empty(shape=(n_ts, actnum.size), dtype=float)
        padded_data[..., actnum_ravel] = data
        # Drop the local reference to the unpadded array before filling the rest.
        del data
        padded_data[..., not_actnum_ravel] = fill_na
        if inplace:
            setattr(self, attr, padded_data)
            return self
        return padded_data

    @apply_to_each_input
    def strip_na(self, attr):
        """Remove non-active cells from the state vector.

        The mask of active cells is taken from ``self.field.grid.actnum``.

        Parameters
        ----------
        attr : str, array-like
            Attributes to be stripped.

        Returns
        -------
        output : stripped attribute.

        Notes
        -----
        Outputs 1d array for each timestamp.
        """
        data = self.ravel(attr)
        # Cast to bool (as ``pad_na`` does) so that an integer 0/1 ``actnum`` is
        # applied as a mask rather than interpreted as integer cell indices.
        active_mask = self.field.grid.actnum.ravel(order='F').astype(bool)
        # Data already holds active cells only.
        if data.shape[1] == active_mask.sum():
            return data
        return data[..., active_mask]

    def __getitem__(self, keys):
        """Get an attribute by name or a new component indexed along the first axis.

        Parameters
        ----------
        keys : str or index
            A string is passed to the parent ``__getitem__``. Anything else is
            used to index the first axis of every attribute.

        Returns
        -------
        output
            Parent lookup result for a string key, otherwise a new instance of
            the same class with indexed attributes.
        """
        if isinstance(keys, str):
            return super().__getitem__(keys)
        out = self.__class__()
        for attr, data in self.items():
            # Reshape restores the leading axis, e.g. after integer indexing.
            data = data[keys].reshape((-1,) + data.shape[1:])
            setattr(out, attr, data)
        return out

    def show_slice(self, attr, t=None, i=None, j=None, k=None, figsize=None, **kwargs):
        """Visualize slices of 4D states arrays.

        If no slice is specified, spatial slices will be shown with
        interactive slider widgets.

        Parameters
        ----------
        attr : str
            Attribute to show.
        t : int or None, optional
            Timestamp to show.
        i : int or None, optional
            Slice along x-axis to show.
        j : int or None, optional
            Slice along y-axis to show.
        k : int or None, optional
            Slice along z-axis to show.
        figsize : array-like, optional
            Output plot size.
        kwargs : dict, optional
            Additional keyword arguments for plot.

        Returns
        -------
        output : States
            The component itself.
        """
        if all(index is None for index in (t, i, j, k)):
            show_slice_interactive(self, attr, figsize=figsize, **kwargs)
        else:
            show_slice_static(self, attr, t=t, i=i, j=j, k=k, figsize=figsize, **kwargs)
        return self