Bases: InputConversion[TabularDataset, Table]
The input conversion for a neural network; it defines the input parameters for the neural network.
Source code in src/safeds/ml/nn/converters/_input_converter_table.py
class InputConversionTable(InputConversion[TabularDataset, Table]):
    """The input conversion for a neural network, defines the input parameters for the neural network."""

    def __init__(self) -> None:
        """Create a converter that has not yet recorded any feature or target names."""
        self._target_name = ""
        self._feature_names: list[str] = []
        # True until a dataset has passed fit validation and its schema has been recorded.
        self._first = True

    @property
    def _data_size(self) -> int:
        """Return the number of recorded feature names."""
        return len(self._feature_names)

    def _data_conversion_fit(self, input_data: TabularDataset, batch_size: int, num_of_classes: int = 1) -> DataLoader:
        """
        Convert the training data into a data loader.

        Delegates to `input_data._into_dataloader_with_classes`, forwarding `batch_size` and `num_of_classes`.
        """
        return input_data._into_dataloader_with_classes(
            batch_size,
            num_of_classes,
        )

    def _data_conversion_predict(self, input_data: Table, batch_size: int) -> DataLoader:
        """
        Convert the prediction input into a data loader.

        Delegates to `input_data._into_dataloader`, forwarding `batch_size`.
        """
        return input_data._into_dataloader(batch_size)

    def _data_conversion_output(self, input_data: Table, output_data: Tensor) -> TabularDataset:
        """
        Attach the model output to the input table and return the result as a tabular dataset.

        The values of `output_data` are added as a new column named after the recorded target, and that column is
        used as the target of the returned dataset.
        """
        return input_data.add_columns([Column(self._target_name, output_data.tolist())]).to_tabular_dataset(
            self._target_name,
        )

    def _is_fit_data_valid(self, input_data: TabularDataset) -> bool:
        """
        Check whether the given dataset can be used for fitting.

        On the first dataset that passes validation, its feature names and target name are recorded; every call
        (including that one) then compares the dataset's feature names against the recorded ones, ignoring order.

        Parameters
        ----------
        input_data:
            The dataset to validate.

        Returns
        -------
        is_valid:
            Whether the feature names of `input_data` match the recorded feature names.

        Raises
        ------
        InvalidFitDataError
            If any feature column or the target column contains missing values or non-numerical data.
        """
        columns_with_missing_values = []
        columns_with_non_numerical_data = []

        for col in input_data.features.add_columns([input_data.target]).to_columns():
            if col.missing_value_count() > 0:
                columns_with_missing_values.append(col.name)
            if not col.type.is_numeric:
                columns_with_non_numerical_data.append(col.name)

        reason = ""
        if columns_with_missing_values:
            reason += f"The following Columns contain missing values: {columns_with_missing_values}\n"
        if columns_with_non_numerical_data:
            reason += f"The following Columns contain non-numerical data: {columns_with_non_numerical_data}"
        if reason:
            raise InvalidFitDataError(reason)

        # Record the schema only after validation succeeded. Doing it earlier would pin the column names of a
        # rejected dataset, so a corrected dataset passed in afterwards would be compared against a stale schema.
        if self._first:
            self._feature_names = input_data.features.column_names
            self._target_name = input_data.target.name
            self._first = False

        return sorted(input_data.features.column_names) == sorted(self._feature_names)

    def _is_predict_data_valid(self, input_data: Table) -> bool:
        """Return whether the column names of `input_data` match the recorded feature names, ignoring order."""
        return sorted(input_data.column_names) == sorted(self._feature_names)