diff options
Diffstat (limited to 'crypto/ntru.c')
-rw-r--r-- | crypto/ntru.c | 1889 |
1 files changed, 1889 insertions, 0 deletions
diff --git a/crypto/ntru.c b/crypto/ntru.c new file mode 100644 index 00000000..edc57a91 --- /dev/null +++ b/crypto/ntru.c @@ -0,0 +1,1889 @@ +/* + * Implementation of OpenSSH 9.x's hybrid key exchange protocol + * sntrup761x25519-sha512@openssh.com . + * + * This consists of the 'Streamlined NTRU Prime' quantum-resistant + * cryptosystem, run in parallel with ordinary Curve25519 to generate + * a shared secret combining the output of both systems. + * + * (Hence, even if you don't trust this newfangled NTRU Prime thing at + * all, it's at least no _less_ secure than the kex you were using + * already.) + * + * References for the NTRU Prime cryptosystem, up to and including + * binary encodings of public and private keys and the exact preimages + * of the hashes used in key exchange: + * + * https://ntruprime.cr.yp.to/ + * https://ntruprime.cr.yp.to/nist/ntruprime-20201007.pdf + * + * The SSH protocol layer is not documented anywhere I could find (as + * of 2022-04-15, not even in OpenSSH's PROTOCOL.* files). I had to + * read OpenSSH's source code to find out how it worked, and the + * answer is as follows: + * + * This hybrid kex method is treated for SSH purposes as a form of + * elliptic-curve Diffie-Hellman, and shares the same SSH message + * sequence: client sends SSH2_MSG_KEX_ECDH_INIT containing its public + * half, server responds with SSH2_MSG_KEX_ECDH_REPLY containing _its_ + * public half plus the host key and signature on the shared secret. + * + * (This is a bit of a fudge, because unlike actual ECDH, this kex + * method is asymmetric: one side sends a public key, and the other + * side encrypts something with it and sends the ciphertext back. So + * while the normal ECDH implementations can compute the two sides + * independently in parallel, this system reusing the same messages + * has to be serial. But the order of the messages _is_ firmly + * specified in SSH ECDH, so it works anyway.) + * + * For this kex method, SSH2_MSG_KEX_ECDH_INIT still contains a single + * SSH 'string', which consists of the concatenation of a Streamlined + * NTRU Prime public key with the Curve25519 public value. (Both of + * these have fixed length in bytes, so there's no ambiguity in the + * concatenation.) + * + * SSH2_MSG_KEX_ECDH_REPLY is mostly the same as usual. The only + * string in the packet that varies is the second one, which would + * normally contain the server's public elliptic curve point. Instead, + * it now contains the concatenation of + * + * - a Streamlined NTRU Prime ciphertext + * - the 'confirmation hash' specified in ntruprime-20201007.pdf, + * hashing the plaintext of that ciphertext together with the + * public key + * - the Curve25519 public point as usual. + * + * Again, all three of those elements have fixed lengths. + * + * The client decrypts the ciphertext, checks the confirmation hash, + * and if successful, generates the 'session hash' specified in + * ntruprime-20201007.pdf, which is 32 bytes long and is the ultimate + * output of the Streamlined NTRU Prime key exchange. + * + * The output of the hybrid kex method as a whole is an SSH 'string' + * of length 64 containing the SHA-512 hash of the concatenatio of + * + * - the Streamlined NTRU Prime session hash (32 bytes) + * - the Curve25519 shared secret (32 bytes). + * + * That string is included directly into the SSH exchange hash and key + * derivation hashes, in place of the mpint that comes out of most + * other kex methods. + */ + +#include <stdio.h> +#include <stdlib.h> +#include <assert.h> + +#include "putty.h" +#include "ssh.h" +#include "mpint.h" +#include "ntru.h" + +/* ---------------------------------------------------------------------- + * Preliminaries: we're going to need to do modular arithmetic on + * small values (considerably smaller than 2^16), and we need to do it + * without using integer division which might not be time-safe. + * + * The strategy for this is the same as I used in + * mp_mod_known_integer: see there for the proofs. The basic idea is + * that we precompute the reciprocal of our modulus as a fixed-point + * number, and use that to get an approximate quotient which we + * subtract off. For these integer sizes, precomputing a fixed-point + * reciprocal of the form (2^48 / modulus) leaves us at most off by 1 + * in the quotient, so there's a single (time-safe) trial subtraction + * at the end. + * + * (It's possible that some speed could be gained by not reducing + * fully at every step. But then you'd have to carefully identify all + * the places in the algorithm where things are compared to zero. This + * was the easiest way to get it all working in the first place.) + */ + +/* Precompute the reciprocal */ +static uint64_t reciprocal_for_reduction(uint16_t q) +{ + return ((uint64_t)1 << 48) / q; +} + +/* Reduce x mod q, assuming qrecip == reciprocal_for_reduction(q) */ +static uint16_t reduce(uint32_t x, uint16_t q, uint64_t qrecip) +{ + uint64_t unshifted_quot = x * qrecip; + uint64_t quot = unshifted_quot >> 48; + uint16_t reduced = x - quot * q; + reduced -= q * (1 & ((q-1 - reduced) >> 15)); + return reduced; +} + +/* Reduce x mod q as above, but also return the quotient */ +static uint16_t reduce_with_quot(uint32_t x, uint32_t *quot_out, + uint16_t q, uint64_t qrecip) +{ + uint64_t unshifted_quot = x * qrecip; + uint64_t quot = unshifted_quot >> 48; + uint16_t reduced = x - quot * q; + uint64_t extraquot = (1 & ((q-1 - reduced) >> 15)); + reduced -= extraquot * q; + *quot_out = quot + extraquot; + return reduced; +} + +/* Invert x mod q, assuming it's nonzero. (For time-safety, no check + * is made for zero; it just returns 0.) */ +static uint16_t invert(uint16_t x, uint16_t q, uint64_t qrecip) +{ + /* Fermat inversion: compute x^(q-2), since x^(q-1) == 1. */ + uint32_t sq = x, bit = 1, acc = 1, exp = q-2; + while (1) { + if (exp & bit) { + acc = reduce(acc * sq, q, qrecip); + exp &= ~bit; + if (!exp) + return acc; + } + sq = reduce(sq * sq, q, qrecip); + bit <<= 1; + } +} + +/* Check whether x == 0, time-safely, and return 1 if it is or 0 otherwise. */ +static unsigned iszero(uint16_t x) +{ + return 1 & ~((x + 0xFFFF) >> 16); +} + +/* + * Handy macros to cut down on all those extra function parameters. In + * the common case where a function is working mod the same modulus + * throughout (and has called it q), you can just write 'SETUP;' at + * the top and then call REDUCE(...) and INVERT(...) without having to + * write out q and qrecip every time. + */ +#define SETUP uint64_t qrecip = reciprocal_for_reduction(q) +#define REDUCE(x) reduce(x, q, qrecip) +#define INVERT(x) invert(x, q, qrecip) + +/* ---------------------------------------------------------------------- + * Quotient-ring functions. + * + * NTRU Prime works with two similar but different quotient rings: + * + * Z_q[x] / <x^p-x-1> where p,q are the prime parameters of the system + * Z_3[x] / <x^p-x-1> with the same p, but coefficients mod 3. + * + * The former is a field (every nonzero element is invertible), + * because the system parameters are chosen such that x^p-x-1 is + * invertible over Z_q. The latter is not a field (or not necessarily, + * and in particular, not for the value of p we use here). + * + * In these core functions, you pass in the modulus you want as the + * parameter q, which is either the 'real' q specified in the system + * parameters, or 3 if you're doing one of the mod-3 parts of the + * algorithm. + */ + +/* + * Multiply two elements of a quotient ring. + * + * 'a' and 'b' are arrays of exactly p coefficients, with constant + * term first. 'out' is an array the same size to write the inverse + * into. + */ +void ntru_ring_multiply(uint16_t *out, const uint16_t *a, const uint16_t *b, + unsigned p, unsigned q) +{ + SETUP; + + /* + * Strategy: just compute the full product with 2p coefficients, + * and then reduce it mod x^p-x-1 by working downwards from the + * top coefficient replacing x^{p+k} with (x+1)x^k for k = ...,1,0. + * + * Possibly some speed could be gained here by doing the recursive + * Karatsuba optimisation for the initial multiplication? But I + * haven't tried it. + */ + uint32_t *unreduced = snewn(2*p, uint32_t); + for (unsigned i = 0; i < 2*p; i++) + unreduced[i] = 0; + for (unsigned i = 0; i < p; i++) + for (unsigned j = 0; j < p; j++) + unreduced[i+j] = REDUCE(unreduced[i+j] + a[i] * b[j]); + + for (unsigned i = 2*p - 1; i >= p; i--) { + unreduced[i-p] += unreduced[i]; + unreduced[i-p+1] += unreduced[i]; + unreduced[i] = 0; + } + + for (unsigned i = 0; i < p; i++) + out[i] = REDUCE(unreduced[i]); + + smemclr(unreduced, 2*p * sizeof(*unreduced)); + sfree(unreduced); +} + +/* + * Invert an element of the quotient ring. + * + * 'in' is an array of exactly p coefficients, with constant term + * first. 'out' is an array the same size to write the inverse into. + * + * Method: essentially Stein's gcd algorithm, taking the gcd of the + * input (regarded as an element of Z_q[x] proper) and x^p-x-1. Given + * two polynomials over a field which are not both divisible by x, you + * can find their gcd by iterating the following procedure: + * + * - if one is divisible by x, divide off x + * - otherwise, subtract from the higher-degree one whatever scalar + * multiple of the lower-degree one will make it divisible by x, + * and _then_ divide off x + * + * Neither of these types of step changes the gcd of the two + * polynomials. + * + * Each step reduces the sum of the two polynomials' degree by at + * least one, as long as at least one of the degrees is positive. + * (Maybe more than one if all the stars align in the second case, if + * the subtraction cancels the leading term as well as the constant + * term.) So in at most deg A + deg B steps, we must have reached the + * situation where both polys are constants; in one more step after + * that, one of them will be zero; and in one step after _that_, the + * zero one will reliably be the one we're dividing by x. Or rather, + * that's what happens in the case where A,B are coprime; if not, then + * one hits zero while the other is still nonzero. + * + * In a normal gcd algorithm, you'd track a linear combination of the + * two original polynomials that yields each working value, and end up + * with a linear combination of the inputs that yields the gcd. In + * this algorithm, the 'divide off x' step makes that awkward - but we + * can solve that by instead multiplying by the inverse of x in the + * ring that we want our answer to be valid in! And since the modulus + * polynomial of the ring is x^p-x-1, the inverse of x is easy to + * calculate, because it's always just x^{p-1} - 1, which is also very + * easy to multiply by. + */ +unsigned ntru_ring_invert(uint16_t *out, const uint16_t *in, + unsigned p, unsigned q) +{ + SETUP; + + /* Size of the polynomial arrays we'll work with */ + const size_t SIZE = p+1; + + /* Number of steps of the algorithm is the max possible value of + * deg A + deg B + 2, where deg A <= p-1 and deg B = p */ + const size_t STEPS = 2*p + 1; + + /* Our two working polynomials */ + uint16_t *A = snewn(SIZE, uint16_t); + uint16_t *B = snewn(SIZE, uint16_t); + + /* Coefficient of the input value in each one */ + uint16_t *Ac = snewn(SIZE, uint16_t); + uint16_t *Bc = snewn(SIZE, uint16_t); + + /* Initialise A to the input, and Ac correspondingly to 1 */ + memcpy(A, in, p*sizeof(uint16_t)); + A[p] = 0; + Ac[0] = 1; + for (size_t i = 1; i < SIZE; i++) + Ac[i] = 0; + + /* Initialise B to the quotient polynomial of the ring, x^p-x-1 + * And Bc = 0 */ + B[0] = B[1] = q-1; + for (size_t i = 2; i < p; i++) + B[i] = 0; + B[p] = 1; + for (size_t i = 0; i < SIZE; i++) + Bc[i] = 0; + + /* Run the gcd-finding algorithm. */ + for (size_t i = 0; i < STEPS; i++) { + /* + * First swap round so that A is the one we'll be dividing by x. + * + * In the case where one of the two polys has a zero constant + * term, it's that one. In the other case, it's the one of + * smaller degree. We must compute both, and choose between + * them in a side-channel-safe way. + */ + unsigned x_divides_A = iszero(A[0]); + unsigned x_divides_B = iszero(B[0]); + unsigned B_is_bigger = 0; + { + unsigned not_seen_top_term_of_A = 1, not_seen_top_term_of_B = 1; + for (size_t j = SIZE; j-- > 0 ;) { + not_seen_top_term_of_A &= iszero(A[j]); + not_seen_top_term_of_B &= iszero(B[j]); + B_is_bigger |= (~not_seen_top_term_of_B & + not_seen_top_term_of_A); + } + } + unsigned need_swap = x_divides_B | (~x_divides_A & B_is_bigger); + uint16_t swap_mask = -need_swap; + for (size_t j = 0; j < SIZE; j++) { + uint16_t diff = (A[j] ^ B[j]) & swap_mask; + A[j] ^= diff; + B[j] ^= diff; + } + for (size_t j = 0; j < SIZE; j++) { + uint16_t diff = (Ac[j] ^ Bc[j]) & swap_mask; + Ac[j] ^= diff; + Bc[j] ^= diff; + } + + /* + * Replace A with a linear combination of both A and B that + * has constant term zero, which we do by calculating + * + * (constant term of B) * A - (constant term of A) * B + * + * In one of the two cases, A's constant term is already zero, + * so the coefficient of B will be zero too; hence, this will + * do nothing useful (it will merely scale A by some scalar + * value), but it will take the same length of time as doing + * something, which is just what we want. + */ + uint16_t Amult = B[0], Bmult = q - A[0]; + for (size_t j = 0; j < SIZE; j++) + A[j] = REDUCE(Amult * A[j] + Bmult * B[j]); + /* And do the same transformation to Ac */ + for (size_t j = 0; j < SIZE; j++) + Ac[j] = REDUCE(Amult * Ac[j] + Bmult * Bc[j]); + + /* + * Now divide A by x, and compensate by multiplying Ac by + * x^{p-1}-1 mod x^p-x-1. + * + * That multiplication is particularly easy, precisely because + * x^{p-1}-1 is the multiplicative inverse of x! Each x^n term + * for n>0 just moves down to the x^{n-1} term, and only the + * constant term has to be dealt with in an interesting way. + */ + for (size_t j = 1; j < SIZE; j++) + A[j-1] = A[j]; + A[SIZE-1] = 0; + uint16_t Ac0 = Ac[0]; + for (size_t j = 1; j < p; j++) + Ac[j-1] = Ac[j]; + Ac[p-1] = Ac0; + Ac[0] = REDUCE(Ac[0] + q - Ac0); + } + + /* + * Now we expect that A is 0, and B is a constant. If so, then + * they are coprime, and we're going to return success. If not, + * they have a common factor. + */ + unsigned success = iszero(A[0]) & (1 ^ iszero(B[0])); + for (size_t j = 1; j < SIZE; j++) + success &= iszero(A[j]) & iszero(B[j]); + + /* + * So we're going to return Bc, but first, scale it by the + * multiplicative inverse of the constant we ended up with in + * B[0]. + */ + uint16_t scale = INVERT(B[0]); + for (size_t i = 0; i < p; i++) + out[i] = REDUCE(scale * Bc[i]); + + smemclr(A, SIZE * sizeof(*A)); + sfree(A); + smemclr(B, SIZE * sizeof(*B)); + sfree(B); + smemclr(Ac, SIZE * sizeof(*Ac)); + sfree(Ac); + smemclr(Bc, SIZE * sizeof(*Bc)); + sfree(Bc); + + return success; +} + +/* + * Given an array of values mod q, convert each one to its + * minimum-absolute-value representative, and then reduce mod 3. + * + * Output values are 0, 1 and 0xFFFF, representing -1. + * + * (Normally our arrays of uint16_t are in 'minimal non-negative + * residue' form, so the output of this function is unusual. But it's + * useful to have it in this form so that it can be reused by + * ntru_round3. You can put it back to the usual representation using + * ntru_normalise, below.) + */ +void ntru_mod3(uint16_t *out, const uint16_t *in, unsigned p, unsigned q) +{ + uint64_t qrecip = reciprocal_for_reduction(q); + uint64_t recip3 = reciprocal_for_reduction(3); + + unsigned bias = q/2; + uint16_t adjust = 3 - reduce(bias-1, 3, recip3); + + for (unsigned i = 0; i < p; i++) { + uint16_t val = reduce(in[i] + bias, q, qrecip); + uint16_t residue = reduce(val + adjust, 3, recip3); + out[i] = residue - 1; + } +} + +/* + * Given an array of values mod q, round each one to the nearest + * multiple of 3 to its minimum-absolute-value representative. + * + * Output values are signed integers coerced to uint16_t, so again, + * use ntru_normalise afterwards to put them back to normal. + */ +void ntru_round3(uint16_t *out, const uint16_t *in, unsigned p, unsigned q) +{ + SETUP; + unsigned bias = q/2; + ntru_mod3(out, in, p, q); + for (unsigned i = 0; i < p; i++) + out[i] = REDUCE(in[i] + bias) - bias - out[i]; +} + +/* + * Given an array of signed integers coerced to uint16_t in the range + * [-q/2,+q/2], normalise them back to mod q values. + */ +static void ntru_normalise(uint16_t *out, const uint16_t *in, + unsigned p, unsigned q) +{ + for (unsigned i = 0; i < p; i++) + out[i] = in[i] + q * (in[i] >> 15); +} + +/* + * Given an array of values mod q, add a constant to each one. + */ +void ntru_bias(uint16_t *out, const uint16_t *in, unsigned bias, + unsigned p, unsigned q) +{ + SETUP; + for (unsigned i = 0; i < p; i++) + out[i] = REDUCE(in[i] + bias); +} + +/* + * Given an array of values mod q, multiply each one by a constant. + */ +void ntru_scale(uint16_t *out, const uint16_t *in, uint16_t scale, + unsigned p, unsigned q) +{ + SETUP; + for (unsigned i = 0; i < p; i++) + out[i] = REDUCE(in[i] * scale); +} + +/* + * Given an array of values mod 3, convert them to values mod q in a + * way that maps -1,0,+1 to -1,0,+1. + */ +static void ntru_expand( + uint16_t *out, const uint16_t *in, unsigned p, unsigned q) +{ + for (size_t i = 0; i < p; i++) { + uint16_t v = in[i]; + /* Map 2 to q-1, and leave 0 and 1 unchanged */ + v += (v >> 1) * (q-3); + out[i] = v; + } +} + +/* ---------------------------------------------------------------------- + * Implement the binary encoding from ntruprime-20201007.pdf, which is + * used to encode public keys and ciphertexts (though not plaintexts, + * which are done in a much simpler way). + * + * The general idea is that your encoder takes as input a list of + * small non-negative integers (r_i), and a sequence of limits (m_i) + * such that 0 <= r_i < m_i, and emits a sequence of bytes that encode + * all of these as tightly as reasonably possible. + * + * That's more general than is really needed, because in both the + * actual uses of this encoding, the input m_i are all the same! But + * the array of (r_i,m_i) pairs evolves during encoding, so they don't + * _stay_ all the same, so you still have to have all the generality. + * + * The encoding process makes a number of passes along the list of + * inputs. In each step, pairs of adjacent numbers are combined into + * one larger one by turning (r_i,m_i) and (r_{i+1},m_{i+1}) into the + * pair (r_i + m_i r_{i+1}, m_i m_{i+1}), i.e. so that the original + * numbers could be recovered by taking the quotient and remaiinder of + * the new r value by m_i. Then, if the new m_i is at least 2^14, we + * emit the low 8 bits of r_i to the output stream and reduce r_i and + * its limit correspondingly. So at the end of the pass, we've got + * half as many numbers still to encode, they're all still not too + * big, and we've emitted some amount of data into the output. Then do + * another pass, keep going until there's only one number left, and + * emit it little-endian. + * + * That's all very well, but how do you decode it again? DJB exhibits + * a pair of recursive functions that are supposed to be mutually + * inverse, but I didn't have any confidence that I'd be able to debug + * them sensibly if they turned out not to be (or rather, if I + * implemented one of them wrong). So I came up with my own strategy + * instead. + * + * In my strategy, we start by processing just the (m_i) into an + * 'encoding schedule' consisting of a sequence of simple + * instructions. The instructions operate on a FIFO queue of numbers, + * initialised to the original (r_i). The three instruction types are: + * + * - 'COMBINE': consume two numbers a,b from the head of the queue, + * combine them by calculating a + m*b for some specified m, and + * push the result on the tail of the queue. + * + * - 'BYTE': divide the tail element of the queue by 2^8 and emit the + * low bits into the output stream. + * + * - 'COPY': pop a number from the head of the queue and push it + * straight back on the tail. (Used for handling the leftover + * element at the end of a pass if the input to the pass was a list + * of odd length.) + * + * So we effectively implement DJB's encoding process in simulation, + * and instead of actually processing a set of (r_i), we 'compile' the + * process into a sequence of instructions that can be handed just the + * (r_i) later and encode them in the right way. At the end of the + * instructions, the queue is expected to have been reduced to length + * 1 and contain the single integer 0. + * + * The nice thing about this system is that each of those three + * instructions is easy to reverse. So you can also use the same + * instructions for decoding: start with a queue containing 0, and + * process the instructions in reverse order and reverse sense. So + * BYTE means to _consume_ a byte from the encoded data (starting from + * the rightmost end) and use it to make a queue element bigger; and + * COMBINE run in reverse pops a single element from one end of the + * queue, divides it by m, and pushes the quotient and remainder on + * the other end. + * + * (So it's easy to debug, because the queue passes through the exact + * same sequence of states during decoding that it did during + * encoding, just in reverse order.) + * + * Also, the encoding schedule comes with information about the + * expected size of the encoded data, because you can find that out + * easily by just counting the BYTE commands. + */ + +enum { + /* + * Command values appearing in the 'ops' array. ENC_COPY and + * ENC_BYTE are single values; values of the form + * (ENC_COMBINE_BASE + m) represent a COMBINE command with + * parameter m. + */ + ENC_COPY, ENC_BYTE, ENC_COMBINE_BASE +}; +struct NTRUEncodeSchedule { + /* + * Object representing a compiled set of encoding instructions. + * + * 'nvals' is the number of r_i we expect to encode. 'nops' is the + * number of encoding commands in the 'ops' list; 'opsize' is the + * physical size of the array, used during construction. + * + * 'endpos' is used to avoid a last-minute faff during decoding. + * We implement our FIFO of integers as a ring buffer of size + * 'nvals'. Encoding cycles round it some number of times, and the + * final 0 element ends up at some random location in the array. + * If we know _where_ the 0 ends up during encoding, we can put + * the initial 0 there at the start of decoding, and then when we + * finish reversing all the instructions, we'll end up with the + * output numbers already arranged at their correct positions, so + * that there's no need to rotate the array at the last minute. + */ + size_t nvals, endpos, nops, opsize; + uint32_t *ops; +}; +static inline void sched_append(NTRUEncodeSchedule *sched, uint16_t op) +{ + /* Helper function to append an operation to the schedule, and + * update endpos. */ + sgrowarray(sched->ops, sched->opsize, sched->nops); + sched->ops[sched->nops++] = op; + if (op != ENC_BYTE) + sched->endpos = (sched->endpos + 1) % sched->nvals; +} + +/* + * Take in the list of limit values (m_i) and compute the encoding + * schedule. + */ +NTRUEncodeSchedule *ntru_encode_schedule(const uint16_t *ms_in, size_t n) +{ + NTRUEncodeSchedule *sched = snew(NTRUEncodeSchedule); + sched->nvals = n; + sched->endpos = n-1; + sched->nops = sched->opsize = 0; + sched->ops = NULL; + + assert(n != 0); + + /* + * 'ms' is the list of (m_i) on input to the current pass. + * 'ms_new' is the list output from the current pass. After each + * pass we swap the arrays round. + */ + uint32_t *ms = snewn(n, uint32_t); + uint32_t *msnew = snewn(n, uint32_t); + for (size_t i = 0; i < n; i++) + ms[i] = ms_in[i]; + + while (n > 1) { + size_t nnew = 0; + for (size_t i = 0; i < n; i += 2) { + if (i+1 == n) { + /* + * Odd element at the end of the input list: just copy + * it unchanged to the output. + */ + sched_append(sched, ENC_COPY); + msnew[nnew++] = ms[i]; + break; + } + + /* + * Normal case: consume two elements from the input list + * and combine them. + */ + uint32_t m1 = ms[i], m2 = ms[i+1], m = m1*m2; + sched_append(sched, ENC_COMBINE_BASE + m1); + + /* + * And then, as long as the combined limit is big enough, + * emit an output byte from the bottom of it. + */ + while (m >= (1<<14)) { + sched_append(sched, ENC_BYTE); + m = (m + 0xFF) >> 8; + } + + /* + * Whatever is left after that, we emit into the output + * list and append to the fifo. + */ + msnew[nnew++] = m; + } + + /* + * End of pass. The output list of (m_i) now becomes the input + * list. + */ + uint32_t *tmp = ms; + ms = msnew; + n = nnew; + msnew = tmp; + } + + /* + * When that loop terminates, it's because there's exactly one + * number left to encode. (Or, technically, _at most_ one - but we + * don't support encoding a completely empty list in this + * implementation, because what would be the point?) That number + * is just emitted little-endian until its limit is 1 (meaning its + * only possible actual value is 0). + */ + assert(n == 1); + uint32_t m = ms[0]; + while (m > 1) { + sched_append(sched, ENC_BYTE); + m = (m + 0xFF) >> 8; + } + + sfree(ms); + sfree(msnew); + + return sched; +} + +void ntru_encode_schedule_free(NTRUEncodeSchedule *sched) +{ + sfree(sched->ops); + sfree(sched); +} + +/* + * Calculate the output length of the encoded data in bytes. + */ +size_t ntru_encode_schedule_length(NTRUEncodeSchedule *sched) +{ + size_t len = 0; + for (size_t i = 0; i < sched->nops; i++) + if (sched->ops[i] == ENC_BYTE) + len++; + return len; +} + +/* + * Retrieve the number of items encoded. (Used by testcrypt.) + */ +size_t ntru_encode_schedule_nvals(NTRUEncodeSchedule *sched) +{ + return sched->nvals; +} + +/* + * Actually encode a sequence of (r_i), emitting the output bytes to + * an arbitrary BinarySink. + */ +void ntru_encode(NTRUEncodeSchedule *sched, const uint16_t *rs_in, + BinarySink *bs) +{ + size_t n = sched->nvals; + uint32_t *rs = snewn(n, uint32_t); + for (size_t i = 0; i < n; i++) + rs[i] = rs_in[i]; + + /* + * The head and tail pointers of the queue are both 'full'. That + * is, rs[head] is the first element actually in the queue, and + * rs[tail] is the last element. + * + * So you append to the queue by first advancing 'tail' and then + * writing to rs[tail], whereas you consume from the queue by + * first reading rs[head] and _then_ advancing 'head'. + * + * The more normal thing would be to make 'tail' point to the + * first empty slot instead of the last full one. But then you'd + * have to faff about with modular arithmetic to find the last + * full slot for the BYTE command, so in this case, it's easier to + * do it the less usual way. + */ + size_t head = 0, tail = n-1; + + for (size_t i = 0; i < sched->nops; i++) { + uint16_t op = sched->ops[i]; + switch (op) { + case ENC_BYTE: + put_byte(bs, rs[tail] & 0xFF); + rs[tail] >>= 8; + break; + case ENC_COPY: { + uint32_t r = rs[head]; + head = (head + 1) % n; + tail = (tail + 1) % n; + rs[tail] = r; + break; + } + default: { + uint32_t r1 = rs[head]; + head = (head + 1) % n; + uint32_t r2 = rs[head]; + head = (head + 1) % n; + tail = (tail + 1) % n; + rs[tail] = r1 + (op - ENC_COMBINE_BASE) * r2; + break; + } + } + } + + /* + * Expect that we've ended up with a single zero in the queue, at + * exactly the position that the setup-time analysis predicted it. + */ + assert(head == sched->endpos); + assert(tail == sched->endpos); + assert(rs[head] == 0); + + smemclr(rs, n * sizeof(*rs)); + sfree(rs); +} + +/* + * Decode a ptrlen of binary data into a sequence of (r_i). The data + * is expected to be of exactly the right length (on pain of assertion + * failure). + */ +void ntru_decode(NTRUEncodeSchedule *sched, uint16_t *rs_out, ptrlen data) +{ + size_t n = sched->nvals; + const uint8_t *base = (const uint8_t *)data.ptr; + const uint8_t *pos = base + data.len; + + /* + * Initialise the queue to a single zero, at the 'endpos' position + * that will mean the final output is correctly aligned. + * + * 'head' and 'tail' have the same meanings as in encoding. So + * 'tail' is the location that BYTE modifies and COPY and COMBINE + * consume from, and 'head' is the location that COPY and COMBINE + * push on to. As in encoding, they both point at the extremal + * full slots in the array. + */ + uint32_t *rs = snewn(n, uint32_t); + size_t head = sched->endpos, tail = head; + rs[tail] = 0; + + for (size_t i = sched->nops; i-- > 0 ;) { + uint16_t op = sched->ops[i]; + switch (op) { + case ENC_BYTE: { + assert(pos > base); + uint8_t byte = *--pos; + rs[tail] = (rs[tail] << 8) | byte; + break; + } + case ENC_COPY: { + uint32_t r = rs[tail]; + tail = (tail + n - 1) % n; + head = (head + n - 1) % n; + rs[head] = r; + break; + } + default: { + uint32_t r = rs[tail]; + tail = (tail + n - 1) % n; + + uint32_t m = op - ENC_COMBINE_BASE; + uint64_t mrecip = reciprocal_for_reduction(m); + + uint32_t r1, r2; + r1 = reduce_with_quot(r, &r2, m, mrecip); + + head = (head + n - 1) % n; + rs[head] = r2; + head = (head + n - 1) % n; + rs[head] = r1; + break; + } + } + } + + assert(pos == base); + assert(head == 0); + assert(tail == n-1); + + for (size_t i = 0; i < n; i++) + rs_out[i] = rs[i]; + smemclr(rs, n * sizeof(*rs)); + sfree(rs); +} + +/* ---------------------------------------------------------------------- + * The actual public-key cryptosystem. + */ + +struct NTRUKeyPair { + unsigned p, q, w; + uint16_t *h; /* public key */ + uint16_t *f3, *ginv; /* private key */ + uint16_t *rho; /* for implicit rejection */ +}; + +/* Helper function to free an array of uint16_t containing a ring + * element, clearing it on the way since some of them are sensitive. */ +static void ring_free(uint16_t *val, unsigned p) +{ + smemclr(val, p*sizeof(*val)); + sfree(val); +} + +void ntru_keypair_free(NTRUKeyPair *keypair) +{ + ring_free(keypair->h, keypair->p); + ring_free(keypair->f3, keypair->p); + ring_free(keypair->ginv, keypair->p); + ring_free(keypair->rho, keypair->p); + sfree(keypair); +} + +/* Trivial accessors used by test programs. */ +unsigned ntru_keypair_p(NTRUKeyPair *keypair) { return keypair->p; } +const uint16_t *ntru_pubkey(NTRUKeyPair *keypair) { return keypair->h; } + +/* + * Generate a value of the class DJB describes as 'Short': it consists + * of p terms that are all either 0 or +1 or -1, and exactly w of them + * are not zero. + * + * Values of this kind are used for several purposes: part of the + * private key, a plaintext, and the 'rho' fake-plaintext value used + * for deliberately returning a duff but non-revealing session hash if + * things go wrong. + * + * -1 is represented as 2 in the output array. So if you want these + * numbers mod 3, then they come out already in the right form. + * Otherwise, use ntru_expand. + */ +void ntru_gen_short(uint16_t *v, unsigned p, unsigned w) +{ + /* + * Get enough random data to generate a polynomial all of whose p + * terms are in {0,+1,-1}, and exactly w of them are nonzero. + * We'll do this by making up a completely random sequence of + * {+1,-1} and then setting a random subset of them to 0. + * + * So we'll need p random bits to choose the nonzero values, and + * then (doing it the simplest way) log2(p!) bits to shuffle them, + * plus say 128 bits to ensure any fluctuations in uniformity are + * negligible. + * + * log2(p!) is a pain to calculate, so we'll bound it above by + * p*log2(p), which we bound in turn by p*16. + */ + size_t randbitpos = 17 * p + 128; + mp_int *randdata = mp_resize(mp_random_bits(randbitpos), randbitpos + 32); + + /* + * Initial value before zeroing out some terms: p randomly chosen + * values in {1,2}. + */ + for (size_t i = 0; i < p; i++) + v[i] = 1 + mp_get_bit(randdata, --randbitpos); + + /* + * Hereafter we're going to extract random bits by multiplication, + * treating randdata as a large fixed-point number. + */ + mp_reduce_mod_2to(randdata, randbitpos); + + /* + * Zero out some terms, leaving a randomly selected w of them + * nonzero. + */ + uint32_t nonzeros_left = w; + mp_int *x = mp_new(64); + for (size_t i = p; i-- > 0 ;) { + /* + * Pick a random number out of the number of terms remaning. + */ + mp_mul_integer_into(randdata, randdata, i+1); + mp_rshift_fixed_into(x, randdata, randbitpos); + mp_reduce_mod_2to(randdata, randbitpos); + size_t j = mp_get_integer(x); + + /* + * If that's less than nonzeros_left, then we're leaving this + * number nonzero. Otherwise we're zeroing it out. + */ + uint32_t keep = (uint32_t)(j - nonzeros_left) >> 31; + v[i] &= -keep; /* clear this field if keep == 0 */ + nonzeros_left -= keep; /* decrement counter if keep == 1 */ + } + + mp_free(x); + mp_free(randdata); +} + +/* + * Make a single attempt at generating a key pair. This involves + * inventing random elements of both our quotient rings and hoping + * they're both invertible. + * + * They may not be, if you're unlucky. The element of Z_q/<x^p-x-1> + * will _almost_ certainly be invertible, because that is a field, so + * invertibility can only fail if you were so unlucky as to choose the + * all-0s element. But the element of Z_3/<x^p-x-1> may fail to be + * invertible because it has a common factor with x^p-x-1 (which, over + * Z_3, is not irreducible). + * + * So we can't guarantee to generate a key pair in constant time, + * because there's no predicting how many retries we'll need. However, + * this isn't a failure of side-channel safety, because we completely + * discard all the random numbers and state from each failed attempt. + * So if there were a side-channel leakage from a failure, the only + * thing it would give away would be a bunch of random numbers that + * turned out not to be used anyway. + * + * But a _successful_ call to this function should execute in a + * secret-independent manner, and this 'make a single attempt' + * function is exposed in the API so that 'testsc' can check that. + */ +NTRUKeyPair *ntru_keygen_attempt(unsigned p, unsigned q, unsigned w) +{ + /* + * First invent g, which is the one more likely to fail to invert. + * This is simply a uniformly random polynomial with p terms over + * Z_3. So we need p*log2(3) random bits for it, plus 128 for + * uniformity. It's easiest to bound log2(3) above by 2. + */ + size_t randbitpos = 2 * p + 128; + mp_int *randdata = mp_resize(mp_random_bits(randbitpos), randbitpos + 32); + + /* + * Select p random values from {0,1,2}. + */ + uint16_t *g = snewn(p, uint16_t); + mp_int *x = mp_new(64); + for (size_t i = 0; i < p; i++) { + mp_mul_integer_into(randdata, randdata, 3); + mp_rshift_fixed_into(x, randdata, randbitpos); + mp_reduce_mod_2to(randdata, randbitpos); + g[i] = mp_get_integer(x); + } + mp_free(x); + mp_free(randdata); + + /* + * Try to invert g over Z_3, and fail if it isn't invertible. + */ + uint16_t *ginv = snewn(p, uint16_t); + if (!ntru_ring_invert(ginv, g, p, 3)) { + ring_free(g, p); + ring_free(ginv, p); + return NULL; + } + + /* + * Fine; we have g. Now make up an f, and convert it to a + * polynomial over q. + */ + uint16_t *f = snewn(p, uint16_t); + ntru_gen_short(f, p, w); + ntru_expand(f, f, p, q); + + /* + * Multiply f by 3. + */ + uint16_t *f3 = snewn(p, uint16_t); + ntru_scale(f3, f, 3, p, q); + + /* + * Try to invert 3*f over Z_q. This should be _almost_ guaranteed + * to succeed, since Z_q/<x^p-x-1> is a field, so the only + * non-invertible value is 0. Even so, there _is_ one, so check + * the return value! + */ + uint16_t *f3inv = snewn(p, uint16_t); + if (!ntru_ring_invert(f3inv, f3, p, q)) { + ring_free(f, p); + ring_free(f3, p); + ring_free(f3inv, p); + ring_free(g, p); + ring_free(ginv, p); + return NULL; + } + + /* + * Make the public key, by converting g to a polynomial over q and + * then multiplying by f3inv. + */ + uint16_t *g_q = snewn(p, uint16_t); + ntru_expand(g_q, g, p, q); + uint16_t *h = snewn(p, uint16_t); + ntru_ring_multiply(h, g_q, f3inv, p, q); + + /* + * Make up rho, used to substitute for the plaintext in the + * session hash in case of confirmation failure. + */ + uint16_t *rho = snewn(p, uint16_t); + ntru_gen_short(rho, p, w); + + /* + * And we're done! Free everything except the pieces we're + * returning. + */ + NTRUKeyPair *keypair = snew(NTRUKeyPair); + keypair->p = p; + keypair->q = q; + keypair->w = w; + keypair->h = h; + keypair->f3 = f3; + keypair->ginv = ginv; + keypair->rho = rho; + ring_free(f, p); + ring_free(f3inv, p); + ring_free(g, p); + ring_free(g_q, p); + return keypair; +} + +/* + * The top-level key generation function for real use (as opposed to + * testsc): keep trying to make a key until you succeed. + */ +NTRUKeyPair *ntru_keygen(unsigned p, unsigned q, unsigned w) +{ + while (1) { + NTRUKeyPair *keypair = ntru_keygen_attempt(p, q, w); + if (keypair) + return keypair; + } +} + +/* + * Public-key encryption. + */ +void ntru_encrypt(uint16_t *ciphertext, const uint16_t *plaintext, + uint16_t *pubkey, unsigned p, unsigned q) +{ + uint16_t *r_q = snewn(p, uint16_t); + ntru_expand(r_q, plaintext, p, q); + + uint16_t *unrounded = snewn(p, uint16_t); + ntru_ring_multiply(unrounded, r_q, pubkey, p, q); + + ntru_round3(ciphertext, unrounded, p, q); + ntru_normalise(ciphertext, ciphertext, p, q); + + ring_free(r_q, p); + ring_free(unrounded, p); +} + +/* + * Public-key decryption. + */ +void ntru_decrypt(uint16_t *plaintext, const uint16_t *ciphertext, + NTRUKeyPair *keypair) +{ + unsigned p = keypair->p, q = keypair->q, w = keypair->w; + uint16_t *tmp = snewn(p, uint16_t); + + ntru_ring_multiply(tmp, ciphertext, keypair->f3, p, q); + + ntru_mod3(tmp, tmp, p, q); + ntru_normalise(tmp, tmp, p, 3); + + ntru_ring_multiply(plaintext, tmp, keypair->ginv, p, 3); + ring_free(tmp, p); + + /* + * With luck, this should have recovered exactly the original + * plaintext. But, as per the spec, we check whether it has + * exactly w nonzero coefficients, and if not, then something has + * gone wrong - and in that situation we time-safely substitute a + * different output. + * + * (I don't know exactly why we do this, but I assume it's because + * otherwise the mis-decoded output could be made to disgorge a + * secret about the private key in some way.) + */ + + unsigned weight = p; + for (size_t i = 0; i < p; i++) + weight -= iszero(plaintext[i]); + unsigned ok = iszero(weight ^ w); + + /* + * The default failure return value consists of w 1s followed by + * 0s. + */ + unsigned mask = ok - 1; + for (size_t i = 0; i < w; i++) { + uint16_t diff = (1 ^ plaintext[i]) & mask; + plaintext[i] ^= diff; + } + for (size_t i = w; i < p; i++) { + uint16_t diff = (0 ^ plaintext[i]) & mask; + plaintext[i] ^= diff; + } +} + +/* ---------------------------------------------------------------------- + * Encode and decode public keys, ciphertexts and plaintexts. + * + * Public keys and ciphertexts use the complicated binary encoding + * system implemented above. In both cases, the inputs are regarded as + * symmetric about zero, and are first biased to map their most + * negative permitted value to 0, so that they become non-negative and + * hence suitable as inputs to the encoding system. In the case of a + * ciphertext, where the input coefficients have also been coerced to + * be multiples of 3, we divide by 3 as well, saving space by reducing + * the upper bounds (m_i) on all the encoded numbers. + */ + +/* + * Compute the encoding schedule for a public key. + */ +static NTRUEncodeSchedule *ntru_encode_pubkey_schedule(unsigned p, unsigned q) +{ + uint16_t *ms = snewn(p, uint16_t); + for (size_t i = 0; i < p; i++) + ms[i] = q; + NTRUEncodeSchedule *sched = ntru_encode_schedule(ms, p); + sfree(ms); + return sched; +} + +/* + * Encode a public key. + */ +void ntru_encode_pubkey(const uint16_t *pubkey, unsigned p, unsigned q, + BinarySink *bs) +{ + /* Compute the biased version for encoding */ + uint16_t *biased_pubkey = snewn(p, uint16_t); + ntru_bias(biased_pubkey, pubkey, q / 2, p, q); + + /* Encode it */ + NTRUEncodeSchedule *sched = ntru_encode_pubkey_schedule(p, q); + ntru_encode(sched, biased_pubkey, bs); + ntru_encode_schedule_free(sched); + + ring_free(biased_pubkey, p); +} + +/* + * Decode a public key and write it into 'pubkey'. We also return a + * ptrlen pointing at the chunk of data we removed from the + * BinarySource. + */ +ptrlen ntru_decode_pubkey(uint16_t *pubkey, unsigned p, unsigned q, + BinarySource *src) +{ + NTRUEncodeSchedule *sched = ntru_encode_pubkey_schedule(p, q); + + /* Retrieve the right number of bytes from the source */ + size_t len = ntru_encode_schedule_length(sched); + ptrlen encoded = get_data(src, len); + if (get_err(src)) { + /* If there wasn't enough data, give up and return all-zeroes + * purely for determinism. But that value should never be + * used, because the caller will also check get_err(src). */ + memset(pubkey, 0, p*sizeof(*pubkey)); + } else { + /* Do the decoding */ + ntru_decode(sched, pubkey, encoded); + + /* Unbias the coefficients */ + ntru_bias(pubkey, pubkey, q-q/2, p, q); + } + + ntru_encode_schedule_free(sched); + return encoded; +} + +/* + * For ciphertext biasing: work out the largest absolute value a + * ciphertext element can take, which is given by taking q/2 and + * rounding it to the nearest multiple of 3. + */ +static inline unsigned ciphertext_bias(unsigned q) +{ + return (q/2+1) / 3; +} + +/* + * The number of possible values of a ciphertext coefficient (for use + * as the m_i in encoding) ranges from +ciphertext_bias(q) to + * -ciphertext_bias(q) inclusive. + */ +static inline unsigned ciphertext_m(unsigned q) +{ + return 1 + 2 * ciphertext_bias(q); +} + +/* + * Compute the encoding schedule for a ciphertext. + */ +static NTRUEncodeSchedule *ntru_encode_ciphertext_schedule( + unsigned p, unsigned q) +{ + unsigned m = ciphertext_m(q); + uint16_t *ms = snewn(p, uint16_t); + for (size_t i = 0; i < p; i++) + ms[i] = m; + NTRUEncodeSchedule *sched = ntru_encode_schedule(ms, p); + sfree(ms); + return sched; +} + +/* + * Encode a ciphertext. + */ +void ntru_encode_ciphertext(const uint16_t *ciphertext, unsigned p, unsigned q, + BinarySink *bs) +{ + SETUP; + + /* + * Bias the ciphertext, and scale down by 1/3, which we do by + * modular multiplication by the inverse of 3 mod q. (That only + * works if we know the inputs are all _exact_ multiples of 3 + * - but we do!) + */ + uint16_t *biased_ciphertext = snewn(p, uint16_t); + ntru_bias(biased_ciphertext, ciphertext, 3 * ciphertext_bias(q), p, q); + ntru_scale(biased_ciphertext, biased_ciphertext, INVERT(3), p, q); + + /* Encode. */ + NTRUEncodeSchedule *sched = ntru_encode_ciphertext_schedule(p, q); + ntru_encode(sched, biased_ciphertext, bs); + ntru_encode_schedule_free(sched); + + ring_free(biased_ciphertext, p); +} + +ptrlen ntru_decode_ciphertext(uint16_t *ct, NTRUKeyPair *keypair, + BinarySource *src) +{ + unsigned p = keypair->p, q = keypair->q; + + NTRUEncodeSchedule *sched = ntru_encode_ciphertext_schedule(p, q); + + /* Retrieve the right number of bytes from the source */ + size_t len = ntru_encode_schedule_length(sched); + ptrlen encoded = get_data(src, len); + if (get_err(src)) { + /* As above, return deterministic nonsense on failure */ + memset(ct, 0, p*sizeof(*ct)); + } else { + /* Do the decoding */ + ntru_decode(sched, ct, encoded); + + /* Undo the scaling and bias */ + ntru_scale(ct, ct, 3, p, q); + ntru_bias(ct, ct, q - 3 * ciphertext_bias(q), p, q); + } + + ntru_encode_schedule_free(sched); + return encoded; /* also useful to the caller, optionally */ +} + +/* + * Encode a plaintext. + * + * This is a much simpler encoding than the NTRUEncodeSchedule system: + * since elements of a plaintext are mod 3, we just encode each one in + * 2 bits, applying the usual bias so that {-1,0,+1} map to {0,1,2} + * respectively. + * + * There's no corresponding decode function, because plaintexts are + * never transmitted on the wire (the whole point is that they're too + * secret!). Plaintexts are only encoded in order to put them into + * hash preimages. + */ +void ntru_encode_plaintext(const uint16_t *plaintext, unsigned p, + BinarySink *bs) +{ + unsigned byte = 0, bitpos = 0; + for (size_t i = 0; i < p; i++) { + unsigned encoding = (plaintext[i] + 1) * iszero(plaintext[i] >> 1); + byte |= encoding << bitpos; + bitpos += 2; + if (bitpos == 8 || i+1 == p) { + put_byte(bs, byte); + byte = 0; + bitpos = 0; + } + } +} + +/* ---------------------------------------------------------------------- + * Compute the hashes required by the key exchange layer of NTRU Prime. + * + * There are two of these. The 'confirmation hash' is sent by the + * server along with the ciphertext, and the client can recalculate it + * to check whether the ciphertext was decrypted correctly. Then, the + * 'session hash' is the actual output of key exchange, and if the + * confirmation hash doesn't match, it gets deliberately corrupted. + */ + +/* + * Make the confirmation hash, whose inputs are the plaintext and the + * public key. + * + * This is defined as H(2 || H(3 || r) || H(4 || K)), where r is the + * plaintext and K is the public key (as encoded by the above + * functions), and the constants 2,3,4 are single bytes. The choice of + * hash function (H itself) is SHA-512 truncated to 256 bits. + * + * (To be clear: that is _not_ the thing that FIPS 180-4 6.7 defines + * as "SHA-512/256", which varies the initialisation vector of the + * SHA-512 algorithm as well as truncating the output. _This_ + * algorithm uses the standard SHA-512 IV, and _just_ truncates the + * output, in the manner suggested by FIPS 180-4 section 7.) + * + * 'out' should therefore expect to receive 32 bytes of data. + */ +static void ntru_confirmation_hash( + uint8_t *out, const uint16_t *plaintext, + const uint16_t *pubkey, unsigned p, unsigned q) +{ + /* The outer hash object */ + ssh_hash *hconfirm = ssh_hash_new(&ssh_sha512); + put_byte(hconfirm, 2); /* initial byte 2 */ + + uint8_t hashdata[64]; + + /* Compute H(3 || r) and add it to the main hash */ + ssh_hash *h3r = ssh_hash_new(&ssh_sha512); + put_byte(h3r, 3); + ntru_encode_plaintext(plaintext, p, BinarySink_UPCAST(h3r)); + ssh_hash_final(h3r, hashdata); + put_data(hconfirm, hashdata, 32); + + /* Compute H(4 || K) and add it to the main hash */ + ssh_hash *h4K = ssh_hash_new(&ssh_sha512); + put_byte(h4K, 4); + ntru_encode_pubkey(pubkey, p, q, BinarySink_UPCAST(h4K)); + ssh_hash_final(h4K, hashdata); + put_data(hconfirm, hashdata, 32); + + /* Compute the full output of the main SHA-512 hash */ + ssh_hash_final(hconfirm, hashdata); + + /* And copy the first 32 bytes into the caller's output array */ + memcpy(out, hashdata, 32); + smemclr(hashdata, sizeof(hashdata)); +} + +/* + * Make the session hash, whose inputs are the plaintext, the + * ciphertext, and the confirmation hash (hence, transitively, a + * dependence on the public key as well). + * + * As computed by the server, and by the client if the confirmation + * hash matched, this is defined as + * + * H(1 || H(3 || r) || ciphertext || confirmation hash) + * + * but if the confirmation hash _didn't_ match, then the plaintext r + * is replaced with the dummy plaintext-shaped value 'rho' we invented + * during key generation (presumably to avoid leaking any information + * about our secrets), and the initial byte 1 is replaced with 0 (to + * ensure that the resulting hash preimage can't match any legitimate + * preimage). So in that case, you instead get + * + * H(0 || H(3 || rho) || ciphertext || confirmation hash) + * + * The inputs to this function include 'ok', which is the value to use + * as the initial byte (1 on success, 0 on failure), and 'plaintext' + * which should already have been substituted with rho in case of + * failure. + * + * The ciphertext is provided in already-encoded form. + */ +static void ntru_session_hash( + uint8_t *out, unsigned ok, const uint16_t *plaintext, + unsigned p, ptrlen ciphertext, ptrlen confirmation_hash) +{ + /* The outer hash object */ + ssh_hash *hsession = ssh_hash_new(&ssh_sha512); + put_byte(hsession, ok); /* initial byte 1 or 0 */ + + uint8_t hashdata[64]; + + /* Compute H(3 || r), or maybe H(3 || rho), and add it to the main hash */ + ssh_hash *h3r = ssh_hash_new(&ssh_sha512); + put_byte(h3r, 3); + ntru_encode_plaintext(plaintext, p, BinarySink_UPCAST(h3r)); + ssh_hash_final(h3r, hashdata); + put_data(hsession, hashdata, 32); + + /* Put the ciphertext and confirmation hash in */ + put_datapl(hsession, ciphertext); + put_datapl(hsession, confirmation_hash); + + /* Compute the full output of the main SHA-512 hash */ + ssh_hash_final(hsession, hashdata); + + /* And copy the first 32 bytes into the caller's output array */ + memcpy(out, hashdata, 32); + smemclr(hashdata, sizeof(hashdata)); +} + +/* ---------------------------------------------------------------------- + * Top-level key exchange and SSH integration. + * + * Although this system borrows the ECDH packet structure, it's unlike + * true ECDH in that it is completely asymmetric between client and + * server. So we have two separate vtables of methods for the two + * sides of the system, and a third vtable containing only the class + * methods, in particular a constructor which chooses which one to + * instantiate. + */ + +/* + * The parameters p,q,w for the system. There are other choices of + * these, but OpenSSH only specifies this set. (If that ever changes, + * we'll need to turn these into elements of the state structures.) + */ +#define p_LIVE 761 +#define q_LIVE 4591 +#define w_LIVE 286 + +static char *ssh_ntru_description(const ssh_kex *kex) +{ + return dupprintf("NTRU Prime / Curve25519 hybrid key exchange"); +} + +/* + * State structure for the client, which takes the role of inventing a + * key pair and decrypting a secret plaintext sent to it by the server. + */ +typedef struct ntru_client_key { + NTRUKeyPair *keypair; + ecdh_key *curve25519; + + ecdh_key ek; +} ntru_client_key; + +static void ssh_ntru_client_free(ecdh_key *dh); +static void ssh_ntru_client_getpublic(ecdh_key *dh, BinarySink *bs); +static bool ssh_ntru_client_getkey(ecdh_key *dh, ptrlen remoteKey, + BinarySink *bs); + +static const ecdh_keyalg ssh_ntru_client_vt = { + /* This vtable has no 'new' method, because it's constructed via + * the selector vt below */ + .free = ssh_ntru_client_free, + .getpublic = ssh_ntru_client_getpublic, + .getkey = ssh_ntru_client_getkey, + .description = ssh_ntru_description, +}; + +static ecdh_key *ssh_ntru_client_new(void) +{ + ntru_client_key *nk = snew(ntru_client_key); + nk->ek.vt = &ssh_ntru_client_vt; + + nk->keypair = ntru_keygen(p_LIVE, q_LIVE, w_LIVE); + nk->curve25519 = ecdh_key_new(&ssh_ec_kex_curve25519, false); + + return &nk->ek; +} + +static void ssh_ntru_client_free(ecdh_key *dh) +{ + ntru_client_key *nk = container_of(dh, ntru_client_key, ek); + ntru_keypair_free(nk->keypair); + ecdh_key_free(nk->curve25519); + sfree(nk); +} + +static void ssh_ntru_client_getpublic(ecdh_key *dh, BinarySink *bs) +{ + ntru_client_key *nk = container_of(dh, ntru_client_key, ek); + + /* + * The client's public information is a single SSH string + * containing the NTRU public key and the Curve25519 public point + * concatenated. So write both of those into the output + * BinarySink. + */ + ntru_encode_pubkey(nk->keypair->h, p_LIVE, q_LIVE, bs); + ecdh_key_getpublic(nk->curve25519, bs); +} + +static bool ssh_ntru_client_getkey(ecdh_key *dh, ptrlen remoteKey, + BinarySink *bs) +{ + ntru_client_key *nk = container_of(dh, ntru_client_key, ek); + + /* + * We expect the server to have sent us a string containing a + * ciphertext, a confirmation hash, and a Curve25519 public point. + * Extract all three. + */ + BinarySource src[1]; + BinarySource_BARE_INIT_PL(src, remoteKey); + + uint16_t *ciphertext = snewn(p_LIVE, uint16_t); + ptrlen ciphertext_encoded = ntru_decode_ciphertext( + ciphertext, nk->keypair, src); + ptrlen confirmation_hash = get_data(src, 32); + ptrlen curve25519_remoteKey = get_data(src, 32); + + if (get_err(src) || get_avail(src)) { + /* Hard-fail if the input wasn't exactly the right length */ + ring_free(ciphertext, p_LIVE); + return false; + } + + /* + * Main hash object which will combine the NTRU and Curve25519 + * outputs. + */ + ssh_hash *h = ssh_hash_new(&ssh_sha512); + + /* Reusable buffer for storing various hash outputs. */ + uint8_t hashdata[64]; + + /* + * NTRU side. + */ + { + /* Decrypt the ciphertext to recover the server's plaintext */ + uint16_t *plaintext = snewn(p_LIVE, uint16_t); + ntru_decrypt(plaintext, ciphertext, nk->keypair); + + /* Make the confirmation hash */ + ntru_confirmation_hash(hashdata, plaintext, nk->keypair->h, + p_LIVE, q_LIVE); + + /* Check it matches the one the server sent */ + unsigned ok = smemeq(hashdata, confirmation_hash.ptr, 32); + + /* If not, substitute in rho for the plaintext in the session hash */ + unsigned mask = ok-1; + for (size_t i = 0; i < p_LIVE; i++) + plaintext[i] ^= mask & (plaintext[i] ^ nk->keypair->rho[i]); + + /* Compute the session hash, whether or not we did that */ + ntru_session_hash(hashdata, ok, plaintext, p_LIVE, ciphertext_encoded, + confirmation_hash); + + /* Free temporary values */ + ring_free(plaintext, p_LIVE); + ring_free(ciphertext, p_LIVE); + + /* And put the NTRU session hash into the main hash object. */ + put_data(h, hashdata, 32); + } + + /* + * Curve25519 side. + */ + { + strbuf *otherkey = strbuf_new_nm(); + + /* Call out to Curve25519 to compute the shared secret from that + * kex method */ + bool ok = ecdh_key_getkey(nk->curve25519, curve25519_remoteKey, + BinarySink_UPCAST(otherkey)); + + /* If that failed (which only happens if the other end does + * something wrong, like sending a low-order curve point + * outside the subgroup it's supposed to), we might as well + * just abort and return failure. That's what we'd have done + * in standalone Curve25519. */ + if (!ok) { + ssh_hash_free(h); + smemclr(hashdata, sizeof(hashdata)); + strbuf_free(otherkey); + return false; + } + + /* + * ecdh_key_getkey will have returned us a chunk of data + * containing an encoded mpint, which is how the Curve25519 + * output normally goes into the exchange hash. But in this + * context we want to treat it as a fixed big-endian 32 bytes, + * so extract it from its encoding and put it into the main + * hash object in the new format. + */ + BinarySource src[1]; + BinarySource_BARE_INIT_PL(src, ptrlen_from_strbuf(otherkey)); + mp_int *curvekey = get_mp_ssh2(src); + + for (unsigned i = 32; i-- > 0 ;) + put_byte(h, mp_get_byte(curvekey, i)); + + mp_free(curvekey); + strbuf_free(otherkey); + } + + /* + * Finish up: compute the final output hash (full 64 bytes of + * SHA-512 this time), and return it encoded as a string. + */ + ssh_hash_final(h, hashdata); + put_stringpl(bs, make_ptrlen(hashdata, sizeof(hashdata))); + smemclr(hashdata, sizeof(hashdata)); + + return true; +} + +/* + * State structure for the server, which takes the role of inventing a + * secret plaintext and sending it to the client encrypted with the + * public key the client sent. + */ +typedef struct ntru_server_key { + uint16_t *plaintext; + strbuf *ciphertext_encoded, *confirmation_hash; + ecdh_key *curve25519; + + ecdh_key ek; +} ntru_server_key; + +static void ssh_ntru_server_free(ecdh_key *dh); +static void ssh_ntru_server_getpublic(ecdh_key *dh, BinarySink *bs); +static bool ssh_ntru_server_getkey(ecdh_key *dh, ptrlen remoteKey, + BinarySink *bs); + +static const ecdh_keyalg ssh_ntru_server_vt = { + /* This vtable has no 'new' method, because it's constructed via + * the selector vt below */ + .free = ssh_ntru_server_free, + .getpublic = ssh_ntru_server_getpublic, + .getkey = ssh_ntru_server_getkey, + .description = ssh_ntru_description, +}; + +static ecdh_key *ssh_ntru_server_new(void) +{ + ntru_server_key *nk = snew(ntru_server_key); + nk->ek.vt = &ssh_ntru_server_vt; + + nk->plaintext = snewn(p_LIVE, uint16_t); + nk->ciphertext_encoded = strbuf_new_nm(); + nk->confirmation_hash = strbuf_new_nm(); + ntru_gen_short(nk->plaintext, p_LIVE, w_LIVE); + + nk->curve25519 = ecdh_key_new(&ssh_ec_kex_curve25519, false); + + return &nk->ek; +} + +static void ssh_ntru_server_free(ecdh_key *dh) +{ + ntru_server_key *nk = container_of(dh, ntru_server_key, ek); + ring_free(nk->plaintext, p_LIVE); + strbuf_free(nk->ciphertext_encoded); + strbuf_free(nk->confirmation_hash); + ecdh_key_free(nk->curve25519); + sfree(nk); +} + +static bool ssh_ntru_server_getkey(ecdh_key *dh, ptrlen remoteKey, + BinarySink *bs) +{ + ntru_server_key *nk = container_of(dh, ntru_server_key, ek); + + /* + * In the server, getkey is called first, with the public + * information received from the client. We expect the client to + * have sent us a string containing a public key and a Curve25519 + * public point. + */ + BinarySource src[1]; + BinarySource_BARE_INIT_PL(src, remoteKey); + + uint16_t *pubkey = snewn(p_LIVE, uint16_t); + ntru_decode_pubkey(pubkey, p_LIVE, q_LIVE, src); + ptrlen curve25519_remoteKey = get_data(src, 32); + + if (get_err(src) || get_avail(src)) { + /* Hard-fail if the input wasn't exactly the right length */ + ring_free(pubkey, p_LIVE); + return false; + } + + /* + * Main hash object which will combine the NTRU and Curve25519 + * outputs. + */ + ssh_hash *h = ssh_hash_new(&ssh_sha512); + + /* Reusable buffer for storing various hash outputs. */ + uint8_t hashdata[64]; + + /* + * NTRU side. + */ + { + /* Encrypt the plaintext we generated at construction time, + * and encode the ciphertext into a strbuf so we can reuse it + * for both the session hash and sending to the client. */ + uint16_t *ciphertext = snewn(p_LIVE, uint16_t); + ntru_encrypt(ciphertext, nk->plaintext, pubkey, p_LIVE, q_LIVE); + ntru_encode_ciphertext(ciphertext, p_LIVE, q_LIVE, + BinarySink_UPCAST(nk->ciphertext_encoded)); + ring_free(ciphertext, p_LIVE); + + /* Compute the confirmation hash, and write it into another + * strbuf. */ + ntru_confirmation_hash(hashdata, nk->plaintext, pubkey, + p_LIVE, q_LIVE); + put_data(nk->confirmation_hash, hashdata, 32); + + /* Compute the session hash (which is easy on the server side, + * requiring no conditional substitution). */ + ntru_session_hash(hashdata, 1, nk->plaintext, p_LIVE, + ptrlen_from_strbuf(nk->ciphertext_encoded), + ptrlen_from_strbuf(nk->confirmation_hash)); + + /* And put the NTRU session hash into the main hash object. */ + put_data(h, hashdata, 32); + + /* Now we can free the public key */ + ring_free(pubkey, p_LIVE); + } + + /* + * Curve25519 side. + */ + { + strbuf *otherkey = strbuf_new_nm(); + + /* Call out to Curve25519 to compute the shared secret from that + * kex method */ + bool ok = ecdh_key_getkey(nk->curve25519, curve25519_remoteKey, + BinarySink_UPCAST(otherkey)); + /* As on the client side, abort if Curve25519 reported failure */ + if (!ok) { + ssh_hash_free(h); + smemclr(hashdata, sizeof(hashdata)); + strbuf_free(otherkey); + return false; + } + + /* As on the client side, decode Curve25519's mpint so we can + * re-encode it appropriately for our hash preimage */ + BinarySource src[1]; + BinarySource_BARE_INIT_PL(src, ptrlen_from_strbuf(otherkey)); + mp_int *curvekey = get_mp_ssh2(src); + + for (unsigned i = 32; i-- > 0 ;) + put_byte(h, mp_get_byte(curvekey, i)); + + mp_free(curvekey); + strbuf_free(otherkey); + } + + /* + * Finish up: compute the final output hash (full 64 bytes of + * SHA-512 this time), and return it encoded as a string. + */ + ssh_hash_final(h, hashdata); + put_stringpl(bs, make_ptrlen(hashdata, sizeof(hashdata))); + smemclr(hashdata, sizeof(hashdata)); + + return true; +} + +static void ssh_ntru_server_getpublic(ecdh_key *dh, BinarySink *bs) +{ + ntru_server_key *nk = container_of(dh, ntru_server_key, ek); + + /* + * In the server, this function is called after getkey, so we + * already have all our pieces prepared. Just concatenate them all + * into the 'server's public data' string to go in ECDH_REPLY. + */ + put_datapl(bs, ptrlen_from_strbuf(nk->ciphertext_encoded)); + put_datapl(bs, ptrlen_from_strbuf(nk->confirmation_hash)); + ecdh_key_getpublic(nk->curve25519, bs); +} + +/* ---------------------------------------------------------------------- + * Selector vtable that instantiates the appropriate one of the above, + * depending on is_server. + */ +static ecdh_key *ssh_ntru_new(const ssh_kex *kex, bool is_server) +{ + if (is_server) + return ssh_ntru_server_new(); + else + return ssh_ntru_client_new(); +} + +static const ecdh_keyalg ssh_ntru_selector_vt = { + /* This is a never-instantiated vtable which only implements the + * functions that don't require an instance. */ + .new = ssh_ntru_new, + .description = ssh_ntru_description, +}; + +static const ssh_kex ssh_ntru_curve25519 = { + .name = "sntrup761x25519-sha512@openssh.com", + .main_type = KEXTYPE_ECDH, + .hash = &ssh_sha512, + .ecdh_vt = &ssh_ntru_selector_vt, +}; + +static const ssh_kex *const hybrid_list[] = { + &ssh_ntru_curve25519, +}; + +const ssh_kexes ssh_ntru_hybrid_kex = { lenof(hybrid_list), hybrid_list }; |