WeightAveraging
- class lightning.pytorch.callbacks.WeightAveraging(device=None, use_buffers=True, **kwargs)
Bases: Callback

A callback that updates an averaged model for Stochastic Weight Averaging (SWA) or Exponential Moving Average (EMA) after each training step.

Arguments given to the constructor will be passed to the AveragedModel constructor. If no device is specified, the device of the original model will be used. Contrary to AveragedModel, use_buffers is set to True by default. That is, by default the callback will compute running averages for both the parameters and the buffers of the model. Setting use_buffers to False will cause only the model parameters to be averaged, leaving the update of the batch normalization statistics to the user (using torch.optim.swa_utils.update_bn()).

You can provide a custom averaging function with the avg_fn or multi_avg_fn parameter. See the AveragedModel class for details. If no averaging function is provided, the default is to compute the equally weighted average of the weights (SWA).

You can customize when the average model is updated by overriding the should_update() method. The callback calls it with either step_idx or epoch_idx, and the method returns a boolean indicating whether to update after the given step or epoch. The default is to update after every step.

During validation and after training finishes, the current model parameters will be replaced with the averaged values.
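If you set use_buffers to False, the batch normalization statistics of the averaged weights must be recomputed once training is done. A minimal sketch, assuming model is the trained LightningModule and train_dataloader is the loader used during fitting:

```python
import torch

# After trainer.fit() with WeightAveraging(use_buffers=False), the averaged
# parameters have already been copied into the model, but its batch-norm
# running statistics are stale. Recompute them with one pass over the data.
torch.optim.swa_utils.update_bn(train_dataloader, model)
```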
See also the documentation on the weight averaging callbacks provided by Lightning.
Note
To ensure that the AveragedModel will contain all layers, setup() will call configure_model() before instantiating the AveragedModel. However, that hook is not called in a strategy-aware context, so sharded models do not work with weight averaging, and a warning will be issued.

Example:
```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import WeightAveraging
from torch.optim.swa_utils import get_ema_avg_fn


class EMAWeightAveraging(WeightAveraging):
    def __init__(self):
        super().__init__(avg_fn=get_ema_avg_fn())

    def should_update(self, step_idx=None, epoch_idx=None):
        # Start after 100 steps.
        return (step_idx is not None) and (step_idx >= 100)


trainer = Trainer(callbacks=EMAWeightAveraging(), max_epochs=10)
trainer.fit(model, dataloader)
```
- Parameters:
  - device (Union[device, str, int, None]) – By default, the AveragedModel will be stored on the same device as the original model. If the device argument is provided, the AveragedModel will be stored on this device instead. If you run out of GPU memory, you might want to use "cpu".
  - use_buffers (bool) – If False, the buffers of the model will not be averaged.
  - kwargs (Any) – Additional keyword arguments to be passed to the AveragedModel constructor, such as avg_fn or multi_avg_fn.
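For illustration, the following sketch stores the averaged copy on the CPU and passes a custom averaging function through the constructor (the decay value is arbitrary):

```python
from lightning.pytorch.callbacks import WeightAveraging
from torch.optim.swa_utils import get_ema_avg_fn

# Keep the averaged weights on the CPU to save GPU memory; avg_fn is
# forwarded to the AveragedModel constructor via **kwargs.
callback = WeightAveraging(device="cpu", avg_fn=get_ema_avg_fn(decay=0.995))
```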
- load_state_dict(state_dict)
Called when loading a checkpoint.
Reloads the callback state given a state_dict.
- on_load_checkpoint(trainer, pl_module, checkpoint)
Called when loading a model checkpoint.
Loads the current model and the AveragedModel parameters from the checkpoint.
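Because both the current and the averaged weights round-trip through the checkpoint, training can be resumed without losing the running average. A minimal sketch, where "last.ckpt" is a placeholder path and model and dataloader are assumed to exist:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import WeightAveraging

# Resuming restores both the current model and the AveragedModel state.
trainer = Trainer(callbacks=WeightAveraging(), max_epochs=20)
trainer.fit(model, dataloader, ckpt_path="last.ckpt")
```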
- on_save_checkpoint(trainer, pl_module, checkpoint)
Called when saving a checkpoint.
Moves the current model state to the key current_model_state and places the average model state in state_dict instead. Any other state variables of the AveragedModel will be saved in averaging_state.
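A sketch of what this layout means when inspecting such a checkpoint offline; "last.ckpt" is a placeholder path, and the exact averaging_state keys depend on the AveragedModel configuration:

```python
import torch

# Lightning checkpoints contain non-tensor objects, hence weights_only=False.
ckpt = torch.load("last.ckpt", map_location="cpu", weights_only=False)
print(ckpt["state_dict"].keys())           # the averaged model weights
print(ckpt["current_model_state"].keys())  # the in-progress (non-averaged) weights
print(ckpt["averaging_state"].keys())      # remaining AveragedModel state
```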
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
Called when a training batch ends.
Updates the AveragedModel parameters, if requested by self.should_update().
- Parameters:
  - trainer (Trainer) – The current Trainer instance.
  - pl_module (LightningModule) – The current LightningModule instance.
  - outputs – The outputs of the training step.
  - batch – The current training batch.
  - batch_idx (int) – The index of the current batch.
- Return type: None
- on_train_end(trainer, pl_module)
Called when training ends.
Transfers parameters from the AveragedModel to the current model.
- Parameters:
  - trainer (Trainer) – The current Trainer instance.
  - pl_module (LightningModule) – The current LightningModule instance.
- Return type: None
- on_train_epoch_end(trainer, pl_module)
Called when a training epoch ends.
Updates the AveragedModel parameters, if requested by self.should_update().
- Parameters:
  - trainer (Trainer) – The current Trainer instance.
  - pl_module (LightningModule) – The current LightningModule instance.
- Return type: None
- on_validation_epoch_end(trainer, pl_module)
Called when a validation epoch ends.
Recovers the current model parameters from the AveragedModel.
- Parameters:
  - trainer (Trainer) – The current Trainer instance.
  - pl_module (LightningModule) – The current LightningModule instance.
- Return type: None
- on_validation_epoch_start(trainer, pl_module)
Called when a validation epoch begins.
Transfers parameter values from the AveragedModel to the current model.
- Parameters:
  - trainer (Trainer) – The current Trainer instance.
  - pl_module (LightningModule) – The current LightningModule instance.
- Return type: None
- setup(trainer, pl_module, stage)
Called when fit, validate, test, predict, or tune begins.
Creates an AveragedModel when fit begins.
- Parameters:
  - trainer (Trainer) – The current Trainer instance.
  - pl_module (LightningModule) – The current LightningModule instance.
  - stage (str) – Either "fit", "validate", "test", or "predict".
- Return type: None
- should_update(step_idx=None, epoch_idx=None)
Called after every optimizer step and after every training epoch to check whether the average model should be updated.
One of the arguments is set to the zero-based index of the last training step or epoch. The default implementation returns True when any step_idx is provided. The user can customize when the average model gets updated by overriding this method.
- Return type: bool
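For example, a subclass could update the average once per epoch instead of after every step. A minimal sketch (the class name EpochWeightAveraging is hypothetical):

```python
from lightning.pytorch.callbacks import WeightAveraging

class EpochWeightAveraging(WeightAveraging):
    def should_update(self, step_idx=None, epoch_idx=None):
        # Update only at epoch boundaries; ignore the per-step calls.
        return epoch_idx is not None
```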