Adding new models
The model module contains all of the architectures implemented as part of this package. We offer GAN- and VAE-based architectures with a number of adjustments to achieve privacy and other augmented functionalities. The module handles the training of these architectures and the generation of synthetic data from them, according to the user's choice of model(s) and configuration.
It is likely that, as the literature matures, more effective architectures will emerge for the type of tabular data NHSSynth is designed for. Below we discuss how to add new models to the package.
Model design
The models in this package are built entirely in PyTorch and use Opacus for differential privacy.
We have built the VAE and (Tabular)GAN implementations in this package to serve as the foundations for a number of other architectures. As such, we try to maintain a somewhat modular approach to building up more complex differentially private (or otherwise augmented) architectures. Each model inherits from either the GAN or VAE class (in files of the same name), which in turn inherit from a generic Model class found in the common folder. This folder contains components of models which are not to be instantiated themselves, e.g. a mixin class for differential privacy, the MLP underlying the GAN, and so on.
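The layering described above can be sketched as follows. This is a minimal illustration, not the package's code: the class bodies, the constructor parameters (n_columns, latent_dim, target_epsilon) and the mixin name DPMixin are all assumptions made for the example. Only the names Model, VAE and DPVAE and the overall inheritance shape come from the text.

```python
class Model:
    """Stand-in for the generic base class: dataset details, device, epochs."""

    def __init__(self, n_columns, device="cpu", num_epochs=100):
        self.n_columns = n_columns
        self.device = device
        self.num_epochs = num_epochs


class DPMixin:
    """Hypothetical privacy mixin; not meant to be instantiated on its own."""

    def __init__(self, *args, target_epsilon=3.0, **kwargs):
        # Pass everything else along the method resolution order.
        super().__init__(*args, **kwargs)
        self.target_epsilon = target_epsilon


class VAE(Model):
    """Stand-in foundational architecture with one illustrative argument."""

    def __init__(self, *args, latent_dim=8, **kwargs):
        super().__init__(*args, **kwargs)
        self.latent_dim = latent_dim


class DPVAE(DPMixin, VAE):
    """A more complex architecture composed purely from existing parts."""


model = DPVAE(n_columns=12, latent_dim=4, target_epsilon=1.0)
print([cls.__name__ for cls in DPVAE.__mro__])
```

Because every constructor forwards unused keyword arguments with super(), the composed class needs no __init__ of its own in this sketch.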
The Model class, from which all of the models derive, handles all of the general attributes. Roughly, these are the specifics of the dataset that the model instance relates to, the device on which training is carried out, and other training parameters such as the total number of epochs to execute.
We define these at the model level because, when using differential privacy or other privacy accounting methods, we must know the data and the length of training exposure ahead of time in order to calculate the level of noise required to reach a given privacy guarantee.
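To make "length of training exposure" concrete, the sketch below computes the quantities a privacy accountant typically needs before training starts: the fraction of the dataset seen per step and the total number of steps. The TrainingExposure class and its field names are hypothetical and are not part of the package; the actual noise calibration is left to the privacy library.

```python
import math
from dataclasses import dataclass


@dataclass
class TrainingExposure:
    """Hypothetical container for what must be fixed before training."""

    n_rows: int       # size of the training dataset
    batch_size: int   # rows drawn per optimisation step
    num_epochs: int   # total passes over the data

    @property
    def sample_rate(self) -> float:
        # Fraction of the dataset touched by a single step.
        return self.batch_size / self.n_rows

    @property
    def total_steps(self) -> int:
        # Number of noisy gradient updates over the whole run.
        return self.num_epochs * math.ceil(self.n_rows / self.batch_size)


exposure = TrainingExposure(n_rows=1000, batch_size=100, num_epochs=5)
print(exposure.sample_rate, exposure.total_steps)
```

Changing the number of epochs after the noise level has been fixed would change total_steps, and with it the privacy guarantee, which is why these values live on the model rather than being passed in later.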
Implementing a new model
To add a new architecture, first investigate the modular parts already implemented to ensure that what you want to build is not already possible by composing them. Then ensure that your architecture inherits from either GAN or VAE, or from Model directly if you wish to implement a different type of generative model.
In all of these cases, the interface expects the implementation to provide the following methods:
get_args: a class method that lists the architecture-specific arguments that the model requires. This is used to facilitate default arguments in the Python API whilst still allowing arguments in the CLI to be propagated and recorded automatically in the experiment output. It should return a list of variable names equal to the concatenation of the arguments of all of the non-Model parent classes (e.g. DPVAE has DP and VAE args) plus any architecture-specific arguments in the __init__ method of the model in question.
get_metrics: another class method that behaves similarly to the above. It should return a list of valid metrics to track during training for this model.
train: a method handling the training loop for the model. It should take num_epochs, patience and displayed_metrics as arguments and return a tuple containing the number of epochs that were executed plus a bundle of training metrics (the values over time of the metrics returned by get_metrics on the class). During the execution of this method, the utility methods defined in Model should be called in order: _start_training at the beginning, then _record_metrics at each training step of the data loader, and finally _finish_training to clean up progress bars and so on. displayed_metrics determines which metrics are actively displayed during training.
generate: a method to call on the trained model which generates N samples of data and calls the model's associated MetaTransformer to return a valid pandas DataFrame of synthetic data ready to output.
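The four methods above can be sketched with a toy model. Everything here is an illustrative assumption: the Model stand-in, the signatures of _start_training, _record_metrics and _finish_training, the ToyModel class and its scale argument are invented for the example. To stay free of dependencies, generate returns a list of row dictionaries where a real model would pass its output through the MetaTransformer to obtain a pandas DataFrame, and patience-based early stopping is omitted.

```python
import random


class Model:
    """Hypothetical stand-in for the package's generic Model class."""

    def __init__(self, columns, seed=0):
        self.columns = columns
        self.rng = random.Random(seed)
        self.metrics = {}

    # Helper names come from the text; these signatures are assumed.
    def _start_training(self, num_epochs, patience, displayed_metrics):
        self.metrics = {name: [] for name in self.get_metrics()}

    def _record_metrics(self, values):
        for name, value in values.items():
            self.metrics[name].append(value)

    def _finish_training(self, num_epochs):
        pass  # the real helper would clean up progress bars here


class ToyModel(Model):
    """Toy 'architecture' whose only parameter is a noise scale."""

    def __init__(self, columns, scale=1.0, seed=0):
        super().__init__(columns, seed)
        self.scale = scale

    @classmethod
    def get_args(cls):
        # Only the architecture-specific __init__ arguments.
        return ["scale"]

    @classmethod
    def get_metrics(cls):
        return ["Loss"]

    def train(self, num_epochs=5, patience=2, displayed_metrics=("Loss",)):
        # patience is accepted to match the interface; early stopping is omitted.
        self._start_training(num_epochs, patience, displayed_metrics)
        epochs_run = 0
        for epoch in range(num_epochs):
            for _batch in range(3):  # stands in for iterating a data loader
                self.scale *= 0.9  # pretend optimisation step
                self._record_metrics({"Loss": self.scale})
            epochs_run = epoch + 1
        self._finish_training(num_epochs)
        return epochs_run, self.metrics

    def generate(self, N):
        # A real model would hand this to its MetaTransformer instead.
        return [
            {col: self.rng.gauss(0, self.scale) for col in self.columns}
            for _ in range(N)
        ]


model = ToyModel(columns=["age", "height"])
epochs_run, metrics = model.train(num_epochs=2)
rows = model.generate(4)
```

The call order inside train mirrors the description: one _start_training, one _record_metrics per data loader step, and one _finish_training.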
Adding a new model to the CLI
Once you have implemented your new model, you must add it to the CLI. To do this, first export the model's class into the MODELS constant in the __init__ file in the models subfolder. Then add a new function and option in module_arguments.py to list the arguments, and their types, that are unique to this type of architecture.
Note that you should not duplicate arguments that are already defined in the Model class, or in foundational model architectures such as the GAN if you are implementing an extension to it. If you have set up get_args correctly, all of this will be propagated automatically.
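The registration and propagation steps might look roughly like the sketch below, which uses only the standard argparse module. The shape of MODELS (a name-to-class dictionary), the function name add_toy_model_args, the --architecture option and the --scale argument are all assumptions made for illustration; check the existing entries in the package for the actual conventions.

```python
import argparse


class ToyModel:
    """Hypothetical new architecture with one specific argument."""

    def __init__(self, scale=1.0):
        self.scale = scale

    @classmethod
    def get_args(cls):
        return ["scale"]


# Assumed shape of the MODELS constant: CLI name -> model class.
MODELS = {"ToyModel": ToyModel}


def add_toy_model_args(parser):
    """Hypothetical function listing only ToyModel's own arguments."""
    group = parser.add_argument_group("ToyModel")
    group.add_argument("--scale", type=float, default=1.0,
                       help="noise scale used by the toy model")


parser = argparse.ArgumentParser()
parser.add_argument("--architecture", choices=list(MODELS), default="ToyModel")
add_toy_model_args(parser)

args = parser.parse_args(["--architecture", "ToyModel", "--scale", "0.5"])
model_cls = MODELS[args.architecture]
# get_args tells the caller which parsed values belong to this model.
model_kwargs = {name: getattr(args, name) for name in model_cls.get_args()}
model = model_cls(**model_kwargs)
```

Since the keyword arguments are selected by name from get_args, an argument that is declared on the CLI but missing from get_args would silently never reach the model in this sketch.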