Source code for tdg.analysis.bootstrap

import torch
from tdg.h5 import H5able
from tdg.performance import Timer

from functools import lru_cache as cached

import logging
logger = logging.getLogger(__name__)

[docs]class Bootstrap(H5able): r''' The bootstrap is a resampling technique for estimating uncertainties. For samples with weights :math:`w` the expectation value of an observable is .. math:: \left\langle O \right\rangle = \frac{\left\langle O w \right\rangle}{\left\langle w \right\rangle} and an accurate bootstrap estimate of the left-hand side requires tracking the correlations between the numerator and denominator. Moreover, quoting correlated uncertainties requires resampling different observables in the same way. Parameters ---------- ensemble: :class:`~.GrandCanonical` or :class:`~.Binned` The ensemble to resample. draws: int The number of times to resample. Any observable that :code:`ensemble` supports can be called from the :class:`~.Bootstrap`. :class:`~.Bootstrap` uses :code:`__getattr__` trickery under the hood to intercept calls and perform the weighted average transparently. Each observable returns a :code:`torch.tensor` of the same dimension as the ensemble's observable. However, rather than configurations first, :code:`draws` are first. Each draw is a weighted average over the resampled weight, as shown above, and is therefore an estimator for the expectation value. These are guaranteed (by the `central limit theorem`_) to be normally distributed. To get an uncertainty estimate one need only take the :code:`.mean()` for a central value and :code:`.std()` for the uncertainty on the mean. .. _central limit theorem: https://en.wikipedia.org/wiki/Central_limit_theorem ''' _observables = set() # The _observables are populated by the @observable and @derived decorators. _intermediates = set() # The _intermediates are populated by the @intermediate and @derived_intermediate decorator. def __init__(self, ensemble, draws=100): self.Ensemble = ensemble r'''The ensemble from which to resample.''' self.Action = ensemble.Action r'''The action underlying the ensemble.''' self.draws = draws r'''The number of resamplings.''' cfgs = ensemble.configurations.shape[0] self.indices = torch.randint(0, cfgs, (cfgs, draws)) r'''The random draws themselves; configurations × draws.''' def __len__(self): return self.draws
[docs] def measure(self, *observables): r''' Compute each :ref:`@observable <observables>` and @derived quantity in `observables`; log an error for any :ref:`unregistered observable <custom observables>` or derived quantity. Parameters ---------- observables: strings Names of observables or derived quantities. Returns ------- :class:`~.Bootstrap`; itself, now with some observables and derived quantities bootstrapped. .. note:: If no `observables` are passed, bootstraps **every** registered `@observable` and `@derived` quantity. ''' if not observables: observables = self._observables with Timer(logger.info, f'Bootstrap of {len(self)} draws', per=len(self)): for observable in observables: if observable not in self._observables | self._intermediates: logger.error(f'No registered observable "{observable}"') continue try: getattr(self, observable) except AttributeError as error: logger.error(str(error)) return self
def _resample(self, obs): # Each observable should be multiplied by its respective weight. # Each draw should be divided by its average weight. w = self.Ensemble.weights[self.indices] # This index ordering is needed to broadcast the weights division correctly. # See https://github.com/evanberkowitz/two-dimensional-gasses/issues/55 # We return the bootstrap axis to the front to provide an analogous interface for Bootstrap and GrandCanonical quantities. return torch.einsum('...d->d...', torch.einsum('cd,cd...->c...d', w, obs[self.indices]).mean(axis=0) / w.mean(axis=0)) def __getattr__(self, name): if name not in self._observables | self._intermediates | self.Ensemble._intermediates : raise AttributeError(name) with Timer(logger.info, f'Bootstrapping {name}', per=len(self)): try: forward = self.Ensemble.__getattribute__(name) except: forward = self.Ensemble.__getattr__(name) ## I believe this code block is a remnant from callable observables. ## Callable observables were eliminated in #58 https://github.com/evanberkowitz/two-dimensional-gasses/pull/58 ## to keep the interface much simpler. ## ## Normally I'd just delete this block of code. But since I'm not POSITIVE this __getattr__ isn't doing something ## else clever I'll just comment it out for now and remove it permanently in the future. ## ## TODO: remove this if its absence hasn't caused a problem. ## If it is removed, we can also remove the importing of lru_cache as cached above. ## Marked for deletion on 21 July 2023. # # if callable(forward): # @cached # def curry(*args, **kwargs): # return self._resample(forward(*args, **kwargs)) # return curry # ## In fact, it may even be that the try/except above can go too; that was also about intercepting calls. self.__dict__[name] = self._resample(forward) return self.__dict__[name]