vault-k8s

Vault

  • Canonical Telco
Channel Revision Published Runs on
latest/edge 89 31 Jan 2024
Ubuntu 22.04 Ubuntu 20.04
latest/edge 9 27 Jan 2023
Ubuntu 22.04 Ubuntu 20.04
1.16/stable 323 20 Jan 2025
Ubuntu 22.04
1.16/candidate 323 20 Jan 2025
Ubuntu 22.04
1.16/beta 323 20 Jan 2025
Ubuntu 22.04
1.16/edge 326 20 Jan 2025
Ubuntu 22.04
1.15/stable 248 24 Jul 2024
Ubuntu 22.04
1.15/candidate 248 24 Jul 2024
Ubuntu 22.04
1.15/beta 248 24 Jul 2024
Ubuntu 22.04
1.15/edge 248 10 Jul 2024
Ubuntu 22.04
juju deploy vault-k8s --channel 1.16/stable
Show information

Platform:

charms.vault_k8s.v0.vault_client

#!/usr/bin/env python3
# Copyright 2023 Canonical Ltd.
# Licensed under the Apache2.0. See LICENSE file in charm source for details.

"""Library for interacting with a Vault cluster.

This library shares operations that interact with Vault through its API. It is
intended to be used by charms that need to manage a Vault cluster.
"""

import logging
from abc import abstractmethod
from dataclasses import dataclass
from enum import Enum
from io import IOBase
from typing import List, MutableMapping, Protocol

import hvac
import requests
from hvac.exceptions import Forbidden, InternalServerError, InvalidPath, InvalidRequest, VaultError
from requests.exceptions import ConnectionError, RequestException

# The unique Charmhub library identifier, never change it
LIBID = "674754a3268d4507b749ec34214706fd"

# Increment this major API version when introducing breaking changes
LIBAPI = 0

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 25


# Vault HTTP API path requested by VaultClient.get_raft_cluster_state.
RAFT_STATE_ENDPOINT = "v1/sys/storage/raft/autopilot/state"


class LogAdapter(logging.LoggerAdapter):
    """Logger adapter that tags every log line with a fixed prefix."""

    prefix = "vault_client"

    def process(self, msg: str, kwargs: MutableMapping) -> tuple[str, MutableMapping]:
        """Prepend the bracketed prefix to the message.

        Args:
            msg: The log message.
            kwargs: Keyword arguments of the logging call, returned untouched.

        Returns:
            The prefixed message together with the unchanged kwargs.
        """
        tagged = f"[{self.prefix}] {msg}"
        return tagged, kwargs


# Module-level logger; LogAdapter prefixes every message it emits.
logger = LogAdapter(logging.getLogger(__name__), {})


@dataclass
class Token:
    """Token-based authentication for Vault.

    Token authentication is the most basic method and is always available.

    Attributes:
        token: The Vault token assigned to the client on login.
    """

    token: str

    def login(self, client: hvac.Client):
        """Set this token on the given client so it is used for its requests."""
        vault_token = self.token
        client.token = vault_token


@dataclass
class AppRole:
    """AppRole-based authentication for Vault.

    AppRole is primarily meant for authenticating automation programs.

    Attributes:
        role_id: Identifier of the AppRole.
        secret_id: Secret paired with the role id.
    """

    role_id: str
    secret_id: str

    def login(self, client: hvac.Client):
        """Log the given client in through its approle auth backend.

        The login is requested with `use_token=True`.
        """
        approle_backend = client.auth.approle
        approle_backend.login(
            role_id=self.role_id,
            secret_id=self.secret_id,
            use_token=True,
        )


class AuthMethod(Protocol):
    """Structural type for auth methods: any class providing `login(client)`."""

    @abstractmethod
    def login(self, client: hvac.Client) -> None:
        """Log the given client in to Vault using this method.

        Raises:
            NotImplementedError: Always, in this protocol definition.
        """
        raise NotImplementedError


@dataclass
class Certificate:
    """A certificate produced by the PKI secrets engine.

    Attributes:
        certificate: The signed certificate.
        ca: The issuing CA certificate.
        chain: The CA chain of the certificate.
    """

    certificate: str
    ca: str
    chain: List[str]


class AuditDeviceType(Enum):
    """Audit device types that can be enabled in Vault."""

    FILE = "file"
    SYSLOG = "syslog"
    SOCKET = "socket"


class SecretsBackend(Enum):
    """Secrets backends of Vault supported by this library."""

    KV_V2 = "kv-v2"
    PKI = "pki"
    TRANSIT = "transit"


class VaultClientError(Exception):
    """Root of the exceptions raised by the Vault client."""


class VaultClient:
    """Class to interact with Vault through its API."""

    def __init__(self, url: str, ca_cert_path: str | None):
        """Build the underlying hvac client.

        Args:
            url: Address of the Vault API.
            ca_cert_path: Path to a CA certificate, handed to hvac as `verify`.
                When empty or None, `verify` is set to False instead.
        """
        verify = ca_cert_path or False
        self._client = hvac.Client(url=url, verify=verify)

    def authenticate(self, auth_details: AuthMethod) -> bool:
        """Find and use the token related with the given auth method.

        Returns:
            bool: True if the authentication was successful and the token was accepted by vault.
        """
        try:
            auth_details.login(self._client)
            self._client.auth.token.lookup_self()
        except (VaultError, ConnectionError, Forbidden) as e:
            logger.warning("Failed login to Vault: %s", e)
            return False
        return True

    @property
    def token(self) -> str:
        """Return the token used to authenticate with Vault."""
        return self._client.token

    def is_api_available(self) -> bool:
        """Return whether Vault is available."""
        try:
            self._client.sys.read_health_status(standby_ok=True)
            return True
        except (VaultError, RequestException) as e:
            logger.error("Error while checking Vault health status: %s", e)
            return False

    def is_initialized(self) -> bool:
        """Return whether Vault is initialized."""
        return self._client.sys.is_initialized()

    def is_sealed(self) -> bool:
        """Return whether Vault is sealed."""
        try:
            return self._client.sys.is_sealed()
        except VaultError as e:
            # This seems to happen if the seal status is checked immediately
            # after initializing the vault when autounseal is enabled.
            # There is a short period of time where the vault is initialized,
            # but hasn't finished setting up the autounseal configuration yet,
            # and the server will return an internal server error during this
            # period.
            #
            # hvac.exceptions.InternalServerError:
            # core: barrier reports initialized but no seal configuration found
            logger.error("Error while checking Vault seal status: %s", e)
            raise VaultClientError(e) from e

    def read(self, path: str) -> dict:
        """Read the data at the given path."""
        try:
            data = self._client.read(path)
        except VaultError as e:
            logger.error("Error while writing data to %s: %s", path, e)
            return {}
        if data is None:
            return {}
        if isinstance(data, requests.Response):
            data = data.json()
        return data.get("data", {})

    def write(self, path: str, data: dict) -> bool:
        """Write the given data at the given path.

        Args:
            path: The Vault path to write to.
            data: The data to store.

        Returns:
            True if the write went through, False if it raised a Vault error.
        """
        try:
            result = self._client.write_data(path, data=data)
        except VaultError as error:
            logger.error("Error while writing data to %s: %s", path, error)
            return False
        else:
            logger.info("Wrote data to %s: %s", path, result)
            return True

    def list(self, path: str) -> List[str]:
        """List the keys at the given path."""
        try:
            data = self._client.list(path)
        except VaultError as e:
            logger.error("Error while listing keys at %s: %s", path, e)
            return []
        if data is None:
            return []
        if isinstance(data, requests.Response):
            data = data.json()
        try:
            return data["data"]["keys"]
        except KeyError:
            return []

    def needs_migration(self) -> bool:
        """Return true if the vault needs to be migrated, false otherwise."""
        return self._client.seal_status["migration"]  # type: ignore -- bad type hint in stubs

    def get_seal_type(self) -> str:
        """Return the seal type of the Vault."""
        return self._client.seal_status["type"]  # type: ignore -- bad type hint in stubs

    def is_seal_type_transit(self) -> bool:
        """Return whether Vault is sealed by the transit backend."""
        return "transit" == self.get_seal_type()

    def is_active(self) -> bool:
        """Return whether the Vault node is active or not.

        Returns:
            True if initialized, unsealed and active, False otherwise.
        """
        try:
            health_status = self._client.sys.read_health_status()
            return health_status.status_code == 200
        except (VaultError, RequestException) as e:
            logger.error("Error while checking Vault health status: %s", e)
            return False

    def is_active_or_standby(self) -> bool:
        """Return the health status of Vault.

        Returns:
            True if initialized, unsealed and active or standby, False otherwise.
        """
        try:
            health_status = self._client.sys.read_health_status()
            return health_status.status_code == 200 or health_status.status_code == 429
        except (VaultError, RequestException) as e:
            logger.error("Error while checking Vault health status: %s", e)
            return False

    def enable_audit_device(self, device_type: AuditDeviceType, path: str) -> None:
        """Enable an audit device, tolerating one that is already enabled.

        Args:
            device_type: The type of audit device to enable.
            path: The path that will receive audit logs; passed to Vault as
                the `file_path` option.

        Raises:
            VaultClientError: If Vault rejects the request for any reason
                other than the path already being in use.
        """
        try:
            self._client.sys.enable_audit_device(
                device_type=device_type.value,
                options={"file_path": path},
            )
            logger.info("Enabled audit device `%s` for path `%s`", device_type.value, path)
        except InvalidRequest as error:
            payload = error.json
            if not payload or not isinstance(payload, dict):
                raise VaultClientError(error) from error
            messages = payload.get("errors", [])
            already_enabled = len(messages) == 1 and messages[0].startswith(
                "path already in use"
            )
            if not already_enabled:
                raise VaultClientError(error) from error
            logger.info("Audit device already enabled.")
        except VaultError as error:
            raise VaultClientError(error) from error

    def enable_approle_auth_method(self) -> None:
        """Enable the approle auth method, tolerating one already enabled.

        Raises:
            VaultClientError: If Vault rejects the request for any reason
                other than the path already being in use.
        """
        try:
            self._client.sys.enable_auth_method("approle")
            logger.info("Enabled approle auth method.")
        except InvalidRequest as error:
            payload = error.json
            if not payload or not isinstance(payload, dict):
                raise VaultClientError(error) from error
            messages = payload.get("errors", [])
            already_enabled = len(messages) == 1 and messages[0].startswith(
                "path is already in use"
            )
            if not already_enabled:
                raise VaultClientError(error) from error
            logger.info("Approle already enabled.")
        except VaultError as error:
            raise VaultClientError(error) from error

    def create_or_update_policy_from_file(
        self, name: str, path: str, **formatting_args: str
    ) -> None:
        """Create or update a Vault policy whose content is read from a file.

        Args:
            name: Name of the policy to create
            path: The path of the file where the policy is defined, ending with .hcl
            **formatting_args: Values substituted into the file content with
                `str.format`; when none are given the content is used as is.

        Raises:
            VaultClientError: If Vault fails to create or update the policy.
        """
        # TODO: Remove this method when it is no longer needed. Prefer create_or_update_policy.
        with open(path, "r") as policy_file:
            template = policy_file.read()
        rendered = template.format(**formatting_args) if formatting_args else template
        try:
            self._client.sys.create_or_update_policy(name=name, policy=rendered)
        except VaultError as error:
            raise VaultClientError(error) from error
        logger.debug("Created or updated charm policy: %s", name)

    def create_or_update_policy(self, name: str, content: str) -> None:
        """Create or update a Vault policy from the given content.

        Args:
            name: Name of the policy to create
            content: The policy content

        Raises:
            VaultClientError: If Vault fails to create or update the policy.
        """
        policy_api = self._client.sys
        try:
            policy_api.create_or_update_policy(name=name, policy=content)
        except VaultError as error:
            raise VaultClientError(error) from error
        logger.debug("Created or updated charm policy: %s", name)

    def create_or_update_approle(
        self,
        name: str,
        token_ttl: str | None = None,
        token_max_ttl: str | None = None,
        policies: List[str] | None = None,
        cidrs: List[str] | None = None,
        token_period: str | None = None,
    ) -> str:
        """Create/update a role within vault associating the supplied policies.

        Args:
            name: Name of the role to be created or updated
            policies: The attached list of policy names this approle will have access to
            token_ttl: Incremental lifetime for generated tokens, provided as a duration string such as "5m"
            token_max_ttl: Maximum lifetime for generated tokens, provided as a duration string such as "5m"
            token_period: The period within which the token must be renewed. See Vault documentation for more information.
            cidrs: The list of IP networks that are allowed to authenticate
        """
        self._client.auth.approle.create_or_update_approle(
            name,
            bind_secret_id="true",
            token_ttl=token_ttl,
            token_max_ttl=token_max_ttl,
            token_policies=policies,
            token_bound_cidrs=cidrs,
            token_period=token_period,
        )
        response = self._client.auth.approle.read_role_id(name)
        return response["data"]["role_id"]

    def generate_role_secret_id(self, name: str, cidrs: List[str] | None = None) -> str:
        """Generate a new secret tied to an AppRole."""
        response = self._client.auth.approle.generate_secret_id(name, cidr_list=cidrs)
        return response["data"]["secret_id"]

    def read_role_secret(self, name: str, id: str) -> dict:
        """Get definition of a secret tied to an AppRole."""
        response = self._client.auth.approle.read_secret_id(name, id)
        return response["data"]

    def enable_secrets_engine(self, backend_type: SecretsBackend, path: str) -> None:
        """Enable a secrets engine at a path, tolerating one already enabled.

        Args:
            backend_type: The secrets backend to enable.
            path: The mount path for the backend.

        Raises:
            VaultClientError: If Vault rejects the request as invalid for any
                reason other than the path already being in use.
        """
        backend = backend_type.value
        try:
            self._client.sys.enable_secrets_engine(
                backend_type=backend,
                description=f"Charm created '{backend_type.value}' backend",
                path=path,
            )
            logger.info("Enabled %s backend", backend)
        except InvalidRequest as error:
            # TODO: Fix the type stubs for hvac to properly identify the json attribute
            payload = error.json
            if not payload or not isinstance(payload, dict):
                raise VaultClientError(error) from error
            messages = payload.get("errors", [])
            already_enabled = len(messages) == 1 and messages[0].startswith(
                "path is already in use"
            )
            if not already_enabled:
                raise VaultClientError(error) from error
            logger.info("%s backend already enabled", backend)

    def disable_secrets_engine(self, path: str) -> None:
        """Disable the secrets engine mounted at the given path.

        An invalid path is treated as the engine being disabled already.

        Args:
            path: The mount path of the engine to disable.
        """
        try:
            self._client.sys.disable_secrets_engine(path)
        except InvalidPath:
            logger.info("Secret engine at `%s` is already disabled", path)
        else:
            logger.info("Disabled secret engine at %s", path)

    def get_intermediate_ca(self, mount: str) -> str:
        """Get the intermediate CA for the PKI backend."""
        return self._client.secrets.pki.read_ca_certificate(mount_point=mount)

    def import_ca_certificate_and_key(self, mount: str, certificate: str, private_key: str):
        """Submit a CA certificate and its private key to the PKI backend.

        The two are combined into a single PEM bundle before submission.

        Args:
            mount: The mount point of the PKI backend.
            certificate: The CA certificate.
            private_key: The private key of the CA certificate.
        """
        bundle = generate_pem_bundle(certificate=certificate, private_key=private_key)
        pki = self._client.secrets.pki
        pki.submit_ca_information(pem_bundle=bundle, mount_point=mount)

    def sign_pki_certificate_signing_request(
        self,
        mount: str,
        role: str,
        csr: str,
        common_name: str,
        ttl: str,
    ) -> Certificate | None:
        """Have the PKI backend sign a certificate signing request.

        Args:
            mount: The PKI mount point.
            role: The role to use for signing the certificate.
            csr: The certificate signing request.
            common_name: The common name for the certificate.
            ttl: The relative validity for the certificate.
                Should be a string in the format of a number with a unit such as
                "120m", "10h" or "90d".

        Returns:
            Certificate: The signed certificate object, or None if Vault
                rejected the request as invalid.
        """
        try:
            response = self._client.secrets.pki.sign_certificate(
                name=role,
                csr=csr,
                common_name=common_name,
                mount_point=mount,
                extra_params={"ttl": ttl},
            )
            logger.info("Signed a PKI certificate for %s", common_name)
            signed = response["data"]
            return Certificate(
                certificate=signed["certificate"],
                ca=signed["issuing_ca"],
                chain=signed["ca_chain"],
            )
        except InvalidRequest as error:
            logger.warning("Error while signing PKI certificate: %s", error)
            return None

    def create_or_update_pki_charm_role(
        self, role: str, allowed_domains: str, max_ttl: str, mount: str
    ) -> None:
        """Create a PKI role, or update it if it already exists.

        The role is always configured with `allow_subdomains` enabled.

        Args:
            role: The name of the role to create or update.
            allowed_domains: The list of allowed domains for the role.
            max_ttl: The maximum TTL for the role.
                It is also used by Vault as a maximum validity for the certificates issued by this role.
                Should be a string in the format of a number with a unit such as
                "120m", "10h" or "90d".
            mount: The mount point of the PKI backend for which the role will be created.
        """
        role_settings = {
            "allowed_domains": allowed_domains,
            "allow_subdomains": True,
            "max_ttl": max_ttl,
        }
        self._client.secrets.pki.create_or_update_role(
            name=role,
            mount_point=mount,
            extra_params=role_settings,
        )
        logger.info(
            "Created or updated PKI role `%s` with `allowed_domains=%s` and `max_ttl=%s`",
            role,
            allowed_domains,
            max_ttl,
        )

    def is_pki_role_created(self, role: str, mount: str) -> bool:
        """Check if the role is created for the PKI backend."""
        try:
            existing_roles = self._client.secrets.pki.list_roles(mount_point=mount)
            return role in existing_roles["data"]["keys"]
        except InvalidPath:
            return False

    def create_snapshot(self) -> requests.Response:
        """Create a snapshot of the Vault data."""
        return self._client.sys.take_raft_snapshot()

    def restore_snapshot(self, snapshot: IOBase) -> None:
        """Restore a snapshot of the Vault data.

        Uses force_restore_raft_snapshot to restore the snapshot
        even if the unseal key used at backup time is different from the current one.
        """
        response = self._client.sys.force_restore_raft_snapshot(snapshot)
        if not 200 <= response.status_code < 300:
            logger.warning("Error while restoring snapshot: %s", response.text)
            raise VaultClientError(f"Error while restoring snapshot: {response.text}")

    def get_raft_cluster_state(self) -> dict:
        """Fetch the raft cluster state.

        Returns:
            The `data` section of the response of the raft state endpoint.
        """
        adapter = self._client.adapter
        state = adapter.get(RAFT_STATE_ENDPOINT)
        return state["data"]

    def is_raft_cluster_healthy(self) -> bool:
        """Return the `healthy` flag of the raft cluster state."""
        cluster_state = self.get_raft_cluster_state()
        return cluster_state["healthy"]

    def remove_raft_node(self, id: str) -> None:
        """Remove a peer from the raft cluster.

        Internal server errors and connection errors are logged as warnings
        and not raised.

        Args:
            id: The server id of the raft node to remove.
        """
        try:
            self._client.sys.remove_raft_node(server_id=id)
        except (InternalServerError, ConnectionError) as error:
            logger.warning("Error while removing raft node: %s", error)
        else:
            logger.info("Removed raft node %s", id)

    def is_node_in_raft_peers(self, id: str) -> bool:
        """Check if node is in raft peers."""
        try:
            raft_config = self._client.sys.read_raft_config()
        except (InternalServerError, ConnectionError) as e:
            logger.warning("Error while reading raft config: %s", e)
            return False
        for peer in raft_config["data"]["config"]["servers"]:
            if peer["node_id"] == id:
                return True
        return False

    def get_num_raft_peers(self) -> int:
        """Return the number of raft peers."""
        try:
            raft_config = self._client.sys.read_raft_config()
        except (InternalServerError, ConnectionError) as e:
            logger.warning("Error while reading raft config: %s", e)
            return 0
        return len(raft_config["data"]["config"]["servers"])

    def is_common_name_allowed_in_pki_role(self, role: str, mount: str, common_name: str) -> bool:
        """Return whether the provided common name is in the list of domains allowed by the specified PKI role."""
        try:
            return common_name in self._client.secrets.pki.read_role(
                name=role, mount_point=mount
            ).get("data", {}).get("allowed_domains", [])
        except InvalidPath:
            logger.warning("Role does not exist on the specified path.")
            return False

    def get_role_max_ttl(self, role: str, mount: str) -> int | None:
        """Get the max ttl for the specified PKI role in seconds."""
        try:
            return (
                self._client.secrets.pki.read_role(name=role, mount_point=mount)
                .get("data", {})
                .get("max_ttl")
            )
        except InvalidPath:
            logger.warning("Role does not exist on the specified path.")
            return None

    def list_pki_issuers(self, mount: str) -> List[str]:
        """Get the list of issuers for the PKI backend.

        Args:
            mount: The mount point of the PKI backend.

        Returns:
            The list of issuers (i.e. ["issuer1", "issuer2"]).
        """
        try:
            return self._client.secrets.pki.list_issuers(mount_point=mount)["data"]["keys"]
        except (InvalidPath, KeyError) as e:
            logger.error("No issuers found on the specified path: %s", e)
            raise VaultClientError("No issuers found on the specified path.")

    def create_transit_key(self, mount_point: str, key_name: str) -> None:
        """Create a new key in the transit backend.

        Args:
            mount_point: The mount point of the transit backend.
            key_name: The name of the key to create.
        """
        transit = self._client.secrets.transit
        result = transit.create_key(mount_point=mount_point, name=key_name)
        logger.debug("Created a new transit key. response=%s", result)

    def delete_role(self, name: str) -> None:
        """Delete the approle with the given name."""
        return self._client.auth.approle.delete_role(name)

    def delete_policy(self, name: str) -> None:
        """Delete the policy with the given name."""
        return self._client.sys.delete_policy(name)


def generate_pem_bundle(certificate: str, private_key: str) -> str:
    """Build a PEM bundle: the certificate, a newline, then the private key.

    Args:
        certificate: The certificate, placed first in the bundle.
        private_key: The private key, placed after the certificate.

    Returns:
        The certificate and the private key joined by a single newline.
    """
    bundle = f"{certificate}\n{private_key}"
    return bundle