From Machine Learning Engineering in Action by Ben Wilson. This article discusses artifact management and using passive retraining to improve model accuracy over time, as well as other helpful tips.
Let’s imagine that we’re working at the fire risk department of the forest service. In our efforts to effectively dispatch personnel and equipment to high-risk areas in the park system, we’ve arrived at a solution that works remarkably well. Our features are generally locked-in and performing well. We’ve evaluated the performance of the predictions and are seeing genuine value from the model.
Throughout this process of getting the features into a good state, we’ve been iterating through the improvement cycle as shown in figure 1.
Figure 1 Improvements to a deployed model on the road to production steady-state operation.
As the cycle in figure 1 shows, we’ve been iteratively releasing new versions of the model, testing against a baseline deployment, collecting feedback, and working to improve the predictions. At some point, however, we’re going to move into model-sustaining mode. We’ve worked as hard as we can to improve the features going into the model and have found that the return on investment (ROI) of continuing to add new elements of data to the project simply isn’t worth it. We’re now in the position of scheduled passive retraining of our model based on new data coming in over time.
When we’re at this ‘steady-state’ point, the last thing that we want to do is have one of the DS team members spend an afternoon manually retraining a model, manually comparing its results to the current production-deployed model with ad hoc analysis, and deciding whether the model should be updated.
Oh, come on. No one does this manually.
From my own history as a DS, I didn’t start using passive retraining for the first 6 years of solving problems. It wasn’t due to a lack of need, nor was it due to a lack of tooling. It was pure and simple ignorance. I had no idea how big of a problem drift could be (I learned that the hard way several times over by having solutions devolve into irrelevance due to my neglect), nor did I understand or appreciate the importance of attribution calculations.
Over years of repeatedly screwing up my solutions, I found techniques that others had written about through researching solutions to my self-imposed woes of inadequately engineered projects. I came to embrace the ideas that led me to DS work to begin with: automating annoying and repetitive tasks. By removing the manual activity of monitoring the health of my projects (drift tracking in an ad-hoc manner), I found that I had solved two primary problems that were plaguing me.
Firstly, I freed up my time. Doing ad-hoc analyses on prediction results and feature stability takes a lot of time. To add to that, it’s incredibly boring work. The second big problem was in accuracy. Manual evaluation of model performance is repetitive and error prone. Missing details through a manual analysis can mean deploying a model version that is worse than the currently deployed one, introducing issues that are far more significant than a slightly poorer prediction performance.
I’ve learned my lesson about automating retraining (typically opting for passive retraining systems rather than the far more complex active ones if I can get away with it). Like everything else I’ve learned in my career, I’ve learned it by screwing it up. Hopefully you can avoid the same fate.
All of this can be automated with passive retraining. Figure 2 shows this concept of a scheduled retraining event.
Figure 2 A passive retraining system logical diagram
With the automation of scheduled retraining in place as shown in figure 2, the primary concern with this system is knowing what is running in production. For instance, what happens if a problem is uncovered in production once a new version is released? What can we do to recover from a concept drift that has dramatically affected a retraining event? How do we roll back the model to the previous version without having to rebuild it?
We can allay these concerns by using a model registry.
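A registry turns a rollback into a metadata operation rather than a model rebuild. As a minimal sketch (assuming an MLflow registry, which we discuss next, and using a hypothetical registered model name and version numbers), demoting the problematic version and re-promoting the previous one looks like this:

from mlflow.tracking import MlflowClient

client = MlflowClient()

# Demote the newly promoted (problematic) version back out of Production...
client.transition_model_version_stage(
    name="forest_fire_risk", version="4", stage="Archived")

# ...and restore the previously archived, known-good version.
client.transition_model_version_stage(
    name="forest_fire_risk", version="3", stage="Production")

Because the registry keeps the serialized artifact for every version, no retraining run is needed to perform the swap.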
MLflow’s model registry
In this situation that we find ourselves in, with scheduled updates to a model happening autonomously, it is important for us to know the state of the production deployment. Not only do we need to know what it is right now, but if questions arise about the past performance of a passive retraining system, we also need a means of investigating the historical provenance of the model.
Figure 3 shows a comparison of what a request for explanation on a historical issue would look like between using and not using a registry for provenance tracking.
Figure 3 Passive retraining schedule with a historic issue found far in the future.
As shown in figure 3, the process of attempting to recreate a past run is fraught with peril and has a high risk of being unable to reproduce the issue that the business has found in historical predictions. With no registry to record the artifacts that were utilized in production, manual work must be done to recreate the conditions under which the previously running model was built. This can be incredibly challenging (if not impossible) at most companies, because the underlying data used to train the model may have changed, making it impossible to recreate that state.
The preferred approach, also shown in figure 3, is to utilize a model registry service. MLflow, for instance, offers exactly this functionality within its APIs, allowing us to log details of each retraining run to the tracking server, handle production promotion if the scheduled retraining job performs better on holdout data, and archive the older model for future reference. If we had used this framework, testing the conditions of a model that had at one point run in production would be as simple as recalling the artifact from the registry entry, loading it into a notebook environment, and generating the explainable correlation reports with tools such as shap.
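As a rough sketch of that investigation workflow (the registered model name, version, and the historical_features DataFrame below are hypothetical placeholders, not part of the project code):

import mlflow
import shap

# Load the archived version that was serving during the period in question.
model = mlflow.pyfunc.load_model("models:/forest_fire_risk/3")

# historical_features is assumed to be a pandas DataFrame of the feature data
# from the time window the business is asking about.
background = historical_features.sample(100, random_state=42)
evaluation = historical_features.sample(500, random_state=7)

# Rebuild an attribution report against the historical feature data.
explainer = shap.KernelExplainer(model.predict, background)
shap_values = explainer.shap_values(evaluation)
shap.summary_plot(shap_values, evaluation)

With the artifact one URI away, answering “why did the model do that back then?” becomes an afternoon of analysis rather than an impossible reconstruction.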
Is a registry really that important?
Well, in two words: “it depends”.
I remember with a distinct spine-chilling horror one of my first “major real-deal no-kidding really-serious” ML implementations that I built. It wasn’t by any means my first “production” release of a solution, but it was the first one that actually had some serious attention being paid to it. It helped to run a rather significant part of the business, and as such, was closely scrutinized by a lot of people. Rightfully so, if I may add.
My “deployment” (if it could be called that) involved a passive-like retraining system that stored the last known good hyperparameters of the previous day’s tuning run, using those values as a starting point to begin automated tuning. After optimizing to all of the new feature training data that was available, it chose the best performing model, ran a prediction on the new data, and overwrote a serving table with the predictions.
It wasn’t until a full 3 months into the project’s production run that the first serious question came up regarding why the model was predicting in an unexpected way with certain customers. The business couldn’t figure out why it was doing that, so they approached me and asked for me to investigate.
Having no record of the model (it wasn’t even saved off anywhere), and with the training data constantly changing over time as the features updated, it was completely impossible for me to explain the historical performance of the model.
The business was less than pleased with this answer. Although the model didn’t get shut off (it probably should have, though), it made me realize the importance of storing and cataloguing models for the precise reason of being able to explain why the solution behaves the way it does, even if that explanation comes months past the point at which the model was being used.
Interfacing with the Model Registry
To get a feel for how this code would look to support an integration with the model registry service of MLflow, let’s adapt our use case to support this passive retraining functionality. To start with, we’ll need to create an adjudication system that checks the current production model’s performance against the scheduled retraining results. After having that comparison built, we can interface with the registry service to replace the current production model with the newer model (if it’s better), or to stay with the current production model based on its performance against the same holdout data that the new model was tested against.
Let’s look at an example of how to interface with the MLflow model registry to support automated passive retraining that retains provenance of the model’s state over time. Listing 1 establishes the first portion of what we need to build to have a historical status table of each of the scheduled retraining events.
Listing 1 Registry state row generation and logging
from dataclasses import dataclass
from datetime import datetime

import pandas as pd
from pyspark.sql import functions as F

# `spark` below refers to the active SparkSession provided by the notebook
# environment (e.g., Databricks).


@dataclass
class Registry:  #A
    model_name: str
    production_version: int
    updated: bool
    training_time: str


class RegistryStructure:  #B
    def __init__(self, data):
        self.data = data

    def generate_row(self):
        spark_df = spark.createDataFrame(pd.DataFrame([vars(self.data)]))  #C
        return (spark_df
                .withColumn("training_time",
                            F.to_timestamp(F.col("training_time")))
                .withColumn("production_version",
                            F.col("production_version").cast("long")))


class RegistryLogging:
    def __init__(self, database, table, delta_location, model_name,
                 production_version, updated):
        self.database = database
        self.table = table
        self.delta_location = delta_location
        self.entry_data = Registry(model_name, production_version, updated,
                                   self._get_time())  #D

    @classmethod
    def _get_time(cls):
        return datetime.today().strftime('%Y-%m-%d %H:%M:%S')

    def _check_exists(self):  #E
        return spark._jsparkSession.catalog().tableExists(self.database,
                                                          self.table)

    def write_entry(self):  #F
        log_row = RegistryStructure(self.entry_data).generate_row()
        log_row.write.format("delta").mode("append").save(self.delta_location)
        if not self._check_exists():
            spark.sql(f"""CREATE TABLE IF NOT EXISTS
                {self.database}.{self.table}
                USING DELTA LOCATION '{self.delta_location}';""")
#A A data class to wrap the data we’re going to be logging.
#B Class for converting the registration data to a Spark Dataframe to write a row to a Delta Table for provenance.
#C Accessing the members of the data class in a short-hand fashion to cast to a Pandas Dataframe and then a Spark DataFrame (leveraging implicit type inferences).
#D Building the Spark DataFrame Row at class initialization.
#E Method for determining if the Delta table has been created yet or not.
#F Writing the log data to Delta in append mode and creating the table reference in the Hive metastore if it doesn’t already exist.
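Writing an entry is then a single call. A hypothetical usage of the RegistryLogging class from listing 1 (the Delta storage path and model name here are placeholders) might look like this:

RegistryLogging(
    database="mleng_demo",
    table="registry_status",
    delta_location="dbfs:/tmp/mleng_demo/registry_status",  # placeholder path
    model_name="forest_fire_risk",                          # placeholder name
    production_version=3,
    updated=True
).write_entry()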
The code shown in listing 1 helps to set the stage for provenance of the model training history. Since we’re looking to automate retraining on a schedule, it’s far easier to have a tracking table that records the history of changes in a centralized location. If we have multiple builds of this model, as well as other projects that are registered, we can have a single snapshot view of the state of production passive retraining without needing to do anything more than write a simple query. Listing 2 illustrates what a query of this table would look like. With multiple models logged to a transaction history table like this, adding a filter such as df.filter(F.col("model_name") == "<model name>") allows for rapid access to the historical log for a single model.
Listing 2 Querying the registry state table
from pyspark.sql import functions as F

REGISTRY_TABLE = "mleng_demo.registry_status"

display(spark.table(REGISTRY_TABLE).orderBy(F.col("training_time")))  #A
#A Since we’ve registered the table in our row-input stage earlier, we can refer to it directly by its database.table name. We can then order the commits chronologically.
Executing the code in listing 2 results in what is shown in figure 4.
Figure 4 Querying the registry state transaction table
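The per-model filter mentioned just before listing 2 would look something like the following sketch (the model name is a hypothetical placeholder):

from pyspark.sql import functions as F

# Narrow the shared transaction table to a single project's retraining history.
registry_history = (spark.table("mleng_demo.registry_status")
                    .filter(F.col("model_name") == "forest_fire_risk")
                    .orderBy(F.col("training_time")))
display(registry_history)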
In addition to this log, there is also a graphical user interface (GUI) for the model registry within MLflow. Figure 5 shows a screen capture of the GUI that matches to the registry table from listing 2.
Figure 5 The MLflow Model Registry GUI for our experiments
Now that we have the historical tracking functionality set up, we can write the interface to MLflow’s registry server to support passive retraining. Listing 3 shows the implementation for leveraging the tracking server’s entries, the registry service for querying current production metadata, and an automated state transition of the retrained model for supplanting the current production model if it performs better.
Listing 3 Passive retraining model registration logic
import mlflow
from mlflow.tracking import MlflowClient


class ModelRegistration:
    def __init__(self, experiment_name, experiment_title, model_name, metric,
                 direction):
        self.experiment_name = experiment_name
        self.experiment_title = experiment_title
        self.model_name = model_name
        self.metric = metric
        self.direction = direction
        self.client = MlflowClient()
        self.experiment_id = mlflow.get_experiment_by_name(
            experiment_name).experiment_id

    def _get_best_run_info(self, key):  #A
        run_data = mlflow.search_runs(
            self.experiment_id,
            order_by=[f"metrics.{self.metric} {self.direction}"])
        return run_data.head(1)[key].values[0]

    def _get_registered_status(self):
        return self.client.get_registered_model(name=self.experiment_title)

    def _get_current_prod(self):  #B
        return [x.run_id for x in self._get_registered_status().latest_versions
                if x.current_stage == "Production"][0]

    def _get_prod_version(self):
        return int([x.version for x in
                    self._get_registered_status().latest_versions
                    if x.current_stage == "Production"][0])

    def _get_metric(self, run_id):
        return mlflow.get_run(run_id).data.metrics.get(self.metric)

    def _find_best(self):  #C
        try:
            current_prod_id = self._get_current_prod()
            prod_metric = self._get_metric(current_prod_id)
        except mlflow.exceptions.RestException:
            current_prod_id = -1
            prod_metric = 1e7
        best_id = self._get_best_run_info('run_id')
        best_metric = self._get_metric(best_id)
        if self.direction == "ASC":
            if prod_metric < best_metric:
                return current_prod_id
            else:
                return best_id
        else:
            if prod_metric > best_metric:
                return current_prod_id
            else:
                return best_id

    def _generate_artifact_path(self, run_id):
        return f"runs:/{run_id}/{self.model_name}"

    def register_best(self, registration_message, logging_location, log_db,
                      log_table):  #D
        best_id = self._find_best()
        try:
            current_prod = self._get_current_prod()
            current_prod_version = self._get_prod_version()
        except mlflow.exceptions.RestException:
            current_prod = -1
            current_prod_version = -1
        updated = current_prod != best_id
        if updated:
            register_new = mlflow.register_model(
                self._generate_artifact_path(best_id), self.experiment_title)
            self.client.update_registered_model(
                name=register_new.name,
                description="Forest Fire Prediction for the National Park")
            self.client.update_model_version(
                name=register_new.name,
                version=register_new.version,
                description=registration_message)
            self.client.transition_model_version_stage(
                name=register_new.name,
                version=register_new.version,
                stage="Production")
            if current_prod_version > 0:
                self.client.transition_model_version_stage(
                    name=register_new.name,
                    version=current_prod_version,
                    stage="Archived")
            RegistryLogging(log_db, log_table, logging_location,
                            self.experiment_title, int(register_new.version),
                            updated).write_entry()
            return "upgraded prod"
        else:
            RegistryLogging(log_db, log_table, logging_location,
                            self.experiment_title, int(current_prod_version),
                            updated).write_entry()
            return "no change"

    def get_model_as_udf(self):  #E
        prod_id = self._get_current_prod()
        artifact_uri = self._generate_artifact_path(prod_id)
        return mlflow.pyfunc.spark_udf(spark, model_uri=artifact_uri)
#A Extracting all of the previous run data for the history of the production deployment and returning the run_id that has the best performance against the validation data.
#B Query for the model that is currently registered as ‘production deployed’ in the registry.
#C Method for determining if the current scheduled passive retraining run is performing better than production on its holdout data. It will return the run_id of the best logged run.
#D Utilize the MLflow Model Registry API to register the new model if it is better and de-register the current production model if it’s being replaced.
#E Acquiring the current production model for batch inference on a Spark DataFrame using a PythonUDF.
The code shown in listing 3 allows us to fully manage the passive retraining of this model implementation (see the companion GitHub repository for this book for the full code). By leveraging the MLflow Model Registry API, we can meet the needs of production scheduled predictions through one-line access to the model artifact. This not only greatly simplifies the scheduled batch prediction job, but also meets the needs of the investigation we began discussing in this section. When the business asks why past behavior was occurring, this approach and general model provenance architecture allow us to rapidly determine which model was being used during the time in question. Having the ability to retrieve the model with such ease, we can manually test the feature data against that model, run simulations with tools like shap, and rapidly answer their questions without having to struggle with recreating a potentially impossible state.
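To make that concrete, a hypothetical driver for the scheduled retraining job might tie the pieces together as follows (the experiment name, metric, storage paths, and the feature_df DataFrame are all placeholder assumptions):

registry = ModelRegistration(
    experiment_name="/Experiments/forest_fire_risk",  # placeholder experiment
    experiment_title="forest_fire_risk",
    model_name="model",
    metric="rmse",
    direction="ASC"
)

# Promote the new run only if it beats the current production model.
status = registry.register_best(
    registration_message="Scheduled passive retraining run",
    logging_location="dbfs:/tmp/mleng_demo/registry_status",
    log_db="mleng_demo",
    log_table="registry_status"
)

# Batch inference with whatever is currently marked as Production.
predict_udf = registry.get_model_as_udf()
scored = feature_df.withColumn("prediction", predict_udf(*feature_df.columns))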
Building features like these into your ML infrastructure may not seem particularly important when you’re focused on meeting a production release deadline, but the minor amount of work involved in creating them can save dozens (if not hundreds) of hours of work later when the inevitable questions arise about model performance during a time that is far removed from the current state of production. After all, models can, should, and will change over time as we battle against the constant states of drift that are affecting their performance. Trying to recreate something that you’ve already spent time building is a complete waste of time and energy.
In the same vein of using a model registry to keep track of model artifacts, the features used to train models and to generate predictions can be catalogued for efficiency’s sake as well. This concept is materialized in the form of feature stores.
That’s cool and all, but what about active retraining?
The primary difference between the approach employed in passive retraining and the needs of an active retraining architecture is the mechanism that initiates retraining.
Passive retraining, scheduled by cron, is a ‘best hope’ strategy that attempts to find an improved model fit by incorporating new training data in an effort to counteract drift. Active retraining, on the other hand, monitors the state of predictions and features to determine algorithmically when it makes sense to trigger a retraining.
An active system can be beneficial if drift is happening at unpredictable rates (i.e., a model has been performing well for weeks, falls apart in the span of a few days, gets retrained, and performs well for only a few days before needing retraining again), as it is designed to respond to unpredictable performance degradation. To create this responsive feedback loop that triggers a retraining event, the quality of predictions needs to be monitored. A system needs to be built that ingests the predictions, merges in ground-truth results that arrive with highly variable lag (in some cases seconds later, in others weeks later), and sets statistically significant thresholds on aggregated result states over time in order to generate a retraining signal.
These systems are highly dependent on the nature of the problem being solved by the ML solution, and as such, vary in their design and implementation so much that even a generic example architecture is irrelevant for presentation here.
For instance, if you’re trying to determine the success of a model’s ability to predict the weather in the next hour at a location, you can get feedback within an hour. You could build a system that merges the hour-lagged real weather against the predictions, feeding the actual model accuracy into a windowed aggregation of accuracy rate over the last 48 hours. Should the aggregated rate of success in weather forecasting drop below a defined threshold of 70%, a retraining of the model can be initiated autonomously. The newly trained model can be compared against the current production model by validating both on a standard (new) holdout validation data set; the new model can then either take over entirely using a blue/green deployment strategy, or be dynamically allocated traffic using a multi-armed bandit algorithm that routes requests based on relative performance improvement compared to the current production model.
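A minimal sketch of that hour-lagged accuracy monitor might look like the following (the table names, column names, and the trigger hook are all assumptions for illustration, not part of the book’s code):

from pyspark.sql import functions as F

# Join predictions against the observed weather that arrives an hour later.
joined = (spark.table("weather.predictions").alias("p")
          .join(spark.table("weather.observed").alias("o"),
                ["location_id", "forecast_hour"]))

# Aggregate the hit rate over the trailing 48 hours.
accuracy_48h = (joined
                .filter(F.col("forecast_hour") >=
                        F.expr("current_timestamp() - INTERVAL 48 HOURS"))
                .agg(F.avg((F.col("p.predicted") == F.col("o.actual"))
                           .cast("double")).alias("accuracy")))

# Fall below the agreed 70% threshold and a retraining job is kicked off.
if accuracy_48h.first()["accuracy"] < 0.7:
    trigger_retraining_job()  # hypothetical hook into the retraining pipeline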
Active retraining is complex, in a nutshell. It’s something that I recommend people investigate after they find that passive retraining simply “isn’t cutting it anymore”, rather than a complexity to employ just because it seems important. There are far more moving parts, services, and pieces of infrastructure to manage when retraining is handled autonomously. The cloud services bill that you get when running such a system will reflect the increase in complexity as well (it’s expensive).
That’s all for this article. If you want to learn more, check out the book here.