optim
- class trojanzoo.optim.Optimizer(iteration=20, stop_threshold=None, loss_fn=None, **kwargs)
An abstract input optimizer class that inherits from trojanzoo.utils.module.Process.
- Parameters:
  - iteration (int) – Number of optimization iterations. Defaults to 20.
  - stop_threshold (float | None) – Threshold used in the early stop check. Defaults to None (no early stop).
  - loss_fn (Callable) – Loss function, usually with reduction='none' so that it returns one loss value per input.
  - **kwargs – Keyword arguments passed to trojanzoo.utils.module.Process.
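The constructor stores defaults that the methods below fall back to whenever a per-call argument is None, and update_input is the one abstract method a subclass must provide. The following is a minimal, framework-free sketch of that contract, not trojanzoo's code: SketchOptimizer, its resolve helper and NoOpOptimizer are hypothetical names, and the Process base class is left out.

```python
from abc import ABC, abstractmethod
from typing import Callable, Optional


class SketchOptimizer(ABC):
    """Hypothetical stand-in for Optimizer; the Process base class is omitted."""

    def __init__(self, iteration: int = 20,
                 stop_threshold: Optional[float] = None,
                 loss_fn: Optional[Callable] = None) -> None:
        self.iteration = iteration
        self.stop_threshold = stop_threshold  # None: no early stop
        self.loss_fn = loss_fn

    def resolve(self, iteration: Optional[int] = None,
                loss_fn: Optional[Callable] = None,
                stop_threshold: Optional[float] = None) -> tuple:
        # Hypothetical helper: a per-call argument overrides the instance
        # default only when it is not None.
        return (iteration if iteration is not None else self.iteration,
                loss_fn if loss_fn is not None else self.loss_fn,
                stop_threshold if stop_threshold is not None else self.stop_threshold)

    @abstractmethod
    def update_input(self, current_idx, adv_input, org_input, *args, **kwargs):
        """Perform one optimization step; concrete subclasses must override this."""


class NoOpOptimizer(SketchOptimizer):
    def update_input(self, current_idx, adv_input, org_input, *args, **kwargs):
        pass  # trivial step that leaves adv_input unchanged


opt = NoOpOptimizer(stop_threshold=0.1)
print(opt.resolve())             # (20, None, 0.1)
print(opt.resolve(iteration=5))  # (5, None, 0.1)
```

Because update_input is declared with abc.abstractmethod, instantiating SketchOptimizer directly raises TypeError; only subclasses that implement the step can be created.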
- early_stop_check(*args, current_idx=None, adv_input=None, loss_values=None, loss_fn=None, stop_threshold=None, loss_kwargs={}, **kwargs)
Early stop check using stop_threshold.
- Parameters:
  - current_idx (torch.Tensor) – The indices of adv_input that need to be checked (the other indices have already stopped early).
  - adv_input (torch.Tensor) – The entire batched adversarial input tensor with shape (N, *).
  - loss_values (torch.Tensor) – Batched loss tensor with shape (N). If None, it is calculated from loss_fn and adv_input. Defaults to None.
  - loss_fn (collections.abc.Callable | None) – Loss function (usually with reduction='none'). Defaults to self.loss_fn.
  - stop_threshold (float | None) – Threshold used in the early stop check. None (the default) means using self.stop_threshold.
  - loss_kwargs (dict[str, torch.Tensor]) – Keyword arguments passed to loss_fn; their values are also indexed by current_idx.
  - *args – Any positional argument (unused).
  - **kwargs – Any keyword argument (unused).
- Returns:
  torch.Tensor – Batched torch.BoolTensor with shape (N).
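The check produces one boolean per input. The sketch below models it with plain Python lists instead of tensors, and it rests on two assumptions not stated above: an input counts as finished once its loss is strictly below the threshold, and indices outside current_idx are reported as already stopped. The function name early_stop_check_sketch is hypothetical.

```python
from typing import Callable, List, Optional, Sequence


def early_stop_check_sketch(current_idx: Sequence[int],
                            adv_input: Sequence[float],
                            loss_values: Optional[Sequence[float]] = None,
                            loss_fn: Optional[Callable[[float], float]] = None,
                            stop_threshold: Optional[float] = None) -> List[bool]:
    """Return a length-N mask; True means input i no longer needs optimizing."""
    n = len(adv_input)
    if stop_threshold is None:
        return [False] * n  # no threshold: nothing ever stops early
    mask = [True] * n  # indices outside current_idx have already stopped
    for i in current_idx:
        # Use the precomputed loss if given, otherwise evaluate loss_fn.
        loss = loss_values[i] if loss_values is not None else loss_fn(adv_input[i])
        mask[i] = loss < stop_threshold  # assumed stop rule
    return mask


adv = [0.9, 0.5, 0.2]
print(early_stop_check_sketch([0, 2], adv, loss_fn=lambda x: x * x,
                              stop_threshold=0.1))  # [False, True, True]
```

Only the indices in current_idx trigger a loss evaluation, which mirrors why loss_kwargs must be indexed by current_idx as well.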
- optimize(_input, *args, iteration=None, loss_fn=None, stop_threshold=None, output=None, **kwargs)
Main optimization method.
- Parameters:
  - _input (torch.Tensor) – The batched input tensor to optimize.
  - iteration (int | None) – Number of optimization iterations. Defaults to self.iteration.
  - loss_fn (Callable) – Loss function (usually with reduction='none'). Defaults to self.loss_fn.
  - stop_threshold (float | None) – Threshold used in the early stop check. None (the default) means using self.stop_threshold.
  - output (int | Iterable[str]) – Output level integer or output items. If it is an int, get_output_int() is called. Defaults to self.output.
- Returns:
  (torch.Tensor, torch.Tensor) – The batched adversarial input tensor and the batched optimization iterations (-1 for inputs that never reach the stop threshold).
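Putting the pieces together, one plausible control flow for optimize is: check which inputs are still active, stop if none are, otherwise run one update on the active indices, and record for each input the iteration at which it reached the threshold. This is a plain-Python sketch under those assumptions, not trojanzoo's implementation; optimize_sketch, update_step and the toy halve step are hypothetical, and the stop rule (loss strictly below the threshold) is assumed.

```python
from typing import Callable, List, Optional, Tuple


def optimize_sketch(_input: List[float],
                    update_step: Callable[[List[int], List[float], List[float]], None],
                    loss_fn: Callable[[float], float],
                    iteration: int = 20,
                    stop_threshold: Optional[float] = None) -> Tuple[List[float], List[int]]:
    """Return (adv_input, iters); iters[i] is -1 if input i never met the threshold."""
    org_input = list(_input)
    adv_input = list(_input)
    iters = [-1] * len(_input)
    for _iter in range(iteration + 1):
        current_idx = [i for i, it in enumerate(iters) if it == -1]
        if stop_threshold is not None:
            for i in current_idx:
                if loss_fn(adv_input[i]) < stop_threshold:
                    iters[i] = _iter  # input i met the threshold after _iter steps
            current_idx = [i for i in current_idx if iters[i] == -1]
        if not current_idx or _iter == iteration:
            break  # everything stopped early, or the iteration budget is spent
        update_step(current_idx, adv_input, org_input)  # mutates adv_input in place
    return adv_input, iters


def halve(current_idx, adv_input, org_input):
    for i in current_idx:
        adv_input[i] /= 2  # toy update: move each active input toward 0


adv, iters = optimize_sketch([1.0, 0.05, 8.0], halve, loss_fn=lambda x: x * x,
                             iteration=5, stop_threshold=0.01)
print(adv, iters)  # [0.0625, 0.05, 0.25] [4, 0, -1]
```

The second input already satisfies the threshold, so it is recorded at iteration 0 and never updated; the third never gets there within five steps, so its entry stays -1.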
- output_info(*args, mode='start', _iter=0, iteration=0, output=None, indent=None, **kwargs)
Output information.
- Parameters:
  - mode (str) – The output mode (e.g., 'start', 'end', 'middle', 'memory'). Must be a mode accepted by get_output_int(). Defaults to 'start'.
  - _iter (int) – Current iteration. Defaults to 0.
  - iteration (int) – Total number of iterations. Defaults to 0.
  - output (Iterable[str]) – Output items. Defaults to self.output.
  - indent (int) – The space indent for the entire string. Defaults to self.indent.
  - *args – Any positional argument (unused).
  - **kwargs – Any keyword argument (unused).
- preprocess_input(*args, adv_input=None, org_input=None, **kwargs)
Preprocess the input tensors.
- Parameters:
  - adv_input (torch.Tensor) – The entire batched adversarial input tensor with shape (N, *).
  - org_input (torch.Tensor) – The entire batched original input tensor with shape (N, *).
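The base class does not say what preprocessing should do, so the following shows only one thing a subclass might plausibly put here: a PGD-style random start that perturbs adv_input uniformly within eps of org_input and keeps values in [0, 1]. Everything in it is an assumption for illustration, including the name preprocess_input_sketch, the eps and seed parameters, the [0, 1] value range, and the use of plain lists instead of tensors.

```python
import random
from typing import List, Optional


def preprocess_input_sketch(adv_input: List[float], org_input: List[float],
                            eps: float = 0.1,
                            seed: Optional[int] = None) -> List[float]:
    """Hypothetical random start: perturb adv_input within eps of org_input."""
    rng = random.Random(seed)
    out = []
    for x, x0 in zip(adv_input, org_input):
        x = x + rng.uniform(-eps, eps)        # uniform random perturbation
        x = min(max(x, x0 - eps), x0 + eps)   # stay within eps of the original
        out.append(min(max(x, 0.0), 1.0))     # keep a valid [0, 1] value
    return out


org = [0.0, 0.5, 1.0]
print(preprocess_input_sketch(list(org), org, eps=0.1, seed=0))
```

Clipping to the eps-neighbourhood first and to [0, 1] second keeps both constraints satisfied, because every original value already lies in [0, 1].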
- abstract update_input(current_idx, adv_input, org_input, *args, **kwargs)
Optimize the input tensor for one iteration.
- Parameters:
  - current_idx (torch.Tensor) – The indices of adv_input that need to be optimized (the other indices have already stopped early).
  - adv_input (torch.Tensor) – The entire batched adversarial input tensor with shape (N, *).
  - org_input (torch.Tensor) – The entire batched original input tensor with shape (N, *).
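A concrete update_input only has to touch the indices listed in current_idx. As a sketch, here is a signed-gradient (PGD-style) descent step on a toy scalar loss, projected back into an eps-ball around org_input and into [0, 1]. This is a substituted example technique, not trojanzoo's update rule: update_input_sketch, grad_fn, alpha and eps are hypothetical, gradients are supplied analytically instead of via autograd, and plain lists stand in for tensors.

```python
from typing import Callable, List, Sequence


def update_input_sketch(current_idx: Sequence[int], adv_input: List[float],
                        org_input: Sequence[float],
                        grad_fn: Callable[[float], float],
                        alpha: float = 0.01, eps: float = 0.1) -> None:
    """One hypothetical signed-gradient step, applied in place to active indices only."""
    for i in current_idx:
        g = grad_fn(adv_input[i])
        sign = (g > 0) - (g < 0)  # -1, 0 or +1
        x = adv_input[i] - alpha * sign  # descend the loss
        # Project into [org - eps, org + eps], then into the valid range [0, 1].
        x = min(max(x, org_input[i] - eps), org_input[i] + eps)
        adv_input[i] = min(max(x, 0.0), 1.0)


adv = [0.5, 0.5, 0.25]
update_input_sketch([0, 2], adv, org_input=[0.5, 0.5, 0.25],
                    grad_fn=lambda x: 2 * (x - 1),  # gradient of (x - 1) ** 2
                    alpha=0.25, eps=0.125)
print(adv)  # [0.625, 0.5, 0.375]
```

Index 1 is not in current_idx, so it keeps its value; the other two move by alpha toward the loss minimum at 1 and are then clipped to within eps of their originals.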