fim
- class trojanzoo.utils.fim.BaseKFAC(net, eps=0.1, sua=False, update_freq=1, alpha=1.0, constraint_norm=False, state_type=BaseState)
Base K-FAC preconditioner for Linear and Conv2d layers.
Computes the K-FAC approximation of the second moment of the gradients. It works for Linear and Conv2d layers and silently skips other layers.
- Parameters:
net (torch.nn.Module) – Network to precondition.
eps (float) – Tikhonov regularization parameter for the inverses.
sua (bool) – Applies SUA approximation.
update_freq (int) – Recompute the inverses every update_freq updates.
alpha (float) – Running average parameter (if alpha == 1, no running average is used).
constraint_norm (bool) – Scale the gradients by the squared Fisher norm.
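How alpha and update_freq interact can be sketched in plain Python. This is an illustrative sketch, not the library's code: the helper names `running_average` and `should_update_inverse` are hypothetical, and the weighting convention (alpha weights the new estimate) is an assumption chosen so that alpha == 1 disables the running average, as documented above.

```python
def running_average(old, batch, alpha):
    """Blend a stored statistic with a fresh per-batch estimate.

    Assumed convention: alpha weights the new estimate, so alpha == 1
    discards the history (no running average).
    """
    return (1.0 - alpha) * old + alpha * batch


def should_update_inverse(iteration, update_freq):
    """Return True on the iterations where the inverses would be recomputed."""
    return iteration % update_freq == 0


# Hypothetical scalar statistics standing in for the covariance factors.
batches = [4.0, 2.0, 6.0]
stat = 0.0
for it, batch in enumerate(batches):
    stat = running_average(stat, batch, alpha=0.5)
```

A larger update_freq amortizes the cost of the matrix inversions over more steps, at the price of preconditioning with staler inverses.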
- class trojanzoo.utils.fim.KFAC(*args, pi=False, **kwargs)
K-FAC preconditioner for torch.nn.Linear and torch.nn.Conv2d layers.
Computes the K-FAC approximation of the second moment of the gradients. It works for Linear and Conv2d layers and silently skips other layers.
- Parameters:
net (torch.nn.Module) – Network to precondition.
pi (bool) – Computes pi correction for Tikhonov regularization.
eps (float) – Tikhonov regularization parameter for the inverses.
sua (bool) – Applies SUA approximation.
update_freq (int) – Recompute the inverses every update_freq updates.
alpha (float) – Running average parameter (if alpha == 1, no running average is used).
constraint_norm (bool) – Scale the gradients by the squared Fisher norm.
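The idea behind a pi correction can be sketched as follows. This is a minimal sketch under stated assumptions, not the class's actual implementation: it uses the trace-ratio heuristic commonly described for K-FAC damping, the function names `trace` and `split_damping` are hypothetical, and the exact formula used by KFAC may differ. The point it illustrates is that eps is split between the two Kronecker factors so that the product of the two damping terms stays equal to eps.

```python
import math


def trace(mat):
    """Trace of a square matrix given as a list of rows."""
    return sum(mat[i][i] for i in range(len(mat)))


def split_damping(xxt, ggt, eps, use_pi=False):
    """Split eps into damping terms for the two Kronecker factors.

    Returns (damp_x, damp_g) with damp_x * damp_g == eps. With use_pi, the
    split follows the ratio of the factors' average diagonal entries
    (assumed heuristic); without it, eps is split evenly.
    """
    pi = 1.0
    if use_pi:
        pi = (trace(xxt) / len(xxt)) / (trace(ggt) / len(ggt))
    return math.sqrt(eps * pi), math.sqrt(eps / pi)


xxt = [[4.0, 0.0], [0.0, 4.0]]  # toy activation factor
ggt = [[1.0]]                   # toy gradient factor
damp_x, damp_g = split_damping(xxt, ggt, eps=0.04, use_pi=True)
```

In this toy case the activation factor has the larger average diagonal, so it receives the larger share of the damping.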
- class trojanzoo.utils.fim.EKFAC(net, *args, ra=False, **kwargs)
EKFAC preconditioner for torch.nn.Linear and torch.nn.Conv2d layers.
Computes the EKFAC approximation of the second moment of the gradients. It works for Linear and Conv2d layers and silently skips other layers.
- Parameters:
net (torch.nn.Module) – Network to precondition.
eps (float) – Tikhonov regularization parameter for the inverses.
sua (bool) – Applies SUA approximation.
ra (bool) – Computes stats using a running average of averaged gradients instead of an intra-minibatch estimate.
update_freq (int) – Recompute the inverses every update_freq updates.
alpha (float) – Running average parameter.
constraint_norm (bool) – Scale the gradients by the squared Fisher norm.
- class trojanzoo.utils.fim.BaseState
A basic storage class.
- Variables:
x (torch.Tensor) – (N, in, xh, xw).
gy (torch.Tensor) – (N, out, yh, yw).
num_locations (int) – yh * yw.
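The relation between gy and num_locations can be illustrated with shapes alone. This is a hedged sketch: the helper name `num_locations_from_shape` is hypothetical, and treating a 2-D gy (as a Linear layer would produce) as a single location is an assumption not stated in the documentation above.

```python
def num_locations_from_shape(gy_shape):
    """Number of spatial output positions implied by the shape of gy.

    For a Conv2d output of shape (N, out, yh, yw) this is yh * yw.
    Assumption: a 2-D shape (N, out) counts as a single location.
    """
    if len(gy_shape) == 2:
        return 1
    _, _, yh, yw = gy_shape
    return yh * yw


conv_gy_shape = (32, 16, 8, 8)  # batch of 32, 16 output channels, 8x8 map
locations = num_locations_from_shape(conv_gy_shape)
```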
- class trojanzoo.utils.fim.KFACState
A storage class for KFAC.
- Variables:
xxt (torch.Tensor) – (in [* kh * kw] + 1, in [* kh * kw] + 1).
ggt (torch.Tensor) – (out, out).
ixxt (torch.Tensor) – (in [* kh * kw] + 1, in [* kh * kw] + 1).
iggt (torch.Tensor) – (out, out).
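The shapes listed above can be worked out for a concrete layer with a small helper. This is a sketch for illustration only: `kfac_factor_shapes` is a hypothetical function, and both the `bias` flag and the reading of "+ 1" as the bias term are assumptions.

```python
def kfac_factor_shapes(in_features, out_features, kernel_size=None, bias=True):
    """Expected shapes of (xxt, ggt); ixxt and iggt share them.

    kernel_size is (kh, kw) for a Conv2d layer and None for a Linear layer.
    Assumption: the "+ 1" in the documented shape accounts for the bias.
    """
    dim_x = in_features
    if kernel_size is not None:
        kh, kw = kernel_size
        dim_x *= kh * kw
    if bias:
        dim_x += 1
    return (dim_x, dim_x), (out_features, out_features)


# Conv2d with 3 input channels, 16 output channels and a 3x3 kernel.
conv_xxt_shape, conv_ggt_shape = kfac_factor_shapes(3, 16, kernel_size=(3, 3))
# Linear layer mapping 10 features to 5.
lin_xxt_shape, lin_ggt_shape = kfac_factor_shapes(10, 5)
```

Because the activation factor grows with in * kh * kw, wide convolutions with large kernels dominate the memory and inversion cost of the stored state.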