Linear Models


In this section, we’ll show how to use Jai to train a linear model.

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

california_housing = fetch_california_housing(as_frame=True)
df = california_housing['frame']

# The target is the median house value per block group, in units of $100,000
target =
X_train, X_test, y_train, y_test = train_test_split(df, target)
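Note that `california_housing['frame']` holds the feature columns together with the target column `MedHouseVal`, so passing `df` to `train_test_split` unchanged keeps the target among the features. That is one likely reason the error metrics reported later in this section are close to zero. Below is a minimal sketch of separating the two before splitting, shown on a small synthetic frame that only mimics the dataset layout (the values and `random_state` are illustrative assumptions):

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Synthetic stand-in for california_housing['frame']: features plus target.
frame = pd.DataFrame({
    "MedInc": [8.3, 7.2, 5.6, 3.8, 4.0, 3.1, 2.1, 6.4],
    "HouseAge": [41, 21, 52, 52, 30, 18, 36, 25],
    "MedHouseVal": [4.5, 3.6, 3.5, 3.4, 2.7, 1.9, 1.2, 3.9],
})

# Separate the target so it cannot leak into the features.
target = frame["MedHouseVal"]
features = frame.drop(columns=["MedHouseVal"])

X_train, X_test, y_train, y_test = train_test_split(
    features, target, random_state=0
)
print(X_train.shape, X_test.shape)
```

Whether to drop the target column depends on what the example is meant to demonstrate; with the column left in, a linear model can reproduce the target almost exactly.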

Linear model

Here is how to train a linear model using the LinearModel module.

We use scikit-learn’s models in the back end, so most of the parameters described in the scikit-learn documentation can be used.

Tasks:
- regression: LinearRegression
- sgd_regression: SGDRegressor
- classification: LogisticRegression
- sgd_classification: SGDClassifier
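The task names above correspond to scikit-learn estimators. As a rough local illustration of what each task name refers to, the same estimators can be instantiated directly. The `TASK_TO_ESTIMATOR` dictionary and the synthetic data are assumptions made for this sketch, not part of Jai's API:

```python
import numpy as np
from sklearn.linear_model import (
    LinearRegression,
    LogisticRegression,
    SGDClassifier,
    SGDRegressor,
)

# Hypothetical lookup table mirroring the task list; not part of Jai's API.
TASK_TO_ESTIMATOR = {
    "regression": LinearRegression,
    "sgd_regression": SGDRegressor,
    "classification": LogisticRegression,
    "sgd_classification": SGDClassifier,
}

# Synthetic data following y = 2*x0 - x1 + 0.5 exactly.
rng = np.random.default_rng(0)
X = rng.normal(size=(50, 2))
y = 2.0 * X[:, 0] - 1.0 * X[:, 1] + 0.5

# Keyword arguments are the ones documented by scikit-learn.
estimator = TASK_TO_ESTIMATOR["regression"](fit_intercept=True)
estimator.fit(X, y)
print(estimator.coef_, estimator.intercept_)
```

Because the synthetic target is exactly linear in the inputs, the fitted coefficients should recover the weights used to generate it.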

from jai import LinearModel
model = LinearModel("california_linear", "regression")
report =, y_train, overwrite=True)
# After training, the model is ready to be consumed
0 1.308
1 0.885
2 2.193
3 3.427
4 3.042
... ...
5155 3.442
5156 0.665
5157 1.487
5158 2.750
5159 2.329

5160 rows × 1 columns

# You can improve the model using one new sample
model.learn(X_test.iloc[[0]], y_test.iloc[[0]])
{'before': {'MAE': 2.1094237467877974e-14,
  'MSE': 4.44966854351227e-28,
  'MAPE': 1.6127092865350133e-14},
 'after': {'MAE': 2.1094237467877974e-14,
  'MSE': 4.44966854351227e-28,
  'MAPE': 1.6127092865350133e-14},
 'change': True}
# Or you can improve the model using multiple new samples
model.learn(X_test.iloc[1:4], y_test.iloc[1:4])
{'before': {'MAE': 1.021405182655144e-14,
  'MSE': 1.3995707226796117e-28,
  'MAPE': 8.507385228034694e-15,
  'R2_Score': 1.0},
 'after': {'MAE': 1.021405182655144e-14,
  'MSE': 1.3995707226796117e-28,
  'MAPE': 8.507385228034694e-15,
  'R2_Score': 1.0},
 'change': True}
0 1.308
1 0.885
2 2.193
3 3.427
4 3.042
... ...
5155 3.442
5156 0.665
5157 1.487
5158 2.750
5159 2.329

5160 rows × 1 columns
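The reports returned by `model.learn` list MAE, MSE and MAPE, plus an R2 score in the multi-sample case. The sketch below computes these metrics from their textbook definitions in plain Python, to show what the numbers mean. Whether Jai computes them in exactly this way is an assumption, and `regression_report` is a hypothetical helper name. R2 is skipped for fewer than two samples because it is undefined there, which may be why the single-sample report has no `R2_Score` entry.

```python
def regression_report(y_true, y_pred):
    """Return MAE, MSE, MAPE and (when defined) R2 for two equal-length sequences."""
    n = len(y_true)
    errors = [t - p for t, p in zip(y_true, y_pred)]
    report = {
        "MAE": sum(abs(e) for e in errors) / n,
        "MSE": sum(e * e for e in errors) / n,
        # MAPE assumes no true value is zero.
        "MAPE": sum(abs(e) / abs(t) for e, t in zip(errors, y_true)) / n,
    }
    if n >= 2:
        mean_true = sum(y_true) / n
        total = sum((t - mean_true) ** 2 for t in y_true)
        residual = sum(e * e for e in errors)
        # R2 is also undefined when all true values are identical.
        if total > 0:
            report["R2_Score"] = 1.0 - residual / total
    return report


single = regression_report([2.0], [2.5])
several = regression_report([1.0, 2.0, 4.0], [1.5, 2.0, 3.0])
print(single)
print(several)
```

Comparing such a report computed before and after an update gives the same kind of `before`/`after` view shown in the outputs above.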

Pretrained Bases application

In Jai, you can reuse previously trained collections for new tasks.

If a collection has already been trained, we can build a linear model on top of the vectors stored in it.

from jai import Trainer

# Let's create a collection using only the features of the dataset.
trainer = Trainer("pretrained_california")


Recognized fit arguments:
- db_type: SelfSupervised
query =, overwrite=True)
Insert Data: 100%|██████████| 1/1 [00:00<00:00,  1.73it/s]

Recognized fit arguments:
- db_type: SelfSupervised
JAI is working: 100%|██████████|20/20 [02:22]

Setup Report:

Best model at epoch: 63 val_loss: 0.71

Now we’ll build a linear model using that collection.

First, we’ll build a mapping table to train the new linear model.

In this case, we’ll only use one feature, id_california, which is the id value from the previous collection. Each id will correspond to its vector stored in Jai, and those vectors are used to train the linear model.
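Conceptually, each id in `id_california` is replaced by the vector stored for it in the parent collection, and the linear model is fit on those vectors. The sketch below imitates that locally with NumPy, pandas and scikit-learn. The `stored_vectors` dictionary, its 3-dimensional random vectors, and the targets are made-up stand-ins; Jai performs the actual lookup on its side, and the details of how it does so are not shown here.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)

# Made-up stand-in for the vectors a pretrained collection stores per id.
ids = [101, 102, 103, 104, 105, 106]
stored_vectors = {i: rng.normal(size=3) for i in ids}

# Mapping frame: a single column holding the id of each row in the parent collection.
df_map = pd.DataFrame(ids, index=ids, columns=["id_california"])

# Replace each id by its vector, then fit a linear model on the vectors.
X_vectors = np.vstack([stored_vectors[i] for i in df_map["id_california"]])
true_weights = np.array([1.0, -2.0, 0.5])
y = X_vectors @ true_weights  # synthetic target, linear in the vectors

model_sketch = LinearRegression().fit(X_vectors, y)
print(model_sketch.coef_)
```

The mapping frame carries no raw features at all: the id column is only a key, and all predictive information comes from the vectors it points to.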

import pandas as pd
df_train = pd.DataFrame(X_train.index, index=X_train.index, columns=["id_california"])
model = LinearModel("california_pretrained", "regression")
report =, y_train, pretrained_bases=[{"id_name":"id_california", "db_parent": "pretrained_california"}], overwrite=True)
# You can add the new data to the previous collection
Insert Data: 100%|██████████| 1/1 [00:00<00:00,  3.75it/s]
JAI is working: 100%|██████████|20/20 [00:02]
({0: {'Task': 'Adding new data for tabular setup',
   'Status': 'Completed',
   'Description': 'Insertion completed.',
   'Interrupted': False}},
 {'Task': 'Adding new data to database',
  'Status': 'Running',
  'Description': 'Task started to run now',
  'Interrupted': False})
# And map the ids to make the prediction
df_test = pd.DataFrame(X_test.index, index=X_test.index, columns=["id_california"])
0 1.386100
1 0.785691
2 2.185110
3 3.764664
4 3.319453
... ...
5155 4.088492
5156 0.703170
5157 1.311731
5158 2.087266
5159 2.143795

5160 rows × 1 columns

# Or, if you don't want to modify the previous collection, you can consume the model using the original raw data
0 1.386100
1 0.785691
2 2.185110
3 3.764664
4 3.319453
... ...
5155 4.088492
5156 0.703170
5157 1.311731
5158 2.087266
5159 2.143795

5160 rows × 1 columns