import re
from collections import OrderedDict
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
import wrapt

from minerva.utils.typing import PathLike

# Type alias for callables that extract a submodule from a loaded model
# (assumed definition; used by the `extractor` parameters below).
ModuleExtractor = Callable[[torch.nn.Module], torch.nn.Module]
class LoadableModule:
    """Interface for modules that can be loaded from a file. This is a dummy
    class that should be inherited by such classes, allowing them to be used
    in type hints.
    """

    pass
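
# LoadableModule carries no behavior of its own; it only lets loadable
# wrappers be identified by type. For example (a hypothetical helper):
#
#     def describe(model: LoadableModule) -> str:
#         return repr(model)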
class FromPretrained(wrapt.ObjectProxy, LoadableModule):
    """Wraps a model whose weights are loaded from a checkpoint file. The
    wrapper acts as a proxy: all attributes and methods of the wrapped model
    are accessible through the FromPretrained object directly.
    """
def __init__(
self,
model: torch.nn.Module,
ckpt_path: Optional[PathLike] = None,
filter_keys: Optional[List[str]] = None,
keys_to_rename: Optional[Dict[str, str]] = None,
strict: bool = False,
ckpt_key: Optional[str] = "state_dict",
extractor: Optional[ModuleExtractor] = None,
error_on_missing_keys: bool = True,
ckpt_load_weights_only: bool = True,
):
"""Load a model from a checkpoint file and wrap it in a FromPretrained
object. The FromPretrained object acts as a proxy to the model, allowing
to call it as if it was the original model. All the attributes and
methods of the model are accessible through the FromPretrained object
directly.
Parameters
----------
model : torch.nn.Module
The model to be loaded (initialized randomly).
        ckpt_path : Optional[PathLike], optional
            The path to the checkpoint file from which the model will be
            loaded. If None, no state_dict is loaded and the model is left
            exactly as it was initialized. By default None.
filter_keys : Optional[List[str]], optional
List of regular expressions to filter keys from the state_dict.
Only keys that match any of the regular expressions will be kept.
If None, all keys will be kept. By default None.
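            For example (a hypothetical pattern), filter_keys=[r"^encoder\."]
            keeps only the keys that start with "encoder.".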
        keys_to_rename : Optional[Dict[str, str]], optional
            A dictionary with keys being regular expressions and values being
            prefixes to be added to the keys that match the regular
            expressions. If the prefix is an empty string, the matched part
            of the key will be removed. Keys that do not match any regular
            expression will remain the same. If a key matches multiple
            regular expressions, the first one will be used. Finally, if an
            empty string is used as a key, all keys will have the prefix
            added (this has priority over the other keys). By default None.
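            For example (hypothetical patterns),
            keys_to_rename={r"^model\.": ""} strips the leading "model." from
            every matching key, while keys_to_rename={"": "backbone."}
            prefixes every key with "backbone.".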
strict : bool, optional
If True, the state_dict must match the keys of the model exactly.
If False, the state_dict can have extra keys that will be ignored.
By default False
        ckpt_key : Optional[str], optional
            The key in the checkpoint file where the state_dict is stored.
            If None, the whole checkpoint will be used as the state_dict.
            Otherwise, the value stored under that key will be used (falling
            back to the whole checkpoint if the key is absent). By default
            "state_dict".
        extractor : Optional[ModuleExtractor], optional
            Once the model is loaded, the extractor will be called with the
            model as its argument. The extractor should return the desired
            submodel (for instance, the model without some final layers). By
            default None.
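            For example (a hypothetical attribute name),
            extractor=lambda m: m.backbone would keep only the model's
            backbone submodule.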
error_on_missing_keys : bool, optional
If True, raise an error if some keys are missing in the state_dict
when loading the model. If False, ignore missing keys.
By default True
        ckpt_load_weights_only : bool, optional
            Passed to torch.load as its weights_only argument. If True, only
            plain weight tensors (and other primitive containers) are
            unpickled from the checkpoint, which is safer. If False, the
            whole pickled checkpoint object is loaded. By default True.
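
        Examples
        --------
        A minimal usage sketch (the checkpoint path and the key patterns
        below are placeholders, not real files or keys):

        >>> model = torch.nn.Linear(10, 2)
        >>> model = FromPretrained(
        ...     model,
        ...     ckpt_path="checkpoint.ckpt",
        ...     filter_keys=[r"^model\."],
        ...     keys_to_rename={r"^model\.": ""},
        ... )
        >>> output = model(torch.randn(1, 10))  # used like the wrapped model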
"""
super().__init__(model)
self.__wrapped__ = model
if ckpt_path is not None:
# Load the state_dict from the checkpoint
ckpt = torch.load(
ckpt_path,
map_location="cpu",
weights_only=ckpt_load_weights_only,
)
# Get the state_dict from the checkpoint
if ckpt_key is not None:
state_dict = ckpt.get(ckpt_key, ckpt)
else:
state_dict = ckpt
# Filter keys if needed
if filter_keys is not None:
d = OrderedDict()
# Iterate over all keys in the state_dict
for k, v in state_dict.items():
# Check if key name matches any of the filter keys
# If it does, add it to the new state_dict (and break)
# Thus, if multiple filter keys match the same key, the
# first one will be used
# If no filter matches, the key will be ignored (not added)
for pattern in filter_keys:
if re.search(pattern, k):
d[k] = v
break
# Update the state_dict
state_dict = d
# Rename keys with prefix
            if keys_to_rename is not None:
                print(f"Performing key renaming with: {keys_to_rename}")
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    new_k = k
                    if "" in keys_to_rename:
                        # An empty-string key has priority: every key gets
                        # its value as a prefix.
                        new_k = f"{keys_to_rename['']}{k}"
                    else:
                        # Apply the first regular expression that matches the
                        # key, then stop searching. Substituting an empty
                        # prefix removes the matched part of the key.
                        for old_key, new_prefix in keys_to_rename.items():
                            if re.match(old_key, k):
                                new_k = re.sub(old_key, new_prefix, k)
                                break
                    print(
                        f"\tRenaming key: {k} -> {new_k} (changed: {k != new_k})"
                    )
                    new_state_dict[new_k] = v
                state_dict = new_state_dict
# Load the modified state_dict
missing_keys, unexpected_keys = self.__wrapped__.load_state_dict(
state_dict, strict=strict
)
if error_on_missing_keys and missing_keys:
raise ValueError(f"Missing keys: {missing_keys}")
print(f"Model loaded from {ckpt_path}")
# Print message with missing and unexpected keys
if missing_keys:
print(
f"When loading model, the following keys are missing: {missing_keys}"
)
if unexpected_keys:
print(
f"When loading model, the following keys are unexpected: {unexpected_keys}. ",
end="",
)
if not strict:
print("Ignoring unexpected keys.")
else:
print()
else:
print("WARNING: Model loaded without state_dict.")
if extractor is not None:
print("Extracting submodel...")
self.__wrapped__ = extractor(self.__wrapped__)
    # Additional methods to make the wrapped object callable and to allow
    # pickling.
def __getattr__(self, name):
return getattr(self.__wrapped__, name)
def __call__(self, *args, **kwargs):
return self.__wrapped__(*args, **kwargs)
def __reduce_ex__(self, proto):
return self.__wrapped__.__reduce_ex__(proto)
def __repr__(self):
return self.__wrapped__.__repr__()
def __str__(self):
return self.__wrapped__.__str__()
class FromModel(wrapt.ObjectProxy, LoadableModule):
    """This class loads a complete (picklable) model from a model file,
    extracts the desired submodel, and wraps it in a FromModel object. The
    FromModel object acts as a proxy to the submodel, allowing it to be
    called as if it were the original model. All the attributes and methods
    of the submodel are accessible through the FromModel object directly.
    """
def __init__(
self,
model_path: PathLike,
        extractor: Optional[Callable[[torch.nn.Module], torch.nn.Module]] = None,
):
"""This class perform the following steps:
1. Load the whole model from the model file (pickable).
2. Extract the desired submodel using the extractor function.
Parameters
----------
model_path : PathLike
Path to the model file from which the model will be loaded.
extractor : Callable[[torch.nn.Module], torch.nn.Module], optional
The extractor function to be used to extract the desired submodel
from the loaded model. The default is None, that is, use the module
as it is loaded from the file.
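
        Examples
        --------
        A minimal usage sketch (the file name and the submodule attribute
        are placeholders):

        >>> encoder = FromModel(
        ...     "full_model.pt",
        ...     extractor=lambda m: m.encoder,
        ... )
        >>> features = encoder(batch)  # called like the extracted submodel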
"""
        # Full (pickled) models need weights_only=False to be unpickled by
        # torch.load (the default flipped to True in newer PyTorch releases).
        model = torch.load(model_path, map_location="cpu", weights_only=False)
super().__init__(model)
self.__wrapped__ = model
if extractor is not None:
self.__wrapped__ = extractor(self.__wrapped__)
def __getattr__(self, name):
return getattr(self.__wrapped__, name)
def __call__(self, *args, **kwargs):
return self.__wrapped__(*args, **kwargs)