wound_segmentation.checkpoints module
Utilities for downloading and managing model checkpoints.
This module provides a function to automatically download pretrained model weights from Google Drive if they are not found locally. The downloaded archive is extracted into the default model directory, and the ZIP file is deleted afterward.
Logging is used to provide feedback during the download and extraction processes. This module is typically used in the entry point script (main.py) to ensure model weights are available before inference.
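A minimal sketch of how an entry point such as main.py might make sure the weights exist before inference. The helper name `ensure_weights` and its parameters are hypothetical. The `download` callable stands in for `download_weights`. The `model_weights` directory is assumed from the path used in the Examples section.

```python
from pathlib import Path
from typing import Callable


def ensure_weights(
    model_dir: Path,
    model_name: str,
    download: Callable[[], None],
) -> Path:
    """Return the local weights path, calling ``download`` only if it is missing.

    Hypothetical helper: ``download`` stands in for ``download_weights``.
    """
    weights_path = model_dir / model_name
    if not weights_path.exists():
        # Weights not found locally: fetch and extract them once.
        download()
        if not weights_path.exists():
            # The download ran but did not produce the expected file.
            raise RuntimeError(f"Weights still missing after download: {weights_path}")
    return weights_path
```

In main.py, a call such as `ensure_weights(Path("model_weights"), "segmentation_model.weights.h5", download_weights)` would then trigger the download only when the weights file is not already on disk.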
Functions
- download_weights(gdrive_id, model_name)
This function downloads zipped weights from Google Drive, extracts them to a target directory, and logs the progress. It raises a RuntimeError if any step fails.
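The download, extract, and delete steps can be sketched with the standard library as follows. This is an illustration under stated assumptions, not the module's actual implementation. The Google Drive transfer is replaced by a hypothetical `fetch` callable that writes the archive to a given path; a real version would use a Google Drive client at that point. The archive name `weights.zip` is hypothetical. The sketch also assumes the weights file sits at the top level of the archive.

```python
import logging
import zipfile
from pathlib import Path
from typing import Callable

logger = logging.getLogger(__name__)


def download_and_extract(
    fetch: Callable[[Path], None],
    model_dir: Path,
    model_name: str,
) -> Path:
    """Download a zip via ``fetch``, extract it into ``model_dir``, delete the zip.

    ``fetch`` is a hypothetical stand-in for the Google Drive download step;
    it must write the archive to the path it is given.
    """
    zip_path = model_dir / "weights.zip"  # hypothetical archive name
    try:
        model_dir.mkdir(parents=True, exist_ok=True)
        logger.info("Downloading weights archive to %s", zip_path)
        fetch(zip_path)
        logger.info("Extracting %s into %s", zip_path, model_dir)
        with zipfile.ZipFile(zip_path) as archive:
            archive.extractall(model_dir)
    except Exception as exc:
        # Report any failed step as RuntimeError, as the documentation states.
        raise RuntimeError(f"Failed to download or extract weights: {exc}") from exc
    finally:
        # Remove the archive whether or not extraction succeeded.
        zip_path.unlink(missing_ok=True)
    weights_path = model_dir / model_name
    if not weights_path.exists():
        raise RuntimeError(f"{model_name} not found in the extracted archive")
    logger.info("Weights available at %s", weights_path)
    return weights_path
```

Deleting the archive in `finally` also cleans up after a failed extraction. Moving the `unlink` call after the `with` block would instead delete it only on success, matching the description above more literally.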
Notes
The Google Drive ID and default model name are configured in 'constants.py'.
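A hypothetical sketch of what these settings might look like in constants.py. The constant names `GDRIVE_ID`, `MODEL_NAME`, and `MODEL_DIR` are assumptions. The values are copied from the `download_weights` defaults documented below, and the directory is inferred from the path in the Examples section.

```python
# constants.py (hypothetical constant names; values taken from the
# download_weights defaults)
from pathlib import Path

GDRIVE_ID: str = "1Rldcue5dVF2XTp1kx3G5hBvG3xApwSaS"
MODEL_NAME: str = "segmentation_model.weights.h5"
MODEL_DIR: Path = Path("model_weights")  # assumed from the Examples path
```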
- wound_segmentation.checkpoints.download_weights(gdrive_id: str = '1Rldcue5dVF2XTp1kx3G5hBvG3xApwSaS', model_name: str = 'segmentation_model.weights.h5') -> None
Download and save model weights from Google Drive.
This function downloads a ZIP archive containing the model weights using the specified Google Drive file ID, extracts its contents to the default model directory, and then deletes the archive.
- Parameters:
gdrive_id (str) -- Google Drive file ID of the zipped weights.
model_name (str) -- Filename of the model weights (.weights.h5).
- Raises:
RuntimeError -- If download or extraction fails.
Examples
>>> download_weights()
>>> model = load_segmentation_model("model_weights/segmentation_model.weights.h5")