ocl.conditioning
Implementation of conditioning approaches for slots.
RandomConditioning
Random conditioning with potentially learnt mean and stddev.
Source code in ocl/conditioning.py
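A minimal sketch of the idea follows. It assumes the common Slot Attention recipe: one mean vector and one log standard deviation shared by all slots, sampled with the reparameterization trick. The class name RandomConditioningSketch is hypothetical. The constructor arguments are borrowed from the RandomConditioningWithQMCSampling table below rather than taken from the RandomConditioning source. The output is assumed to be a tensor of shape (batch_size, n_slots, object_dim).

```python
from typing import Callable

import torch
from torch import nn


class RandomConditioningSketch(nn.Module):
    """Hypothetical sketch: one Gaussian shared by all slots."""

    def __init__(
        self,
        object_dim: int,
        n_slots: int,
        learn_mean: bool = True,
        learn_std: bool = True,
        mean_init: Callable[[torch.Tensor], None] = nn.init.xavier_uniform_,
        logsigma_init: Callable[[torch.Tensor], None] = nn.init.xavier_uniform_,
    ):
        super().__init__()
        self.n_slots = n_slots
        # Shape (1, 1, object_dim): a single distribution broadcast over batch and slots.
        self.slots_mu = nn.Parameter(torch.zeros(1, 1, object_dim), requires_grad=learn_mean)
        self.slots_logsigma = nn.Parameter(
            torch.zeros(1, 1, object_dim), requires_grad=learn_std
        )
        # torch.nn.init functions modify the tensor in place without tracking gradients.
        mean_init(self.slots_mu)
        logsigma_init(self.slots_logsigma)

    def forward(self, batch_size: int) -> torch.Tensor:
        mu = self.slots_mu.expand(batch_size, self.n_slots, -1)
        sigma = self.slots_logsigma.exp().expand(batch_size, self.n_slots, -1)
        # Fresh, independent noise for every (instance, slot) pair.
        eps = torch.randn(mu.shape, device=mu.device, dtype=mu.dtype)
        # Reparameterization trick: mu + sigma * eps is differentiable in mu and sigma.
        return mu + sigma * eps
```

Sampling as mu + sigma * eps, rather than calling a distribution's sample(), keeps the draw differentiable with respect to both learnt parameters.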
LearntConditioning
Conditioning with a learnt set of slot initializations, similar to DETR.
__init__
Initialize LearntConditioning.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object_dim | int | Dimensionality of the conditioning vector to generate. | required |
n_slots | int | Number of conditioning vectors to generate. | required |
slot_init | Optional[Callable[[torch.Tensor], None]] | Callable used to initialize individual slots. | None |
forward
Generate conditioning vectors for batch_size instances.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch_size | int | Number of instances to create conditioning vectors for. | required |
Returns:
Type | Description |
---|---|
ocl.typing.ConditioningOutput | The conditioning vectors. |
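A minimal sketch of DETR-style learnt conditioning follows. It assumes the slots are stored as a single (1, n_slots, object_dim) parameter that is broadcast over the batch. LearntConditioningSketch is a hypothetical name. Falling back to torch.nn.init.normal_ when slot_init is None is also an assumption, not documented behavior.

```python
from typing import Callable, Optional

import torch
from torch import nn


class LearntConditioningSketch(nn.Module):
    """Hypothetical sketch: a fixed, learnt set of slot initializations."""

    def __init__(
        self,
        object_dim: int,
        n_slots: int,
        slot_init: Optional[Callable[[torch.Tensor], None]] = None,
    ):
        super().__init__()
        # One learnt vector per slot; the leading 1 is broadcast over the batch.
        self.slots = nn.Parameter(torch.zeros(1, n_slots, object_dim))
        if slot_init is None:
            # Assumed fallback initializer; not taken from ocl/conditioning.py.
            slot_init = nn.init.normal_
        slot_init(self.slots)

    def forward(self, batch_size: int) -> torch.Tensor:
        # No sampling: every instance in the batch receives the same learnt vectors.
        return self.slots.expand(batch_size, -1, -1)
```

Unlike the random variants, this forward is deterministic. expand returns a view without copying the parameter, so a consumer that writes into the result in place should call clone() first.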
RandomConditioningWithQMCSampling
Bases: RandomConditioning
Random Gaussian conditioning using Quasi-Monte Carlo (QMC) samples.
__init__
Initialize RandomConditioningWithQMCSampling.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object_dim | int | Dimensionality of the conditioning vector to generate. | required |
n_slots | int | Number of conditioning vectors to generate. | required |
learn_mean | bool | Learn the mean vector of the sampling distribution. | True |
learn_std | bool | Learn the std vector of the sampling distribution. | True |
mean_init | Callable[[torch.Tensor], None] | Callable to initialize the mean vector. | torch.nn.init.xavier_uniform_ |
logsigma_init | Callable[[torch.Tensor], None] | Callable to initialize logsigma. | torch.nn.init.xavier_uniform_ |
forward
Generate conditioning vectors for batch_size instances.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch_size | int | Number of instances to create conditioning vectors for. | required |
Returns:
Type | Description |
---|---|
ocl.typing.ConditioningOutput | The conditioning vectors. |
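One way to obtain Gaussian QMC samples is sketched below under stated assumptions. Points are drawn from a scrambled Sobol sequence with torch.quasirandom.SobolEngine and mapped through the standard normal inverse CDF, sqrt(2) * erfinv(2u - 1). This construction is a stand-in chosen for illustration. The class name, the seed argument, and the use of Sobol points are not taken from ocl/conditioning.py. The mean and log standard deviation are handled as in plain random conditioning.

```python
import math
from typing import Callable, Optional

import torch
from torch import nn
from torch.quasirandom import SobolEngine


class QMCRandomConditioningSketch(nn.Module):
    """Hypothetical sketch: shared Gaussian conditioning driven by Sobol QMC points."""

    def __init__(
        self,
        object_dim: int,
        n_slots: int,
        learn_mean: bool = True,
        learn_std: bool = True,
        mean_init: Callable[[torch.Tensor], None] = nn.init.xavier_uniform_,
        logsigma_init: Callable[[torch.Tensor], None] = nn.init.xavier_uniform_,
        seed: Optional[int] = None,  # assumption: not a documented parameter
    ):
        super().__init__()
        self.n_slots = n_slots
        self.slots_mu = nn.Parameter(torch.zeros(1, 1, object_dim), requires_grad=learn_mean)
        self.slots_logsigma = nn.Parameter(
            torch.zeros(1, 1, object_dim), requires_grad=learn_std
        )
        mean_init(self.slots_mu)
        logsigma_init(self.slots_logsigma)
        # Scrambled Sobol sequence over the object_dim-dimensional unit cube.
        self.sobol = SobolEngine(dimension=object_dim, scramble=True, seed=seed)

    def _qmc_randn(self, n: int) -> torch.Tensor:
        """Return n quasi-random standard normal vectors of shape (n, object_dim)."""
        u = self.sobol.draw(n)
        # Keep u strictly inside (0, 1) so the inverse CDF stays finite.
        u = u.clamp(1e-6, 1.0 - 1e-6)
        # Standard normal inverse CDF: Phi^-1(u) = sqrt(2) * erfinv(2u - 1).
        return math.sqrt(2.0) * torch.erfinv(2.0 * u - 1.0)

    def forward(self, batch_size: int) -> torch.Tensor:
        z = self._qmc_randn(batch_size * self.n_slots)
        # Match the parameters' dtype and device; the Sobol engine runs on the CPU.
        z = z.view(batch_size, self.n_slots, -1).to(self.slots_mu)
        # Scale and shift the QMC normals; broadcasting covers batch and slots.
        return self.slots_mu + self.slots_logsigma.exp() * z
```

Low-discrepancy points cover the unit cube more evenly than torch.randn draws, so a batch of slot initializations spreads out more uniformly. The engine is stateful: each call to forward continues the sequence rather than restarting it.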
SlotwiseLearntConditioning
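Random conditioning with a learnt mean and standard deviation for each slot.
Because each slot samples from its own distribution, this removes the permutation equivariance of the original Slot Attention conditioning, in which all slots share one distribution.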
__init__
Initialize SlotwiseLearntConditioning.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object_dim | int | Dimensionality of the conditioning vector to generate. | required |
n_slots | int | Number of conditioning vectors to generate. | required |
mean_init | Callable[[torch.Tensor], None] | Callable to initialize the mean vector. | torch.nn.init.normal_ |
logsigma_init | Callable[[torch.Tensor], None] | Callable to initialize logsigma. | torch.nn.init.xavier_uniform_ |
forward
Generate conditioning vectors for batch_size instances.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch_size | int | Number of instances to create conditioning vectors for. | required |
Returns:
Type | Description |
---|---|
ocl.typing.ConditioningOutput | The conditioning vectors. |
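A minimal sketch of slotwise learnt conditioning follows. It assumes the mean and log standard deviation are stored as (1, n_slots, object_dim) parameters, so every slot position has its own Gaussian. SlotwiseLearntConditioningSketch is a hypothetical name, and its defaults mirror the parameter table for SlotwiseLearntConditioning.

```python
from typing import Callable

import torch
from torch import nn


class SlotwiseLearntConditioningSketch(nn.Module):
    """Hypothetical sketch: an independent learnt Gaussian for each slot position."""

    def __init__(
        self,
        object_dim: int,
        n_slots: int,
        mean_init: Callable[[torch.Tensor], None] = nn.init.normal_,
        logsigma_init: Callable[[torch.Tensor], None] = nn.init.xavier_uniform_,
    ):
        super().__init__()
        # One mean and one log-std vector per slot, broadcast over the batch.
        self.slots_mu = nn.Parameter(torch.zeros(1, n_slots, object_dim))
        self.slots_logsigma = nn.Parameter(torch.zeros(1, n_slots, object_dim))
        mean_init(self.slots_mu)
        logsigma_init(self.slots_logsigma)

    def forward(self, batch_size: int) -> torch.Tensor:
        mu = self.slots_mu.expand(batch_size, -1, -1)
        sigma = self.slots_logsigma.exp().expand(batch_size, -1, -1)
        eps = torch.randn(mu.shape, device=mu.device, dtype=mu.dtype)
        # Slot k is always drawn from the k-th learnt distribution.
        return mu + sigma * eps
```

Because slot k always samples from the k-th distribution, permuting the slots changes the joint distribution of the output. That is the loss of permutation equivariance described for this class.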