From 54f7e327b30486f2ecda64d425becef4445ec773 Mon Sep 17 00:00:00 2001 From: Cayman Date: Fri, 15 Dec 2023 17:40:36 -0500 Subject: [PATCH] feat: refactor the codebase --- src/crypto.ts | 22 +- src/crypto/index.ts | 6 +- src/crypto/js.ts | 8 +- src/encoder.ts | 53 +---- src/handshake-xx.ts | 179 --------------- src/handshakes/abstract-handshake.ts | 181 ---------------- src/handshakes/xx.ts | 169 --------------- src/logger.ts | 39 ++-- src/noise.ts | 157 ++++++++------ src/nonce.ts | 9 +- src/performHandshake.ts | 90 ++++++++ src/protocol.ts | 313 +++++++++++++++++++++++++++ src/{crypto => }/streaming.ts | 27 +-- src/types.ts | 92 ++++---- src/utils.ts | 140 ++++-------- test/handshakes/xx.spec.ts | 160 -------------- test/noise.spec.ts | 60 ----- test/performHandshake.spec.ts | 138 ++++++++++++ test/protocol.spec.ts | 112 ++++++++++ test/xx-handshake.spec.ts | 143 ------------ 20 files changed, 897 insertions(+), 1201 deletions(-) delete mode 100644 src/handshake-xx.ts delete mode 100644 src/handshakes/abstract-handshake.ts delete mode 100644 src/handshakes/xx.ts create mode 100644 src/performHandshake.ts create mode 100644 src/protocol.ts rename src/{crypto => }/streaming.ts (66%) delete mode 100644 test/handshakes/xx.spec.ts create mode 100644 test/performHandshake.spec.ts create mode 100644 test/protocol.spec.ts delete mode 100644 test/xx-handshake.spec.ts diff --git a/src/crypto.ts b/src/crypto.ts index 2cf9df0..75c0f57 100644 --- a/src/crypto.ts +++ b/src/crypto.ts @@ -1,15 +1,27 @@ -import { type Uint8ArrayList } from 'uint8arraylist' -import type { bytes32, Hkdf, KeyPair } from './types.js' +import type { ICrypto, KeyPair } from './types.js' +import type { Uint8ArrayList } from 'uint8arraylist' +/** Underlying crypto implementation, meant to be overridable */ export interface ICryptoInterface { hashSHA256(data: Uint8Array | Uint8ArrayList): Uint8Array - getHKDF(ck: bytes32, ikm: Uint8Array): Hkdf + getHKDF(ck: Uint8Array, ikm: Uint8Array): [Uint8Array, Uint8Array, Uint8Array] generateX25519KeyPair(): KeyPair generateX25519KeyPairFromSeed(seed: Uint8Array): KeyPair generateX25519SharedKey(privateKey: Uint8Array | Uint8ArrayList, publicKey: Uint8Array | Uint8ArrayList): Uint8Array - chaCha20Poly1305Encrypt(plaintext: Uint8Array | Uint8ArrayList, nonce: Uint8Array, ad: Uint8Array, k: bytes32): Uint8ArrayList | Uint8Array - chaCha20Poly1305Decrypt(ciphertext: Uint8Array | Uint8ArrayList, nonce: Uint8Array, ad: Uint8Array, k: bytes32, dst?: Uint8Array): Uint8ArrayList | Uint8Array | null + chaCha20Poly1305Encrypt(plaintext: Uint8Array | Uint8ArrayList, nonce: Uint8Array, ad: Uint8Array, k: Uint8Array): Uint8ArrayList | Uint8Array + chaCha20Poly1305Decrypt(ciphertext: Uint8Array | Uint8ArrayList, nonce: Uint8Array, ad: Uint8Array, k: Uint8Array, dst?: Uint8Array): Uint8ArrayList | Uint8Array +} + +export function wrapCrypto (crypto: ICryptoInterface): ICrypto { + return { + generateKeypair: crypto.generateX25519KeyPair, + dh: (keypair, publicKey) => crypto.generateX25519SharedKey(keypair.privateKey, publicKey).subarray(0, 32), + encrypt: crypto.chaCha20Poly1305Encrypt, + decrypt: crypto.chaCha20Poly1305Decrypt, + hash: crypto.hashSHA256, + hkdf: crypto.getHKDF + } } diff --git a/src/crypto/index.ts b/src/crypto/index.ts index a428c1e..bd9d288 100644 --- a/src/crypto/index.ts +++ b/src/crypto/index.ts @@ -112,7 +112,11 @@ const asCrypto: Pick { @@ -24,51 +23,3 @@ export const uint16BEDecode: LengthDecoderFunction = (data: Uint8Array | Uint8Ar return data.getUint16(0) } uint16BEDecode.bytes = 2 - -export function encode0 (message: MessageBuffer): Uint8ArrayList { - return new Uint8ArrayList(message.ne, message.ciphertext) -} - -export function encode1 (message: MessageBuffer): Uint8ArrayList { - return new Uint8ArrayList(message.ne, message.ns, message.ciphertext) -} - -export function encode2 (message: MessageBuffer): Uint8ArrayList { - return new Uint8ArrayList(message.ns, message.ciphertext) -} - -export function decode0 (input: bytes): MessageBuffer { - if (input.length < 32) { - throw new Error('Cannot decode stage 0 MessageBuffer: length less than 32 bytes.') - } - - return { - ne: input.subarray(0, 32), - ciphertext: input.subarray(32, input.length), - ns: uint8ArrayAlloc(0) - } -} - -export function decode1 (input: bytes): MessageBuffer { - if (input.length < 80) { - throw new Error('Cannot decode stage 1 MessageBuffer: length less than 80 bytes.') - } - - return { - ne: input.subarray(0, 32), - ns: input.subarray(32, 80), - ciphertext: input.subarray(80, input.length) - } -} - -export function decode2 (input: bytes): MessageBuffer { - if (input.length < 48) { - throw new Error('Cannot decode stage 2 MessageBuffer: length less than 48 bytes.') - } - - return { - ne: uint8ArrayAlloc(0), - ns: input.subarray(0, 48), - ciphertext: input.subarray(48, input.length) - } -} diff --git a/src/handshake-xx.ts b/src/handshake-xx.ts deleted file mode 100644 index d43e724..0000000 --- a/src/handshake-xx.ts +++ /dev/null @@ -1,179 +0,0 @@ -import { alloc as uint8ArrayAlloc } from 'uint8arrays/alloc' -import { decode0, decode1, decode2, encode0, encode1, encode2 } from './encoder.js' -import { InvalidCryptoExchangeError, UnexpectedPeerError } from './errors.js' -import { XX } from './handshakes/xx.js' -import { - logLocalStaticKeys, - logLocalEphemeralKeys, - logRemoteEphemeralKey, - logRemoteStaticKey, - logCipherState -} from './logger.js' -import { - decodePayload, - getPeerIdFromPayload, - verifySignedPayload -} from './utils.js' -import type { ICryptoInterface } from './crypto.js' -import type { NoiseComponents } from './index.js' -import type { NoiseExtensions } from './proto/payload.js' -import type { bytes, bytes32, IHandshake, CipherState, NoiseSession, KeyPair } from './types.js' -import type { Logger, PeerId } from '@libp2p/interface' -import type { LengthPrefixedStream } from 'it-length-prefixed-stream' -import type { Uint8ArrayList } from 'uint8arraylist' - -export class XXHandshake implements IHandshake { - public isInitiator: boolean - public session: NoiseSession - public remotePeer!: PeerId - public remoteExtensions: NoiseExtensions = { webtransportCerthashes: [] } - - protected payload: bytes - protected connection: LengthPrefixedStream - protected xx: XX - protected staticKeypair: KeyPair - - private readonly prologue: bytes32 - private readonly log: Logger - - constructor ( - components: NoiseComponents, - isInitiator: boolean, - payload: bytes, - prologue: bytes32, - crypto: ICryptoInterface, - staticKeypair: KeyPair, - connection: LengthPrefixedStream, - remotePeer?: PeerId, - handshake?: XX - ) { - this.log = components.logger.forComponent('libp2p:noise:xxhandshake') - this.isInitiator = isInitiator - this.payload = payload - this.prologue = prologue - this.staticKeypair = staticKeypair - this.connection = connection - if (remotePeer) { - this.remotePeer = remotePeer - } - this.xx = handshake ?? new XX(components, crypto) - this.session = this.xx.initSession(this.isInitiator, this.prologue, this.staticKeypair) - } - - // stage 0 - public async propose (): Promise { - logLocalStaticKeys(this.session.hs.s, this.log) - if (this.isInitiator) { - this.log.trace('Stage 0 - Initiator starting to send first message.') - const messageBuffer = this.xx.sendMessage(this.session, uint8ArrayAlloc(0)) - await this.connection.write(encode0(messageBuffer)) - this.log.trace('Stage 0 - Initiator finished sending first message.') - logLocalEphemeralKeys(this.session.hs.e, this.log) - } else { - this.log.trace('Stage 0 - Responder waiting to receive first message...') - const receivedMessageBuffer = decode0((await this.connection.read()).subarray()) - const { valid } = this.xx.recvMessage(this.session, receivedMessageBuffer) - if (!valid) { - throw new InvalidCryptoExchangeError('xx handshake stage 0 validation fail') - } - this.log.trace('Stage 0 - Responder received first message.') - logRemoteEphemeralKey(this.session.hs.re, this.log) - } - } - - // stage 1 - public async exchange (): Promise { - if (this.isInitiator) { - this.log.trace('Stage 1 - Initiator waiting to receive first message from responder...') - const receivedMessageBuffer = decode1((await this.connection.read()).subarray()) - const { plaintext, valid } = this.xx.recvMessage(this.session, receivedMessageBuffer) - if (!valid) { - throw new InvalidCryptoExchangeError('xx handshake stage 1 validation fail') - } - this.log.trace('Stage 1 - Initiator received the message.') - logRemoteEphemeralKey(this.session.hs.re, this.log) - logRemoteStaticKey(this.session.hs.rs, this.log) - - this.log.trace("Initiator going to check remote's signature...") - try { - const decodedPayload = decodePayload(plaintext) - this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload) - await verifySignedPayload(this.session.hs.rs, decodedPayload, this.remotePeer) - this.setRemoteNoiseExtension(decodedPayload.extensions) - } catch (e) { - const err = e as Error - throw new UnexpectedPeerError(`Error occurred while verifying signed payload: ${err.message}`) - } - this.log.trace('All good with the signature!') - } else { - this.log.trace('Stage 1 - Responder sending out first message with signed payload and static key.') - const messageBuffer = this.xx.sendMessage(this.session, this.payload) - await this.connection.write(encode1(messageBuffer)) - this.log.trace('Stage 1 - Responder sent the second handshake message with signed payload.') - logLocalEphemeralKeys(this.session.hs.e, this.log) - } - } - - // stage 2 - public async finish (): Promise { - if (this.isInitiator) { - this.log.trace('Stage 2 - Initiator sending third handshake message.') - const messageBuffer = this.xx.sendMessage(this.session, this.payload) - await this.connection.write(encode2(messageBuffer)) - this.log.trace('Stage 2 - Initiator sent message with signed payload.') - } else { - this.log.trace('Stage 2 - Responder waiting for third handshake message...') - const receivedMessageBuffer = decode2((await this.connection.read()).subarray()) - const { plaintext, valid } = this.xx.recvMessage(this.session, receivedMessageBuffer) - if (!valid) { - throw new InvalidCryptoExchangeError('xx handshake stage 2 validation fail') - } - this.log.trace('Stage 2 - Responder received the message, finished handshake.') - - try { - const decodedPayload = decodePayload(plaintext) - this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload) - await verifySignedPayload(this.session.hs.rs, decodedPayload, this.remotePeer) - this.setRemoteNoiseExtension(decodedPayload.extensions) - } catch (e) { - const err = e as Error - throw new UnexpectedPeerError(`Error occurred while verifying signed payload: ${err.message}`) - } - } - logCipherState(this.session, this.log) - } - - public encrypt (plaintext: Uint8Array | Uint8ArrayList, session: NoiseSession): Uint8Array | Uint8ArrayList { - const cs = this.getCS(session) - - return this.xx.encryptWithAd(cs, uint8ArrayAlloc(0), plaintext) - } - - public decrypt (ciphertext: Uint8Array | Uint8ArrayList, session: NoiseSession, dst?: Uint8Array): { plaintext: Uint8Array | Uint8ArrayList, valid: boolean } { - const cs = this.getCS(session, false) - - return this.xx.decryptWithAd(cs, uint8ArrayAlloc(0), ciphertext, dst) - } - - public getRemoteStaticKey (): Uint8Array | Uint8ArrayList { - return this.session.hs.rs - } - - private getCS (session: NoiseSession, encryption = true): CipherState { - if (!session.cs1 || !session.cs2) { - throw new InvalidCryptoExchangeError('Handshake not completed properly, cipher state does not exist.') - } - - if (this.isInitiator) { - return encryption ? session.cs1 : session.cs2 - } else { - return encryption ? session.cs2 : session.cs1 - } - } - - protected setRemoteNoiseExtension (e: NoiseExtensions | null | undefined): void { - if (e) { - this.remoteExtensions = e - } - } -} diff --git a/src/handshakes/abstract-handshake.ts b/src/handshakes/abstract-handshake.ts deleted file mode 100644 index 4594b7b..0000000 --- a/src/handshakes/abstract-handshake.ts +++ /dev/null @@ -1,181 +0,0 @@ -import { Uint8ArrayList } from 'uint8arraylist' -import { fromString as uint8ArrayFromString } from 'uint8arrays' -import { alloc as uint8ArrayAlloc } from 'uint8arrays/alloc' -import { equals as uint8ArrayEquals } from 'uint8arrays/equals' -import { Nonce } from '../nonce.js' -import type { ICryptoInterface } from '../crypto.js' -import type { NoiseComponents } from '../index.js' -import type { bytes, bytes32, CipherState, MessageBuffer, SymmetricState } from '../types.js' -import type { Logger } from '@libp2p/interface' - -export interface DecryptedResult { - plaintext: Uint8ArrayList | Uint8Array - valid: boolean -} - -export interface SplitState { - cs1: CipherState - cs2: CipherState -} - -const EMPTY_KEY = uint8ArrayAlloc(32) - -export abstract class AbstractHandshake { - public crypto: ICryptoInterface - private readonly log: Logger - - constructor (components: NoiseComponents, crypto: ICryptoInterface) { - this.log = components.logger.forComponent('libp2p:noise:abstract-handshake') - this.crypto = crypto - } - - public encryptWithAd (cs: CipherState, ad: Uint8Array, plaintext: Uint8Array | Uint8ArrayList): Uint8Array | Uint8ArrayList { - const e = this.encrypt(cs.k, cs.n, ad, plaintext) - cs.n.increment() - - return e - } - - public decryptWithAd (cs: CipherState, ad: Uint8Array, ciphertext: Uint8Array | Uint8ArrayList, dst?: Uint8Array): DecryptedResult { - const { plaintext, valid } = this.decrypt(cs.k, cs.n, ad, ciphertext, dst) - if (valid) cs.n.increment() - - return { plaintext, valid } - } - - // Cipher state related - protected hasKey (cs: CipherState): boolean { - return !this.isEmptyKey(cs.k) - } - - protected isEmptyKey (k: bytes32): boolean { - return uint8ArrayEquals(EMPTY_KEY, k) - } - - protected encrypt (k: bytes32, n: Nonce, ad: Uint8Array, plaintext: Uint8Array | Uint8ArrayList): Uint8Array | Uint8ArrayList { - n.assertValue() - - return this.crypto.chaCha20Poly1305Encrypt(plaintext, n.getBytes(), ad, k) - } - - protected encryptAndHash (ss: SymmetricState, plaintext: bytes): Uint8Array | Uint8ArrayList { - let ciphertext - if (this.hasKey(ss.cs)) { - ciphertext = this.encryptWithAd(ss.cs, ss.h, plaintext) - } else { - ciphertext = plaintext - } - - this.mixHash(ss, ciphertext) - return ciphertext - } - - protected decrypt (k: bytes32, n: Nonce, ad: bytes, ciphertext: Uint8Array | Uint8ArrayList, dst?: Uint8Array): DecryptedResult { - n.assertValue() - - const encryptedMessage = this.crypto.chaCha20Poly1305Decrypt(ciphertext, n.getBytes(), ad, k, dst) - - if (encryptedMessage) { - return { - plaintext: encryptedMessage, - valid: true - } - } else { - return { - plaintext: uint8ArrayAlloc(0), - valid: false - } - } - } - - protected decryptAndHash (ss: SymmetricState, ciphertext: Uint8Array | Uint8ArrayList): DecryptedResult { - let plaintext: Uint8Array | Uint8ArrayList - let valid = true - if (this.hasKey(ss.cs)) { - ({ plaintext, valid } = this.decryptWithAd(ss.cs, ss.h, ciphertext)) - } else { - plaintext = ciphertext - } - - this.mixHash(ss, ciphertext) - return { plaintext, valid } - } - - protected dh (privateKey: bytes32, publicKey: Uint8Array | Uint8ArrayList): bytes32 { - try { - const derivedU8 = this.crypto.generateX25519SharedKey(privateKey, publicKey) - - if (derivedU8.length === 32) { - return derivedU8 - } - - return derivedU8.subarray(0, 32) - } catch (e) { - const err = e as Error - this.log.error('error deriving shared key', err) - return uint8ArrayAlloc(32) - } - } - - protected mixHash (ss: SymmetricState, data: Uint8Array | Uint8ArrayList): void { - ss.h = this.getHash(ss.h, data) - } - - protected getHash (a: Uint8Array, b: Uint8Array | Uint8ArrayList): Uint8Array { - const u = this.crypto.hashSHA256(new Uint8ArrayList(a, b)) - return u - } - - protected mixKey (ss: SymmetricState, ikm: bytes32): void { - const [ck, tempK] = this.crypto.getHKDF(ss.ck, ikm) - ss.cs = this.initializeKey(tempK) - ss.ck = ck - } - - protected initializeKey (k: bytes32): CipherState { - return { k, n: new Nonce() } - } - - // Symmetric state related - - protected initializeSymmetric (protocolName: string): SymmetricState { - const protocolNameBytes = uint8ArrayFromString(protocolName, 'utf-8') - const h = this.hashProtocolName(protocolNameBytes) - - const ck = h - const key = uint8ArrayAlloc(32) - const cs: CipherState = this.initializeKey(key) - - return { cs, ck, h } - } - - protected hashProtocolName (protocolName: Uint8Array): bytes32 { - if (protocolName.length <= 32) { - const h = uint8ArrayAlloc(32) - h.set(protocolName) - return h - } else { - return this.getHash(protocolName, uint8ArrayAlloc(0)) - } - } - - protected split (ss: SymmetricState): SplitState { - const [tempk1, tempk2] = this.crypto.getHKDF(ss.ck, uint8ArrayAlloc(0)) - const cs1 = this.initializeKey(tempk1) - const cs2 = this.initializeKey(tempk2) - - return { cs1, cs2 } - } - - protected writeMessageRegular (cs: CipherState, payload: bytes): MessageBuffer { - const ciphertext = this.encryptWithAd(cs, uint8ArrayAlloc(0), payload) - const ne = uint8ArrayAlloc(32) - const ns = uint8ArrayAlloc(0) - - return { ne, ns, ciphertext } - } - - protected readMessageRegular (cs: CipherState, message: MessageBuffer): DecryptedResult { - return this.decryptWithAd(cs, uint8ArrayAlloc(0), message.ciphertext) - } -} diff --git a/src/handshakes/xx.ts b/src/handshakes/xx.ts deleted file mode 100644 index ceb8e05..0000000 --- a/src/handshakes/xx.ts +++ /dev/null @@ -1,169 +0,0 @@ -import { alloc as uint8ArrayAlloc } from 'uint8arrays/alloc' -import { isValidPublicKey } from '../utils.js' -import { AbstractHandshake, type DecryptedResult } from './abstract-handshake.js' -import type { bytes32, bytes, CipherState, HandshakeState, KeyPair, MessageBuffer, NoiseSession } from '../types.js' -import type { Uint8ArrayList } from 'uint8arraylist' - -export class XX extends AbstractHandshake { - private initializeState (prologue: bytes32, s: KeyPair): HandshakeState { - const name = 'Noise_XX_25519_ChaChaPoly_SHA256' - const ss = this.initializeSymmetric(name) - this.mixHash(ss, prologue) - const psk = uint8ArrayAlloc(32) - const rs = uint8ArrayAlloc(32) // no static key yet - const re = uint8ArrayAlloc(32) - - return { ss, s, rs, psk, re } - } - - private writeMessageA (hs: HandshakeState, payload: bytes, e?: KeyPair): MessageBuffer { - const ns = uint8ArrayAlloc(0) - - if (e !== undefined) { - hs.e = e - } else { - hs.e = this.crypto.generateX25519KeyPair() - } - - const ne = hs.e.publicKey - - this.mixHash(hs.ss, ne) - const ciphertext = this.encryptAndHash(hs.ss, payload) - - return { ne, ns, ciphertext } - } - - private writeMessageB (hs: HandshakeState, payload: bytes): MessageBuffer { - hs.e = this.crypto.generateX25519KeyPair() - const ne = hs.e.publicKey - this.mixHash(hs.ss, ne) - - this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.re)) - const spk = hs.s.publicKey - const ns = this.encryptAndHash(hs.ss, spk) - - this.mixKey(hs.ss, this.dh(hs.s.privateKey, hs.re)) - const ciphertext = this.encryptAndHash(hs.ss, payload) - - return { ne, ns, ciphertext } - } - - private writeMessageC (hs: HandshakeState, payload: bytes): { messageBuffer: MessageBuffer, cs1: CipherState, cs2: CipherState, h: bytes } { - const spk = hs.s.publicKey - const ns = this.encryptAndHash(hs.ss, spk) - this.mixKey(hs.ss, this.dh(hs.s.privateKey, hs.re)) - const ciphertext = this.encryptAndHash(hs.ss, payload) - const ne = uint8ArrayAlloc(32) - const messageBuffer: MessageBuffer = { ne, ns, ciphertext } - const { cs1, cs2 } = this.split(hs.ss) - - return { h: hs.ss.h, messageBuffer, cs1, cs2 } - } - - private readMessageA (hs: HandshakeState, message: MessageBuffer): DecryptedResult { - if (isValidPublicKey(message.ne)) { - hs.re = message.ne - } - - this.mixHash(hs.ss, hs.re) - return this.decryptAndHash(hs.ss, message.ciphertext) - } - - private readMessageB (hs: HandshakeState, message: MessageBuffer): DecryptedResult { - if (isValidPublicKey(message.ne)) { - hs.re = message.ne - } - - this.mixHash(hs.ss, hs.re) - if (!hs.e) { - throw new Error('Handshake state `e` param is missing.') - } - this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.re)) - const { plaintext: ns, valid: valid1 } = this.decryptAndHash(hs.ss, message.ns) - if (valid1 && isValidPublicKey(ns)) { - hs.rs = ns - } - this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.rs)) - const { plaintext, valid: valid2 } = this.decryptAndHash(hs.ss, message.ciphertext) - return { plaintext, valid: (valid1 && valid2) } - } - - private readMessageC (hs: HandshakeState, message: MessageBuffer): { h: bytes, plaintext: Uint8Array | Uint8ArrayList, valid: boolean, cs1: CipherState, cs2: CipherState } { - const { plaintext: ns, valid: valid1 } = this.decryptAndHash(hs.ss, message.ns) - if (valid1 && isValidPublicKey(ns)) { - hs.rs = ns - } - if (!hs.e) { - throw new Error('Handshake state `e` param is missing.') - } - this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.rs)) - - const { plaintext, valid: valid2 } = this.decryptAndHash(hs.ss, message.ciphertext) - const { cs1, cs2 } = this.split(hs.ss) - - return { h: hs.ss.h, plaintext, valid: (valid1 && valid2), cs1, cs2 } - } - - public initSession (initiator: boolean, prologue: bytes32, s: KeyPair): NoiseSession { - const hs = this.initializeState(prologue, s) - - return { - hs, - i: initiator, - mc: 0 - } - } - - public sendMessage (session: NoiseSession, message: bytes, ephemeral?: KeyPair): MessageBuffer { - let messageBuffer: MessageBuffer - if (session.mc === 0) { - messageBuffer = this.writeMessageA(session.hs, message, ephemeral) - } else if (session.mc === 1) { - messageBuffer = this.writeMessageB(session.hs, message) - } else if (session.mc === 2) { - const { h, messageBuffer: resultingBuffer, cs1, cs2 } = this.writeMessageC(session.hs, message) - messageBuffer = resultingBuffer - session.h = h - session.cs1 = cs1 - session.cs2 = cs2 - } else if (session.mc > 2) { - if (session.i) { - if (!session.cs1) { - throw new Error('CS1 (cipher state) is not defined') - } - - messageBuffer = this.writeMessageRegular(session.cs1, message) - } else { - if (!session.cs2) { - throw new Error('CS2 (cipher state) is not defined') - } - - messageBuffer = this.writeMessageRegular(session.cs2, message) - } - } else { - throw new Error('Session invalid.') - } - - session.mc++ - return messageBuffer - } - - public recvMessage (session: NoiseSession, message: MessageBuffer): DecryptedResult { - let plaintext: Uint8Array | Uint8ArrayList = uint8ArrayAlloc(0) - let valid = false - if (session.mc === 0) { - ({ plaintext, valid } = this.readMessageA(session.hs, message)) - } else if (session.mc === 1) { - ({ plaintext, valid } = this.readMessageB(session.hs, message)) - } else if (session.mc === 2) { - const { h, plaintext: resultingPlaintext, valid: resultingValid, cs1, cs2 } = this.readMessageC(session.hs, message) - plaintext = resultingPlaintext - valid = resultingValid - session.h = h - session.cs1 = cs1 - session.cs2 = cs2 - } - session.mc++ - return { plaintext, valid } - } -} diff --git a/src/logger.ts b/src/logger.ts index 86b052c..6a0a006 100644 --- a/src/logger.ts +++ b/src/logger.ts @@ -1,16 +1,21 @@ import { toString as uint8ArrayToString } from 'uint8arrays/to-string' import { DUMP_SESSION_KEYS } from './constants.js' -import type { NoiseSession, KeyPair } from './types.js' +import type { CipherState } from './protocol.js' +import type { KeyPair } from './types.js' import type { Logger } from '@libp2p/interface' import type { Uint8ArrayList } from 'uint8arraylist' -export function logLocalStaticKeys (s: KeyPair, keyLogger: Logger): void { +export function logLocalStaticKeys (s: KeyPair | undefined, keyLogger: Logger): void { if (!keyLogger.enabled || !DUMP_SESSION_KEYS) { return } - keyLogger(`LOCAL_STATIC_PUBLIC_KEY ${uint8ArrayToString(s.publicKey, 'hex')}`) - keyLogger(`LOCAL_STATIC_PRIVATE_KEY ${uint8ArrayToString(s.privateKey, 'hex')}`) + if (s) { + keyLogger(`LOCAL_STATIC_PUBLIC_KEY ${uint8ArrayToString(s.publicKey, 'hex')}`) + keyLogger(`LOCAL_STATIC_PRIVATE_KEY ${uint8ArrayToString(s.privateKey, 'hex')}`) + } else { + keyLogger('Missing local static keys.') + } } export function logLocalEphemeralKeys (e: KeyPair | undefined, keyLogger: Logger): void { @@ -26,31 +31,35 @@ export function logLocalEphemeralKeys (e: KeyPair | undefined, keyLogger: Logger } } -export function logRemoteStaticKey (rs: Uint8Array | Uint8ArrayList, keyLogger: Logger): void { +export function logRemoteStaticKey (rs: Uint8Array | Uint8ArrayList | undefined, keyLogger: Logger): void { if (!keyLogger.enabled || !DUMP_SESSION_KEYS) { return } - keyLogger(`REMOTE_STATIC_PUBLIC_KEY ${uint8ArrayToString(rs.subarray(), 'hex')}`) + if (rs) { + keyLogger(`REMOTE_STATIC_PUBLIC_KEY ${uint8ArrayToString(rs.subarray(), 'hex')}`) + } else { + keyLogger('Missing remote static public key.') + } } -export function logRemoteEphemeralKey (re: Uint8Array | Uint8ArrayList, keyLogger: Logger): void { +export function logRemoteEphemeralKey (re: Uint8Array | Uint8ArrayList | undefined, keyLogger: Logger): void { if (!keyLogger.enabled || !DUMP_SESSION_KEYS) { return } - keyLogger(`REMOTE_EPHEMERAL_PUBLIC_KEY ${uint8ArrayToString(re.subarray(), 'hex')}`) + if (re) { + keyLogger(`REMOTE_EPHEMERAL_PUBLIC_KEY ${uint8ArrayToString(re.subarray(), 'hex')}`) + } else { + keyLogger('Missing remote ephemeral keys.') + } } -export function logCipherState (session: NoiseSession, keyLogger: Logger): void { +export function logCipherState (cs1: CipherState, cs2: CipherState, keyLogger: Logger): void { if (!keyLogger.enabled || !DUMP_SESSION_KEYS) { return } - if (session.cs1 && session.cs2) { - keyLogger(`CIPHER_STATE_1 ${session.cs1.n.getUint64()} ${uint8ArrayToString(session.cs1.k, 'hex')}`) - keyLogger(`CIPHER_STATE_2 ${session.cs2.n.getUint64()} ${uint8ArrayToString(session.cs2.k, 'hex')}`) - } else { - keyLogger('Missing cipher state.') - } + keyLogger(`CIPHER_STATE_1 ${cs1.n.getUint64()} ${cs1.k && uint8ArrayToString(cs1.k, 'hex')}`) + keyLogger(`CIPHER_STATE_2 ${cs2.n.getUint64()} ${cs2.k && uint8ArrayToString(cs2.k, 'hex')}`) } diff --git a/src/noise.ts b/src/noise.ts index 575e23b..c50542e 100644 --- a/src/noise.ts +++ b/src/noise.ts @@ -1,3 +1,6 @@ +import { unmarshalPrivateKey } from '@libp2p/crypto/keys' +import { type MultiaddrConnection, type SecuredConnection, type PeerId, CodeError, type PrivateKey } from '@libp2p/interface' +import { peerIdFromKeys } from '@libp2p/peer-id' import { decode } from 'it-length-prefixed' import { lpStream, type LengthPrefixedStream } from 'it-length-prefixed-stream' import { duplexPair } from 'it-pair/duplex' @@ -5,31 +8,22 @@ import { pipe } from 'it-pipe' import { alloc as uint8ArrayAlloc } from 'uint8arrays/alloc' import { NOISE_MSG_MAX_LENGTH_BYTES } from './constants.js' import { defaultCrypto } from './crypto/index.js' -import { decryptStream, encryptStream } from './crypto/streaming.js' +import { wrapCrypto, type ICryptoInterface } from './crypto.js' import { uint16BEDecode, uint16BEEncode } from './encoder.js' -import { XXHandshake } from './handshake-xx.js' import { type MetricsRegistry, registerMetrics } from './metrics.js' -import { getPayload } from './utils.js' -import type { ICryptoInterface } from './crypto.js' +import { performHandshakeInitiator, performHandshakeResponder } from './performHandshake.js' +import { decryptStream, encryptStream } from './streaming.js' import type { NoiseComponents } from './index.js' import type { NoiseExtensions } from './proto/payload.js' -import type { bytes, IHandshake, INoiseConnection, KeyPair } from './types.js' -import type { MultiaddrConnection, SecuredConnection, PeerId } from '@libp2p/interface' +import type { HandshakeResult, ICrypto, INoiseConnection, KeyPair } from './types.js' import type { Duplex } from 'it-stream-types' import type { Uint8ArrayList } from 'uint8arraylist' -interface HandshakeParams { - connection: LengthPrefixedStream - isInitiator: boolean - localPeer: PeerId - remotePeer?: PeerId -} - export interface NoiseInit { /** * x25519 private key, reuse for faster handshakes */ - staticNoiseKey?: bytes + staticNoiseKey?: Uint8Array extensions?: NoiseExtensions crypto?: ICryptoInterface prologueBytes?: Uint8Array @@ -37,10 +31,10 @@ export interface NoiseInit { export class Noise implements INoiseConnection { public protocol = '/noise' - public crypto: ICryptoInterface + public crypto: ICrypto private readonly prologue: Uint8Array - private readonly staticKeys: KeyPair + private readonly staticKey: KeyPair private readonly extensions?: NoiseExtensions private readonly metrics?: MetricsRegistry private readonly components: NoiseComponents @@ -50,15 +44,16 @@ export class Noise implements INoiseConnection { const { metrics } = components this.components = components - this.crypto = crypto ?? defaultCrypto + const _crypto = crypto ?? defaultCrypto + this.crypto = wrapCrypto(_crypto) this.extensions = extensions this.metrics = metrics ? registerMetrics(metrics) : undefined if (staticNoiseKey) { // accepts x25519 private key of length 32 - this.staticKeys = this.crypto.generateX25519KeyPairFromSeed(staticNoiseKey) + this.staticKey = _crypto.generateX25519KeyPairFromSeed(staticNoiseKey) } else { - this.staticKeys = this.crypto.generateX25519KeyPair() + this.staticKey = _crypto.generateX25519KeyPair() } this.prologue = prologueBytes ?? uint8ArrayAlloc(0) } @@ -79,12 +74,19 @@ export class Noise implements INoiseConnection { maxDataLength: NOISE_MSG_MAX_LENGTH_BYTES } ) - const handshake = await this.performHandshake({ - connection: wrappedConnection, - isInitiator: true, - localPeer, - remotePeer - }) + + if (!localPeer.privateKey) { + throw new CodeError('local peerId does not contain private key', 'ERR_NO_PRIVATE_KEY') + } + const privateKey = await unmarshalPrivateKey(localPeer.privateKey) + + const remoteIdentityKey = remotePeer?.publicKey + + const handshake = await this.performHandshakeInitiator( + wrappedConnection, + privateKey, + remoteIdentityKey + ) const conn = await this.createSecureConnection(wrappedConnection, handshake) connection.source = conn.source @@ -92,8 +94,8 @@ export class Noise implements INoiseConnection { return { conn: connection, - remoteExtensions: handshake.remoteExtensions, - remotePeer: handshake.remotePeer + remoteExtensions: handshake.payload.extensions, + remotePeer: await peerIdFromKeys(handshake.payload.identityKey) } } @@ -113,12 +115,19 @@ export class Noise implements INoiseConnection { maxDataLength: NOISE_MSG_MAX_LENGTH_BYTES } ) - const handshake = await this.performHandshake({ - connection: wrappedConnection, - isInitiator: false, - localPeer, - remotePeer - }) + + if (!localPeer.privateKey) { + throw new CodeError('local peerId does not contain private key', 'ERR_NO_PRIVATE_KEY') + } + const privateKey = await unmarshalPrivateKey(localPeer.privateKey) + + const remoteIdentityKey = remotePeer?.publicKey + + const handshake = await this.performHandshakeResponder( + wrappedConnection, + privateKey, + remoteIdentityKey + ) const conn = await this.createSecureConnection(wrappedConnection, handshake) connection.source = conn.source @@ -126,56 +135,74 @@ export class Noise implements INoiseConnection { return { conn: connection, - remotePeer: handshake.remotePeer, - remoteExtensions: handshake.remoteExtensions + remoteExtensions: handshake.payload.extensions, + remotePeer: await peerIdFromKeys(handshake.payload.identityKey) } } /** - * Perform XX handshake. + * Perform XX handshake as initiator. */ - private async performHandshake (params: HandshakeParams): Promise { - const payload = await getPayload(params.localPeer, this.staticKeys.publicKey, this.extensions) + private async performHandshakeInitiator ( + connection: LengthPrefixedStream, + // TODO: pass private key in noise constructor via Components + privateKey: PrivateKey, + remoteIdentityKey?: Uint8Array | Uint8ArrayList + ): Promise { + let result: HandshakeResult + try { + result = await performHandshakeInitiator({ + connection, + privateKey, + remoteIdentityKey, + log: this.components.logger.forComponent('libp2p:noise:xxhandshake'), + crypto: this.crypto, + prologue: this.prologue, + s: this.staticKey, + extensions: this.extensions + }) + this.metrics?.xxHandshakeSuccesses.increment() + } catch (e: unknown) { + this.metrics?.xxHandshakeErrors.increment() + throw e + } - // run XX handshake - return this.performXXHandshake(params, payload) + return result } - private async performXXHandshake ( - params: HandshakeParams, - payload: bytes - ): Promise { - const { isInitiator, remotePeer, connection } = params - const handshake = new XXHandshake( - this.components, - isInitiator, - payload, - this.prologue, - this.crypto, - this.staticKeys, - connection, - remotePeer - ) - + /** + * Perform XX handshake as responder. + */ + private async performHandshakeResponder ( + connection: LengthPrefixedStream, + // TODO: pass private key in noise constructor via Components + privateKey: PrivateKey, + remoteIdentityKey?: Uint8Array | Uint8ArrayList + ): Promise { + let result: HandshakeResult try { - await handshake.propose() - await handshake.exchange() - await handshake.finish() + result = await performHandshakeResponder({ + connection, + privateKey, + remoteIdentityKey, + log: this.components.logger.forComponent('libp2p:noise:xxhandshake'), + crypto: this.crypto, + prologue: this.prologue, + s: this.staticKey, + extensions: this.extensions + }) this.metrics?.xxHandshakeSuccesses.increment() } catch (e: unknown) { this.metrics?.xxHandshakeErrors.increment() - if (e instanceof Error) { - e.message = `Error occurred during XX handshake: ${e.message}` - throw e - } + throw e } - return handshake + return result } private async createSecureConnection ( connection: LengthPrefixedStream>>, - handshake: IHandshake + handshake: HandshakeResult ): Promise>> { // Create encryption box/unbox wrapper const [secure, user] = duplexPair() diff --git a/src/nonce.ts b/src/nonce.ts index 8180df9..d27b862 100644 --- a/src/nonce.ts +++ b/src/nonce.ts @@ -1,5 +1,4 @@ import { alloc as uint8ArrayAlloc } from 'uint8arrays/alloc' -import type { bytes, uint64 } from './types.js' export const MIN_NONCE = 0 // For performance reasons, the nonce is represented as a JS `number` @@ -17,8 +16,8 @@ const ERR_MAX_NONCE = 'Cipherstate has reached maximum n, a new handshake must b * Maintaining different representations help improve performance. */ export class Nonce { - private n: uint64 - private readonly bytes: bytes + private n: number + private readonly bytes: Uint8Array private readonly view: DataView constructor (n = MIN_NONCE) { @@ -34,11 +33,11 @@ export class Nonce { this.view.setUint32(4, this.n, true) } - getBytes (): bytes { + getBytes (): Uint8Array { return this.bytes } - getUint64 (): uint64 { + getUint64 (): number { return this.n } diff --git a/src/performHandshake.ts b/src/performHandshake.ts new file mode 100644 index 0000000..e0bfd53 --- /dev/null +++ b/src/performHandshake.ts @@ -0,0 +1,90 @@ +import { + logLocalStaticKeys, + logLocalEphemeralKeys, + logRemoteEphemeralKey, + logRemoteStaticKey, + logCipherState +} from './logger.js' +import { ZEROLEN, XXHandshakeState } from './protocol.js' +import { createHandshakePayload, decodeHandshakePayload } from './utils.js' +import type { HandshakeResult, HandshakeParams } from './types.js' + +export async function performHandshakeInitiator (init: HandshakeParams): Promise { + const { log, connection, crypto, privateKey, prologue, s, remoteIdentityKey, extensions } = init + + const payload = await createHandshakePayload(privateKey, s.publicKey, extensions) + const xx = new XXHandshakeState({ + crypto, + protocolName: 'Noise_XX_25519_ChaChaPoly_SHA256', + initiator: true, + prologue, + s + }) + + logLocalStaticKeys(xx.s, log) + log.trace('Stage 0 - Initiator starting to send first message.') + await connection.write(xx.writeMessageA(ZEROLEN)) + log.trace('Stage 0 - Initiator finished sending first message.') + logLocalEphemeralKeys(xx.e, log) + + log.trace('Stage 1 - Initiator waiting to receive first message from responder...') + const plaintext = xx.readMessageB(await connection.read()) + log.trace('Stage 1 - Initiator received the message.') + logRemoteEphemeralKey(xx.re, log) + logRemoteStaticKey(xx.rs, log) + + log.trace("Initiator going to check remote's signature...") + const receivedPayload = await decodeHandshakePayload(plaintext, xx.rs, remoteIdentityKey) + log.trace('All good with the signature!') + + log.trace('Stage 2 - Initiator sending third handshake message.') + await connection.write(xx.writeMessageC(payload)) + log.trace('Stage 2 - Initiator sent message with signed payload.') + + const [cs1, cs2] = xx.ss.split() + logCipherState(cs1, cs2, log) + + return { + payload: receivedPayload, + encrypt: (plaintext) => cs1.encryptWithAd(ZEROLEN, plaintext), + decrypt: (ciphertext, dst) => cs2.decryptWithAd(ZEROLEN, ciphertext, dst) + } +} + +export async function performHandshakeResponder (init: HandshakeParams): Promise { + const { log, connection, crypto, privateKey, prologue, s, remoteIdentityKey, extensions } = init + + const payload = await createHandshakePayload(privateKey, s.publicKey, extensions) + const xx = new XXHandshakeState({ + crypto, + protocolName: 'Noise_XX_25519_ChaChaPoly_SHA256', + initiator: false, + prologue, + s + }) + + logLocalStaticKeys(xx.s, log) + log.trace('Stage 0 - Responder waiting to receive first message.') + xx.readMessageA(await connection.read()) + log.trace('Stage 0 - Responder received first message.') + logRemoteEphemeralKey(xx.re, log) + + log.trace('Stage 1 - Responder sending out first message with signed payload and static key.') + await connection.write(xx.writeMessageB(payload)) + log.trace('Stage 1 - Responder sent the second handshake message with signed payload.') + logLocalEphemeralKeys(xx.e, log) + + log.trace('Stage 2 - Responder waiting for third handshake message...') + const plaintext = xx.readMessageC(await connection.read()) + log.trace('Stage 2 - Responder received the message, finished handshake.') + const receivedPayload = await decodeHandshakePayload(plaintext, xx.rs, remoteIdentityKey) + + const [cs1, cs2] = xx.ss.split() + logCipherState(cs1, cs2, log) + + return { + payload: receivedPayload, + encrypt: (plaintext) => cs2.encryptWithAd(ZEROLEN, plaintext), + decrypt: (ciphertext, dst) => cs1.decryptWithAd(ZEROLEN, ciphertext, dst) + } +} diff --git a/src/protocol.ts b/src/protocol.ts new file mode 100644 index 0000000..a71ac3b --- /dev/null +++ b/src/protocol.ts @@ -0,0 +1,313 @@ +import { Uint8ArrayList } from 'uint8arraylist' +import { fromString as uint8ArrayFromString } from 'uint8arrays' +import { alloc as uint8ArrayAlloc } from 'uint8arrays/alloc' +import { InvalidCryptoExchangeError } from './errors.js' +import { Nonce } from './nonce.js' +import type { ICipherState, ISymmetricState, IHandshakeState, KeyPair, ICrypto } from './types.js' + +// Code in this file is a direct translation of a subset of the noise protocol https://noiseprotocol.org/noise.html, +// agnostic to libp2p's usage of noise + +export const ZEROLEN = uint8ArrayAlloc(0) + +interface ICipherStateWithKey extends ICipherState { + k: Uint8Array +} + +export class CipherState implements ICipherState { + public k?: Uint8Array + public n: Nonce + private readonly crypto: ICrypto + + constructor (crypto: ICrypto, k: Uint8Array | undefined = undefined, n = 0) { + this.crypto = crypto + this.k = k + this.n = new Nonce(n) + } + + public hasKey (): this is ICipherStateWithKey { + return Boolean(this.k) + } + + public encryptWithAd (ad: Uint8Array, plaintext: Uint8Array | Uint8ArrayList): Uint8Array | Uint8ArrayList { + if (!this.hasKey()) { + return plaintext + } + + this.n.assertValue() + const e = this.crypto.encrypt(plaintext, this.n.getBytes(), ad, this.k) + this.n.increment() + + return e + } + + public decryptWithAd (ad: Uint8Array, ciphertext: Uint8Array | Uint8ArrayList, dst?: Uint8Array): Uint8Array | Uint8ArrayList { + if (!this.hasKey()) { + return ciphertext + } + + this.n.assertValue() + const plaintext = this.crypto.decrypt(ciphertext, this.n.getBytes(), ad, this.k, dst) + this.n.increment() + + return plaintext + } +} + +export class SymmetricState implements ISymmetricState { + public cs: CipherState + public ck: Uint8Array + public h: Uint8Array + private readonly crypto: ICrypto + + constructor (crypto: ICrypto, protocolName: string) { + this.crypto = crypto + + const protocolNameBytes = uint8ArrayFromString(protocolName, 'utf-8') + this.h = hashProtocolName(crypto, protocolNameBytes) + + this.ck = this.h + this.cs = new CipherState(crypto) + } + + public mixKey (ikm: Uint8Array): void { + const [ck, tempK] = this.crypto.hkdf(this.ck, ikm) + this.ck = ck + this.cs = new CipherState(this.crypto, tempK) + } + + public mixHash (data: Uint8Array | Uint8ArrayList): void { + this.h = this.crypto.hash(new Uint8ArrayList(this.h, data)) + } + + public encryptAndHash (plaintext: Uint8Array | Uint8ArrayList): Uint8Array | Uint8ArrayList { + const ciphertext = this.cs.encryptWithAd(this.h, plaintext) + this.mixHash(ciphertext) + return ciphertext + } + + public decryptAndHash (ciphertext: Uint8Array | Uint8ArrayList): Uint8Array | Uint8ArrayList { + const plaintext = this.cs.decryptWithAd(this.h, ciphertext) + this.mixHash(ciphertext) + return plaintext + } + + public split (): [CipherState, CipherState] { + const [tempK1, tempK2] = this.crypto.hkdf(this.ck, ZEROLEN) + return [new CipherState(this.crypto, tempK1), new CipherState(this.crypto, tempK2)] + } +} + +// const MESSAGE_PATTERNS = ['e', 's', 'ee', 'es', 'se', 'ss'] as const +// type MessagePattern = Array + +export interface HandshakeStateInit { + crypto: ICrypto + protocolName: string + initiator: boolean + prologue: Uint8Array + s?: KeyPair + e?: KeyPair + rs?: Uint8Array | Uint8ArrayList + re?: Uint8Array | Uint8ArrayList +} + +export abstract class AbstractHandshakeState implements IHandshakeState { + public ss: SymmetricState + public s?: KeyPair + public e?: KeyPair + public rs?: Uint8Array | Uint8ArrayList + public re?: Uint8Array | Uint8ArrayList + public initiator: boolean + protected readonly crypto: ICrypto + + constructor (init: HandshakeStateInit) { + const { crypto, protocolName, prologue, initiator, s, e, rs, re } = init + this.crypto = crypto + this.ss = new SymmetricState(crypto, protocolName) + this.ss.mixHash(prologue) + this.initiator = initiator + this.s = s + this.e = e + this.rs = rs + this.re = re + } + + protected writeE (): Uint8Array { + if (this.e) { + throw new Error('ephemeral keypair is already set') + } + const e = this.crypto.generateKeypair() + this.ss.mixHash(e.publicKey) + this.e = e + return e.publicKey + } + + protected writeS (): Uint8Array | Uint8ArrayList { + if (!this.s) { + throw new Error('static keypair is not set') + } + return this.ss.encryptAndHash(this.s.publicKey) + } + + protected writeEE (): void { + if (!this.e) { + throw new Error('ephemeral keypair is not set') + } + if (!this.re) { + throw new Error('remote ephemeral public key is not set') + } + this.ss.mixKey(this.crypto.dh(this.e, this.re)) + } + + protected writeES (): void { + if (this.initiator) { + if (!this.e) { + throw new Error('ephemeral keypair is not set') + } + if (!this.rs) { + throw new Error('remote static public key is not set') + } + this.ss.mixKey(this.crypto.dh(this.e, this.rs)) + } else { + if (!this.s) { + throw new Error('static keypair is not set') + } + if (!this.re) { + throw new Error('remote ephemeral public key is not set') + } + this.ss.mixKey(this.crypto.dh(this.s, this.re)) + } + } + + protected writeSE (): void { + if (this.initiator) { + if (!this.s) { + throw new Error('static keypair is not set') + } + if (!this.re) { + throw new Error('remote ephemeral public key is not set') + } + this.ss.mixKey(this.crypto.dh(this.s, this.re)) + } else { + if (!this.e) { + throw new Error('ephemeral keypair is not set') + } + if (!this.rs) { + throw new Error('remote static public key is not set') + } + this.ss.mixKey(this.crypto.dh(this.e, this.rs)) + } + } + + protected readE (message: Uint8ArrayList, offset = 0): void { + if (this.re) { + throw new Error('remote ephemeral public key is already set') + } + if (message.byteLength < offset + 32) { + throw new Error('message is not long enough') + } + this.re = message.sublist(offset, offset + 32) + this.ss.mixHash(this.re) + } + + protected readS (message: Uint8ArrayList, offset = 0): number { + if (this.rs) { + throw new Error('remote static public key is already set') + } + const cipherLength = 32 + (this.ss.cs.hasKey() ? 16 : 0) + if (message.byteLength < offset + cipherLength) { + throw new Error('message is not long enough') + } + const temp = message.sublist(offset, offset + cipherLength) + this.rs = this.ss.decryptAndHash(temp) + return cipherLength + } + + protected readEE (): void { + this.writeEE() + } + + protected readES (): void { + this.writeES() + } + + protected readSE (): void { + this.writeSE() + } +} + +/** + * A IHandshakeState that's optimized for the XX pattern + */ +export class XXHandshakeState extends AbstractHandshakeState { + // e + writeMessageA (payload: Uint8Array | Uint8ArrayList): Uint8Array | Uint8ArrayList { + return new Uint8ArrayList(this.writeE(), this.ss.encryptAndHash(payload)) + } + + // e, ee, s, es + writeMessageB (payload: Uint8Array | Uint8ArrayList): Uint8Array | Uint8ArrayList { + const e = this.writeE() + this.writeEE() + const encS = this.writeS() + this.writeES() + + return new Uint8ArrayList(e, encS, this.ss.encryptAndHash(payload)) + } + + // s, se + writeMessageC (payload: Uint8Array | Uint8ArrayList): Uint8Array | Uint8ArrayList { + const encS = this.writeS() + this.writeSE() + + return new Uint8ArrayList(encS, this.ss.encryptAndHash(payload)) + } + + // e + readMessageA (message: Uint8ArrayList): Uint8Array | Uint8ArrayList { + try { + this.readE(message) + + return this.ss.decryptAndHash(message.sublist(32)) + } catch (e) { + throw new InvalidCryptoExchangeError(`handshake stage 0 validation fail: ${(e as Error).message}`) + } + } + + // e, ee, s, es + readMessageB (message: Uint8ArrayList): Uint8Array | Uint8ArrayList { + try { + this.readE(message) + this.readEE() + const consumed = this.readS(message, 32) + this.readES() + + return this.ss.decryptAndHash(message.sublist(32 + consumed)) + } catch (e) { + throw new InvalidCryptoExchangeError(`handshake stage 1 validation fail: ${(e as Error).message}`) + } + } + + // s, se + readMessageC (message: Uint8ArrayList): Uint8Array | Uint8ArrayList { + try { + const consumed = this.readS(message) + this.readSE() + + return this.ss.decryptAndHash(message.sublist(consumed)) + } catch (e) { + throw new InvalidCryptoExchangeError(`handshake stage 2 validation fail: ${(e as Error).message}`) + } + } +} + +function hashProtocolName (crypto: ICrypto, protocolName: Uint8Array): Uint8Array { + if (protocolName.length <= 32) { + const h = uint8ArrayAlloc(32) + h.set(protocolName) + return h + } else { + return crypto.hash(protocolName) + } +} diff --git a/src/crypto/streaming.ts b/src/streaming.ts similarity index 66% rename from src/crypto/streaming.ts rename to src/streaming.ts index e1001bb..f7e2b7f 100644 --- a/src/crypto/streaming.ts +++ b/src/streaming.ts @@ -1,14 +1,14 @@ import { Uint8ArrayList } from 'uint8arraylist' -import { NOISE_MSG_MAX_LENGTH_BYTES, NOISE_MSG_MAX_LENGTH_BYTES_WITHOUT_TAG } from '../constants.js' -import { uint16BEEncode } from '../encoder.js' -import type { MetricsRegistry } from '../metrics.js' -import type { IHandshake } from '../types.js' +import { NOISE_MSG_MAX_LENGTH_BYTES, NOISE_MSG_MAX_LENGTH_BYTES_WITHOUT_TAG } from './constants.js' +import { uint16BEEncode } from './encoder.js' +import type { MetricsRegistry } from './metrics.js' +import type { HandshakeResult } from './types.js' import type { Transform } from 'it-stream-types' const CHACHA_TAG_LENGTH = 16 // Returns generator that encrypts payload from the user -export function encryptStream (handshake: IHandshake, metrics?: MetricsRegistry): Transform> { +export function encryptStream (handshake: HandshakeResult, metrics?: MetricsRegistry): Transform> { return async function * (source) { for await (const chunk of source) { for (let i = 0; i < chunk.length; i += NOISE_MSG_MAX_LENGTH_BYTES_WITHOUT_TAG) { @@ -20,9 +20,9 @@ export function encryptStream (handshake: IHandshake, metrics?: MetricsRegistry) let data: Uint8Array | Uint8ArrayList if (chunk instanceof Uint8Array) { - data = handshake.encrypt(chunk.subarray(i, end), handshake.session) + data = handshake.encrypt(chunk.subarray(i, end)) } else { - data = handshake.encrypt(chunk.sublist(i, end), handshake.session) + data = handshake.encrypt(chunk.sublist(i, end)) } metrics?.encryptedPackets.increment() @@ -34,7 +34,7 @@ export function encryptStream (handshake: IHandshake, metrics?: MetricsRegistry) } // Decrypt received payload to the user -export function decryptStream (handshake: IHandshake, metrics?: MetricsRegistry): Transform, AsyncGenerator> { +export function decryptStream (handshake: HandshakeResult, metrics?: MetricsRegistry): Transform, AsyncGenerator> { return async function * (source) { for await (const chunk of source) { for (let i = 0; i < chunk.length; i += NOISE_MSG_MAX_LENGTH_BYTES) { @@ -53,13 +53,14 @@ export function decryptStream (handshake: IHandshake, metrics?: MetricsRegistry) // this is ok because chacha20 reads bytes one by one and don't reread after that // it's also tested in https://github.com/ChainSafe/as-chacha20poly1305/pull/1/files#diff-25252846b58979dcaf4e41d47b3eadd7e4f335e7fb98da6c049b1f9cd011f381R48 const dst = chunk.subarray(i, end - CHACHA_TAG_LENGTH) - const { plaintext: decrypted, valid } = handshake.decrypt(encrypted, handshake.session, dst) - if (!valid) { + try { + const plaintext = handshake.decrypt(encrypted, dst) + metrics?.decryptedPackets.increment() + yield plaintext + } catch (e) { metrics?.decryptErrors.increment() - throw new Error('Failed to validate decrypted chunk') + throw e } - metrics?.decryptedPackets.increment() - yield decrypted } } } diff --git a/src/types.ts b/src/types.ts index 9b43cff..37feb73 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,37 +1,44 @@ import type { Nonce } from './nonce' -import type { NoiseExtensions } from './proto/payload' -import type { ConnectionEncrypter, PeerId } from '@libp2p/interface' +import type { NoiseExtensions, NoiseHandshakePayload } from './proto/payload' +import type { ConnectionEncrypter, Logger, PrivateKey } from '@libp2p/interface' +import type { LengthPrefixedStream } from 'it-length-prefixed-stream' import type { Uint8ArrayList } from 'uint8arraylist' -export type bytes = Uint8Array -export type bytes32 = Uint8Array -export type bytes16 = Uint8Array - -export type uint64 = number - -export interface IHandshake { - session: NoiseSession - remotePeer: PeerId - remoteExtensions: NoiseExtensions - encrypt(plaintext: Uint8Array | Uint8ArrayList, session: NoiseSession): Uint8Array | Uint8ArrayList - decrypt(ciphertext: Uint8Array | Uint8ArrayList, session: NoiseSession, dst?: Uint8Array): { plaintext: Uint8Array | Uint8ArrayList, valid: boolean } +/** Crypto functions defined by the noise protocol, abstracted from the underlying implementations */ +export interface ICrypto { + generateKeypair(): KeyPair + dh(keypair: KeyPair, publicKey: Uint8Array | Uint8ArrayList): Uint8Array + encrypt(plaintext: Uint8Array | Uint8ArrayList, nonce: Uint8Array, ad: Uint8Array, k: Uint8Array): Uint8ArrayList | Uint8Array + decrypt(ciphertext: Uint8Array | Uint8ArrayList, nonce: Uint8Array, ad: Uint8Array, k: Uint8Array, dst?: Uint8Array): Uint8ArrayList | Uint8Array + hash(data: Uint8Array | Uint8ArrayList): Uint8Array + hkdf(ck: Uint8Array, ikm: Uint8Array): [Uint8Array, Uint8Array, Uint8Array] } -export type Hkdf = [bytes, bytes, bytes] +export interface HandshakeParams { + log: Logger + connection: LengthPrefixedStream + crypto: ICrypto + privateKey: PrivateKey + prologue: Uint8Array + /** static keypair */ + s: KeyPair + remoteIdentityKey?: Uint8Array | Uint8ArrayList + extensions?: NoiseExtensions +} -export interface MessageBuffer { - ne: bytes32 - ns: Uint8Array | Uint8ArrayList - ciphertext: Uint8Array | Uint8ArrayList +export interface HandshakeResult { + payload: NoiseHandshakePayload + encrypt (plaintext: Uint8Array | Uint8ArrayList): Uint8Array | Uint8ArrayList + decrypt (ciphertext: Uint8Array | Uint8ArrayList, dst?: Uint8Array): Uint8Array | Uint8ArrayList } /** * A CipherState object contains k and n variables, which it uses to encrypt and decrypt ciphertexts. * During the handshake phase each party has a single CipherState, but during the transport phase each party has two CipherState objects: one for sending, and one for receiving. */ -export interface CipherState { +export interface ICipherState { /** A cipher key of 32 bytes (which may be empty). Empty is a special value which indicates k has not yet been initialized. */ - k: bytes32 + k?: Uint8Array /** * An 8-byte (64-bit) unsigned integer nonce. * @@ -45,54 +52,33 @@ export interface CipherState { * A SymmetricState object contains a CipherState plus ck and h variables. It is so-named because it encapsulates all the "symmetric crypto" used by Noise. * During the handshake phase each party has a single SymmetricState, which can be deleted once the handshake is finished. */ -export interface SymmetricState { - cs: CipherState +export interface ISymmetricState { + cs: ICipherState /** A chaining key of 32 bytes. */ - ck: bytes32 + ck: Uint8Array /** A hash output of 32 bytes. */ - h: bytes32 + h: Uint8Array } /** * A HandshakeState object contains a SymmetricState plus DH variables (s, e, rs, re) and a variable representing the handshake pattern. * During the handshake phase each party has a single HandshakeState, which can be deleted once the handshake is finished. */ -export interface HandshakeState { - ss: SymmetricState +export interface IHandshakeState { + ss: ISymmetricState /** The local static key pair */ - s: KeyPair + s?: KeyPair /** The local ephemeral key pair */ e?: KeyPair /** The remote party's static public key */ - rs: Uint8Array | Uint8ArrayList + rs?: Uint8Array | Uint8ArrayList /** The remote party's ephemeral public key */ - re: bytes32 -} - -export interface NoiseSession { - hs: HandshakeState - h?: bytes32 - cs1?: CipherState - cs2?: CipherState - mc: uint64 - i: boolean -} - -/** - * The Noise Protocol Framework caters for sending early data alongside handshake messages. We leverage this construct to transmit: - * - * 1. the libp2p identity key along with a signature, to authenticate each party to the other. - * 2. extensions used by the libp2p stack. - */ -export interface INoisePayload { - identityKey: bytes - identitySig: bytes - data: bytes + re?: Uint8Array | Uint8ArrayList } export interface KeyPair { - publicKey: bytes32 - privateKey: bytes32 + publicKey: Uint8Array + privateKey: Uint8Array } export interface INoiseConnection extends ConnectionEncrypter { } diff --git a/src/utils.ts b/src/utils.ts index 1a31112..4c75fb4 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,61 +1,58 @@ -import { unmarshalPublicKey, unmarshalPrivateKey } from '@libp2p/crypto/keys' -import { peerIdFromKeys } from '@libp2p/peer-id' -import { type Uint8ArrayList, isUint8ArrayList } from 'uint8arraylist' +import { unmarshalPublicKey } from '@libp2p/crypto/keys' +import { type Uint8ArrayList } from 'uint8arraylist' +import { equals, toString } from 'uint8arrays' import { concat as uint8ArrayConcat } from 'uint8arrays/concat' import { fromString as uint8ArrayFromString } from 'uint8arrays/from-string' +import { UnexpectedPeerError } from './errors.js' import { type NoiseExtensions, NoiseHandshakePayload } from './proto/payload.js' -import type { bytes } from './types.js' -import type { PeerId } from '@libp2p/interface' +import type { PrivateKey } from '@libp2p/interface' -export async function getPayload ( - localPeer: PeerId, - staticPublicKey: bytes, +export async function createHandshakePayload ( + privateKey: PrivateKey, + staticPublicKey: Uint8Array | Uint8ArrayList, extensions?: NoiseExtensions -): Promise { - const signedPayload = await signPayload(localPeer, getHandshakePayload(staticPublicKey)) +): Promise { + const identitySig = await privateKey.sign(getSignaturePayload(staticPublicKey)) - if (localPeer.publicKey == null) { - throw new Error('PublicKey was missing from local PeerId') - } - - return createHandshakePayload( - localPeer.publicKey, - signedPayload, - extensions - ) -} - -export function createHandshakePayload ( - libp2pPublicKey: Uint8Array, - signedPayload: Uint8Array, - extensions?: NoiseExtensions -): bytes { return NoiseHandshakePayload.encode({ - identityKey: libp2pPublicKey, - identitySig: signedPayload, - extensions: extensions ?? { webtransportCerthashes: [] } - }).subarray() + identityKey: privateKey.public.bytes, + identitySig, + extensions + }) } -export async function signPayload (peerId: PeerId, payload: Uint8Array | Uint8ArrayList): Promise { - if (peerId.privateKey == null) { - throw new Error('PrivateKey was missing from PeerId') +export async function decodeHandshakePayload ( + payloadBytes: Uint8Array | Uint8ArrayList, + remoteStaticKey?: Uint8Array | Uint8ArrayList, + remoteIdentityKey?: Uint8Array | Uint8ArrayList +): Promise { + try { + const payload = NoiseHandshakePayload.decode(payloadBytes) + if (remoteIdentityKey) { + const remoteIdentityKeyBytes = remoteIdentityKey.subarray() + if (!equals(remoteIdentityKeyBytes, payload.identityKey)) { + throw new Error(`Payload identity key ${toString(payload.identityKey, 'hex')} does not match expected remote identity key ${toString(remoteIdentityKeyBytes, 'hex')}`) + } + } + + if (!remoteStaticKey) { + throw new Error('Remote static does not exist') + } + + const signaturePayload = getSignaturePayload(remoteStaticKey) + const publicKey = unmarshalPublicKey(payload.identityKey) + + if (!(await publicKey.verify(signaturePayload, payload.identitySig))) { + throw new Error('Invalid payload signature') + } + + return payload + } catch (e) { + throw new UnexpectedPeerError((e as Error).message) } - - const privateKey = await unmarshalPrivateKey(peerId.privateKey) - - return privateKey.sign(payload) -} - -export async function getPeerIdFromPayload (payload: NoiseHandshakePayload): Promise { - return peerIdFromKeys(payload.identityKey) -} - -export function decodePayload (payload: Uint8Array | Uint8ArrayList): NoiseHandshakePayload { - return NoiseHandshakePayload.decode(payload) } -export function getHandshakePayload (publicKey: Uint8Array | Uint8ArrayList): Uint8Array | Uint8ArrayList { +export function getSignaturePayload (publicKey: Uint8Array | Uint8ArrayList): Uint8Array | Uint8ArrayList { const prefix = uint8ArrayFromString('noise-libp2p-static-key:') if (publicKey instanceof Uint8Array) { @@ -66,54 +63,3 @@ export function getHandshakePayload (publicKey: Uint8Array | Uint8ArrayList): Ui return publicKey } - -/** - * Verifies signed payload, throws on any irregularities. - * - * @param noiseStaticKey - owner's noise static key - * @param payload - decoded payload - * @param remotePeer - owner's libp2p peer ID - * @returns peer ID of payload owner - */ -export async function verifySignedPayload ( - noiseStaticKey: Uint8Array | Uint8ArrayList, - payload: NoiseHandshakePayload, - remotePeer: PeerId -): Promise { - // Unmarshaling from PublicKey protobuf - const payloadPeerId = await peerIdFromKeys(payload.identityKey) - if (!payloadPeerId.equals(remotePeer)) { - throw new Error(`Payload identity key ${payloadPeerId.toString()} does not match expected remote peer ${remotePeer.toString()}`) - } - const generatedPayload = getHandshakePayload(noiseStaticKey) - - if (payloadPeerId.publicKey == null) { - throw new Error('PublicKey was missing from PeerId') - } - - if (payload.identitySig == null) { - throw new Error('Signature was missing from message') - } - - const publicKey = unmarshalPublicKey(payloadPeerId.publicKey) - - const valid = await publicKey.verify(generatedPayload, payload.identitySig) - - if (!valid) { - throw new Error("Static key doesn't match to peer that signed payload!") - } - - return payloadPeerId -} - -export function isValidPublicKey (pk: Uint8Array | Uint8ArrayList): boolean { - if (!(pk instanceof Uint8Array) && !(isUint8ArrayList(pk))) { - return false - } - - if (pk.byteLength !== 32) { - return false - } - - return true -} diff --git a/test/handshakes/xx.spec.ts b/test/handshakes/xx.spec.ts deleted file mode 100644 index eb06d36..0000000 --- a/test/handshakes/xx.spec.ts +++ /dev/null @@ -1,160 +0,0 @@ -import { Buffer } from 'buffer' -import { defaultLogger } from '@libp2p/logger' -import { expect, assert } from 'aegir/chai' -import { equals as uint8ArrayEquals } from 'uint8arrays/equals' -import { toString as uint8ArrayToString } from 'uint8arrays/to-string' -import { pureJsCrypto } from '../../src/crypto/js.js' -import { XX } from '../../src/handshakes/xx.js' -import { createHandshakePayload, getHandshakePayload } from '../../src/utils.js' -import { generateEd25519Keys } from '../utils.js' -import type { KeyPair, NoiseSession } from '../../src/types.js' - -describe('XX Handshake', () => { - const prologue = Buffer.alloc(0) - - it('Test creating new XX session', async () => { - try { - const xx = new XX({ logger: defaultLogger() }, pureJsCrypto) - - const kpInitiator: KeyPair = pureJsCrypto.generateX25519KeyPair() - - xx.initSession(true, prologue, kpInitiator) - } catch (e) { - const err = e as Error - assert(false, err.message) - } - }) - - it('Test get HKDF', () => { - const ckBytes = Buffer.from('4e6f6973655f58585f32353531395f58436861436861506f6c795f53484132353600000000000000000000000000000000000000000000000000000000000000', 'hex') - const ikm = Buffer.from('a3eae50ea37a47e8a7aa0c7cd8e16528670536dcd538cebfd724fb68ce44f1910ad898860666227d4e8dd50d22a9a64d1c0a6f47ace092510161e9e442953da3', 'hex') - const ck = Buffer.alloc(32) - ckBytes.copy(ck) - - const [k1, k2, k3] = pureJsCrypto.getHKDF(ck, ikm) - expect(uint8ArrayToString(k1, 'hex')).to.equal('cc5659adff12714982f806e2477a8d5ddd071def4c29bb38777b7e37046f6914') - expect(uint8ArrayToString(k2, 'hex')).to.equal('a16ada915e551ab623f38be674bb4ef15d428ae9d80688899c9ef9b62ef208fa') - expect(uint8ArrayToString(k3, 'hex')).to.equal('ff67bf9727e31b06efc203907e6786667d2c7a74ac412b4d31a80ba3fd766f68') - }) - - async function doHandshake (xx: XX): Promise<{ nsInit: NoiseSession, nsResp: NoiseSession }> { - const kpInit = pureJsCrypto.generateX25519KeyPair() - const kpResp = pureJsCrypto.generateX25519KeyPair() - - // initiator setup - const libp2pInitKeys = await generateEd25519Keys() - const initSignedPayload = await libp2pInitKeys.sign(getHandshakePayload(kpInit.publicKey)) - - // responder setup - const libp2pRespKeys = await generateEd25519Keys() - const respSignedPayload = await libp2pRespKeys.sign(getHandshakePayload(kpResp.publicKey)) - - // initiator: new XX noise session - const nsInit = xx.initSession(true, prologue, kpInit) - // responder: new XX noise session - const nsResp = xx.initSession(false, prologue, kpResp) - - /* STAGE 0 */ - - // initiator creates payload - libp2pInitKeys.marshal().slice(0, 32) - const libp2pInitPubKey = libp2pInitKeys.marshal().slice(32, 64) - - const payloadInitEnc = createHandshakePayload(libp2pInitPubKey, initSignedPayload) - - // initiator sends message - const message = Buffer.concat([Buffer.alloc(0), payloadInitEnc]) - const messageBuffer = xx.sendMessage(nsInit, message) - - expect(messageBuffer.ne.length).not.equal(0) - - // responder receives message - xx.recvMessage(nsResp, messageBuffer) - - /* STAGE 1 */ - - // responder creates payload - libp2pRespKeys.marshal().slice(0, 32) - const libp2pRespPubKey = libp2pRespKeys.marshal().slice(32, 64) - const payloadRespEnc = createHandshakePayload(libp2pRespPubKey, respSignedPayload) - - const message1 = Buffer.concat([message, payloadRespEnc]) - const messageBuffer2 = xx.sendMessage(nsResp, message1) - - expect(messageBuffer2.ne.length).not.equal(0) - expect(messageBuffer2.ns.length).not.equal(0) - - // initiator receive payload - xx.recvMessage(nsInit, messageBuffer2) - - /* STAGE 2 */ - - // initiator send message - const messageBuffer3 = xx.sendMessage(nsInit, Buffer.alloc(0)) - - // responder receive message - xx.recvMessage(nsResp, messageBuffer3) - - if (nsInit.cs1 == null || nsResp.cs1 == null || nsInit.cs2 == null || nsResp.cs2 == null) { - throw new Error('CipherState missing') - } - - assert(uint8ArrayEquals(nsInit.cs1.k, nsResp.cs1.k)) - assert(uint8ArrayEquals(nsInit.cs2.k, nsResp.cs2.k)) - - return { nsInit, nsResp } - } - - it('Test handshake', async () => { - try { - const xx = new XX({ logger: defaultLogger() }, pureJsCrypto) - await doHandshake(xx) - } catch (e) { - const err = e as Error - assert(false, err.message) - } - }) - - it('Test symmetric encrypt and decrypt', async () => { - try { - const xx = new XX({ logger: defaultLogger() }, pureJsCrypto) - const { nsInit, nsResp } = await doHandshake(xx) - const ad = Buffer.from('authenticated') - const message = Buffer.from('HelloCrypto') - - if (nsInit.cs1 == null || nsResp.cs1 == null || nsInit.cs2 == null || nsResp.cs2 == null) { - throw new Error('CipherState missing') - } - - const ciphertext = xx.encryptWithAd(nsInit.cs1, ad, message) - assert(!uint8ArrayEquals(Buffer.from('HelloCrypto'), ciphertext.subarray()), 'Encrypted message should not be same as plaintext.') - const { plaintext: decrypted, valid } = xx.decryptWithAd(nsResp.cs1, ad, ciphertext) - - assert(uint8ArrayEquals(Buffer.from('HelloCrypto'), decrypted.subarray()), 'Decrypted text not equal to original message.') - assert(valid) - } catch (e) { - const err = e as Error - assert(false, err.message) - } - }) - - it('Test multiple messages encryption and decryption', async () => { - const xx = new XX({ logger: defaultLogger() }, pureJsCrypto) - const { nsInit, nsResp } = await doHandshake(xx) - const ad = Buffer.from('authenticated') - const message = Buffer.from('ethereum1') - - if (nsInit.cs1 == null || nsResp.cs1 == null || nsInit.cs2 == null || nsResp.cs2 == null) { - throw new Error('CipherState missing') - } - - const encrypted = xx.encryptWithAd(nsInit.cs1, ad, message) - const { plaintext: decrypted } = xx.decryptWithAd(nsResp.cs1, ad, encrypted) - assert.equal('ethereum1', uint8ArrayToString(decrypted.subarray(), 'utf8'), 'Decrypted text not equal to original message.') - - const message2 = Buffer.from('ethereum2') - const encrypted2 = xx.encryptWithAd(nsInit.cs1, ad, message2) - const { plaintext: decrypted2 } = xx.decryptWithAd(nsResp.cs1, ad, encrypted2) - assert.equal('ethereum2', uint8ArrayToString(decrypted2.subarray(), 'utf-8'), 'Decrypted text not equal to original message.') - }) -}) diff --git a/test/noise.spec.ts b/test/noise.spec.ts index 25f0046..caeb7a3 100644 --- a/test/noise.spec.ts +++ b/test/noise.spec.ts @@ -7,17 +7,10 @@ import { lpStream } from 'it-length-prefixed-stream' import { duplexPair } from 'it-pair/duplex' import sinon from 'sinon' import { equals as uint8ArrayEquals } from 'uint8arrays/equals' -import { fromString as uint8ArrayFromString } from 'uint8arrays/from-string' import { toString as uint8ArrayToString } from 'uint8arrays/to-string' -import { NOISE_MSG_MAX_LENGTH_BYTES } from '../src/constants.js' import { pureJsCrypto } from '../src/crypto/js.js' -import { decode0, decode2, encode1, uint16BEDecode, uint16BEEncode } from '../src/encoder.js' -import { XXHandshake } from '../src/handshake-xx.js' -import { XX } from '../src/handshakes/xx.js' import { Noise } from '../src/noise.js' -import { createHandshakePayload, getHandshakePayload, getPayload, signPayload } from '../src/utils.js' import { createPeerIdsFromFixtures } from './fixtures/peer.js' -import { getKeyPairFromPeerId } from './utils.js' import type { PeerId } from '@libp2p/interface/peer-id' import type { Uint8ArrayList } from 'uint8arraylist' @@ -55,59 +48,6 @@ describe('Noise', () => { } }) - it('should test that secureOutbound is spec compliant', async () => { - const noiseInit = new Noise({ logger: defaultLogger() }, { staticNoiseKey: undefined }) - const [inboundConnection, outboundConnection] = duplexPair() - - const [outbound, { wrapped, handshake }] = await Promise.all([ - noiseInit.secureOutbound(localPeer, outboundConnection, remotePeer), - (async () => { - const wrapped = lpStream( - inboundConnection, - { - lengthEncoder: uint16BEEncode, - lengthDecoder: uint16BEDecode, - maxDataLength: NOISE_MSG_MAX_LENGTH_BYTES - } - ) - const prologue = Buffer.alloc(0) - const staticKeys = pureJsCrypto.generateX25519KeyPair() - const xx = new XX({ logger: defaultLogger() }, pureJsCrypto) - - const payload = await getPayload(remotePeer, staticKeys.publicKey) - const handshake = new XXHandshake({ logger: defaultLogger() }, false, payload, prologue, pureJsCrypto, staticKeys, wrapped, localPeer, xx) - - let receivedMessageBuffer = decode0((await wrapped.read()).slice()) - // The first handshake message contains the initiator's ephemeral public key - expect(receivedMessageBuffer.ne.length).equal(32) - xx.recvMessage(handshake.session, receivedMessageBuffer) - - // Stage 1 - const { publicKey: libp2pPubKey } = getKeyPairFromPeerId(remotePeer) - const signedPayload = await signPayload(remotePeer, getHandshakePayload(staticKeys.publicKey).subarray()) - const handshakePayload = createHandshakePayload(libp2pPubKey, signedPayload) - - const messageBuffer = xx.sendMessage(handshake.session, handshakePayload) - await wrapped.write(encode1(messageBuffer)) - - // Stage 2 - finish handshake - receivedMessageBuffer = decode2((await wrapped.read()).slice()) - xx.recvMessage(handshake.session, receivedMessageBuffer) - return { wrapped, handshake } - })() - ]) - - const wrappedOutbound = byteStream(outbound.conn) - await wrappedOutbound.write(uint8ArrayFromString('test')) - - // Check that noise message is prefixed with 16-bit big-endian unsigned integer - const data = (await wrapped.read()).slice() - const { plaintext: decrypted, valid } = handshake.decrypt(data, handshake.session) - // Decrypted data should match - expect(uint8ArrayEquals(decrypted.subarray(), uint8ArrayFromString('test'))).to.be.true() - expect(valid).to.be.true() - }) - it('should test large payloads', async function () { this.timeout(10000) try { diff --git a/test/performHandshake.spec.ts b/test/performHandshake.spec.ts new file mode 100644 index 0000000..67bd2bd --- /dev/null +++ b/test/performHandshake.spec.ts @@ -0,0 +1,138 @@ +import { Buffer } from 'buffer' +import { unmarshalPrivateKey } from '@libp2p/crypto/keys' +import { defaultLogger } from '@libp2p/logger' +import { assert, expect } from 'aegir/chai' +import { lpStream } from 'it-length-prefixed-stream' +import { duplexPair } from 'it-pair/duplex' +import { toString as uint8ArrayToString } from 'uint8arrays' +import { equals as uint8ArrayEquals } from 'uint8arrays/equals' +import { defaultCrypto } from '../src/crypto/index.js' +import { wrapCrypto } from '../src/crypto.js' +import { performHandshakeInitiator, performHandshakeResponder } from '../src/performHandshake.js' +import { createPeerIdsFromFixtures } from './fixtures/peer.js' +import type { PrivateKey } from '@libp2p/interface' +import type { PeerId } from '@libp2p/interface/peer-id' + +describe('performHandshake', () => { + let peerA: PeerId, peerB: PeerId, fakePeer: PeerId + let privateKeyA: PrivateKey, privateKeyB: PrivateKey + + before(async () => { + [peerA, peerB, fakePeer] = await createPeerIdsFromFixtures(3) + if (!peerA.privateKey || !peerB.privateKey || !fakePeer.privateKey) throw new Error('unreachable') + privateKeyA = await unmarshalPrivateKey(peerA.privateKey) + privateKeyB = await unmarshalPrivateKey(peerB.privateKey) + }) + + it('should propose, exchange and finish handshake', async () => { + const duplex = duplexPair() + const connectionInitiator = lpStream(duplex[0]) + const connectionResponder = lpStream(duplex[1]) + + const prologue = Buffer.alloc(0) + const staticKeysInitiator = defaultCrypto.generateX25519KeyPair() + const staticKeysResponder = defaultCrypto.generateX25519KeyPair() + + const [initiator, responder] = await Promise.all([ + performHandshakeInitiator({ + log: defaultLogger().forComponent('test'), + connection: connectionInitiator, + crypto: wrapCrypto(defaultCrypto), + privateKey: privateKeyA, + prologue, + remoteIdentityKey: peerB.publicKey, + s: staticKeysInitiator + }), + performHandshakeResponder({ + log: defaultLogger().forComponent('test'), + connection: connectionResponder, + crypto: wrapCrypto(defaultCrypto), + privateKey: privateKeyB, + prologue, + remoteIdentityKey: peerA.publicKey, + s: staticKeysResponder + }) + ]) + + // Test encryption and decryption + const encrypted = initiator.encrypt(Buffer.from('encryptthis')) + const decrypted = responder.decrypt(encrypted) + assert(uint8ArrayEquals(decrypted.subarray(), Buffer.from('encryptthis'))) + }) + + it('Initiator should fail to exchange handshake if given wrong public key in payload', async () => { + try { + const duplex = duplexPair() + const connectionInitiator = lpStream(duplex[0]) + const connectionResponder = lpStream(duplex[1]) + + const prologue = Buffer.alloc(0) + const staticKeysInitiator = defaultCrypto.generateX25519KeyPair() + const staticKeysResponder = defaultCrypto.generateX25519KeyPair() + + await Promise.all([ + performHandshakeInitiator({ + log: defaultLogger().forComponent('test'), + connection: connectionInitiator, + crypto: wrapCrypto(defaultCrypto), + privateKey: privateKeyA, + prologue, + remoteIdentityKey: fakePeer.publicKey, // <----- look here + s: staticKeysInitiator + }), + performHandshakeResponder({ + log: defaultLogger().forComponent('test'), + connection: connectionResponder, + crypto: wrapCrypto(defaultCrypto), + privateKey: privateKeyB, + prologue, + remoteIdentityKey: peerA.publicKey, + s: staticKeysResponder + }) + ]) + + assert(false, 'Should throw exception') + } catch (e) { + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + expect((e as Error).message).equals(`Payload identity key ${uint8ArrayToString(peerB.publicKey!, 'hex')} does not match expected remote identity key ${uint8ArrayToString(fakePeer.publicKey!, 'hex')}`) + } + }) + + it('Responder should fail to exchange handshake if given wrong public key in payload', async () => { + try { + const duplex = duplexPair() + const connectionInitiator = lpStream(duplex[0]) + const connectionResponder = lpStream(duplex[1]) + + const prologue = Buffer.alloc(0) + const staticKeysInitiator = defaultCrypto.generateX25519KeyPair() + const staticKeysResponder = defaultCrypto.generateX25519KeyPair() + + await Promise.all([ + performHandshakeInitiator({ + log: defaultLogger().forComponent('test'), + connection: connectionInitiator, + crypto: wrapCrypto(defaultCrypto), + privateKey: privateKeyA, + prologue, + remoteIdentityKey: peerB.publicKey, + s: staticKeysInitiator + }), + performHandshakeResponder({ + log: defaultLogger().forComponent('test'), + connection: connectionResponder, + crypto: wrapCrypto(defaultCrypto), + privateKey: privateKeyB, + prologue, + remoteIdentityKey: fakePeer.publicKey, + s: staticKeysResponder + }) + ]) + + assert(false, 'Should throw exception') + } catch (e) { + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + expect((e as Error).message).equals(`Payload identity key ${uint8ArrayToString(peerA.publicKey!, 'hex')} does not match expected remote identity key ${uint8ArrayToString(fakePeer.publicKey!, 'hex')}`) + } + }) +}) diff --git a/test/protocol.spec.ts b/test/protocol.spec.ts new file mode 100644 index 0000000..d27cd09 --- /dev/null +++ b/test/protocol.spec.ts @@ -0,0 +1,112 @@ +import { Buffer } from 'buffer' +import { expect, assert } from 'aegir/chai' +import { Uint8ArrayList } from 'uint8arraylist' +import { equals as uint8ArrayEquals } from 'uint8arrays/equals' +import { toString as uint8ArrayToString } from 'uint8arrays/to-string' +import { pureJsCrypto } from '../src/crypto/js.js' +import { wrapCrypto } from '../src/crypto.js' +import { type CipherState, type SymmetricState, XXHandshakeState, ZEROLEN } from '../src/protocol.js' + +describe('XXHandshakeState', () => { + const prologue = Buffer.alloc(0) + const protocolName = 'Noise_XX_25519_ChaChaPoly_SHA256' + + it('Test creating new XX session', async () => { + try { + // eslint-disable-next-line no-new + new XXHandshakeState({ crypto: wrapCrypto(pureJsCrypto), protocolName, initiator: true, prologue }) + } catch (e) { + assert(false, (e as Error).message) + } + }) + + it('Test get HKDF', () => { + const ckBytes = Buffer.from('4e6f6973655f58585f32353531395f58436861436861506f6c795f53484132353600000000000000000000000000000000000000000000000000000000000000', 'hex') + const ikm = Buffer.from('a3eae50ea37a47e8a7aa0c7cd8e16528670536dcd538cebfd724fb68ce44f1910ad898860666227d4e8dd50d22a9a64d1c0a6f47ace092510161e9e442953da3', 'hex') + const ck = Buffer.alloc(32) + ckBytes.copy(ck) + + const [k1, k2, k3] = pureJsCrypto.getHKDF(ck, ikm) + expect(uint8ArrayToString(k1, 'hex')).to.equal('cc5659adff12714982f806e2477a8d5ddd071def4c29bb38777b7e37046f6914') + expect(uint8ArrayToString(k2, 'hex')).to.equal('a16ada915e551ab623f38be674bb4ef15d428ae9d80688899c9ef9b62ef208fa') + expect(uint8ArrayToString(k3, 'hex')).to.equal('ff67bf9727e31b06efc203907e6786667d2c7a74ac412b4d31a80ba3fd766f68') + }) + + interface ProtocolHandshakeResult { ss: SymmetricState, cs1: CipherState, cs2: CipherState } + async function doHandshake (): Promise<{ nsInit: ProtocolHandshakeResult, nsResp: ProtocolHandshakeResult }> { + const kpInit = pureJsCrypto.generateX25519KeyPair() + const kpResp = pureJsCrypto.generateX25519KeyPair() + + // initiator: new XX noise session + const nsInit = new XXHandshakeState({ crypto: wrapCrypto(pureJsCrypto), protocolName, prologue, initiator: true, s: kpInit }) + // responder: new XX noise session + const nsResp = new XXHandshakeState({ crypto: wrapCrypto(pureJsCrypto), protocolName, prologue, initiator: false, s: kpResp }) + + /* STAGE 0 */ + + // initiator sends message + // responder receives message + nsResp.readMessageA(new Uint8ArrayList(nsInit.writeMessageA(ZEROLEN))) + + /* STAGE 1 */ + + // responder sends message + // initiator receives message + nsInit.readMessageB(new Uint8ArrayList(nsResp.writeMessageB(ZEROLEN))) + + /* STAGE 2 */ + + // initiator sends message + // responder receives message + nsResp.readMessageC(new Uint8ArrayList(nsInit.writeMessageC(ZEROLEN))) + + const nsInitSplit = nsInit.ss.split() + const nsRespSplit = nsResp.ss.split() + + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + assert(uint8ArrayEquals(nsInitSplit[0].k!, nsRespSplit[0].k!)) + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + assert(uint8ArrayEquals(nsInitSplit[1].k!, nsRespSplit[1].k!)) + + return { + nsInit: { ss: nsInit.ss, cs1: nsInitSplit[0], cs2: nsInitSplit[1] }, + nsResp: { ss: nsResp.ss, cs1: nsRespSplit[0], cs2: nsRespSplit[1] } + } + } + + it('Test symmetric encrypt and decrypt', async () => { + try { + const { nsInit, nsResp } = await doHandshake() + const ad = Buffer.from('authenticated') + const message = Buffer.from('HelloCrypto') + + const ciphertext = nsInit.cs1.encryptWithAd(ad, message) + assert(!uint8ArrayEquals(Buffer.from('HelloCrypto'), ciphertext.subarray()), 'Encrypted message should not be same as plaintext.') + const decrypted = nsResp.cs1.decryptWithAd(ad, ciphertext) + + assert(uint8ArrayEquals(Buffer.from('HelloCrypto'), decrypted.subarray()), 'Decrypted text not equal to original message.') + } catch (e) { + assert(false, (e as Error).message) + } + }) + + it('Test multiple messages encryption and decryption', async () => { + const { nsInit, nsResp } = await doHandshake() + const ad = Buffer.from('authenticated') + + for (let i = 0; i < 50; i++) { + const strMessage = 'ethereum' + String(i) + const message = Buffer.from(strMessage) + { + const encrypted = nsInit.cs1.encryptWithAd(ad, message) + const decrypted = nsResp.cs1.decryptWithAd(ad, encrypted) + assert.equal(strMessage, uint8ArrayToString(decrypted.subarray(), 'utf8'), 'Decrypted text not equal to original message.') + } + { + const encrypted = nsResp.cs2.encryptWithAd(ad, message) + const decrypted = nsInit.cs2.decryptWithAd(ad, encrypted) + assert.equal(strMessage, uint8ArrayToString(decrypted.subarray(), 'utf8'), 'Decrypted text not equal to original message.') + } + } + }) +}) diff --git a/test/xx-handshake.spec.ts b/test/xx-handshake.spec.ts deleted file mode 100644 index 56d0958..0000000 --- a/test/xx-handshake.spec.ts +++ /dev/null @@ -1,143 +0,0 @@ -import { Buffer } from 'buffer' -import { defaultLogger } from '@libp2p/logger' -import { assert, expect } from 'aegir/chai' -import { lpStream } from 'it-length-prefixed-stream' -import { duplexPair } from 'it-pair/duplex' -import { equals as uint8ArrayEquals } from 'uint8arrays/equals' -import { defaultCrypto } from '../src/crypto/index.js' -import { XXHandshake } from '../src/handshake-xx.js' -import { getPayload } from '../src/utils.js' -import { createPeerIdsFromFixtures } from './fixtures/peer.js' -import type { PeerId } from '@libp2p/interface/peer-id' - -describe('XX Handshake', () => { - let peerA: PeerId, peerB: PeerId, fakePeer: PeerId - - before(async () => { - [peerA, peerB, fakePeer] = await createPeerIdsFromFixtures(3) - }) - - it('should propose, exchange and finish handshake', async () => { - try { - const duplex = duplexPair() - const connectionFrom = lpStream(duplex[0]) - const connectionTo = lpStream(duplex[1]) - - const prologue = Buffer.alloc(0) - const staticKeysInitiator = defaultCrypto.generateX25519KeyPair() - const staticKeysResponder = defaultCrypto.generateX25519KeyPair() - - const initPayload = await getPayload(peerA, staticKeysInitiator.publicKey) - const handshakeInitiator = new XXHandshake({ logger: defaultLogger() }, true, initPayload, prologue, defaultCrypto, staticKeysInitiator, connectionFrom, peerB) - - const respPayload = await getPayload(peerB, staticKeysResponder.publicKey) - const handshakeResponder = new XXHandshake({ logger: defaultLogger() }, false, respPayload, prologue, defaultCrypto, staticKeysResponder, connectionTo, peerA) - - await Promise.all([ - handshakeInitiator.propose(), - handshakeResponder.propose() - ]) - - await Promise.all([ - handshakeResponder.exchange(), - handshakeInitiator.exchange() - ]) - - await Promise.all([ - handshakeInitiator.finish(), - handshakeResponder.finish() - ]) - - const sessionInitator = handshakeInitiator.session - const sessionResponder = handshakeResponder.session - - // Test shared key - if (sessionInitator.cs1 && sessionResponder.cs1 && sessionInitator.cs2 && sessionResponder.cs2) { - assert(uint8ArrayEquals(sessionInitator.cs1.k, sessionResponder.cs1.k)) - assert(uint8ArrayEquals(sessionInitator.cs2.k, sessionResponder.cs2.k)) - } else { - assert(false) - } - - // Test encryption and decryption - const encrypted = handshakeInitiator.encrypt(Buffer.from('encryptthis'), handshakeInitiator.session) - const { plaintext: decrypted, valid } = handshakeResponder.decrypt(encrypted, handshakeResponder.session) - assert(uint8ArrayEquals(decrypted.subarray(), Buffer.from('encryptthis'))) - assert(valid) - } catch (e) { - const err = e as Error - assert(false, err.message) - } - }) - - it('Initiator should fail to exchange handshake if given wrong public key in payload', async () => { - try { - const duplex = duplexPair() - const connectionFrom = lpStream(duplex[0]) - const connectionTo = lpStream(duplex[1]) - - const prologue = Buffer.alloc(0) - const staticKeysInitiator = defaultCrypto.generateX25519KeyPair() - const staticKeysResponder = defaultCrypto.generateX25519KeyPair() - - const initPayload = await getPayload(peerA, staticKeysInitiator.publicKey) - const handshakeInitiator = new XXHandshake({ logger: defaultLogger() }, true, initPayload, prologue, defaultCrypto, staticKeysInitiator, connectionFrom, fakePeer) - - const respPayload = await getPayload(peerB, staticKeysResponder.publicKey) - const handshakeResponder = new XXHandshake({ logger: defaultLogger() }, false, respPayload, prologue, defaultCrypto, staticKeysResponder, connectionTo, peerA) - - await Promise.all([ - handshakeInitiator.propose(), - handshakeResponder.propose() - ]) - - await Promise.all([ - handshakeResponder.exchange(), - handshakeInitiator.exchange() - ]) - - assert(false, 'Should throw exception') - } catch (e) { - const err = e as Error - expect(err.message).equals(`Error occurred while verifying signed payload: Payload identity key ${peerB.toString()} does not match expected remote peer ${fakePeer.toString()}`) - } - }) - - it('Responder should fail to exchange handshake if given wrong public key in payload', async () => { - try { - const duplex = duplexPair() - const connectionFrom = lpStream(duplex[0]) - const connectionTo = lpStream(duplex[1]) - - const prologue = Buffer.alloc(0) - const staticKeysInitiator = defaultCrypto.generateX25519KeyPair() - const staticKeysResponder = defaultCrypto.generateX25519KeyPair() - - const initPayload = await getPayload(peerA, staticKeysInitiator.publicKey) - const handshakeInitiator = new XXHandshake({ logger: defaultLogger() }, true, initPayload, prologue, defaultCrypto, staticKeysInitiator, connectionFrom, peerB) - - const respPayload = await getPayload(peerB, staticKeysResponder.publicKey) - const handshakeResponder = new XXHandshake({ logger: defaultLogger() }, false, respPayload, prologue, defaultCrypto, staticKeysResponder, connectionTo, fakePeer) - - await Promise.all([ - handshakeInitiator.propose(), - handshakeResponder.propose() - ]) - - await Promise.all([ - handshakeResponder.exchange(), - handshakeInitiator.exchange() - ]) - - await Promise.all([ - handshakeInitiator.finish(), - handshakeResponder.finish() - ]) - - assert(false, 'Should throw exception') - } catch (e) { - const err = e as Error - expect(err.message).equals(`Error occurred while verifying signed payload: Payload identity key ${peerA.toString()} does not match expected remote peer ${fakePeer.toString()}`) - } - }) -})