Source code for prescient_sdk.client

import datetime
import logging
import urllib.parse
from pathlib import Path

import boto3
import google.auth.transport.requests
import google.oauth2.credentials
import msal
import requests
from google_auth_oauthlib.flow import InstalledAppFlow

from prescient_sdk.config import Settings

logger = logging.getLogger("prescient_sdk")


class PrescientClient:
    """
    Client for interacting with the Prescient API.

    This client is used to authenticate, and for obtaining bucket credentials
    using the authenticated token.

    The client also provides helpers such as providing the STAC URL for the
    Prescient API, and authentication headers for making requests to the STAC API.

    Token expiration is handled automatically for long running instances of the
    client (in a notebook for example).

    Configuration Options:

    1. Construct the `prescient_sdk.config.Settings` object directly
    2. Specify the path to an environment file containing configuration values
    3. Do neither, and allow the client to build the `Settings` object using
       default methods (env variables, config.env file location in the working
       directory)

    Note that you cannot specify the env_file location AND provide a `Settings`
    object.

    Args:
        env_file (str | Path, optional): Path to a configuration file.
            Defaults to None.
        settings (Settings, optional): Configuration settings for the client.
            Defaults to None.

    Raises:
        ValueError: If both an environment file and a settings object are provided.
        ValueError: If the provided configuration file is not found.
    """

    # Keys whose values are secrets and must never appear in error messages.
    _TOKEN_KEYS = ("id_token", "refresh_token", "access_token")

    # Upper bound used when building the STS ``RoleSessionName``.
    _MAX_ROLE_SESSION_NAME_LENGTH = 64

    def __init__(
        self,
        env_file: str | Path | None = None,
        settings: Settings | None = None,
    ):
        if env_file and settings:
            raise ValueError(
                "Cannot provide both an environment file and a settings object"
            )

        if env_file:
            env_file = Path(env_file)
            if not env_file.exists():
                raise ValueError(f"Configuration file not found: {env_file}")
            logger.info("Loading configuration variables from %s", env_file)

        # default configuration values are set in the Settings class
        # (prescient_sdk.config.py)
        if settings is None:
            if env_file:
                settings = Settings(_env_file=env_file)  # type: ignore
            else:
                # if no env file is present, we use default settings
                # which can be sourced from a config.env file in the working
                # directory, or env variables
                settings = Settings()  # type: ignore

        self.settings: Settings = settings

        self._expiration_duration = 1 * 60 * 60  # Fixed to 1hr

        # initialize empty credentials
        self._auth_credentials: dict = {}
        self._bucket_credentials: dict = {}
        self._upload_bucket_credentials: dict = {}

    @property
    def stac_catalog_url(self) -> str:
        """
        Get the STAC URL.

        Returns:
            str: The STAC URL.
        """
        return urllib.parse.urljoin(self.settings.prescient_endpoint_url, "stac")

    def _fetch_microsoft_credentials(self) -> dict:
        """Acquire or refresh credentials using Microsoft MSAL.

        The interactive flow is used whenever no usable refresh token is cached
        (missing key, ``None`` or empty value); otherwise the cached refresh
        token is exchanged silently.

        Returns:
            dict: Raw MSAL token response containing ``id_token`` and
            ``refresh_token``.
        """
        authority_url = urllib.parse.urljoin(
            self.settings.prescient_auth_url, self.settings.prescient_tenant_id
        )
        app = msal.PublicClientApplication(
            client_id=self.settings.prescient_client_id, authority=authority_url
        )
        scopes = ["https://graph.microsoft.com/.default"]

        refresh_token = self._auth_credentials.get("refresh_token")
        if not refresh_token:
            return app.acquire_token_interactive(scopes=scopes)

        return app.acquire_token_by_refresh_token(
            refresh_token=refresh_token,
            scopes=scopes,
        )

    def _fetch_google_credentials(self) -> dict:
        """Acquire or refresh credentials using Google OAuth2.

        Uses ``google-auth-oauthlib`` for the interactive browser flow and
        ``google-auth`` for silent token refresh. The returned dict is
        normalized to include the same ``id_token`` and ``refresh_token`` keys
        used by the Microsoft flow so that all downstream code is unaffected.

        The interactive flow is used whenever no usable refresh token is cached
        (missing key, ``None`` or empty value).

        Returns:
            dict: Normalized credential dict containing ``id_token``,
            ``refresh_token``, and ``access_token``.
        """
        token_uri = urllib.parse.urljoin(
            self.settings.prescient_auth_url, "/o/oauth2/token"
        )
        scopes = ["openid", "https://www.googleapis.com/auth/userinfo.email"]

        refresh_token = self._auth_credentials.get("refresh_token")
        if not refresh_token:
            flow = InstalledAppFlow.from_client_config(
                client_config={
                    "installed": {
                        "client_id": self.settings.prescient_client_id,
                        "client_secret": self.settings.prescient_google_client_secret,
                        "auth_uri": urllib.parse.urljoin(
                            self.settings.prescient_auth_url, "/o/oauth2/auth"
                        ),
                        "token_uri": token_uri,
                    }
                },
                scopes=scopes,
            )
            credentials = flow.run_local_server(
                port=self.settings.prescient_google_redirect_port
            )
        else:
            credentials = google.oauth2.credentials.Credentials(
                token=None,
                refresh_token=refresh_token,
                token_uri=token_uri,
                client_id=self.settings.prescient_client_id,
                client_secret=self.settings.prescient_google_client_secret,
            )
            credentials.refresh(google.auth.transport.requests.Request())

        return {
            "id_token": credentials.id_token,
            "refresh_token": credentials.refresh_token,
            "access_token": credentials.token,
        }

    @property
    def auth_credentials(self) -> dict:
        """
        Get the authorization credentials for the client.

        Cached credentials are returned while they are unexpired; otherwise a
        new set is fetched from the configured provider and stamped with an
        ``expiration`` of "fetch start + ``_expiration_duration``".

        Returns:
            dict: Token response containing at minimum::

                {
                    "id_token": "string",
                    "refresh_token": "string",
                    "access_token": "string",
                }

        Raises:
            ValueError: If a valid id_token cannot be obtained.
        """
        if not self.credentials_expired:
            return self._auth_credentials

        # Taken before the request so the recorded expiry is never later than
        # the real one.
        time_zero = datetime.datetime.now(datetime.timezone.utc)

        if self.settings.prescient_auth_provider == "google":
            self._auth_credentials = self._fetch_google_credentials()
        else:
            self._auth_credentials = self._fetch_microsoft_credentials()

        # A missing, empty or None id_token all mean authentication failed.
        if not self._auth_credentials.get("id_token"):
            # Report the response without any token values so secrets do not
            # end up in tracebacks or logs.
            details = {
                key: value
                for key, value in self._auth_credentials.items()
                if key not in self._TOKEN_KEYS
            }
            raise ValueError(f"Failed to obtain Auth token: {details}")

        self._auth_credentials["expiration"] = time_zero + datetime.timedelta(
            seconds=self._expiration_duration
        )

        return self._auth_credentials

    @property
    def headers(self) -> dict:
        """
        Get headers for a request, including the auth header with a bearer token.

        Returns:
            dict: The headers.
        """
        return {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Authorization": f"Bearer {self.auth_credentials['id_token']}",
        }

    def _cached_credentials_valid(self, credentials: dict) -> bool:
        """Tell whether a cached set of bucket credentials may still be served.

        Cached credentials are valid only when they are non-empty, the auth
        credentials have not expired, and their own ``Expiration`` (when it is
        a timezone-aware datetime) is still in the future.

        Args:
            credentials (dict): Cached bucket credentials (possibly empty).

        Returns:
            bool: True if the cached credentials can be returned as-is.
        """
        if not credentials or self.credentials_expired:
            return False

        expiration = credentials.get("Expiration")
        if (
            isinstance(expiration, datetime.datetime)
            and expiration.utcoffset() is not None
        ):
            return datetime.datetime.now(datetime.timezone.utc) < expiration

        # No comparable expiry recorded: rely on the auth expiry check above.
        return True

    def _assume_role_with_web_identity(self, role: str, session_name: str) -> dict:
        """Exchange the auth id_token for temporary AWS credentials via STS.

        Args:
            role (str): Value passed to STS as ``RoleArn``.
            session_name (str): Value passed to STS as ``RoleSessionName``.

        Returns:
            dict: The ``Credentials`` entry of the STS response.

        Raises:
            ValueError: If the response holds no credentials.
        """
        id_token = self.auth_credentials.get("id_token")
        sts_client = boto3.client("sts", region_name=self.settings.prescient_aws_region)

        # exchange token with aws temp creds
        response: dict = sts_client.assume_role_with_web_identity(
            DurationSeconds=self._expiration_duration,
            RoleArn=role,
            RoleSessionName=session_name,
            WebIdentityToken=id_token,
        )

        credentials = response.get("Credentials")
        if not credentials:
            raise ValueError(f"Failed to obtain creds: {response}")

        return credentials

    def _get_bucket_credentials(self, role: str):
        """Get temporary credentials for ``role`` with a role-specific session name.

        Args:
            role (str): Role to assume, normally an ARN.

        Returns:
            dict: STS credentials with ``Expiration`` converted to UTC.

        Raises:
            ValueError: If the credentials response is empty.
        """
        # assume arn string, otherwise last 10 characters of role string
        try:
            role_name_stub = role.split("/")[1]
        except IndexError:
            role_name_stub = role[-10:]

        # STS rejects session names longer than 64 characters, so trim long
        # role names instead of failing the request.
        role_session_name = f"prescient-s3-access-{role_name_stub}"[
            : self._MAX_ROLE_SESSION_NAME_LENGTH
        ]

        credentials = self._assume_role_with_web_identity(role, role_session_name)

        # convert datetime to UTC for later comparison
        credentials["Expiration"] = credentials["Expiration"].astimezone(
            datetime.timezone.utc
        )

        return credentials

    @property
    def bucket_credentials(self):
        """Get bucket credentials using an auth access token

        Credentials come from STS when ``prescient_aws_role`` is configured and
        from the fileproxy endpoint otherwise. They are cached until either the
        auth credentials or the bucket credentials themselves expire.

        Returns:
            dict: bucket temporary credentials::

                {
                    "AccessKeyId": "string",
                    "SecretAccessKey": "string",
                    "SessionToken": "string",
                    "Expiration": datetime(2015, 1, 1)
                }

        Raises:
            ValueError: If the credentials response is empty
        """
        if self._cached_credentials_valid(self._bucket_credentials):
            return self._bucket_credentials

        if self.settings.prescient_aws_role:
            self._bucket_credentials = self._fetch_sts_credentials()
        else:
            self._bucket_credentials = self._fetch_fileproxy_credentials()

        # The fileproxy path yields an ISO-8601 string, STS yields a datetime;
        # normalize both to a UTC datetime.
        expiration = self._bucket_credentials["Expiration"]
        if isinstance(expiration, str):
            expiration = datetime.datetime.fromisoformat(
                expiration.replace("Z", "+00:00")
            )
        self._bucket_credentials["Expiration"] = expiration.astimezone(
            datetime.timezone.utc
        )

        return self._bucket_credentials

    def _fetch_sts_credentials(self) -> dict:
        """Exchange the auth id_token for AWS credentials via STS."""
        return self._assume_role_with_web_identity(
            self.settings.prescient_aws_role, "prescient-s3-access"
        )

    def _fetch_fileproxy_credentials(self) -> dict:
        """Fetch temporary bucket credentials from the Prescient fileproxy endpoint.

        The endpoint returns snake_case keys; map them to the PascalCase shape
        used by the rest of the client (matching the boto3 STS response).
        """
        url = urllib.parse.urljoin(
            self.settings.prescient_endpoint_url, "fileproxy/credentials"
        )
        response = requests.get(url, headers=self.headers)
        response.raise_for_status()
        payload = response.json()
        return {
            "AccessKeyId": payload["access_key_id"],
            "SecretAccessKey": payload["secret_access_key"],
            "SessionToken": payload["session_token"],
            "Expiration": payload["expiration"],
        }

    @property
    def upload_bucket_credentials(self):
        """Get upload bucket credentials using an auth access token

        Cached until either the auth credentials or the upload credentials
        themselves expire.

        Returns:
            dict: bucket temporary credentials::

                {
                    "AccessKeyId": "string",
                    "SecretAccessKey": "string",
                    "SessionToken": "string",
                    "Expiration": datetime(2015, 1, 1)
                }

        Raises:
            ValueError: If ``prescient_upload_role`` is not configured.
            ValueError: If the credentials response is empty
        """
        if self._cached_credentials_valid(self._upload_bucket_credentials):
            return self._upload_bucket_credentials

        if not self.settings.prescient_upload_role:
            raise ValueError(
                "prescient_upload_role is not configured; set PRESCIENT_UPLOAD_ROLE "
                "to use the upload bucket."
            )

        self._upload_bucket_credentials = self._get_bucket_credentials(
            role=self.settings.prescient_upload_role
        )

        return self._upload_bucket_credentials

    @property
    def session(self) -> boto3.Session:
        """
        Get an AWS session for authenticating to the bucket

        Returns:
            Session: boto3 Session object
        """
        # Read the property once so all three values come from the same set.
        credentials = self.bucket_credentials
        return boto3.Session(
            aws_access_key_id=credentials["AccessKeyId"],
            aws_secret_access_key=credentials["SecretAccessKey"],
            aws_session_token=credentials["SessionToken"],
        )

    @property
    def upload_session(self) -> boto3.Session:
        """
        Get an AWS session for authenticating to the upload bucket

        Returns:
            Session: boto3 Session object
        """
        # Read the property once so all three values come from the same set.
        credentials = self.upload_bucket_credentials
        return boto3.Session(
            aws_access_key_id=credentials["AccessKeyId"],
            aws_secret_access_key=credentials["SecretAccessKey"],
            aws_session_token=credentials["SessionToken"],
            region_name=self.settings.prescient_aws_region,
        )

    @property
    def credentials_expired(self) -> bool:
        """Checks to see if the client credentials have expired.

        Note: if auth credentials have expired, all credentials are considered
        expired as they all depend on auth credentials.

        Returns:
            bool: True - credentials are expired, False - credentials have NOT
            expired.
        """
        expiration = self._auth_credentials.get("expiration")
        if expiration is None:
            # Never authenticated, or expiry was dropped by a forced refresh.
            return True
        return datetime.datetime.now(datetime.timezone.utc) >= expiration

    def refresh_credentials(self, force=False):
        """
        Will refresh all the client credentials.

        Args:
            force: If True will force the creds to be refreshed, even when
                nothing has been fetched yet.

        Returns:
            None
        """
        if force:
            # A default avoids KeyError when no auth has happened yet.
            self._auth_credentials.pop("expiration", None)
            # Drop both caches; otherwise the upload credentials would be
            # reused once the bucket refresh has renewed the auth token.
            self._bucket_credentials = {}
            self._upload_bucket_credentials = {}

        _ = self.bucket_credentials  # Will call self.auth_credentials
        if self.settings.prescient_upload_role:
            _ = self.upload_bucket_credentials