doc/.venv/Lib/site-packages/cerulean/ssh_terminal.py

247 lines
9.4 KiB
Python
Raw Normal View History

import logging
import socket
from time import perf_counter
from types import TracebackType
from typing import Any, List, Optional, Tuple, Type, TYPE_CHECKING
import paramiko
from cerulean.credential import (Credential, PasswordCredential,
PubKeyCredential)
from cerulean.terminal import Terminal
from cerulean.util import BaseExceptionType
# Module-level logger named after this module; SshTerminal uses it to report
# connection and command-execution diagnostics.
logger = logging.getLogger(__name__)
class SshTerminal(Terminal):
    """A terminal that runs commands over SSH.

    This terminal connects to a host using SSH, then lets you run
    commands there. It can be used as a context manager, in which case
    the connections are closed on exit.

    Arguments:
        host: The hostname to connect to.
        port: The port to connect on.
        credential: The credential to authenticate with.
    """

    def __init__(self, host: str, port: int, credential: Credential) -> None:
        self.__host = host
        self.__port = port
        self.__credential = credential
        # The main connection is opened immediately.
        self.__transport = self.__ensure_connection(None)
        # The second connection is only opened on demand, by
        # _get_downstream_sftp_client().
        self.__transport2 = None    # type: Optional[paramiko.Transport]

    def __enter__(self) -> 'SshTerminal':
        """Enter context manager."""
        return self

    def __exit__(self, exc_type: Optional[BaseExceptionType],
                 exc_value: Optional[BaseException],
                 traceback: Optional[TracebackType]) -> None:
        """Exit context manager, closing the terminal."""
        self.close()

    def close(self) -> None:
        """Close the terminal.

        This closes both the main connection and, if it was ever
        opened, the second (downstream SFTP) connection.
        """
        try:
            self.__transport.close()
        finally:
            # The second connection only exists if a downstream SFTP
            # client was requested; it must be closed too or it leaks.
            if self.__transport2 is not None:
                self.__transport2.close()
                self.__transport2 = None
        logger.debug('Disconnected from SSH server')

    def __eq__(self, other: Any) -> bool:
        """Returns True iff this terminal equals other.

        Two SshTerminals are equal if they have the same host and port.
        Any other kind of Terminal is unequal; for objects that are not
        Terminals, NotImplemented is returned.
        """
        if not isinstance(other, Terminal):
            return NotImplemented
        if not isinstance(other, SshTerminal):
            return False
        return self.__host == other.__host and self.__port == other.__port

    def _get_sftp_client(self) -> paramiko.SFTPClient:
        """Get an SFTP client using this terminal.

        This function is used by SftpFileSystem to get an SFTP client
        using this Terminal's connection. This is a private function,
        but SftpFileSystem is a friend class.

        Returns:
            An SFTP client object using this terminal's connection.

        Raises:
            RuntimeError: If no SFTP channel could be opened.
            ConnectionError: If the connection was lost and could not
                be re-established.
        """
        self.__transport = self.__ensure_connection(self.__transport)
        client = paramiko.SFTPClient.from_transport(self.__transport)
        if client is None:
            raise RuntimeError('Could not open a channel for SFTP')
        return client

    def _get_downstream_sftp_client(self) -> paramiko.SFTPClient:
        """Gets a second SFTP client using this terminal.

        This is a work-around for an issue in paramiko that keeps us
        from copying data upstream and downstream simultaneously
        through a single connection with reasonable performance.
        We solve it by opening a second connection for the downstream
        part.

        Returns:
            An SFTP client object using a second connection.

        Raises:
            RuntimeError: If no SFTP channel could be opened.
            ConnectionError: If the second connection could not be
                (re-)established.
        """
        self.__transport2 = self.__ensure_connection(self.__transport2)
        client = paramiko.SFTPClient.from_transport(self.__transport2)
        if client is None:
            raise RuntimeError('Could not open a channel for SFTP')
        return client

    def run(self,
            timeout: float,
            command: str,
            args: List[str],
            stdin_data: Optional[str] = None,
            workdir: Optional[str] = None
            ) -> Tuple[Optional[int], str, str]:
        """Run a command on the remote host and collect its output.

        The command is attempted repeatedly, reconnecting if needed,
        until it has been executed or ``timeout`` seconds have passed.

        Arguments:
            timeout: Time in seconds during which to keep trying; also
                used as the timeout for each read from the command's
                output streams.
            command: The command to run.
            args: Arguments for the command, joined with spaces.
            stdin_data: Text to send to the command's standard input,
                encoded as UTF-8. If None, nothing is sent.
            workdir: Directory to ``cd`` into before running the
                command, if given and non-empty.

        Returns:
            A tuple (exit_code, stdout_text, stderr_text). The exit
            code is None if reading the output timed out, in which case
            the texts contain whatever was received so far.

        Raises:
            ConnectionError: If the command could not be executed
                within the timeout, or if reconnecting failed.
        """
        # NOTE(review): command, args and workdir are pasted into a shell
        # command line without quoting. Callers may rely on that (e.g. to
        # pass pre-quoted arguments), so it is left as is, but untrusted
        # values must not be passed in here.
        joined_args = ' '.join(args)
        if workdir:
            cmd_str = 'cd {}; {} {}'.format(workdir, command, joined_args)
        else:
            cmd_str = '{} {}'.format(command, joined_args)
        logger.debug('Executing %s', cmd_str)

        last_exception = None   # type: Optional[BaseException]
        start_time = perf_counter()
        # NOTE(review): a failure after exec_command() leads to a retry,
        # so the remote command may be started more than once.
        while perf_counter() < start_time + timeout:
            self.__transport = self.__ensure_connection(self.__transport)
            try:
                session = self.__transport.open_session()
                logger.debug('Opened session')
                session.exec_command(command=cmd_str)
                logger.debug('exec_command done')

                if stdin_data is not None:
                    session.sendall(bytes(stdin_data, 'utf-8'))
                    session.shutdown_write()
                    logger.debug('stdin sent')

                got_all_stdout, stdout_text = self.__get_data_from_channel(
                        session, 'stdout', timeout)
                got_all_stderr, stderr_text = self.__get_data_from_channel(
                        session, 'stderr', timeout)
                logger.debug(
                        'got output %s %s %s %s', got_all_stdout, stdout_text,
                        got_all_stderr, stderr_text)

                if not got_all_stdout or not got_all_stderr:
                    logger.debug('Command did not finish within timeout')
                    session.close()
                    return None, stdout_text, stderr_text

                session.settimeout(2.0)
                exit_status = session.recv_exit_status()
                logger.debug('received exit status %s', exit_status)
                session.close()

                if exit_status == -1:
                    # No exit status was obtained; treat this as a failed
                    # attempt so that it is retried like a connection error.
                    raise EOFError('Execution failed, connection'
                                   ' or server issue?')

                logger.debug('Command executed successfully')
                return exit_status, stdout_text, stderr_text

            except (paramiko.SSHException, EOFError, ConnectionError) as e:
                last_exception = e
            except OSError as e:
                if 'Socket' in str(e):
                    # Force a fresh connection, and keep it: discarding the
                    # returned transport would leak it and leave the closed
                    # one in place.
                    self.__transport = self.__ensure_connection(
                            self.__transport, True)
                last_exception = e

        if last_exception is None:
            # The loop body never ran, i.e. the timeout had already expired.
            raise ConnectionError(
                    'Timed out before the command could be executed')
        raise ConnectionError(str(last_exception)) from last_exception

    def __get_data_from_channel(self, channel: paramiko.Channel,
                                stream_name: str,
                                timeout: float) -> Tuple[bool, str]:
        """Reads text from standard output or standard error.

        Data is read until the stream is exhausted (an empty read) or a
        read times out.

        Arguments:
            channel: The channel to read from.
            stream_name: 'stdout' to read standard output; any other
                value reads standard error.
            timeout: Timeout in seconds set on the channel, so it
                applies to each individual read.

        Returns:
            A tuple (complete, text), where complete is False if a read
            timed out, and text is the data received, decoded as UTF-8.
        """
        if stream_name == 'stdout':
            receive = paramiko.Channel.recv
        else:
            receive = paramiko.Channel.recv_stderr

        channel.settimeout(timeout)
        max_chunk_size = 1024 * 1024

        data = bytearray()
        try:
            while True:
                chunk = receive(channel, max_chunk_size)
                if not chunk:
                    # An empty read signals the end of the stream.
                    break
                data.extend(chunk)
        except socket.timeout:
            return False, data.decode('utf-8')

        return True, data.decode('utf-8')

    def __get_key_from_file(self, filename: str,
                            passphrase: Optional[str]) -> paramiko.pkey.PKey:
        """Loads a private key from a file.

        The file is tried as an Ed25519, an ECDSA and an RSA key, in
        that order, and the first one that loads is returned.

        Arguments:
            filename: Path of the private key file.
            passphrase: Passphrase to decrypt the key with, if any.

        Returns:
            The loaded key.

        Raises:
            RuntimeError: If the file could not be loaded as any of the
                supported key types.
        """
        key_classes = (
                paramiko.ed25519key.Ed25519Key,
                paramiko.ecdsakey.ECDSAKey,
                paramiko.rsakey.RSAKey)

        # Collects the reason each key type was rejected, for the debug log.
        messages = ''
        for key_class in key_classes:
            try:
                key = key_class.from_private_key_file(
                        filename=filename, password=passphrase)
            except paramiko.ssh_exception.SSHException as e:
                messages += '{}; '.format(e)
                continue
            if key is not None:
                return key

        logger.debug('Invalid key: %s', messages)
        raise RuntimeError(
                'Invalid key specified, could not open as RSA, ECDSA or'
                ' Ed25519 key'
                )

    def __ensure_connection(self, transport: Optional[paramiko.Transport],
                            force: bool = False) -> paramiko.Transport:
        """Returns a connected transport, reconnecting if necessary.

        If the given transport is active and no reconnect is forced, it
        is returned as is. Otherwise it is closed (if there is one) and
        a new, authenticated transport is created.

        Arguments:
            transport: The current transport, or None if there is none.
            force: If True, always replace the transport by a new one.

        Returns:
            An active transport; either the one passed in or a new one.

        Raises:
            ConnectionError: If connecting or authenticating failed
                with an SSH error.
            RuntimeError: If the credential is of an unknown kind, or
                the key file could not be loaded.
        """
        if transport is not None and transport.is_active() and not force:
            return transport

        if transport is not None:
            transport.close()

        transport = paramiko.Transport((self.__host, self.__port))
        logger.info(
                'Connecting to %s on port %s', self.__host, self.__port)
        # NOTE(review): no hostkey is passed to connect() below, so the
        # server's host key is not verified here.
        try:
            if isinstance(self.__credential, PasswordCredential):
                logger.debug('Authenticating using a password')
                transport.connect(
                        username=self.__credential.username,
                        password=self.__credential.password)
            elif isinstance(self.__credential, PubKeyCredential):
                logger.debug('Authenticating using a public key')
                # Despite its name, public_key is used as the path of a
                # private key file here.
                key = self.__get_key_from_file(
                        self.__credential.public_key,
                        self.__credential.passphrase)
                transport.connect(
                        username=self.__credential.username, pkey=key)
            else:
                raise RuntimeError('Unknown kind of credential')
            logger.info('Connection (re)established')
        except paramiko.SSHException as e:
            # Do not leave a half-open transport behind on failure.
            transport.close()
            raise ConnectionError(
                    'Cerulean was disconnected and could not reconnect'
                    ) from e
        except Exception:
            transport.close()
            raise

        return transport