Source code for profile_photo.utils.aws.client_cache
from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass, InitVar
from typing import ClassVar
from boto3 import Session, client
from botocore.client import BaseClient
from botocore.config import Config
from botocore.exceptions import UnknownServiceError
[docs]@dataclass
class ClientCache:
# boto3 service name
SERVICE_NAME: ClassVar[str] = None
# sub-classes can specify that client objects should be thread-safe
THREAD_SAFE: ClassVar[bool] = False
_clients: ClassVar[defaultdict] = defaultdict(dict)
# The specified region that is bound to a client. Default to 'us-east-1'.
region_name: str = 'us-east-1'
# Optional AWS profile name.
profile_name: str | None = None
# When enabled, automatically initializes the client for the region; this
# is useful when service requests will be made in multiple threads and it
# is desirable to reuse the same client between threads.
init_client: InitVar[bool] = False
# Maximum pool connections for multi-thread usage
# Ref: https://stackoverflow.com/a/68760777/10237506
max_pool_connections: int | None = None
def __post_init__(self, init_client: bool):
self.region_name = self.region_name.lower()
if init_client:
_ = self.client
@property
def client(self):
return self._get_client(self.region_name)
@client.setter
def client(self, value):
raise Exception('Member read-only')
def _get_client(self, region_name):
"""
Internal method to return a low-level SecretManager client for a given region name
"""
if region_name not in self._clients[self.SERVICE_NAME]:
self._clients[self.SERVICE_NAME][region_name] = self._create_client()
return self._clients[self.SERVICE_NAME][region_name]
def _create_client(self) -> BaseClient:
if not self.SERVICE_NAME:
raise ValueError(
'Sub-classes must provide a value for "SERVICE_NAME"')
try:
if self.THREAD_SAFE or self.profile_name:
client_func = Session(profile_name=self.profile_name).client
else:
client_func = client
client_kwargs = {}
if self.max_pool_connections:
client_kwargs['config'] = Config(
max_pool_connections=self.max_pool_connections)
return client_func(self.SERVICE_NAME, self.region_name,
**client_kwargs)
except UnknownServiceError:
raise NotImplementedError(
'Sub-classes must override this method')