def transform(self, y):
    """Transform y into a compatible type (tf.data.Dataset).

    The input is first passed to `self.check_data_type`. It is then
    converted according to its type:

    * `tf.data.Dataset`: returned unchanged.
    * `np.ndarray`: a 1-D array is reshaped to a single column
      (`reshape(-1, 1)`); arrays of any other rank are used as-is.
    * `pd.DataFrame`: its underlying `.values` array is used.
    * `pd.Series`: its `.values` array is reshaped to a single column.

    Args:
        y: The data to convert.

    Returns:
        A `tf.data.Dataset` built with `from_tensor_slices`, or `y` itself
        when it already is a dataset. If `y` matches none of the types
        above (and `check_data_type` did not raise), `None` is returned.
    """
    self.check_data_type(y)

    if isinstance(y, tf.data.Dataset):
        # Already in the target format; hand it back untouched.
        return y

    # Normalise every supported container to a plain array first, so the
    # dataset is built in exactly one place.
    if isinstance(y, np.ndarray):
        values = y.reshape(-1, 1) if y.ndim == 1 else y
    elif isinstance(y, pd.DataFrame):
        values = y.values
    elif isinstance(y, pd.Series):
        values = y.values.reshape(-1, 1)
    else:
        # No branch matched: keep the original fall-through result.
        return None

    return tf.data.Dataset.from_tensor_slices(values)
def postprocess(self, y):
After Change
def transform(self, y):
    """Transform y into a compatible type (tf.data.Dataset).

    Args:
        y: The data to convert. Which types are accepted is decided by
            `self._check`, which is not visible from this method.

    Returns:
        Whatever `self._convert_to_dataset(y)` returns.
    """
    # Validate before converting so conversion never runs on rejected
    # input (presumably `_check` raises on bad data -- TODO confirm).
    self._check(y)
    return self._convert_to_dataset(y)
def _convert_to_dataset(self, y):
if isinstance(y, tf.data.Dataset):
return y