mirror of http://git.sairate.top/sairate/doc.git
247 lines
9.4 KiB
Python
247 lines
9.4 KiB
Python
|
import logging
|
||
|
import socket
|
||
|
from time import perf_counter
|
||
|
from types import TracebackType
|
||
|
from typing import Any, List, Optional, Tuple, Type, TYPE_CHECKING
|
||
|
|
||
|
import paramiko
|
||
|
from cerulean.credential import (Credential, PasswordCredential,
|
||
|
PubKeyCredential)
|
||
|
from cerulean.terminal import Terminal
|
||
|
from cerulean.util import BaseExceptionType
|
||
|
|
||
|
|
||
|
logger = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
class SshTerminal(Terminal):
    """A terminal that runs commands over SSH.

    This terminal connects to a host using SSH, then lets you run \
    commands there. The connection is opened when the object is \
    constructed, and is re-established on demand whenever it is \
    found to be inactive.

    Arguments:
        host: The hostname to connect to.
        port: The port to connect on.
        credential: The credential to authenticate with.

    """
    def __init__(self, host: str, port: int, credential: Credential) -> None:
        """Create the terminal and connect to the server.

        Arguments:
            host: The hostname to connect to.
            port: The port to connect on.
            credential: The credential to authenticate with.

        Raises:
            ConnectionError: If paramiko reports an SSH error while
                connecting or authenticating.
            RuntimeError: If the credential is of an unknown kind, or
                the key file could not be read as a supported key.
        """
        self.__host = host
        self.__port = port
        self.__credential = credential

        # Primary connection, used by run() and _get_sftp_client().
        self.__transport = self.__ensure_connection(None)
        # Second connection, opened lazily by
        # _get_downstream_sftp_client().
        self.__transport2 = None  # type: Optional[paramiko.Transport]

    def __enter__(self) -> 'SshTerminal':
        """Enter context manager."""
        return self

    def __exit__(self, exc_type: Optional[BaseExceptionType],
                 exc_value: Optional[BaseException],
                 traceback: Optional[TracebackType]) -> None:
        """Exit context manager."""
        self.close()

    def close(self) -> None:
        """Close the terminal.

        This closes any connections and frees resources associated \
        with the terminal, including the second connection if one \
        was opened by _get_downstream_sftp_client().
        """
        try:
            self.__transport.close()
        finally:
            # The downstream connection only exists if it was requested;
            # close it too so that it does not leak.
            if self.__transport2 is not None:
                self.__transport2.close()
                self.__transport2 = None
        logger.debug('Disconnected from SSH server')

    def __eq__(self, other: Any) -> bool:
        """Returns True iff this terminal equals other.

        Two SshTerminals are equal if they have the same host and
        port. Any other kind of Terminal is unequal, and for objects
        that are not Terminals NotImplemented is returned.
        """
        if not isinstance(other, Terminal):
            return NotImplemented
        if not isinstance(other, SshTerminal):
            return False
        return self.__host == other.__host and self.__port == other.__port

    def __hash__(self) -> int:
        """Returns a hash consistent with __eq__ (host and port)."""
        return hash((self.__host, self.__port))

    def _get_sftp_client(self) -> paramiko.SFTPClient:
        """Get an SFTP client using this terminal.

        This function is used by SftpFileSystem to get an SFTP client \
        using this Terminal's connection. This is a private function, \
        but SftpFileSystem is a friend class.

        Returns:
            An SFTP client object using this terminal's connection.

        Raises:
            RuntimeError: If no SFTP channel could be opened.

        """
        self.__transport = self.__ensure_connection(self.__transport)
        client = paramiko.SFTPClient.from_transport(self.__transport)
        if client is None:
            raise RuntimeError('Could not open a channel for SFTP')
        return client

    def _get_downstream_sftp_client(self) -> paramiko.SFTPClient:
        """Gets a second SFTP client using this terminal.

        This is a work-around for an issue in paramiko that keeps us \
        from copying data upstream and downstream simultaneously \
        through a single connection with reasonable performance. \
        We solve it by opening a second connection for the downstream \
        part.

        Returns:
            An SFTP client object using a second connection.

        Raises:
            RuntimeError: If no SFTP channel could be opened.

        """
        self.__transport2 = self.__ensure_connection(self.__transport2)
        client = paramiko.SFTPClient.from_transport(self.__transport2)
        if client is None:
            raise RuntimeError('Could not open a channel for SFTP')
        return client

    def run(self,
            timeout: float,
            command: str,
            args: List[str],
            stdin_data: Optional[str] = None,
            workdir: Optional[str] = None
            ) -> Tuple[Optional[int], str, str]:
        """Run a command on the remote host.

        The command line is built by joining command and args with
        spaces, preceded by a ``cd`` to workdir if one is given. It is
        passed on as-is, without any quoting, so the caller is
        responsible for quoting arguments where needed.

        If an SSH, EOF, connection or OS error occurs, the whole
        attempt (including executing the command) is retried until
        timeout seconds have passed since the call started.

        Arguments:
            timeout: Time in seconds to keep retrying; also used as
                the receive timeout when reading each output stream.
            command: The command to execute.
            args: Arguments to append to the command.
            stdin_data: Text to send to the command's standard input,
                encoded as UTF-8, or None to send nothing.
            workdir: Directory to change to before running the
                command, or None to not change directory.

        Returns:
            A tuple (exit_status, stdout_text, stderr_text). The exit
            status is None if the output could not be read completely
            before the receive timeout expired.

        Raises:
            ConnectionError: If no attempt succeeded within the
                timeout, or if the connection could not be
                re-established.
        """
        if workdir:
            cmd_str = 'cd {}; {} {}'.format(workdir, command, ' '.join(args))
        else:
            cmd_str = '{} {}'.format(command, ' '.join(args))

        logger.debug('Executing %s', cmd_str)
        last_exception = None  # type: Optional[BaseException]
        start_time = perf_counter()
        while perf_counter() < start_time + timeout:
            self.__transport = self.__ensure_connection(self.__transport)
            try:
                session = self.__transport.open_session()
                logger.debug('Opened session')
                session.exec_command(command=cmd_str)
                logger.debug('exec_command done')
                if stdin_data is not None:
                    session.sendall(bytes(stdin_data, 'utf-8'))
                    session.shutdown_write()
                    logger.debug('stdin sent')

                got_all_stdout, stdout_text = self.__get_data_from_channel(
                        session, 'stdout', timeout)
                got_all_stderr, stderr_text = self.__get_data_from_channel(
                        session, 'stderr', timeout)
                logger.debug(
                        'got output %s %s %s %s', got_all_stdout, stdout_text,
                        got_all_stderr, stderr_text)
                if not (got_all_stdout and got_all_stderr):
                    logger.debug('Command did not finish within timeout')
                    session.close()
                    return None, stdout_text, stderr_text

                session.settimeout(2.0)
                exit_status = session.recv_exit_status()
                logger.debug('received exit status %s', exit_status)
                session.close()

                if exit_status == -1:
                    # Treated as a failed attempt; caught below and
                    # retried, which executes the command again.
                    raise EOFError('Execution failed, connection'
                                   ' or server issue?')

                logger.debug('Command executed successfully')
                return exit_status, stdout_text, stderr_text
            except (paramiko.SSHException, EOFError, ConnectionError) as e:
                # Must come before OSError, of which ConnectionError is
                # a subclass, so that these never force a reconnect.
                last_exception = e
            except OSError as e:
                if 'Socket' in str(e):
                    # Keep the freshly made connection; dropping the
                    # return value would leak it and leave us holding
                    # on to the transport that was just closed.
                    self.__transport = self.__ensure_connection(
                            self.__transport, True)
                last_exception = e

        raise ConnectionError(str(last_exception)) from last_exception

    def __get_data_from_channel(self, channel: paramiko.Channel,
                                stream_name: str,
                                timeout: float) -> Tuple[bool, str]:
        """Reads text from standard output or standard error.

        Data is read until the stream ends (an empty read) or until a
        single receive takes longer than timeout.

        Arguments:
            channel: The channel to read from.
            stream_name: 'stdout' to read standard output, anything
                else to read standard error.
            timeout: Timeout in seconds to set on the channel.

        Returns:
            A tuple (complete, text), where complete is False if a
            receive timed out before the end of the stream, and text
            is whatever was received, decoded as UTF-8. Bytes that are
            not valid UTF-8 are replaced by U+FFFD.
        """
        if stream_name == 'stdout':
            receive = channel.recv
        else:
            receive = channel.recv_stderr

        channel.settimeout(timeout)

        data = bytearray()
        complete = True
        try:
            while True:
                chunk = receive(1024 * 1024)
                if len(chunk) == 0:
                    break
                data.extend(chunk)
        except socket.timeout:
            complete = False

        # Decode leniently: the command has already run at this point,
        # and a stray invalid byte (or a multi-byte character cut off
        # by a timeout) should not make us lose all of its output.
        return complete, data.decode('utf-8', errors='replace')

    def __get_key_from_file(self, filename: str,
                            passphrase: Optional[str]) -> paramiko.pkey.PKey:
        """Loads a private key from a file.

        The file is tried as an Ed25519 key, then as an ECDSA key, and
        finally as an RSA key; the first one that loads is returned.

        Arguments:
            filename: Path of the private key file.
            passphrase: Passphrase protecting the key, if any.

        Returns:
            The loaded key.

        Raises:
            RuntimeError: If the file could not be loaded as any of
                the supported key types.
        """
        key_classes = (
                paramiko.ed25519key.Ed25519Key,
                paramiko.ecdsakey.ECDSAKey,
                paramiko.rsakey.RSAKey)

        failures = []  # type: List[str]
        for key_class in key_classes:
            try:
                key = key_class.from_private_key_file(
                        filename=filename, password=passphrase)
            except paramiko.ssh_exception.SSHException as e:
                failures.append('{}; '.format(e))
                continue
            if key is not None:
                return key

        logger.debug('Invalid key: %s', ''.join(failures))
        raise RuntimeError(
                'Invalid key specified, could not open as RSA, ECDSA or'
                ' Ed25519 key'
                )

    def __ensure_connection(self, transport: Optional[paramiko.Transport],
                            force: bool = False) -> paramiko.Transport:
        """Returns an active transport, (re)connecting if needed.

        If the given transport is active and force is False, it is
        returned unchanged. Otherwise it is closed (if there is one)
        and a new transport is created and authenticated using the
        credential this terminal was created with.

        Arguments:
            transport: The current transport, or None if there is none.
            force: Whether to reconnect even if transport is active.

        Returns:
            An authenticated transport; the caller must store it, as
            it may be a different object than the one passed in.

        Raises:
            ConnectionError: If paramiko reports an SSH error while
                connecting or authenticating.
            RuntimeError: If the credential is of an unknown kind, or
                the key file could not be loaded.
        """
        if transport is not None and transport.is_active() and not force:
            return transport

        if transport is not None:
            transport.close()
        transport = paramiko.Transport((self.__host, self.__port))
        logger.info(
                'Connecting to %s on port %s', self.__host, self.__port)
        try:
            if isinstance(self.__credential, PasswordCredential):
                logger.debug('Authenticating using a password')
                transport.connect(
                        username=self.__credential.username,
                        password=self.__credential.password)
            elif isinstance(self.__credential, PubKeyCredential):
                logger.debug('Authenticating using a public key')
                key = self.__get_key_from_file(
                        self.__credential.public_key,
                        self.__credential.passphrase)
                transport.connect(
                        username=self.__credential.username, pkey=key)
            else:
                raise RuntimeError('Unknown kind of credential')
            logger.info('Connection (re)established')
        except paramiko.SSHException as e:
            # Don't leave a half-open connection behind on failure.
            transport.close()
            raise ConnectionError(
                    'Cerulean was disconnected and could not reconnect'
                    ) from e
        except Exception:
            transport.close()
            raise
        return transport
|