Shortcuts

Source code for ride.feature_extraction

from operator import attrgetter
from pathlib import Path

import numpy
import torch

from ride.core import Configs, RideMixin
from ride.logging import get_log_dir
from ride.metrics import MetricDict
from ride.utils.io import bump_version
from ride.utils.logging import getLogger
from ride.utils.utils import rgetattr

# Module-level logger, named after this module.
logger = getLogger(__name__)
class FeatureExtractable(RideMixin):
    """Adds feature extraction capabilities to model

    When the `extract_features_after_layer` hyperparameter is non-empty,
    `on_init_end` registers a forward hook on the named layer. The hook
    appends the layer's outputs (as NumPy arrays) to
    `self.extracted_features`, and `metrics_epoch` writes the collected
    features to a `.npy` file below the log directory.
    """

    hparams: ...

    @staticmethod
    def configs() -> Configs:
        """Return the configuration options contributed by this mixin.

        Returns:
            Configs: holds the single string option
            `extract_features_after_layer` (default `""`).
        """
        c = Configs()
        c.add(
            name="extract_features_after_layer",
            default="",
            type=str,
            description=(
                "Layer name after which to extract features. "
                "Nested layers may be selected using dot-notation, "
                "e.g. `block.subblock.layer1`"
            ),
        )
        return c

    def validate_attributes(self):
        """Check that every option from `configs()` exists on `self.hparams`.

        Raises:
            AttributeError: if `self.hparams` lacks one of the options
                (raised by the attribute lookup itself).
        """
        for hparam in FeatureExtractable.configs().names:
            attrgetter(f"hparams.{hparam}")(self)

    def on_init_end(self, hparams, *args, **kwargs):
        """Register the feature-collecting forward hook, if one is requested.

        Does nothing when `self.hparams.extract_features_after_layer` is
        empty. Otherwise the name is validated against the names yielded by
        `self.named_modules()` and a forward hook is attached to that layer.

        Args:
            hparams: accepted for signature compatibility; the method reads
                `self.hparams` instead.
            *args: unused.
            **kwargs: unused.

        Raises:
            AssertionError: if the requested layer name is not among the
                available (non-root) module names.
        """
        layer_name = self.hparams.extract_features_after_layer
        if not layer_name:
            return

        # The root module is reported under the empty name; exclude it.
        available_layers = [k for k, _ in self.named_modules() if k != ""]

        # Raised explicitly (instead of via an `assert` statement) so that the
        # check is not stripped when Python runs with `-O`. The exception type
        # and message are the same as before.
        if layer_name not in available_layers:
            raise AssertionError(
                f"Invalid `extract_features_after_layer` ({layer_name}). "
                f"Available layers are: {available_layers}"
            )

        layer = rgetattr(self, layer_name)
        self.extracted_features = []

        def store_features(module, inputs, output):
            # `self.extracted_features` is looked up on every call because
            # `metrics_epoch` may rebind it to a fresh list.
            # NOTE(review): iterating `output` presumably walks the batch
            # dimension of a tensor (or the items of a tuple) - confirm
            # against the layers this is used with.
            for o in output:
                self.extracted_features.append(o.detach().cpu().numpy())

        layer.register_forward_hook(store_features)

    def metrics_epoch(
        self,
        preds: torch.Tensor,
        targets: torch.Tensor,
        prefix: str = None,
        clear_extracted_features=True,
        *args,
        **kwargs,
    ) -> MetricDict:
        """Save the features collected so far to disk.

        The features are written with `numpy.save` to
        `<log dir>/features/<prefix>/<layer name with '.' -> '_'>.npy`, after
        that path has been passed through `bump_version`
        (presumably to avoid overwriting an existing file - confirm).

        Args:
            preds: unused by this method.
            targets: unused by this method.
            prefix: sub-directory name for the output file; `None` or an
                empty string places the file directly in `features/`.
            clear_extracted_features: when true, reset the collected features
                after saving - except when `prefix` is `"test"`.
            *args: unused.
            **kwargs: unused.

        Returns:
            MetricDict: always an empty dict; this mixin contributes no
            metrics.
        """
        # No hook was registered, so there is nothing to save.
        if not hasattr(self, "extracted_features"):
            return {}

        # Save
        save_path = bump_version(
            Path(get_log_dir(self))
            / "features"
            / (prefix or "")
            / f"{self.hparams.extract_features_after_layer.replace('.','_')}.npy"
        )
        save_path.parent.mkdir(parents=True, exist_ok=True)
        logger.info(f"💾 Saving extracted features to {str(save_path)}")
        numpy.save(save_path, self.extracted_features)

        if clear_extracted_features and prefix != "test":
            self.extracted_features = []

        return {}

© Copyright (c) 2020-2023, Lukas Hedegaard. Revision cce91b78.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: stable
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.