diff options
Diffstat (limited to 'dnn/torch/osce/adv_train_model.py')
-rw-r--r-- | dnn/torch/osce/adv_train_model.py | 4 |
1 files changed, 4 insertions, 0 deletions
diff --git a/dnn/torch/osce/adv_train_model.py b/dnn/torch/osce/adv_train_model.py index 9cd32000..dcfb65f1 100644 --- a/dnn/torch/osce/adv_train_model.py +++ b/dnn/torch/osce/adv_train_model.py @@ -408,6 +408,10 @@ for ep in range(1, epochs + 1): optimizer.step() + # sparsification + if hasattr(model, 'sparsifier'): + model.sparsifier() + running_model_grad_norm += get_grad_norm(model).detach().cpu().item() running_adv_loss += gen_loss.detach().cpu().item() running_disc_loss += disc_loss.detach().cpu().item() |