model
This module provides utility functions for manipulating PyTorch models, such as replacing layers or loading state_dicts.
adjust_prefix_and_load_state_dict(model, ckpt_path, ckpt_to_model_prefix=None, layers_to_ignore=None)
This function loads a state dictionary from a checkpoint file into a model non-strictly, using load_state_dict(strict=False).
It supports remapping layer names between the checkpoint and model through the ckpt_to_model_prefix parameter.
This is useful when loading weights from a model that was trained as part of a larger architecture, where the layer names may not match the standalone version of the model.
Before using ckpt_to_model_prefix, it's recommended to:
1. Check the layer names in both the checkpoint and target model
2. Map the mismatched prefixes accordingly
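The two steps above can be sketched in plain Python. This is an illustration of the idea, not the library's implementation: remap_prefixes is a hypothetical helper, the key names are made up, and the assumption that prefixes are matched with str.startswith (including the trailing dot) is not confirmed by this page.

```python
def remap_prefixes(state_dict, ckpt_to_model_prefix):
    """Return a copy of state_dict with checkpoint prefixes swapped for model prefixes."""
    remapped = {}
    for key, value in state_dict.items():
        new_key = key
        for ckpt_prefix, model_prefix in ckpt_to_model_prefix.items():
            if key.startswith(ckpt_prefix):
                # Replace only the leading prefix, keep the rest of the name.
                new_key = model_prefix + key[len(ckpt_prefix):]
                break
        remapped[new_key] = value
    return remapped


# Stand-in for a checkpoint's state_dict; the values would normally be tensors.
ckpt_state = {
    "model.backbone.conv1.weight": 0,
    "model.backbone.conv1.bias": 1,
    "model.head.weight": 2,
}

# Step 1: inspect the names stored in the checkpoint (and, likewise, in the model).
print(sorted(ckpt_state))

# Step 2: map the mismatched prefix onto the name used by the standalone model.
remapped = remap_prefixes(ckpt_state, {"model.backbone.": "encoder."})
print(sorted(remapped))
```

Keys that match no prefix are passed through unchanged in this sketch, so they can still be picked up, or reported as unexpected, by a later non-strict load.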
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | Module | The model to load the state_dict into. | required |
ckpt_path | str | The path to the checkpoint file. | required |
ckpt_to_model_prefix | dict[str, str] \| None | Mapping of checkpoint prefixes to model prefixes. | None |
layers_to_ignore | List[str] \| None | Layers to ignore when loading the state_dict. | None |
Returns:
Name | Type | Description |
---|---|---|
Module | Module | The model with the loaded state_dict. |
Raises:
Type | Description |
---|---|
ValueError | If there is no overlap between the checkpoint's and model's state_dict. |
Source code in lighter/utils/model.py
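As a rough sketch of the behaviour documented above (ignoring layers, rejecting a checkpoint with no overlapping keys, then loading non-strictly), the following uses only standard PyTorch calls. It does not call adjust_prefix_and_load_state_dict; the in-memory ckpt_state stands in for a checkpoint file, and treating layers_to_ignore as exact key names is an assumption.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Stand-in for the contents of a checkpoint file: same key names, different values.
ckpt_state = {k: torch.ones_like(v) for k, v in model.state_dict().items()}

# Assumption: ignored layers are given as exact state_dict keys.
layers_to_ignore = ["2.weight", "2.bias"]
filtered = {k: v for k, v in ckpt_state.items() if k not in layers_to_ignore}

# Refuse to continue when no checkpoint key matches any model key.
if not set(filtered) & set(model.state_dict()):
    raise ValueError("No overlap between the checkpoint's and the model's state_dict.")

# strict=False tolerates keys that are missing from, or unexpected in, the checkpoint.
result = model.load_state_dict(filtered, strict=False)
print(result.missing_keys)     # model keys the checkpoint did not supply
print(result.unexpected_keys)  # checkpoint keys the model does not have
```

Inspecting missing_keys and unexpected_keys after a non-strict load is a cheap way to confirm that only the intended layers were skipped.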
remove_n_last_layers_sequentially(model, num_layers=1)
Removes a specified number of layers from the end of a model and returns it as a Sequential model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | Module | The model to modify. | required |
num_layers | int | The number of layers to remove from the end. | 1 |
Returns:
Name | Type | Description |
---|---|---|
Sequential | Sequential | The modified model as a Sequential container. |
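One plausible way to get this behaviour with plain PyTorch is to slice the model's direct children and wrap the remainder in a Sequential. The helper drop_last_layers below is hypothetical and is not taken from the library's source; it assumes num_layers is at least 1.

```python
import torch
from torch import nn


def drop_last_layers(model, num_layers=1):
    """Wrap all but the last num_layers direct children in a Sequential."""
    # children() yields only direct submodules, in registration order.
    return nn.Sequential(*list(model.children())[:-num_layers])


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
trimmed = drop_last_layers(model, num_layers=1)

# The trimmed model now ends at the ReLU, so it outputs 8 features per sample.
features = trimmed(torch.randn(3, 4))
print(features.shape)
```

Because the result is rebuilt from the children alone, any logic that lives in the original model's forward method (reshapes, skip connections, functional activations) is not carried over in this sketch.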
replace_layer_with(model, layer_name, new_layer)
Replaces a specified layer in a PyTorch model with a new layer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | Module | The model to modify. | required |
layer_name | str | The name of the layer to replace, using dot notation if necessary (e.g. "layer10.fc.weights"). | required |
new_layer | Module | The new layer to insert. | required |
Returns:
Name | Type | Description |
---|---|---|
Module | Module | The modified model with the new layer. |
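Replacing a layer addressed by a dotted name can be sketched as walking to the parent module with getattr and then assigning the new layer with setattr. The helper set_layer and the layer names below are hypothetical and only illustrate the technique; the library's own implementation may differ.

```python
from collections import OrderedDict

import torch
from torch import nn


def set_layer(model, layer_name, new_layer):
    """Replace the submodule at a dotted path such as "head.fc"."""
    *parent_path, attr = layer_name.split(".")
    parent = model
    for name in parent_path:
        parent = getattr(parent, name)  # descend one level per path component
    setattr(parent, attr, new_layer)    # nn.Module registers the new submodule
    return model


model = nn.Sequential(OrderedDict([
    ("backbone", nn.Linear(4, 8)),
    ("head", nn.Sequential(OrderedDict([("fc", nn.Linear(8, 2))]))),
]))

# Swap the 2-class head for a 10-class one.
model = set_layer(model, "head.fc", nn.Linear(8, 10))
out = model(torch.randn(3, 4))
print(out.shape)
```

The same pattern is a common way to adapt a pretrained classifier to a different number of output classes.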
replace_layer_with_identity(model, layer_name)
Replaces a specified layer in a PyTorch model with an Identity layer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | Module | The model to modify. | required |
layer_name | str | The name of the layer to replace with an Identity layer, using dot notation if necessary (e.g. "layer10.fc.weights"). | required |
Returns:
Name | Type | Description |
---|---|---|
Module | Module | The modified model with the Identity layer. |
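A typical reason to swap a layer for Identity is to turn a classifier into a feature extractor. The sketch below shows the effect with a direct attribute assignment in plain PyTorch rather than by calling replace_layer_with_identity; the layer names encoder and classifier are made up for the example.

```python
from collections import OrderedDict

import torch
from torch import nn

model = nn.Sequential(OrderedDict([
    ("encoder", nn.Linear(4, 8)),
    ("classifier", nn.Linear(8, 2)),
]))

# Identity passes its input through unchanged, so the model now returns
# the encoder's 8-dimensional output instead of class scores.
model.classifier = nn.Identity()

out = model(torch.randn(3, 4))
print(out.shape)
```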