summaryrefslogtreecommitdiffstats
path: root/preload.c
blob: 87a42ee8888ac84514de93151d860d45417e3332 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
#include <stdio.h>
#include <unistd.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>

#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/syscall.h>
#include <sys/types.h>

#include <linux/memfd.h>

/*
 * bits to get memfd_create to work: the extra MFD_SECRET flag handed to
 * memfd_create() and the ioctl request codes issued on the resulting
 * descriptor by secretmem_open()'s fallback path
 */
#define MFD_SECRET 0x0008U
#define MFD_SECRET_IOCTL '-'
#define MFD_SECRET_EXCLUSIVE	_IOW(MFD_SECRET_IOCTL, 0x13, unsigned long)
#define MFD_SECRET_UNCACHED	_IOW(MFD_SECRET_IOCTL, 0x14, unsigned long)

/* secretmemfd defines: flag values passed to the secretmemfd() wrapper */
#define SECRETMEM_EXCLUSIVE	0x1
#define SECRETMEM_UNCACHED	0x2

/*
 * Abort the process with exit status 1 when the condition is false,
 * reporting the source line of the failed check on stdout.
 */
#define ASSERT(x)							\
	do {								\
		if (!(x)) {						\
			printf("ASSERTION failed at line %d\n", __LINE__); \
			exit(1);					\
		}							\
	} while (0)

/*
 * Direct wrapper around the memfd_create system call
 * (glibc should have defined this by now, sigh).
 * Returns syscall()'s result narrowed to int.
 */
static inline int memfd_create(const char *name, unsigned int flags)
{
	long rc = syscall(__NR_memfd_create, name, flags);

	return (int)rc;
}

/* use a hard-coded syscall number when the system headers provide none */
#ifndef __NR_secretmemfd
#define __NR_secretmemfd 440
#endif

/*
 * Issue the secretmemfd system call with the given flags and return
 * syscall()'s result narrowed to int.
 *
 * NOTE(review): mainline Linux appears to ship this feature as
 * memfd_secret under a different syscall number -- confirm that 440 is
 * right for the target kernel.
 */
static inline int secretmemfd(unsigned long flags)
{
	long rc = syscall(__NR_secretmemfd, flags);

	return (int)rc;
}

/* non-zero: back segments with secret memory; cleared by preload_setup() when NO_SECRET_MEM is set */
static int use_secret = 1;
/* mode ioctl for the memfd fallback; preload_setup() switches it to MFD_SECRET_UNCACHED when SECRET_UNCACHED is set */
static int secret_option = MFD_SECRET_EXCLUSIVE;

static int secretmem_open(void)
{
	unsigned long secretmemfd_flags = SECRETMEM_EXCLUSIVE;
	int fd, ret;

	if (!use_secret)
		return memfd_create("secure", MFD_CLOEXEC);

	if (secret_option == MFD_SECRET_UNCACHED)
		secretmemfd_flags = SECRETMEM_EXCLUSIVE;

	/* try dedicated syscall first */
	fd = secretmemfd(secretmemfd_flags);
	if (fd >= 0 || errno != ENOSYS)
		return fd;

	/* secretmemfd is not implemented, maybe memfd_create(SECRET) is */
	fd = memfd_create("secure", MFD_CLOEXEC|MFD_SECRET);
	if (fd < 0)
		return fd;

	ret = ioctl(fd, secret_option);
	if (ret < 0)
		return ret;

	return fd;
}


/*
 * segment size.  Matches hugepage size.  Parenthesised so the expansion
 * stays a single operand next to /, % or a cast.
 */
#define SEG_SIZE (2*1024*1024)

/* verbosity level; preload_setup() loads it from the MALLOC_DEBUG environment variable */
static int debug = 0;

/* DEBUG prints only when debug > 1, INFO whenever debug is non-zero; both write to stdout */
#define DEBUG(...) do { if (debug > 1) printf(__VA_ARGS__); } while(0)
#define INFO(...) do { if (debug) printf(__VA_ARGS__); } while(0)

/* flag bits kept in the low bits of malloc_chunk_head.head */
#define PINUSE_BIT	0x01	/* tested by prev_in_use(): the preceding chunk is allocated */
#define CINUSE_BIT	0x02	/* tested by in_use(): this chunk is allocated */

/* mask of all flag bits; chunk_size() strips these from head */
#define FLAG_BITS	(CINUSE_BIT | PINUSE_BIT)

/*
 * Header stored at the start of every chunk.  The low bits of head carry
 * the PINUSE/CINUSE flags, which chunk_size() masks off.
 */
struct malloc_chunk_head {
	int	prev_foot;	/* size of previous if free */
	int	head;		/* entire size of this chunk */
};

/*
 * A chunk as it looks while on the free list: the header followed by the
 * list links.  chunk2mem() hands out the memory right after the header, so
 * the link fields overlap what becomes caller data once the chunk is
 * allocated.
 */
struct malloc_chunk {
	struct malloc_chunk_head h;
	struct malloc_chunk *fd;	/* forward link on the free list */
	struct malloc_chunk *bk;	/* backward link on the free list */
};

/* bytes of per-chunk header that precede the memory handed to the caller */
#define CHUNK_SIZE (sizeof(struct malloc_chunk_head))
/* allocation granule: pad_request() rounds every chunk up to a multiple of this */
#define MIN_CHUNK 0x20
/* low-bit mask used for that rounding; only valid while MIN_CHUNK is a power of 2 */
#define CHUNK_ALIGNMENT (MIN_CHUNK - 1)

/*
 * Descriptor for one mmap()ed segment.  alloc_segment() stores it in the
 * first chunk of the segment it describes.
 */
struct segptr {
	char *base;	/* start address of the segment mapping */
	/* no size because they're always SEG_SIZE */
	struct segptr *next;	/* next segment, NULL at the end of the list */
};

/* Global allocator state. */
struct malloc_state {
	struct segptr *seg;	/* head of the singly linked list of segments */
	struct malloc_chunk free;	/* sentinel node of the circular, doubly linked free list */
};

/* the single allocator instance; its list heads are set up by preload_setup() */
static 	struct malloc_state m;

/*
 * Turn a caller's byte count into a chunk size: add room for the chunk
 * header, then round up to the next multiple of MIN_CHUNK.
 */
static size_t pad_request(size_t s)
{
	size_t padded = s + CHUNK_SIZE + CHUNK_ALIGNMENT;

	padded &= ~CHUNK_ALIGNMENT;
	return padded;
}

/* Return 1 when the chunk's CINUSE flag is set, 0 otherwise. */
static int in_use(struct malloc_chunk *c)
{
	return (c->h.head & CINUSE_BIT) != 0;
}

/* Return 1 when the chunk's PINUSE flag is set, 0 otherwise. */
static int prev_in_use(struct malloc_chunk *c)
{
	return (c->h.head & PINUSE_BIT) != 0;
}

/*
 * Fatal-error helper: when cond is non-zero, report str together with the
 * current errno text via perror() and exit with status 1.
 */
static void check(int cond, const char *str)
{
	if (!cond)
		return;

	perror(str);
	exit(1);
}

/*
 * Find the segment whose [base, base + SEG_SIZE) range contains the chunk.
 * Returns NULL when the address lies in none of the known segments.
 */
static struct segptr *chunk_to_segment(struct malloc_chunk *c)
{
	const char *addr = (const char *)c;
	struct segptr *cur = m.seg;

	while (cur != NULL) {
		if (addr >= cur->base && addr < cur->base + SEG_SIZE)
			break;
		cur = cur->next;
	}
	return cur;
}

/* Map a chunk to the caller-visible memory that follows its header. */
static void *chunk2mem(struct malloc_chunk *c)
{
	char *raw = (char *)c;

	return raw + CHUNK_SIZE;
}

/* Inverse of chunk2mem(): step back over the header to reach the chunk. */
static struct malloc_chunk *mem2chunk(void *p)
{
	char *payload = p;

	return (struct malloc_chunk *)(payload - CHUNK_SIZE);
}

/* Size of the whole chunk (header included) with the flag bits stripped. */
static size_t chunk_size(struct malloc_chunk *c)
{
	int size_bits = c->h.head & ~FLAG_BITS;

	return size_bits;
}

/*
 * Return the chunk that physically follows c in its segment, or NULL when
 * c is the last chunk (its end coincides with the end of the segment).
 * Asserts that c belongs to a known segment and that the computed
 * successor address lies within it.
 */
static struct malloc_chunk *next_chunk(struct malloc_chunk *c)
{
	struct segptr *owner = chunk_to_segment(c);
	char *candidate = (char *)c + chunk_size(c);
	char *seg_end;

	ASSERT(owner != NULL);

	seg_end = owner->base + SEG_SIZE;
	ASSERT(candidate > owner->base && candidate <= seg_end);

	if (candidate == seg_end)
		return NULL;

	return (struct malloc_chunk *)candidate;
}



/*
 * Return the chunk that physically precedes c, located by stepping back
 * prev_foot bytes.  Callers in this file only use it when prev_in_use(c)
 * is false.
 */
static struct malloc_chunk *prev_chunk(struct malloc_chunk *c)
{
	char *here = (char *)c;

	return (struct malloc_chunk *)(here - c->h.prev_foot);
}

/* Append c to the tail of the circular free list anchored at m.free. */
static void link_free_chunk(struct malloc_chunk *c)
{
	struct malloc_chunk *head = &m.free;
	struct malloc_chunk *tail = head->bk;

	/* splice c in between the current tail and the sentinel */
	c->bk = tail;
	c->fd = head;
	tail->fd = c;
	head->bk = c;
}

/*
 * Remove c from the free list by joining its two neighbours.  c's own
 * link fields are left untouched.
 */
static void unlink_free_chunk(struct malloc_chunk *c)
{
	struct malloc_chunk *after = c->fd;
	struct malloc_chunk *before = c->bk;

	after->bk = before;
	before->fd = after;
}

/*
 * Debug aid: when debug >= 2, dump every chunk of every segment to stdout
 * as address:in_use:prev_in_use:prev_foot:size, followed by the fd/bk
 * links for chunks that are free.
 */
static void show_segment(void)
{
	struct malloc_chunk *c;
	struct segptr *seg;
	int i;

	if (debug < 2)
		return;

	for (i = 0, seg = m.seg; seg != NULL; i++, seg = seg->next) {
		printf("SHOW SEGMENT %i\n", i);
		for (c = (struct malloc_chunk *)seg->base; c != NULL; c = next_chunk(c)) {
			/*
			 * prev_foot is an int, so it is printed with %d;
			 * %p requires a void * argument
			 */
			printf("%p:%d:%d:%d:%zu", (void *)c, in_use(c), prev_in_use(c),
			       c->h.prev_foot, chunk_size(c));
			if (!in_use(c))
				printf(":%p:%p", (void *)c->fd, (void *)c->bk);
			printf("\n");
		}
	}
	printf("SHOW SEGMENT END\n");
}

/*
 * Map a fresh SEG_SIZE segment and add it to the allocator.
 *
 * The segment is backed by the descriptor from secretmem_open().  Its
 * first chunk is permanently allocated and holds the struct segptr that
 * describes the segment; the rest of the segment becomes a single free
 * chunk that is put on the free list.  Any failure of the open, ftruncate
 * or mmap step terminates the process through check().
 */
static void alloc_segment(void)
{
	int fd;
	int ret;
	void *p;
	struct segptr *seg, *sp;
	const size_t ssize = pad_request(sizeof(*seg));
	struct malloc_chunk *c;

	fd = secretmem_open();
	check(fd < 0, "secretmem_open");

	ret = ftruncate(fd, SEG_SIZE);
	check(ret < 0, "ftruncate");

	p = mmap(NULL, SEG_SIZE, PROT_READ|PROT_WRITE, MAP_SHARED, fd, 0);
	check(p == MAP_FAILED, "mmap");
	/* the descriptor is no longer needed once the mapping exists */
	close(fd);

	/* arithmetic on void * is not standard C, so go through char * */
	DEBUG("initial malloc at 0x%p-0x%p\n", p, (void *)((char *)p + SEG_SIZE));

	/* first chunk: always in use, carries the segment descriptor */
	c = p;
	seg = chunk2mem(c);
	memset(seg, 0, sizeof(*seg));
	seg->base = p;
	c->h.head = ssize | CINUSE_BIT | PINUSE_BIT;
	c->h.prev_foot = 0;

	/* append the descriptor to the end of the segment list */
	if (m.seg == NULL) {
		m.seg = seg;
	} else {
		for (sp = m.seg; sp->next != NULL; sp = sp->next)
			;
		sp->next = seg;
	}

	/* second chunk: everything else in the segment, free */
	c = next_chunk(c);
	c->h.head = (size_t)(((char *)p + SEG_SIZE) - (char *)c) | PINUSE_BIT;
	c->h.prev_foot = ssize;
	/* prev_foot is an int, so it is printed with %d rather than %zu */
	DEBUG("next chunk at 0x%p, head=%zu, prev=%d\n", (void *)c, chunk_size(c), c->h.prev_foot);

	link_free_chunk(c);
}

/*
 * Library constructor: read the configuration environment variables and
 * initialise the allocator state.
 *
 *   MALLOC_DEBUG     - numeric verbosity level stored in debug
 *   NO_SECRET_MEM    - if present, fall back to an ordinary memfd
 *   SECRET_UNCACHED  - if present, request the uncached secret mode
 */
void __attribute__ ((constructor)) preload_setup(void)
{
	const char *level;

	/*
	 * MIN_CHUNK, which must be a power of 2, has to exceed the size
	 * of a free chunk's bookkeeping.  Otherwise splitting a chunk
	 * would overwrite the free list pointers.
	 */
	ASSERT(sizeof(struct malloc_chunk) < MIN_CHUNK);

	level = getenv("MALLOC_DEBUG");
	if (level != NULL)
		debug = atoi(level);

	if (getenv("NO_SECRET_MEM") != NULL)
		use_secret = 0;

	if (getenv("SECRET_UNCACHED") != NULL)
		secret_option = MFD_SECRET_UNCACHED;

	/* no segments yet, and an empty free list points back at its sentinel */
	m.seg = NULL;
	m.free.bk = &m.free;
	m.free.fd = &m.free;
}

/*
 * Best-fit search of the free list: return the smallest free chunk whose
 * size is at least `size`, or NULL when none is large enough.  Among
 * chunks of equal size the one furthest along the list wins.
 */
static struct malloc_chunk *find_free(size_t size)
{
	struct malloc_chunk *best = NULL;
	struct malloc_chunk *cur = m.free.fd;

	while (cur != &m.free) {
		/* free neighbours are always merged, so a free chunk's predecessor is in use */
		ASSERT(prev_in_use(cur));
		if (chunk_size(cur) >= size &&
		    (best == NULL || chunk_size(best) >= chunk_size(cur)))
			best = cur;
		cur = cur->fd;
	}
	return best;
}

/*
 * Turn the free chunk c into an allocated chunk of `size` bytes.
 *
 * If the remainder would be at least as large as the smallest possible
 * chunk, c is cut in two and the tail takes c's place on the free list.
 * Otherwise the whole of c is handed out and removed from the free list.
 */
static void split_free_chunk(struct malloc_chunk *c, size_t size)
{
	const size_t total = chunk_size(c);
	struct malloc_chunk *rest, *after;

	if (total - size >= pad_request(1)) {
		/* shrink c to the requested size and mark it allocated */
		c->h.head = size | CINUSE_BIT | PINUSE_BIT;

		/* the tail of the old chunk becomes a new free chunk */
		rest = next_chunk(c);
		rest->h.head = (total - size) | PINUSE_BIT;
		rest->h.prev_foot  = size;

		/* the remainder inherits c's position in the free list */
		rest->fd = c->fd;
		rest->bk = c->bk;
		c->bk->fd = rest;
		c->fd->bk = rest;

		/* the chunk after the remainder now has a smaller free predecessor */
		after = next_chunk(rest);
		if (after != NULL) {
			ASSERT(!prev_in_use(after));
			ASSERT(after->h.prev_foot == total);
			after->h.prev_foot = chunk_size(rest);
		}
		return;
	}

	/* nothing worth splitting off: give the whole chunk away */
	unlink_free_chunk(c);
	c->h.head |= CINUSE_BIT;

	after = next_chunk(c);
	if (after == NULL)
		return;

	ASSERT(!prev_in_use(after));
	ASSERT(after->h.prev_foot == chunk_size(c));
	after->h.head |= PINUSE_BIT;
}

/*
 * Core allocation routine: pad the request, take the best-fitting free
 * chunk (mapping a new segment if none fits) and return its payload.
 * Asserts if even a fresh segment cannot satisfy the request.
 */
static void *dlmalloc(size_t size)
{
	const size_t need = pad_request(size);
	struct malloc_chunk *hit = find_free(need);

	if (hit == NULL) {
		/* nothing suitable on the free list: grow and retry once */
		alloc_segment();
		hit = find_free(need);
		ASSERT(hit != NULL);
	}
	DEBUG("found chunk 0x%p\n", hit);

	split_free_chunk(hit, need);
	return chunk2mem(hit);
}

void *CRYPTO_malloc(size_t size, const char *file, int line)
{
	void *ret = NULL;
	if (size < SEG_SIZE)
		ret = dlmalloc(size);
	INFO("in crypto malloc from %s:%d %zu@%p\n", file, line, size, ret);
	show_segment();
	return ret;
}

/*
 * Replacement for OpenSSL's CRYPTO_free().
 *
 * Zeroes the payload of the chunk behind ptr, returns the chunk to the
 * free list and merges it with a free physical neighbour on either side.
 * A NULL ptr is ignored.  Asserts if the chunk is not marked in use.
 *
 * ptr        - pointer previously returned by CRYPTO_malloc()
 * file, line - caller location, used only for the INFO trace
 */
void CRYPTO_free(void *ptr, const char *file, int line)
{
	struct malloc_chunk *c, *n;

	INFO("in crypto free from %s:%d: %p\n", file, line, ptr);
	if (ptr == NULL)
		return;

	c = mem2chunk(ptr);
	ASSERT(in_use(c));
	/* shred the data */
	memset(ptr, 0, chunk_size(c) - CHUNK_SIZE);

	/* n is the physical successor, NULL if c ends the segment */
	n = next_chunk(c);
	DEBUG("free c=%p, n=%p, next_in_use=%d, prev_in_use=%d\n",
	       c,n,n && in_use(n),prev_in_use(c));

	/* now check for consolidation with previous */
	if (!prev_in_use(c)) {
		struct malloc_chunk *p = prev_chunk(c);

		/* grow the previous chunk over c; its flag bits are unaffected */
		p->h.head += chunk_size(c);
		if (n)
			n->h.prev_foot = chunk_size(p);

		/* the new consolidated chunk becomes our current
		 * chunk for the next free check below.  The previous
		 * chunk was already linked */
		c = p;
	} else {
		DEBUG("linking %p\n", c);
		link_free_chunk(c);
	}

	/* and finally consolidation with next */
	if (n && !in_use(n)) {
		/* absorb the free successor, then fix up the chunk after it */
		unlink_free_chunk(n);
		c->h.head += chunk_size(n);
		n = next_chunk(c);
		if (n)
			n->h.prev_foot = chunk_size(c);
	} else if (n) {
		/* successor stays allocated: record that its predecessor is now free */
		ASSERT(prev_in_use(n));
		n->h.head &= ~PINUSE_BIT;
	}
	/* mark the (possibly merged) chunk as free */
	c->h.head &= ~CINUSE_BIT;
	show_segment();
}

/*
 * Replacement for OpenSSL's CRYPTO_clear_free().  A NULL str is ignored;
 * anything else is passed to CRYPTO_free().  num is not used.
 */
void CRYPTO_clear_free(void *str, size_t num, const char *file, int line)
{
    if (str != NULL)
        CRYPTO_free(str, file, line);
}