File: //usr/local/CyberCP/lib/python3.10/site-packages/asyncssh/kex_rsa.py
# Copyright (c) 2018-2024 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
# http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
# GNU General Public License, Version 2.0, or any later versions of
# that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
# Ron Frederick - initial implementation, API, and documentation
"""RSA key exchange handler"""
from hashlib import sha1, sha256
from typing import TYPE_CHECKING, Optional, cast
from .kex import Kex, register_kex_alg
from .misc import HashType, KeyExchangeFailed, ProtocolError
from .misc import get_symbol_names, randrange
from .packet import MPInt, String, SSHPacket
from .public_key import KeyImportError, SSHKey
from .public_key import decode_ssh_public_key, generate_private_key
from .rsa import RSAKey
if TYPE_CHECKING:
# pylint: disable=cyclic-import
from .connection import SSHConnection, SSHClientConnection
from .connection import SSHServerConnection
# SSH KEXRSA message values
MSG_KEXRSA_PUBKEY = 30
MSG_KEXRSA_SECRET = 31
MSG_KEXRSA_DONE = 32
class _KexRSA(Kex):
"""Handler for RSA key exchange"""
_handler_names = get_symbol_names(globals(), 'MSG_KEXRSA')
def __init__(self, alg: bytes, conn: 'SSHConnection', hash_alg: HashType,
key_size: int, hash_size: int):
super().__init__(alg, conn, hash_alg)
self._key_size = key_size
self._k_limit = 1 << (key_size - 2*hash_size - 49)
self._host_key_data = b''
self._trans_key: Optional[SSHKey] = None
self._trans_key_data = b''
self._k = 0
self._encrypted_k = b''
async def start(self) -> None:
"""Start RSA key exchange"""
if self._conn.is_server():
server_conn = cast('SSHServerConnection', self._conn)
host_key = server_conn.get_server_host_key()
assert host_key is not None
self._host_key_data = host_key.public_data
self._trans_key = generate_private_key(
'ssh-rsa', key_size=self._key_size)
self._trans_key_data = self._trans_key.public_data
self.send_packet(MSG_KEXRSA_PUBKEY, String(self._host_key_data),
String(self._trans_key_data))
def _compute_hash(self) -> bytes:
"""Compute a hash of key information associated with the connection"""
hash_obj = self._hash_alg()
hash_obj.update(self._conn.get_hash_prefix())
hash_obj.update(String(self._host_key_data))
hash_obj.update(String(self._trans_key_data))
hash_obj.update(String(self._encrypted_k))
hash_obj.update(MPInt(self._k))
return hash_obj.digest()
def _process_pubkey(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a KEXRSA pubkey message"""
if self._conn.is_server():
raise ProtocolError('Unexpected KEXRSA pubkey msg')
self._host_key_data = packet.get_string()
self._trans_key_data = packet.get_string()
packet.check_end()
try:
pubkey = decode_ssh_public_key(self._trans_key_data)
except KeyImportError:
raise ProtocolError('Invalid KEXRSA pubkey msg') from None
trans_key = cast(RSAKey, pubkey)
self._k = randrange(self._k_limit)
self._encrypted_k = \
cast(bytes, trans_key.encrypt(MPInt(self._k), self.algorithm))
self.send_packet(MSG_KEXRSA_SECRET, String(self._encrypted_k))
def _process_secret(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a KEXRSA secret message"""
if self._conn.is_client():
raise ProtocolError('Unexpected KEXRSA secret msg')
self._encrypted_k = packet.get_string()
packet.check_end()
trans_key = cast(RSAKey, self._trans_key)
decrypted_k = trans_key.decrypt(self._encrypted_k, self.algorithm)
if not decrypted_k:
raise KeyExchangeFailed('Key exchange decryption failed')
packet = SSHPacket(decrypted_k)
self._k = packet.get_mpint()
packet.check_end()
server_conn = cast('SSHServerConnection', self._conn)
host_key = server_conn.get_server_host_key()
assert host_key is not None
h = self._compute_hash()
sig = host_key.sign(h)
self.send_packet(MSG_KEXRSA_DONE, String(sig))
self._conn.send_newkeys(MPInt(self._k), h)
def _process_done(self, _pkttype: int, _pktid: int,
packet: SSHPacket) -> None:
"""Process a KEXRSA done message"""
if self._conn.is_server():
raise ProtocolError('Unexpected KEXRSA done msg')
sig = packet.get_string()
packet.check_end()
client_conn = cast('SSHClientConnection', self._conn)
host_key = client_conn.validate_server_host_key(self._host_key_data)
h = self._compute_hash()
if not host_key.verify(h, sig):
raise KeyExchangeFailed('Key exchange hash mismatch')
self._conn.send_newkeys(MPInt(self._k), h)
_packet_handlers = {
MSG_KEXRSA_PUBKEY: _process_pubkey,
MSG_KEXRSA_SECRET: _process_secret,
MSG_KEXRSA_DONE: _process_done
}
for _name, _hash_alg, _key_size, _hash_size, _default in (
(b'rsa2048-sha256', sha256, 2048, 256, True),
(b'rsa1024-sha1', sha1, 1024, 160, False)):
register_kex_alg(_name, _KexRSA, _hash_alg,
(_key_size, _hash_size), _default)