Source code for ergo.platforms.foretold

import json
from dataclasses import dataclass
from typing import List, Union

import jax.numpy as np
import numpy as onp
import pandas as pd
import requests
import seaborn

from ergo.distributions.base import uniform

class Foretold:
    """Interface to the Foretold GraphQL API."""

    # Public Foretold GraphQL endpoint, used unless another api_url is given.
    DEFAULT_API_URL = "https://api.foretold.io/graphql"

    def __init__(self, token=None, api_url=None):
        """
        token (string): Specify an authorization token (supports Bot tokens from Foretold)
        api_url (string): GraphQL endpoint to send requests to; defaults to
            DEFAULT_API_URL
        """
        self.token = token
        self.api_url = api_url if api_url is not None else self.DEFAULT_API_URL

    def get_question(self, id):
        """Retrieve a single question by its id"""
        question = ForetoldQuestion(id, self)
        question.refresh_question()
        return question

    def get_questions(self, ids):
        """
        Retrieve many questions by their ids

        ids (List[string]): List of foretold question ids (at most 500 per request)

        Returns: List of questions corresponding to the ids, or None for
        questions that weren't found.
        """
        measurables = self._query_measurables(ids)
        return [
            ForetoldQuestion(measurable["id"], self, measurable) if measurable else None
            for measurable in measurables
        ]

    def _headers(self):
        """Build request headers, adding a Bearer token when a token is configured."""
        headers = {}
        if self.token is not None:
            headers["Authorization"] = f"Bearer {self.token}"
        return headers

    def _post(self, json_data):
        """Send a json post request to the foretold API, with proper authorization.

        Returns the decoded JSON body. Raises requests.HTTPError for a non-2xx
        status (via raise_for_status).
        """
        response = requests.post(self.api_url, json=json_data, headers=self._headers())
        response.raise_for_status()
        return response.json()

    def _query_measurable(self, id):
        """Retrieve data from api about single question by its id.

        Returns the ``measurable`` object of the GraphQL response (id, channelId
        and the previous aggregate's floatCdf), as returned by the API.
        """
        response = self._post(
            {
                "variables": {"measurableId": id},
                "query": """query ($measurableId: String!) { measurable(id:$measurableId) { id channelId previousAggregate { value { floatCdf { xs ys } } } } }""",
            }
        )
        return response["data"]["measurable"]

    def _query_measurables(self, ids):
        """Retrieve data from api about many question by a list of ids.

        Returns a list aligned with ``ids``: the measurable dict for each id,
        or None where the API returned nothing for that id.

        Raises NotImplementedError for more than 500 ids (no pagination yet)
        and ValueError when the API reports errors.
        """
        # Nothing to look up: skip the network round-trip entirely rather than
        # sending an empty id filter to the API.
        if not ids:
            return []
        if len(ids) > 500:
            # If we want to implement this later,
            # we can properly use the pageInfo in the request
            raise NotImplementedError(
                "We haven't implemented support for more than 500 ids per request"
            )
        response = self._post(
            {
                "variables": {"measurableIds": ids},
                "query": """query ($measurableIds: [String!]) { measurables(measurableIds: $measurableIds, first: 500) { total pageInfo { hasPreviousPage hasNextPage startCursor endCursor __typename } edges { node { id channelId previousAggregate { value { floatCdf { xs ys } } } } } } }""",
            }
        )
        if "errors" in response:
            raise ValueError(
                "Error retrieving foretold measurables. You may not have authorization "
                "to load one or more measurables, or one of the measureable ids may be incorrect"
            )
        if response["data"]["measurables"]["pageInfo"]["hasNextPage"]:
            raise NotImplementedError(
                "We haven't implemented support for more than 500 ids per request"
            )
        measurables_dict = {
            edge["node"]["id"]: edge["node"]
            for edge in response["data"]["measurables"]["edges"]
        }
        return [measurables_dict.get(id, None) for id in ids]

    def create_measurement(
        self, measureable_id: str, cdf: "ForetoldCdf"
    ) -> "requests.Response":
        """Submit a CDF as a competitive measurement on a Foretold question.

        :param measureable_id: id of the question (measurable) to predict on
        :param cdf: the prediction; at most 1000 points
        :returns: the raw HTTP response (status is NOT checked here)
        :raises ValueError: if no token is configured or the CDF is too long
        """
        if self.token is None:
            raise ValueError("A token is required to submit a prediction")
        if len(cdf) > 1000:
            raise ValueError("Maximum CDF length of 1000 exceeded")
        query = _measurement_query(measureable_id, cdf)
        return requests.post(
            self.api_url, json={"query": query}, headers=self._headers()
        )
class ForetoldQuestion:
    """Information about foretold question, including aggregated distribution"""

    def __init__(self, id, foretold, data=None):
        """
        Should not be called directly, instead use Foretold.get_question

        id: measurableId, the second id in the URL for a foretold question
        foretold: Foretold api
        data: Data retrieved from the foretold api
        """
        self.id = id
        self.foretold = foretold
        self.floatCdf = None
        self.channelId = None
        if data is not None:
            self._update_from_data(data)

    def _update_from_data(self, data):
        """Update channelId and floatCdf from a dictionary of data from Foretold.

        Raises ValueError if ``data`` is not subscriptable or has no
        ``channelId``. A missing aggregate just leaves ``floatCdf`` as None.
        """
        try:
            self.channelId = data["channelId"]
        except (KeyError, TypeError) as err:
            raise ValueError("Foretold data missing or invalid") from err
        # If floatCdf is not available, we can just keep it as None
        try:
            self.floatCdf = data["previousAggregate"]["value"]["floatCdf"]
        except (KeyError, TypeError):
            self.floatCdf = None

    def refresh_question(self):
        """Re-fetch this question from Foretold and update channelId / floatCdf.

        Raises ValueError (chained to the original error) if the fetched data
        is unusable.
        """
        # previousAggregate is the most recent aggregated distribution
        try:
            measurable = self.foretold._query_measurable(self.id)
            self._update_from_data(measurable)
        except ValueError as err:
            raise ValueError(f"Error loading distribution {self.id} from Foretold") from err

    @property
    def url(self):
        """Web URL of this question on Foretold."""
        return f"https://www.foretold.io/c/{self.channelId}/m/{self.id}"

    @property
    def community_prediction_available(self):
        """True when an aggregated community CDF was loaded."""
        return self.floatCdf is not None

    def get_float_cdf_or_error(self):
        """Return the community floatCdf dict (keys "xs" and "ys").

        Raises ValueError when no community prediction is available.
        """
        if not self.community_prediction_available:
            raise ValueError("No community prediction available")
        return self.floatCdf

    def quantile(self, q):
        """Quantile of distribution, by linear interpolation of the inverse CDF"""
        floatCdf = self.get_float_cdf_or_error()
        return onp.interp(q, floatCdf["ys"], floatCdf["xs"])

    def sample_community(self):
        """Sample from CDF by inverse-transform sampling.

        NOTE(review): assumes ergo's uniform() draws from [0, 1] — confirm.
        """
        y = uniform()
        return np.array(self.quantile(y))

    def plotCdf(self):
        """Plot the CDF with seaborn and return the matplotlib Axes."""
        floatCdf = self.get_float_cdf_or_error()
        # x/y must be keywords: seaborn >= 0.12 rejects positional data vectors.
        return seaborn.lineplot(x=floatCdf["xs"], y=floatCdf["ys"])

    def submit_from_samples(
        self, samples: Union[np.ndarray, pd.Series], length: int = 20
    ) -> "requests.Response":
        """Submit a prediction to Foretold based on the given samples

        :param samples: Samples on which to base the submission
        :param length: The length of the CDF derived from the samples
        """
        cdf = ForetoldCdf.from_samples(samples, length)
        return self.foretold.create_measurement(self.id, cdf)
@dataclass class ForetoldCdf: xs: List[float] ys: List[float] @staticmethod def from_samples( samples: Union[np.ndarray, pd.Series], length: int ) -> "ForetoldCdf": """Build a Foretold CDF representation from an array of samples See the following for details: :param samples: Samples from which to build the CDF :param length: The length of returned CDF """ if length < 2: raise ValueError("`length` must be at least 2") hist, bin_edges = onp.histogram(samples, bins=length - 1, density=True) # type: ignore bin_width = bin_edges[1] - bin_edges[0] # Foretold expects `0 <= ys <= 1`, so we clip to that . This # is defensive -- at the time of implementation it isn't known # how the API handles violations of this. ys = np.clip(np.hstack([onp.array([0.0]), onp.cumsum(hist) * bin_width]), 0, 1) # type: ignore return ForetoldCdf(bin_edges.tolist(), ys.tolist()) # type: ignore def __len__(self): return len(self.xs) def _measurement_query(measureable_id: str, cdf: ForetoldCdf) -> str: return f"""mutation {{ measurementCreate( input: {{ value: {{ floatCdf: {{ xs: {cdf.xs}, ys: {cdf.ys} }} }} competitorType: COMPETITIVE measurableId: "{measureable_id}" }} ) {{ id }} }} """