Multi-Class Classification
How to log your model schema for multiclass classification models
Multi-class classification models are logged in one of two ways, depending on whether your model provides scores in addition to labels. The case you use determines which performance metrics are available.
Variant | Expected Fields | Performance Metrics | Allowed Metric Families | Arize Representation |
---|---|---|---|---|
Case 1: Supports Only Classification Metrics | prediction label, actual label | Accuracy, Recall, Precision, FPR, FNR, F1, Sensitivity, Specificity | Classification | A single inference represents the prediction, regardless of prediction label cardinality |
Case 2: Supports Classification & AUC Metrics | prediction label, actual label, prediction score, actual score | AUC, PR-AUC, Log Loss, Accuracy, Recall, Precision, FPR, FNR, F1, Sensitivity, Specificity | Classification, AUC/LogLoss | Requires 1 inference per class (one for each value the prediction label can take) |
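The choice between the two cases can be sketched as a check on which fields your data carries. The helper name `metric_families` and the field-name strings below are illustrative assumptions and are not part of the Arize SDK.

```python
# Hypothetical helper: maps the fields you can log to the metric families
# listed in the table above. Not part of the Arize SDK.
CASE_1_FIELDS = {"prediction_label", "actual_label"}
CASE_2_FIELDS = CASE_1_FIELDS | {"prediction_score", "actual_score"}


def metric_families(available_fields):
    """Return the metric families supported by the given set of fields."""
    fields = set(available_fields)
    if CASE_2_FIELDS <= fields:
        return ["Classification", "AUC/LogLoss"]  # Case 2
    if CASE_1_FIELDS <= fields:
        return ["Classification"]  # Case 1
    return []  # not enough fields for either case


print(metric_families(["prediction_label", "actual_label"]))
```

If only labels are available, the sketch reports the Classification family alone, which corresponds to Case 1.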
Python Pandas
Python Single Record
Data Connector
Example Row
Class | tier | zip_code | prediction_label | actual_label | timestamp |
---|---|---|---|---|---|
1 | 'gold' | 12345 | "economy" | "business" | 897029940 |
2 | 'platinum' | 542321 | "business" | "business" | 897029940 |
3 | 'silver' | 12312 | "first" | "business" | 897029940 |
# Imports assumed from the Arize Python SDK (pandas logger)
from arize.pandas.logger import Client, Schema
from arize.utils.types import Environments, ModelTypes

schema = Schema(
    prediction_id_column_name='prediction_id',
    prediction_label_column_name='prediction_label',
    actual_label_column_name='actual_label',
    feature_column_names=['tier'],
    tag_column_names=['zip_code']
)
response = arize_client.log(
    dataframe=sample_df,
    schema=schema,
    model_id='sample-model-1',
    model_version='1.0',
    model_type=ModelTypes.SCORE_CATEGORICAL,
    environment=Environments.PRODUCTION
)
For more details, see the Python Batch API Reference.
# Predicting likelihood of Economy, Business, or First Class
"""
example_record = {
    "prediction_scores": {
        "economy_class": 0.81,
        "business_class": 0.42,
        "first_class": 0.35
    },
    "prediction": "economy_class",
    "actual": "business_class"
}
"""
# Logging only the predicted label and the actual label
response = arize_client.log(
    model_id='sample-model-1',
    model_version='1.0',
    environment=Environments.PRODUCTION,
    model_type=ModelTypes.SCORE_CATEGORICAL,
    prediction_id="1",
    prediction_label="economy_class",
    actual_label="business_class",
)
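When the model emits a score per class but only labels are logged, the prediction label is typically the highest-scoring class. A minimal sketch of that step, assuming the score dictionary from the example record above:

```python
# Scores from the example record above (one score per class).
prediction_scores = {
    "economy_class": 0.81,
    "business_class": 0.42,
    "first_class": 0.35,
}

# Case 1 logs only labels, so pick the class with the highest score.
prediction_label = max(prediction_scores, key=prediction_scores.get)
print(prediction_label)  # economy_class
```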
For more information, see the Python Single Record Logging API Reference.
To upload files instead, see the Data Connectors documentation.
Python Pandas
Python Single Record
Data Connector
Example Row
Class | tier | prediction_label_for_economy | actual_label | prediction_score_for_economy | actual_score | timestamp |
---|---|---|---|---|---|---|
1 | 'gold' | "economy" | "business" | 0.81 | 0 | 897029940 |
Class | tier | prediction_label_for_business | actual_label | prediction_score_for_business | actual_score | timestamp |
---|---|---|---|---|---|---|
2 | 'platinum' | "business" | "business" | 0.42 | 1 | 897029940 |
Class | tier | prediction_label_for_first | actual_label | prediction_score_for_first | actual_score | timestamp |
---|---|---|---|---|---|---|
3 | 'silver' | "first" | "business" | 0.35 | 0 | 897029940 |
Code Example
# Logging probability of Economy Class
schema = Schema(
    prediction_id_column_name='prediction_id',
    prediction_label_column_name='prediction_label_for_economy',
    prediction_score_column_name='prediction_score_for_economy',
    actual_label_column_name='actual_label'
)
response = arize_client.log(
    dataframe=sample_df,
    schema=schema,
    model_id='sample-model-1',
    model_version='1.0',
    model_type=ModelTypes.MULTICLASS_CLASSIFICATION,
    metrics_validation=[Metrics.CLASSIFICATION, Metrics.AUC_LOG_LOSS],
    environment=Environments.PRODUCTION
)
# Logging probability of Business Class
schema = Schema(
    prediction_id_column_name='prediction_id',
    prediction_label_column_name='prediction_label_for_business',
    prediction_score_column_name='prediction_score_for_business',
    actual_label_column_name='actual_label'
)
response = arize_client.log(
    dataframe=sample_df,
    schema=schema,
    model_id='sample-model-1',
    model_version='1.0',
    model_type=ModelTypes.MULTICLASS_CLASSIFICATION,
    metrics_validation=[Metrics.CLASSIFICATION, Metrics.AUC_LOG_LOSS],
    environment=Environments.PRODUCTION
)
# Logging probability of First Class
schema = Schema(
    prediction_id_column_name='prediction_id',
    prediction_label_column_name='prediction_label_for_first',
    prediction_score_column_name='prediction_score_for_first',
    actual_label_column_name='actual_label'
)
response = arize_client.log(
    dataframe=sample_df,
    schema=schema,
    model_id='sample-model-1',
    model_version='1.0',
    model_type=ModelTypes.MULTICLASS_CLASSIFICATION,
    metrics_validation=[Metrics.CLASSIFICATION, Metrics.AUC_LOG_LOSS],
    environment=Environments.PRODUCTION
)
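The three schemas above differ only in the class suffix of two column names, so the column mapping can be generated in a loop. The helper `schema_columns` and the `CLASSES` list are hypothetical names introduced for illustration; the sketch builds plain dictionaries and does not call the Arize SDK.

```python
CLASSES = ["economy", "business", "first"]


def schema_columns(class_name):
    """Build the per-class column names used in the Schema examples above."""
    return {
        "prediction_id_column_name": "prediction_id",
        "prediction_label_column_name": f"prediction_label_for_{class_name}",
        "prediction_score_column_name": f"prediction_score_for_{class_name}",
        "actual_label_column_name": "actual_label",
    }


for class_name in CLASSES:
    columns = schema_columns(class_name)
    # Each dict holds the same keyword arguments passed to Schema(...) above.
    print(columns["prediction_score_column_name"])
```

With the SDK imported, each dictionary could presumably be unpacked as `Schema(**columns)` and logged with the same `arize_client.log(...)` arguments shown above.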
Arize expects the DataFrame's index to be sorted and begin at 0. If you perform operations that might affect the index prior to logging data, reset the index as follows:
dataframe = dataframe.reset_index(drop=True)
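To illustrate why the reset matters, the sketch below filters a small DataFrame, which leaves a gap in the index, and then resets it. The column values are made up for illustration and the example assumes pandas is installed.

```python
import pandas as pd

df = pd.DataFrame({
    "tier": ["gold", "platinum", "silver"],
    "prediction_label": ["economy", "business", "first"],
})

# Filtering keeps the original row labels, so the index no longer runs 0..n-1.
filtered = df[df["tier"] != "platinum"]
print(list(filtered.index))  # [0, 2]

# drop=True discards the old index instead of adding it as a column.
clean = filtered.reset_index(drop=True)
print(list(clean.index))  # [0, 1]
```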
For more details, see the Python Batch API Reference.
# Predicting Economy Class, Business Class, First Class
"""
example_record = {
    "prediction_scores": {
        "economy_class": 0.81,
        "business_class": 0.42,
        "first_class": 0.35
    },
    "predicted_class": "economy_class",
    "actual": "business_class"
}
"""
# Prediction #1 - Logging probability of Economy Class
response = arize_client.log(
    model_id='sample-model-1',
    model_version='1.0',
    model_type=ModelTypes.SCORE_CATEGORICAL,
    environment=Environments.PRODUCTION,
    prediction_id="1-economy",
    prediction_label="economy",
    prediction_score=0.81,
    actual_label="business",
    actual_score=0
)
# Prediction #2 - Logging probability of Business Class
response = arize_client.log(
    model_id='sample-model-1',
    model_version='1.0',
    model_type=ModelTypes.SCORE_CATEGORICAL,
    environment=Environments.PRODUCTION,
    prediction_id="1-business",
    prediction_label="business",
    prediction_score=0.42,
    actual_label="business",
    actual_score=1
)
# Prediction #3 - Logging probability of First Class
response = arize_client.log(
    model_id='sample-model-1',
    model_version='1.0',
    model_type=ModelTypes.SCORE_CATEGORICAL,
    environment=Environments.PRODUCTION,
    prediction_id="1-first",
    prediction_label="first",
    prediction_score=0.35,
    actual_label="business",
    actual_score=0
)
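The three calls above follow one pattern: a single model output is split into one record per class, with an actual score of 1 for the true class and 0 otherwise. A minimal sketch of that expansion follows; `expand_record` is a hypothetical helper, not an Arize SDK function, and it only builds dictionaries.

```python
def expand_record(record_id, prediction_scores, actual_class):
    """Return one loggable record per class for a single model output."""
    records = []
    for class_name, score in prediction_scores.items():
        records.append({
            "prediction_id": f"{record_id}-{class_name}",
            "prediction_label": class_name,
            "prediction_score": score,
            "actual_label": actual_class,
            # Ground-truth score: 1 for the true class, 0 for every other class.
            "actual_score": 1 if class_name == actual_class else 0,
        })
    return records


records = expand_record(
    "1",
    {"economy": 0.81, "business": 0.42, "first": 0.35},
    actual_class="business",
)
for record in records:
    print(record)
```

The dictionary keys mirror the keyword arguments used in the calls above, so each record could presumably be passed along with the model arguments as `arize_client.log(..., **record)`.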
For more information, see the Python Single Record Logging API Reference.
To upload files instead, see the Data Connectors documentation.
Prediction Label: The classification label of this event (Cardinality > 2)
Actual Label: The ground truth label (Cardinality > 2)
Prediction Score: The likelihood of that prediction class (1 per cardinality of prediction label)
Actual Score: The ground truth score of that class (1 per cardinality of prediction label)