# Source code for georead._load_utils

"""Utils for data loading."""

from collections.abc import Callable, Iterator, Sequence, Generator
from contextlib import ExitStack
import copy
import itertools
import logging
import pathlib
import shlex
from types import TracebackType
from typing import Literal, Protocol, Self, TextIO, cast, runtime_checkable
import uuid
import re
import warnings
import chardet

import numpy as np
import numpy.typing as npt
import pandas as pd

from ._data_directory import (
    DATA_DIRECTORY,
    INT_NAN,
    SECTIONS,
    ArraySpecification,
    ArrayWithUnits,
    DTypeString,
    DataType,
    DataTypes,
    IntType,
    KeywordSpecification,
    NoDataSpecification,
    NpArray,
    ObjectSpecification,
    ParametersSpecification,
    RecordValueType,
    RecordsSpecification,
    SpecificationType,
    StatementSpecification,
    StringSpecification,
    TableSpecification,
    ValueType,
    get_dynamic_keyword_specification,
)


# Fallback encodings used by `StringIteratorIO._better_decoding` when a line
# cannot be decoded. Each buffer works on a copy of this list and pops from its
# end, so 'cp1251' is attempted before 'utf-8'.
DEFAULT_ENCODINGS = ['utf-8', 'cp1251']


class PReadBuf(Protocol):
    """Protocol for readable buffer."""

    def __iter__(self) -> Iterator[str]:
        """
        Get sting iterator.

        Returns
        -------
        Iterator[str]
            Iterator.

        """
        ...

    def __next__(self) -> str:
        """
        Next string.

        Returns
        -------
        str
            string.

        """
        ...

    def prev(self) -> Self:
        """
        Move cursor to the previous string.

        Returns
        -------
        Self
            self.

        """
        ...


@runtime_checkable
class PBufPosition(Protocol):
    """
    Protocol for buffers able to report the position of their cursor.

    It is `runtime_checkable`, so `isinstance` can be used to decide whether
    location information is available for messages.
    """

    @property
    def line_number(self) -> int:
        """Current line number."""
        ...

    @property
    def current_file(self) -> pathlib.Path:
        """File currently being read."""
        ...


def _load_string(
    keyword_spec: SpecificationType | None, buf: PReadBuf
) -> str | pd.Timestamp:
    """
    Read a single, optionally quoted, string value terminated by ``/``.

    The terminating ``/`` may follow the value on the same line, or be the
    first character of the next line. Otherwise a warning is issued and the
    extra line is pushed back to the buffer.

    Parameters
    ----------
    keyword_spec : SpecificationType | None
        None or a `StringSpecification`. With ``date`` set, the value is
        parsed into a timestamp.
    buf : PReadBuf
        Buffer to read data from.

    Returns
    -------
    str | pd.Timestamp
        Value with quotes and surrounding whitespace removed.

    Raises
    ------
    ValueError
        If `keyword_spec` is neither None nor a `StringSpecification`.
    """
    first_line = buf.line_number if isinstance(buf, PBufPosition) else None
    raw = next(buf)

    # Split so that even positions hold text outside quotes and odd positions
    # hold quoted content; a '/' only terminates the value outside quotes.
    if "'" in raw:
        parts = re.split(r"'(.*)'", raw)
    elif '"' in raw:
        parts = re.split(r'"(.*)"', raw)
    else:
        pieces = raw.split('/')
        if len(pieces) > 1:
            parts = ['', pieces[0]] + ['/' + p for p in pieces[1:]]
        else:
            parts = ['', pieces[0], '']

    stop = next(
        (idx for idx in range(0, len(parts), 2) if '/' in parts[idx]), None
    )
    if stop is not None:
        value = ''.join(parts[:stop]) + parts[stop].split('/')[0]
    else:
        # No terminator on this line: expect it at the start of the next one.
        value = raw
        closing = next(buf)
        if not closing.startswith('/'):
            warnings.warn(
                _format_warning_messsage(
                    'Data was not properly terminated.', first_line, buf
                )
            )
            _ = buf.prev()

    value = re.sub(r'"(.*?)"', r'\1', value)
    value = re.sub(r'\'(.*?)\'', r'\1', value)
    value = value.strip()
    if keyword_spec is None:
        return value
    if not isinstance(keyword_spec, StringSpecification):
        raise ValueError(
            '`keyword_spec` should be None or of type StringSpecification.'
        )
    return _parse_date(value) if keyword_spec.date else value


def _load_object_list(
    keyword_spec: SpecificationType | None, buf: PReadBuf
) -> list[str] | list[pd.Timestamp]:
    """
    Read a list of object names (or dates), one value per line.

    If the specification is ``terminated``, every line must contain ``/`` and
    the list is closed by a line starting with ``/``. Otherwise the first line
    containing ``/`` closes the list.

    Parameters
    ----------
    keyword_spec : SpecificationType | None
        `ObjectSpecification`, or None for a non-terminated list of strings.
    buf : PReadBuf
        Buffer to read data from.

    Returns
    -------
    list[str] | list[pd.Timestamp]
        Parsed values; timestamps when ``keyword_spec.date`` is set.

    Raises
    ------
    ValueError
        In any of these cases:

        - the specification has the wrong type;
        - a line of a terminated list lacks ``/``;
        - a line holds neither a value nor ``/``;
        - the buffer ends before the list is closed.
    """
    if not isinstance(keyword_spec, ObjectSpecification | None):
        raise ValueError(
            '`keyword_spec` should be of type ObjectSpecification or None.'
        )
    terminated = keyword_spec.terminated if keyword_spec is not None else False
    is_dates = keyword_spec.date if keyword_spec is not None else False
    # `is_dates` is fixed for the whole call, so only one of these lists is
    # ever filled. No per-append type re-validation is needed.
    dates: list[pd.Timestamp] = []
    names: list[str] = []
    while True:
        line = _get_expected_line(buf)
        split = line.split('/')
        val = split[0].strip(' \t\n\'"')
        if terminated and len(split) == 1:
            raise ValueError(f'Line "{line}" is not terminated with "/"')
        if val:
            if is_dates:
                dates.append(_parse_date(val))
            else:
                names.append(val)
        elif len(split) == 1:
            raise ValueError('Object specification expected.')
        if len(split) > 1 and not terminated:
            break
        if line.startswith('/'):
            break
    return dates if is_dates else names


def _parse_date(s: str) -> pd.Timestamp:
    return pd.to_datetime(s)


def _load_table(
    keyword_spec: SpecificationType, buf: PReadBuf, data: DataType | None = None
) -> Sequence[pd.DataFrame] | Sequence[tuple[pd.DataFrame, pd.DataFrame]]:
    """
    Load a set of tables (one per region) described by a `TableSpecification`.

    Parameters
    ----------
    keyword_spec : SpecificationType
        Table specification; must be a `TableSpecification`.
    buf : PReadBuf
        Buffer to read data from.
    data : DataType | None, default None
        Already loaded model data. It is passed to ``keyword_spec.number`` when
        the number of tables is not a fixed integer.

    Returns
    -------
    Sequence[pd.DataFrame] | Sequence[tuple[pd.DataFrame, pd.DataFrame]]
        One table per region. When the specification defines a header, each
        item is a ``(table, header)`` pair.

    Raises
    ------
    ValueError
        If the specification has the wrong type, or the number of values does
        not match the declared columns.
    """

    def _parse_val(val: str | int | float, t: Literal['int', 'float', 'text']):
        # 'nan' is the placeholder `_read_table_data` uses for defaulted values.
        if t == 'int':
            if val == 'nan':
                return INT_NAN
            return int(val)
        if t == 'float':
            return float(val)
        return val

    def _empty_val(t: Literal['int', 'float', 'text']):
        # Fill value for trailing columns missing from a short table.
        if t == 'int':
            return INT_NAN
        if t == 'float':
            return np.nan
        raise ValueError('`t` should be `int` or `float`.')

    if not isinstance(keyword_spec, TableSpecification):
        raise ValueError('`keyword_spec` should be of type TableSpecification.')

    domain = keyword_spec.domain
    # The domain length is the nesting depth (`_read_table_data` accepts 1 or 2).
    depth = 1 if domain is None else len(domain)
    if isinstance(keyword_spec.number, IntType):
        n = keyword_spec.number
    else:
        # The number of tables may depend on previously loaded data.
        n = keyword_spec.number(data)
    table_data = _read_table_data(buf, depth, n)

    plain_tables: list[pd.DataFrame] = []
    headed_tables: list[tuple[pd.DataFrame, pd.DataFrame]] = []
    for region_table_data in table_data:
        header = None
        if keyword_spec.header is not None:
            # The leading values of each region form a one-row header table.
            n_header = len(keyword_spec.header.columns)
            header_data = [
                _parse_val(v, t)
                for v, t in zip(
                    region_table_data[:n_header], keyword_spec.header.dtypes
                )
            ]
            header = pd.DataFrame(
                [header_data], columns=list(keyword_spec.header.columns)
            )
            region_table_data = region_table_data[n_header:]
        if callable(keyword_spec.columns):
            # Column names may be derived from the region header.
            if header is None:
                raise ValueError('`header` should be not None.')
            columns = keyword_spec.columns(header)
        else:
            columns = keyword_spec.columns
        n_attrs = len(columns)
        dtypes = keyword_spec.dtypes
        if isinstance(dtypes, str):
            dtypes = [keyword_spec.dtypes] * n_attrs
        dtypes = cast(Sequence[DTypeString], dtypes)
        if depth == 2:
            # Each subtable starts with the outer domain value, followed by
            # rows of the remaining `n_attrs - 1` attributes.
            table_parts: list[list[float | int | str]] = []
            width = n_attrs - 1
            for d in region_table_data:
                n_rows = (len(d) - 1) / width
                if not n_rows.is_integer():
                    raise ValueError(
                        'Number of element is not aligned with the number of attributes.'
                    )
                for i in range(int(n_rows)):
                    row = d[i * width + 1 : (i + 1) * width + 1]
                    table_parts.append(
                        [_parse_val(d[0], dtypes[0])]
                        + [_parse_val(v, t) for v, t in zip(row, dtypes[1:])]
                    )
            table = pd.DataFrame(table_parts, columns=list(columns))
        else:
            if len(region_table_data) < n_attrs:
                region_table_data += [
                    _empty_val(t) for t in dtypes[len(region_table_data) :]
                ]
            if len(region_table_data) % n_attrs > 0:
                raise ValueError(
                    'Number of values in table is not consistent wit number of columns.'
                )
            n_rows = len(region_table_data) // n_attrs
            data_tmp = [
                list(
                    map(
                        _parse_val,
                        region_table_data[i * n_attrs : (i + 1) * n_attrs],
                        dtypes,
                    )
                )
                for i in range(n_rows)
            ]
            table = pd.DataFrame(data_tmp, columns=list(columns))
        if 'int' in dtypes:
            int_columns = [col for col, t in zip(columns, dtypes) if t == 'int']
            for col in int_columns:
                table[col] = table[col].fillna(INT_NAN)
                if (np.mod(table[col], 1) > 0).any():
                    raise ValueError('Noninteger value in integer column.')
                table[col] = table[col].astype(int)
        if keyword_spec.domain is not None:
            domain_attrs = [columns[i] for i in keyword_spec.domain]
            table = table.set_index(domain_attrs)
        if header is not None:
            headed_tables.append((table, header))
        else:
            plain_tables.append(table)
    # A header is produced for every region iff the specification defines one,
    # so all items share one shape. No per-append type re-validation is needed.
    if keyword_spec.header is not None:
        return headed_tables
    return plain_tables


def _load_single_statement(keyword_spec: SpecificationType, buffer: PReadBuf):
    """
    Read one ``/``-terminated statement into a single-row DataFrame.

    The statement may span several lines. Repetition terms (``N*V``, ``N*``,
    ``*``) are unpacked by `parse_vals`; omitted values stay missing.

    Parameters
    ----------
    keyword_spec : SpecificationType
        Must be a `StatementSpecification`; provides column names and dtypes.
    buffer : PReadBuf
        Buffer to read data from.

    Returns
    -------
    pd.DataFrame
        One row with one column per specified column. Quotes are stripped
        from text columns; float and int columns are converted (missing ints
        become ``INT_NAN``).

    Raises
    ------
    ValueError
        If the specification has the wrong type or the buffer ends early.
    """
    if not isinstance(keyword_spec, StatementSpecification):
        raise ValueError('`keyword_spec` should be of type `StatementSpecification`.')

    names = keyword_spec.columns
    kinds = keyword_spec.dtypes
    values: list[str | None] = [None] * len(names)
    offset = 0
    finished = False
    while not finished:
        pieces = _get_expected_line(buffer).split('/')
        tokens = shlex.split(pieces[0].strip())
        values, offset = parse_vals(values, offset, tokens)
        finished = len(pieces) > 1
    frame = pd.DataFrame(dict(zip(names, values)), index=[0])

    def _columns_of(kind: str) -> list[str]:
        return [name for name, k in zip(names, kinds) if k == kind]

    def _unquote(value: str | None) -> str | None:
        return value.strip('\'"') if value is not None else value

    if 'text' in kinds:
        text_cols = _columns_of('text')
        frame[text_cols] = frame[text_cols].map(_unquote)
    if 'float' in kinds:
        float_cols = _columns_of('float')
        frame[float_cols] = frame[float_cols].astype(float)
    if 'int' in kinds:
        int_cols = _columns_of('int')
        frame[int_cols] = frame[int_cols].astype(float).fillna(INT_NAN).astype(int)
    return frame


def _load_records(keyword_spec: SpecificationType, buffer: PReadBuf) -> RecordValueType:
    """
    Load a sequence of records (statements or arrays).

    For static specifications the records follow
    ``keyword_spec.specifications``. For dynamic ones, each next specification
    is requested from ``keyword_spec.get_next_specification`` with the records
    loaded so far. A ``ValueError`` raised by that callable ends the sequence.

    Parameters
    ----------
    keyword_spec : SpecificationType
        Must be a `RecordsSpecification`.
    buffer : PReadBuf
        Buffer to read data from.

    Returns
    -------
    RecordValueType
        Loaded records in order.

    Raises
    ------
    ValueError
        If the specification has the wrong type, or the callable or list
        required by its mode is None.
    """

    def _load_record(
        spec: StatementSpecification | ArraySpecification, buffer: PReadBuf
    ):
        if isinstance(spec, StatementSpecification):
            return _load_single_statement(spec, buffer)
        return _load_array(spec, buffer)

    if not isinstance(keyword_spec, RecordsSpecification):
        raise ValueError('`keyword_spec` shold be of type `RecordsSpecification`.')

    res: RecordValueType = []
    if keyword_spec.dynamic:
        get_new_spec = keyword_spec.get_next_specification
        # Validate here, outside the `try` below: inside it this error would be
        # swallowed as an ordinary end-of-records signal.
        if get_new_spec is None:
            raise ValueError(
                '`keyword_spec.get_next_specification` should not be None'
            )

        def _spec_generator(
            res: Sequence[
                pd.DataFrame
                | npt.NDArray[np.floating | np.integer | np.bool_]
                | tuple[pd.DataFrame, pd.DataFrame]
            ],
        ) -> Generator[StatementSpecification | ArraySpecification, None, None]:
            while True:
                try:
                    spec = get_new_spec(res)
                except ValueError:
                    return
                yield spec

        # `res` is shared: the generator sees records as they are appended.
        spec_iterable = _spec_generator(res)
    else:
        spec_iterable = keyword_spec.specifications
        if spec_iterable is None:
            raise ValueError(
                '`keyword_spec.specifications` should not be None if not `keyword_spec.dynamic`.'
            )
    for spec in spec_iterable:
        res.append(_load_record(spec, buffer))
    return res


def _read_table_data(buffer: PReadBuf, depth: int, n: IntType) -> list[list[str]]:
    """
    Read raw value tokens for `n` tables.

    Repetition terms are expanded on the fly: ``N*V`` becomes ``N`` copies of
    ``V``. An empty value (``N*`` or ``*``) becomes ``'nan'``.

    Parameters
    ----------
    buffer : PReadBuf
        String buffer to read.
    depth : int
        Depth of the table nesting (2 for multiindex table, 1 for normal table).
    n : IntType
        Number of tables to read.

    Returns
    -------
    list[list[str]] or list[list[list[str]]]
        If ``depth == 1``, one flat list of tokens per table.
        If ``depth == 2``, for each table a list of flat token lists, one per
        subtable.

    Raises
    ------
    ValueError
        In any of these cases:

        - `depth` is not 1 or 2;
        - an unexpected closing slash is found;
        - the buffer ends before `n` tables are read.

    """
    if depth not in (1, 2):
        raise ValueError('`depth` should be 1 or 2.')
    data: list[list[str]] | list[list[list[str]]] = []
    # Cursor into `data`: the last entry indexes the outermost level (table).
    # For depth 2, `ind[0]` indexes the subtable inside the current table.
    ind = [0] * depth
    # True right after a '/'. A further '/' with no data in between closes the
    # enclosing group (a table of a depth-2 set).
    group_end = True
    expr = re.compile(r'(\d*)\*(([^\s]*))')

    def _repl(match: re.Match[str]):
        num = match.groups()[0]
        val = match.groups()[1]
        if len(val) == 0:
            val = 'nan'
        num = int(num) if num else 1
        return ' '.join([val] * num)

    for line in buffer:
        line = line.strip()
        split = line.split('/')
        line = split[0]
        if len(line) > 0:
            cur_item = data
            line = expr.sub(_repl, line)
            for i in reversed(ind):
                if len(cur_item) == i:
                    if len(cur_item) > 0:
                        if isinstance(cur_item[0], str):
                            raise ValueError(
                                '`cur_item` should be of type `list[list[str]]'
                            )
                    cur_item = cast(list[list[str]] | list[list[list[str]]], cur_item)
                    cur_item.append([])

                cur_item = cur_item[i]
            values = line.split()
            if isinstance(cur_item, str):
                raise ValueError('`cur_item` should be of type `list[list[str]]')
            if len(cur_item) > 0:
                if isinstance(cur_item[0], str):
                    raise ValueError('`cur_item` should be of type `list[list[str]]')
            cur_item = cast(list[list[str]], cur_item)
            cur_item.append(values)
            group_end = False
        if len(split) > 1:
            if group_end:
                try:
                    ind[1] += 1
                    if len(data) == n:
                        break
                except IndexError:
                    # Depth-1 tables have no outer group to close.
                    _ = buffer.prev()
                    raise ValueError('Unexpected closing slash.')
                ind[0] = 0
            else:
                ind[0] += 1
                if len(data) == n and depth == 1:
                    break
            group_end = True

    if depth == 1:
        tmp_iter = [data]
    else:
        tmp_iter = data
    tmp_iter = cast(list[list[list[str]]], tmp_iter)
    # Flatten the per-line token lists of every (sub)table into one list.
    for d in tmp_iter:
        for i, vals in enumerate(d):
            d[i] = list(itertools.chain(*vals))
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(data) != n:
        raise ValueError(f'Expected {n} tables, but {len(data)} were read.')
    return data


def _load_array(keyword_spec: SpecificationType, buf: PReadBuf) -> NpArray:
    """
    Read a ``/``-terminated numeric array described by an `ArraySpecification`.

    Parameters
    ----------
    keyword_spec : SpecificationType
        Must be an `ArraySpecification`; its ``dtype`` is passed to `read_array`.
    buf : PReadBuf
        Buffer to read data from.

    Returns
    -------
    NpArray
        Parsed array.

    Raises
    ------
    ValueError
        If `keyword_spec` is not an `ArraySpecification`.
    """
    if not isinstance(keyword_spec, ArraySpecification):
        # Previously the exception carried the raw objects instead of a message.
        raise ValueError('`keyword_spec` should be of type `ArraySpecification`.')
    return read_array(buf, dtype=keyword_spec.dtype)


def _load_array_with_units(keyword_spec: SpecificationType, buf: PReadBuf):
    """
    Read an array whose first word names its units.

    Parameters
    ----------
    keyword_spec : SpecificationType
        Must be an `ArraySpecification`; its ``dtype`` is passed to `read_array`.
    buf : PReadBuf
        Buffer to read data from.

    Returns
    -------
    ArrayWithUnits
        Units word and the parsed array.

    Raises
    ------
    ValueError
        If `keyword_spec` is not an `ArraySpecification`.
    """
    if not isinstance(keyword_spec, ArraySpecification):
        raise ValueError('`keyword_spec` should be of type `ArraySpecification`.')
    # Peek at the first line for the units word, then rewind so `read_array`
    # reads the same line again, skipping that word.
    units = next(buf).split()[0]
    _ = buf.prev()
    values = read_array(buf, dtype=keyword_spec.dtype, skip_first_word=True)
    return ArrayWithUnits(units, values)


def read_array(
    buffer: PReadBuf,
    dtype: type | None = None,
    compressed: bool = True,
    skip_first_word: bool = False,
) -> NpArray:
    """
    Read array data from a string buffer before first occurrence of '/' symbol.

    Parameters
    ----------
    buffer : buffer
        String buffer to read.
    dtype : dtype or None, default None.
        Defines dtype of an output array. If not specified, float array is returned.
    compressed : bool, default True.
        If True, A*B will be interpreted as B repeated A times.
    skip_first_word: bool, default False
        Should first word be skipped.

    Returns
    -------
    arr : ndarray
        Parsed array. If no values precede the terminating '/' (or the buffer
        ends first), an empty array of `dtype` is returned.

    """
    if dtype is None:
        dtype = float
    chunks: list[NpArray] = []
    for i, line in enumerate(buffer):
        last_line = '/' in line
        if last_line:
            line = line.split('/')[0]
        if i == 0 and skip_first_word:
            line = ' '.join(line.split()[1:])
        if compressed:
            x = decompress_array(line, dtype=dtype)
        else:
            x = np.fromstring(line.strip(), dtype=dtype, sep=' ')  # pyright: ignore[reportUnknownVariableType]
        if x.size:
            chunks.append(x)  # pyright: ignore[reportUnknownArgumentType]
        if last_line:
            break
    if not chunks:
        # `np.hstack` raises on an empty sequence (e.g. a keyword with a bare "/").
        return cast(NpArray, np.array([], dtype=dtype))
    return cast(NpArray, np.hstack(chunks))


def decompress_array(s: str, dtype: type | None) -> NpArray:
    """
    Extract compressed numerical array from ASCII string.

    Tokens are separated by whitespace; a token ``A*B`` is interpreted as ``B``
    repeated ``A`` times. Every value is converted as ``dtype(float(token))``.

    Parameters
    ----------
    s : str
        Whitespace-separated tokens.
    dtype : type | None
        Element constructor; float if None.

    Returns
    -------
    NpArray
        Array of the decoded values (dtype inferred by `numpy.array`).

    Raises
    ------
    ValueError
        If a token is neither a number nor a valid ``A*B`` repetition.
    """
    if dtype is None:
        dtype = float
    nums: list[float | int | bool] = []
    for x in s.split():
        try:
            nums.append(dtype(float(x)))
        except ValueError:
            count, sep, value = x.partition('*')
            if not sep:
                raise ValueError(f'Cannot parse array value "{x}".') from None
            try:
                nums.extend([dtype(float(value))] * int(count))
            except ValueError as err:
                raise ValueError(f'Cannot parse array value "{x}".') from err
    return np.array(nums)


def _load_parameters(
    keyword_spec: SpecificationType, buf: PReadBuf
) -> dict[str, str | None]:
    """
    Read ``KEY=VALUE`` pairs and bare ``FLAG`` words up to the closing ``/``.

    Parameters
    ----------
    keyword_spec : SpecificationType
        Must be a `ParametersSpecification`. When ``tabulated`` is set,
        reading is delegated to `_load_parameters_tabulated`.
    buf : PReadBuf
        Buffer to read data from.

    Returns
    -------
    dict[str, str | None]
        Parameter values; bare words map to None.

    Raises
    ------
    ValueError
        If `keyword_spec` is not a `ParametersSpecification`.
    """
    if not isinstance(keyword_spec, ParametersSpecification):
        raise ValueError('`keyword_spec` should be of type ParametersSpecification.')
    if keyword_spec.tabulated:
        return _load_parameters_tabulated(keyword_spec, buf)
    res: dict[str, str | None] = {}
    for line in buf:
        split = line.split('/')
        for word in split[0].split():
            word = word.strip('\'"')
            if '=' in word:
                # Split on the first '=' only so values may contain '='
                # (plain unpacking of `split('=')` raised on such values).
                key, _, val = word.partition('=')
                res[key] = val
            else:
                res[word] = None
        if len(split) > 1:
            break
    return res


def _load_named_table(
    keyword_spec: SpecificationType, buf: PReadBuf
) -> tuple[str, pd.DataFrame]:
    """
    Read a single table whose name is the second word of the previous line.

    The buffer is stepped back one line (expected to be the keyword line) to
    take the name, then exactly one table is loaded with `_load_table`.

    Parameters
    ----------
    keyword_spec : SpecificationType
        Must be a `TableSpecification`.
    buf : PReadBuf
        Buffer to read data from.

    Returns
    -------
    tuple[str, pd.DataFrame]
        Unquoted table name and the table.

    Raises
    ------
    ValueError
        If the specification has the wrong type, the name is missing or
        ambiguous, or the result is not a single DataFrame.
    """
    if not isinstance(keyword_spec, TableSpecification):
        raise ValueError('`keyword_spec` should be of type `TableSpecification`.')
    _ = buf.prev()
    words = next(buf).split()
    word_count = len(words)
    if word_count < 2:
        raise ValueError('No `name` in keyword string.')
    if word_count > 2:
        raise ValueError('Multiple `name`s in keyword string')
    name = words[1].strip('\'"')
    loaded = _load_table(keyword_spec, buf)
    if len(loaded) != 1:
        raise ValueError('There should be exactlly one table.')
    table = loaded[0]
    if isinstance(table, pd.DataFrame):
        return (name, table)
    raise ValueError('`Table should be of type `pandas.DataFrame`.')


def _load_parameters_tabulated(_, buf: PReadBuf) -> dict[str, str | None]:
    """
    Read ``NAME VALUE`` pairs, one per line, up to a closing ``/``.

    A line holding only ``/`` ends the block without adding a pair. A ``/``
    after a pair ends the block after that pair is stored.

    Raises
    ------
    ValueError
        If a line does not contain exactly two words.
    """
    params: dict[str, str | None] = {}
    for raw in buf:
        body, *rest = raw.split('/')
        closed = bool(rest)
        if closed and body == '':
            break
        words = body.split()
        if len(words) != 2:
            raise ValueError('There should be exactly two words on each line.')
        key, value = words
        params[key] = value
        if closed:
            break
    return params


def parse_vals(full: list[str | None], shift: int, vals: list[str]):
    """Parse values (unpack asterisk terms)."""
    full = copy.deepcopy(full)
    i = -1
    for i, v in enumerate(vals):
        if '*' in v:
            v = v.strip('\'"')
            if v == '*':
                continue
            try:
                n = int(v.split('*')[0])
                shift += n - 1
                if v.endswith('*'):
                    continue
                full[i + shift - n + 1 : i + shift + 1] = [v.split('*')[1]] * n
            except ValueError:
                full[i + shift] = v
        else:
            full[i + shift] = v
    return full, i + shift + 1


def _load_statement_list(keyword_spec: SpecificationType, buf: PReadBuf):
    """
    Parse Eclipse keyword data to dataframe.

    Reads ``/``-terminated statements until a line starting with ``/``.

    Parameters
    ----------
    keyword_spec : SpecificationType
        Keyword specification (a `StatementSpecification`).
    buf : PReadBuf
        Buffer to read data from.

    Returns
    -------
    pd.DataFrame
        One row per statement. If the list is empty (the keyword is closed
        immediately by ``/``), the result is an empty DataFrame with the
        specification's columns.

    Raises
    ------
    ValueError
        If `keyword_spec` is not a `StatementSpecification`, or the buffer ends
        before the closing ``/``.

    """
    statements: list[pd.DataFrame] = []
    while True:
        line = _get_expected_line(buf)
        if line.startswith('/'):
            break
        _ = buf.prev()
        statements.append(_load_single_statement(keyword_spec, buf))

    if not statements:
        # `pd.concat` raises on an empty sequence, so build the empty table here.
        if not isinstance(keyword_spec, StatementSpecification):
            raise ValueError(
                '`keyword_spec` should be of type `StatementSpecification`.'
            )
        return pd.DataFrame(columns=list(keyword_spec.columns))
    return pd.concat(statements, ignore_index=True)


def _load_no_data(keyword_spec: SpecificationType | None, buf: PReadBuf):
    """
    Handle a keyword that carries no data.

    If the specification is ``terminated``, the next line must start with
    ``/`` and is consumed. Otherwise nothing is read.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If `keyword_spec` has the wrong type, or the expected ``/`` is missing.
    """
    start_line = buf.line_number if isinstance(buf, PBufPosition) else None
    if keyword_spec is None:
        return None
    if not isinstance(keyword_spec, NoDataSpecification):
        raise ValueError(
            '`keyword_spec` should be of type `NoDataSpecification` or `None`.'
        )
    if keyword_spec.terminated:
        closing = next(buf)
        if not closing.startswith('/'):
            raise ValueError(
                _format_warning_messsage(
                    'Data is not properly terminated.', start_line, buf
                )
            )
    return None


# Dispatch table from a keyword's data type to its loader. Every loader is
# called as ``loader(specification, buffer, data)``, where `data` is the model
# data loaded so far (or None). Only the TABLE_SET loader forwards it (it may
# be needed to compute the number of tables). The `None` key handles
# specifications whose type is None (keywords without data).
# Lambdas look up the loader functions at call time.
LOADERS: dict[
    DataTypes | None,
    Callable[[SpecificationType, PReadBuf, DataType | None], ValueType],
] = {
    None: lambda keyword_spec, buf, _: _load_no_data(keyword_spec, buf),
    DataTypes.STRING: lambda keyword_spec, buf, _: _load_string(keyword_spec, buf),
    DataTypes.OBJECT_LIST: lambda keyword_spec, buf, _: _load_object_list(
        keyword_spec, buf
    ),
    DataTypes.TABLE_SET: lambda keyword_spec, buf, data: _load_table(
        keyword_spec, buf, data
    ),
    DataTypes.ARRAY: lambda keyword_spec, buf, _: _load_array(keyword_spec, buf),
    DataTypes.PARAMETERS: lambda keyword_spec, buf, _: _load_parameters(
        keyword_spec, buf
    ),
    DataTypes.SINGLE_STATEMENT: lambda keyword_spec, buf, _: _load_single_statement(
        keyword_spec, buf
    ),
    DataTypes.STATEMENT_LIST: lambda keyword_spec, buf, _: _load_statement_list(
        keyword_spec, buf
    ),
    DataTypes.RECORDS: lambda keyword_spec, buf, _: _load_records(keyword_spec, buf),
    DataTypes.NAMED_TABLE: lambda keyword_spec, buf, _: _load_named_table(
        keyword_spec, buf
    ),
    DataTypes.ARRAY_WITH_UNITS: lambda keyword_spec, buf, _: _load_array_with_units(
        keyword_spec, buf
    ),
}


def _format_warning_messsage(message: str, start_line: int | None, buf: PReadBuf):
    """
    Append the buffer location (file and line range) to a message.

    Parameters
    ----------
    message : str
        Message text.
    start_line : int | None
        Line number where reading of the current item started.
    buf : PReadBuf
        Buffer being read. Location is only added if it satisfies
        `PBufPosition`.

    Returns
    -------
    str
        ``"<message> <file> <start_line>:<current_line>"``. The plain message
        is returned when the buffer has no position, or `start_line` is None.
    """
    if not isinstance(buf, PBufPosition):
        return message
    if start_line is None:
        warnings.warn('`start_line` is None. Return original message.')
        # Honour the warning above: previously the code fell through and
        # appended "None:" as the start line.
        return message
    return ' '.join((message, f'{buf.current_file} {start_line}:{buf.line_number}'))


class StringIteratorIO:
    """String iterator for text files."""

    def __init__(
        self,
        path: pathlib.Path,
        encoding: str | None = None,
        logger: logging.Logger | None = None,
        root_dir: pathlib.Path | None = None,
    ):
        """
        Initialize.

        Parameters
        ----------
        path : pathlib.Path
            Path to main model file.
        encoding : str | None, default None
            Encoding.
        logger : logging.Logger | None, default None
            Logger.
        root_dir : pathlib.Path | None, default None
            Root directory to evaluate relative path, if None directory of the `path`
            is used.

        """
        self._path: pathlib.Path = path
        self._root_dir: pathlib.Path = (
            root_dir if root_dir is not None else self._path.parent
        )
        if (encoding is not None) and encoding.startswith('auto'):
            encoding_tmp = encoding.split(':')
            if len(encoding_tmp) > 1:
                n_bytes = int(encoding_tmp[1])
            else:
                n_bytes = 5000

            with open(self._path, 'rb') as file:
                raw = file.read(n_bytes)
                self._encoding: str | None = chardet.detect(raw)['encoding']
        else:
            self._encoding = encoding
        self._line_number: int = 0
        self._f = None
        self._buffer: str = ''
        self._last_line: str | None = None
        self._include: StringIteratorIO | None = None
        self._on_last: bool = False
        if logger is None:
            logger = logging.getLogger(str(uuid.uuid4()))
            logger.addHandler(logging.NullHandler())
            logger.propagate = False
        self._logger: logging.Logger = logger
        self._proposed_encodings: list[str] = DEFAULT_ENCODINGS.copy()
        self._stack: ExitStack | None = None

    @property
    def line_number(self) -> int:
        """Number of lines read."""
        if self._include is not None:
            return self._include.line_number
        return self._line_number

    @property
    def current_file(self) -> pathlib.Path:
        """
        File where cursor is located.

        Returns
        -------
        pathlib.Path
            Path to file.

        """
        if self._include is not None:
            return self._include.current_file
        return self._path

    def __iter__(self):
        """Return iterator."""
        return self

    def __next__(self) -> str:
        """
        Return next string.

        Returns
        -------
        str
            Next string.

        """
        if self._include is not None:
            try:
                return next(self._include)
            except StopIteration:
                self._include = None
                self._logger.info(f'Continue reading {self.current_file}.')

        if self._on_last:
            self._on_last = False
            if self._last_line is None:
                raise ValueError('`_last_line` attribute should not be None.')
            return self._last_line
        try:
            if self._f is None:
                raise ValueError('`_f` should not be None.')
            line = next(self._f).split('--')[0].strip()
        except UnicodeDecodeError:
            return self._better_decoding()
        except StopIteration as e:
            self._logger.info(f'Finish reading {self.current_file}.')
            raise e
        self._line_number += 1
        if line:
            if line == 'INCLUDE':
                spec = DATA_DIRECTORY['INCLUDE']
                if spec is None:
                    raise ValueError(
                        'Specification for `INCLUDE` keyword should not be None.'
                    )
                path = LOADERS[DataTypes.STRING](spec.specification, self, None)
                if not isinstance(path, str):
                    raise ValueError('Path should be of type `str`.')
                self._include_file(path)
                return next(self)
            self._last_line = line
            return line
        return next(self)

    def _include_file(self, path: str | pathlib.Path):
        path = self._root_dir.joinpath(pathlib.Path(path))
        if self._stack is None:
            raise ValueError('`self._stack` should not be None.')
        with self._stack as stack:
            self._logger.info('INCLUDE keyword found.')
            self._include = stack.enter_context(
                StringIteratorIO(
                    path, self._encoding, logger=self._logger, root_dir=self._root_dir
                )
            )
            self._stack = stack.pop_all()

    def _better_decoding(self) -> str:
        """Last chance to read line with default encodings."""
        try:
            enc = self._proposed_encodings.pop()
        except IndexError as err:
            raise ValueError(
                'Failed to decode at line {}'.format(self._line_number + 1)
            ) from err
        if enc == self._encoding:
            return self._better_decoding()
        self._f: TextIO | None = open(self._path, 'r', encoding=enc)  # pylint: disable=consider-using-with
        self._encoding = enc
        for _ in range(self._line_number):
            _ = next(self._f)
        return next(self)

    def prev(self):
        """Set current position to previous line."""
        if self._include is not None:
            _ = self._include.prev()
            return self
        if self._on_last:
            raise ValueError('Maximum cache depth is reached.')
        self._on_last = True
        return self

    def __enter__(self):
        """
        Enter context associated with the buffer.

        Adds corresponding context to the `ExitStack`.
        """
        with ExitStack() as stack:
            self._logger.info(f'Start reading {self._path}.')
            self._f = stack.enter_context(
                open(self._path, 'r', encoding=self._encoding)
            )  # pylint: disable=consider-u
            self._stack = stack.pop_all()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ):
        """Exit the context. Closes all stack."""
        _ = exc_type, exc_val, exc_tb
        if self._stack is None:
            raise ValueError('`self._stack` should not be None.')
        self._stack.close()

    def read(self, n: int):
        """Read n characters."""
        while not self._buffer:
            try:
                self._buffer = next(self)
            except StopIteration:
                break
        result = self._buffer[:n]
        self._buffer = self._buffer[len(result) :]
        return result

    def skip_to(self, stop: Sequence[str] | str, *args):  # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
        """Skip strings until stop token."""
        if isinstance(stop, str):
            stop = [stop]
        stop_pattern = '|'.join([x + '$' for x in stop])
        for line in self:
            if re.match(stop_pattern, line.strip(), *args):  # pyright: ignore[reportUnknownArgumentType]
                return


def _get_expected_line(buf: PReadBuf):
    """
    Return the next line of `buf`, treating exhaustion as an error.

    Raises
    ------
    ValueError
        If the buffer is exhausted. The original `StopIteration` is chained
        as the cause.
    """
    try:
        return next(buf)
    except StopIteration as err:
        raise ValueError('Buffer has ended earlier then expected.') from err


def load(
    path: pathlib.Path,
    *,
    logger: logging.Logger | None = None,
    encoding: str | None = None,
    directory_extension: dict[str, KeywordSpecification | None] | None = None,
) -> DataType:
    """
    Load model data.

    Parameters
    ----------
    path : pathlib.Path
        Path to main model file.
    logger : logging.Logger | None, default None
        Logger. If None, a private logger with a `NullHandler` is used.
    encoding : str | None, default None
        Encoding. Passed to `StringIteratorIO`; values starting with
        ``'auto'`` request encoding detection.
    directory_extension: dict[str, KeywordSpecification | None] | None
        Specifications for new keywords or altered specifications for current
        keywords. A None value means the specification is obtained from
        `get_dynamic_keyword_specification`.

    Returns
    -------
    DataType
        Model data: for each section, a list of ``(keyword, data)`` pairs in
        file order.

    """
    if directory_extension is None:
        directory_extension = {}
    res: DataType = {}
    sections = [sec.value for sec in SECTIONS]
    if logger is None:
        logger = logging.getLogger(str(uuid.uuid4()))
        logger.addHandler(logging.NullHandler())
        logger.propagate = False
    filename = path.name
    logger.info(f'Start reading {filename}')
    cur_section = ''
    directory = DATA_DIRECTORY | directory_extension
    with StringIteratorIO(path, encoding=encoding, logger=logger) as lines:
        for line in lines:
            if not line:
                continue
            firstword = line.split(maxsplit=1)[0].upper()
            if firstword in sections:
                cur_section = firstword
                logger.info(f'Start {cur_section} section: line {lines.line_number}.')
                if cur_section not in res:
                    res[cur_section] = []
                continue
            if firstword in directory:
                keyword_spec = directory[firstword]
                if keyword_spec is None:
                    keyword_spec = get_dynamic_keyword_specification(firstword, res)
                keyword_sections = [sec.value for sec in keyword_spec.sections]
                if cur_section not in keyword_sections:
                    # Trailing space added: the two parts were previously glued
                    # together ("...section PROPSis not supported...").
                    logger.warning(
                        f'Keyword {firstword} in section {cur_section} '
                        + f'is not supported (skipping): line {lines.line_number}'
                    )
                    continue
                logger.info(
                    f'Start reading keyword {firstword}: line {lines.line_number}.'
                )
                # Capture loader warnings to re-emit them prefixed with the keyword.
                with warnings.catch_warnings(record=True) as captured:
                    warnings.simplefilter('always')
                    data = LOADERS[keyword_spec.type](
                        keyword_spec.specification, lines, res
                    )
                for w in captured:
                    warnings.warn(
                        f'While reading keyword {firstword}: ' + str(w.message)
                    )
                if cur_section not in res:
                    res[cur_section] = []
                res[cur_section].append((firstword, data))
                logger.info(
                    f'Finish reading keyword {firstword}: line {lines.line_number}.'
                )
            elif firstword.startswith('/'):
                logger.info(f'Unnecessary "/" (skipping): line: {lines.line_number}.')
            else:
                logger.warning(
                    f'Keyword {firstword} in section {cur_section} '
                    + f'is not supported (skipping): line {lines.line_number}'
                )
    return res