/*
 * Copyright (C) 2009 Luigi Rizzo
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 
 * THIS SOFTWARE IS PROVIDED BY AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 * $Id: kldpatch.c 1047 2009-01-11 12:08:03Z luigi $
 *
 */

/*
 * DESCRIPTION

Several USB, PCI and other device drivers require the kernel to have
the exact product/vendor IDs (and sometimes matching quirks) in its
internal tables, to identify the device and possibly enable/disable
specific features. As an example, this is the case for many "umass"
devices, where implementation of standard features varies, and
for uscanner devices, where there is no specific device class info
so the driver has to do an exact match using a lookup table.

Apart from rebuilding the kernel or module, the device table can be
altered with external tools, either on-disk or in memory.

This program can be used to read or patch the in-memory kernel
structures used to store the match/quirks tables. The program is
also able to patch the file containing the module or kernel, so
the change can be made persistent across reboots.

When working on the live kernel, the program uses kldfind and kldsym
to locate the module, address and size for the desired data structure,
and then kvm_open/kvm_read/kvm_write to read and possibly update
the table.

When working on a file, the location of the symbol is determined
using a custom version of the elfdump code.

Patching a binary module is clearly dangerous, so we try to make the
process a little less risky by putting some additional controls in
this program. In particular, kldpatch contains a table of the form

	module symbol  record_format...

defining which modules/symbols can be patched and the structure of
each entry.  As an example:

    uscanner.ko	uscanner_devs		 2:vendor 2:product 4:flags
    umass.ko	umass_devdescrs		 4 4 4 2 2
    if_sis.ko	sis_devs		 2 2 s

means that
+ module uscanner.ko has a device table called uscanner_devs
  where each record is made of 3 numeric fields of size 2 2 4 bytes,
  whose "names" (used to print the content of the table) are
  "vendor", "product" and "flags" respectively;
+ module umass.ko has a device table called umass_devdescrs with
  records of size 4 4 4 2 2 bytes;
+ module if_sis.ko has a device table called sis_devs where each record
  has two numeric entries of 2 bytes each, plus a string pointer;

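Additional format specifiers can be appended to the size: 'b' and 'l'
mark big- and little-endian numbers, 'p' a generic pointer and 's' a
string pointer (both default to the size of a pointer), while 'i'
stands for a host int. A field can be given a name with ':name', and
the rest of a line after '#' is a comment.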

The program accepts commands of the form

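    kldpatch if_sis.ko sis_devs @3 0x1234 0x5678 "temporary entry"

which means "patch record 3 of table sis_devs with the values specified".
One value must be given for each field; values for string and pointer
fields are placeholders and are not written. A negative record number
counts from the end of the table.
The program converts each operand to the size of its field, and makes
sure that the number of operands matches the record format and that
the table entry is within the table.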

 */

/*
 * Built-in table of the modules and symbols that may be patched; it is
 * used unless the -t option supplies another one.  One entry per line:
 *
 *	module	symbol	field ...	[#comment]
 *
 * Leading blanks are ignored.  Each field is "<size>[type][:name]",
 * where type is b (big endian), l (little endian), p (pointer),
 * s (string pointer) or i (host int); see match_table() for the parser.
 */
static const char _default_table[] ="\
    umass.ko	umass_devdescrs	4:vendor 4:product 4:rev 2:proto 2:quirks\n\
    uscanner.ko	uscanner_devs	2:vendor 2:device 4:flags #comment\n\
    if_nfe.ko	nfe_devs	2:vendor 2:device s:name \n\
    if_re.ko	re_devs		2:vendor 2:device 4:type s:name \n\
";


#include <sys/types.h>
#include <sys/param.h>
#include <sys/endian.h>
#include <sys/linker.h>
#include <sys/module.h>

#include <errno.h>
#include <fcntl.h>
#include <kvm.h>
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <unistd.h>	/* read() */

#include "elfdump.h"

/* Set to 1 by the -v option: print diagnostic traces on stderr. */
int verbose = 0;

/*
 * Field descriptor encoding (see struct match_info.args): the low 24
 * bits hold the field length in bytes, bits 24-27 the field type.
 */
enum _ty_tags {
	TAG_LEN_MASK =	(1 << 24) - 1,	/* field length, in bytes */
	TAG_TYPE_MASK =	0xf << 24,	/* field type */
	TAG_HOST =	0 << 24,	/* integer in host byte order */
	TAG_BE =	1 << 24,	/* big-endian integer */
	TAG_LE =	2 << 24,	/* little-endian integer */
	TAG_PTR =	3 << 24,	/* generic pointer */
	TAG_STRING =	4 << 24,	/* pointer to a C string */
};

#define ARGCOUNT	16	/* max number of fields in a record */
/*
 * A structure describing the matching entry: the names and options
 * given on the command line, the record layout parsed by match_table()
 * from the matching table line, and the table location found by do_rw().
 */
struct match_info {
	const char *mod_name;		/* after a match: name from the table */
	const char *symbol;		/* after a match: name from the table */
	const char *filename;		/* patch a file, not the memory */

	void *dump_addr;		/* if set, dump only the record here */
	int data_ofs;			/* record index, then byte offset */
	/* table address and size, from kldsym() or from the ELF file */
	struct kld_sym_lookup sym;

	kvm_t	*k;			/* kvm_open result, NULL for files */
	const char *match_buf;		/* copy of the matching table line */
	int reclen;			/* length of a single record, bytes */
	int argcount;			/* number of fields in a record */
	const char *argdesc[ARGCOUNT];	/* field names (":name"), or NULL */
	int args[ARGCOUNT];	/* field descriptors */
		/* The low 24 bits (TAG_LEN_MASK) hold the field length
		 * in bytes, bits 24-27 (TAG_TYPE_MASK) the field type:
		 *	TAG_HOST	integer in host byte order
		 *	TAG_BE		big-endian integer
		 *	TAG_LE		little-endian integer
		 *	TAG_PTR		generic pointer
		 *	TAG_STRING	pointer to a C string
		 * mod_name, symbol and argdesc[] point into match_buf
		 * once match_table() has succeeded.
		 */
};

/* Error codes accepted by help(); each one selects a diagnostic. */
enum errcode {
	E_NO_ARGS = 0,		/* module and symbol not both given */
	E_INVALID_MODULE = 1,	/* no matching entry in the table */
	E_NO_SYMBOL = 2,	/* symbol lookup failed */
	E_NO_ROOT = 3,		/* kvm (or file) access failed */
	E_NO_MEMORY = 4,	/* allocation failure */
	E_NO_OFS = 5,		/* record number not given as @N */
	E_BAD_OFS = 6,		/* record number out of range */
	E_BAD_ARGS = 7,		/* operands do not match the record */
	E_UNSUPPORTED = 8,	/* unsupported operand size */
	E_WRITE_ERROR = 9,	/* writing the record failed */
	E_FILE = 10,		/* cannot open or read a file */
};

/* print an error message */
static void
help(enum errcode err, struct match_info *m)
{
    switch (err) {
    case E_NO_ARGS:
	fprintf(stderr, "Need at least two arguments\n");
	break;
    case E_INVALID_MODULE:
	fprintf(stderr, "Invalid module name %s\n", m->mod_name);
	break;
    case E_NO_SYMBOL:
	fprintf(stderr, "Symbol %s not found in module %s\n",
		m->symbol, m->mod_name);
	break;
    case E_NO_ROOT:
	fprintf(stderr, "Must be root to use kvm_read/kvm_write\n");
	break;
    case E_NO_OFS:
	fprintf(stderr, "Missing write offset\n");
	break;
    case E_BAD_OFS:
	fprintf(stderr, "Bad write offset\n");
	break;
    case E_NO_MEMORY:
	perror("malloc failure");
	break;
    case E_BAD_ARGS:
	fprintf(stderr, "Invalid argument length\n");
	break;
    case E_UNSUPPORTED:
	fprintf(stderr, "Operand size unsupported\n");
	break;
    case E_WRITE_ERROR:
	fprintf(stderr, "Error writing to kvm\n");
	break;
    case E_FILE:
	fprintf(stderr, "Error opening/reading file\n");
	break;
    }
    exit(1);
}

/*
 * Wrapper around calloc: return a zero-filled buffer of size bytes, or
 * terminate the program through help(E_NO_MEMORY) if it cannot be
 * allocated.
 *
 * size is a size_t so that large symbol sizes are not truncated. A
 * request for 0 bytes is served as 1 byte, because calloc() may legally
 * return NULL for 0, which would be misreported as an allocation failure.
 */
static void *
my_calloc(size_t size)
{
    void *res = calloc(1, size ? size : 1);

    if (res == NULL)
        help(E_NO_MEMORY, NULL);
    return res;
}

/*
 * Print the table held in buf (m->sym.symsize bytes) on stderr: first a
 * header line with the field names, then one line per record.  If
 * m->dump_addr is set, only the record that starts at that address is
 * printed.  Integers are shown in hex, converted from big/little endian
 * when the field type says so.  String fields are followed with
 * kvm_read() when a kvm handle is open; when working on a file
 * (m->k == NULL) the pointer value itself is printed instead, since it
 * cannot be dereferenced.
 */
static void
table_dump(unsigned char *buf, struct match_info *m)
{
    int i, l;
    int n;
    int n_args;

    if (m->reclen <= 0) {	/* no fields: nothing to print, avoid / 0 */
	fprintf(stderr, "invalid record length %d\n", m->reclen);
	return;
    }
    n_args = m->sym.symsize / m->reclen;
    if (verbose)
	fprintf(stderr, "-- %d records of size %d, %d fields\n",
	    n_args, m->reclen, m->argcount);
    /* print headers */
    fprintf(stderr, "Index ");
    for (i = 0; i < m->argcount; i++) {
	l = m->args[i] & TAG_LEN_MASK;
	l = 2*l + 2; /* field length */
	fprintf(stderr, "%*s ", l, m->argdesc[i] ? m->argdesc[i] : "");
    }
    fprintf(stderr, "\n");
    for (n = 0; n < n_args; n++) {
	const unsigned char *p = buf + n * m->reclen;

	if (m->dump_addr && m->dump_addr != p)
	    continue;
	fprintf(stderr, "%4d: ", n);
	/* p advances in the loop step, so every field (strings too) moves it */
	for (i = 0; i < m->argcount; i++, p += l) {
	    int t = m->args[i] & TAG_TYPE_MASK;
	    uint16_t d16;
	    uint32_t d32;
	    uint64_t d64;

	    l = m->args[i] & TAG_LEN_MASK;
	    if (t == TAG_STRING) {
		const char *s = NULL;
		char str[65];

		/* never copy more than one pointer's worth of bytes into s */
		memcpy(&s, p, (size_t)l < sizeof(s) ? (size_t)l : sizeof(s));
		if (m->k == NULL) {	/* file image: cannot follow it */
		    fprintf(stderr, "%p ", (const void *)s);
		    continue;
		}
		memset(str, 0, sizeof(str));
		kvm_read(m->k, (unsigned long)s, str, sizeof(str) - 1);
		fprintf(stderr, "%s ", str);
		continue;
	    }
	    /*
	     * Compare the whole type field: TAG_PTR shares bits with both
	     * TAG_BE and TAG_LE, and pointers must stay in host order.
	     */
	    switch (l) {
	    default:
		fprintf(stderr, "arg size not recognised %d\n", l);
		break;
	    case 1:
		fprintf(stderr, "0x%02x ", (unsigned)*p);
		break;
	    case 2:
		memcpy(&d16, p, sizeof(d16));
		if (t == TAG_BE)
			d16 = be16toh(d16);
		else if (t == TAG_LE)
			d16 = le16toh(d16);
		fprintf(stderr, "0x%04x ", (unsigned)d16);
		break;
	    case 4:
		memcpy(&d32, p, sizeof(d32));
		if (t == TAG_BE)
			d32 = be32toh(d32);
		else if (t == TAG_LE)
			d32 = le32toh(d32);
		fprintf(stderr, "0x%08x ", (unsigned)d32);
		break;
	    case 8:
		memcpy(&d64, p, sizeof(d64));
		if (t == TAG_BE)
			d64 = be64toh(d64);
		else if (t == TAG_LE)
			d64 = le64toh(d64);
		fprintf(stderr, "0x%016jx ", (uintmax_t)d64);
		break;
	    }
	}
	fprintf(stderr, "\n");
    }
}

/*
 * Match name and/or symbol with the content of the table.
 * All input and output fields are in *m.
 * Use NULL, "" or "-" to use a wildcard on name or symbol; a module
 * name without ".ko" also matches a table entry that has the suffix.
 * Empty lines, comment-only lines and lines without a symbol are skipped.
 * On a match, m->match_buf is a private copy of the matching line;
 * m->mod_name, m->symbol and m->argdesc[] are set to point into it;
 * m->reclen, m->argcount and m->args[] describe the record layout.
 * Retval is 0 on failure, 1 on match
 */
static int
match_table(const char *table, struct match_info *m)
{
    const char *c = NULL, *s;
    char *s1, *sym_end;
    int ll = 0;
    int i, nmax;
    int mod_len;	/* length of module name */
    int sym_len;	/* length of symbol name */

    if (!table || !m)
	return 0;
    /* store the length, set to 0 for wildcard */
    mod_len = m->mod_name ? strlen(m->mod_name) : 0;
    if (mod_len == 1 && m->mod_name[0] == '-')
	mod_len = 0;
    sym_len = m->symbol ? strlen(m->symbol) : 0;
    if (sym_len == 1 && m->symbol[0] == '-')
	sym_len = 0;

    /*
     * Match module and symbol. On return, s and c point to the matching
     * module and symbol name, lengths are in mod_len and sym_len.
     * The loop step skips the '\n' only if there is one, so a last line
     * without a trailing newline never moves s past the terminating NUL.
     */
    for (s = table; *s; s += ll + (s[ll] == '\n')) {
	int ln, ls;

	s += strspn(s, " \t");		/* skip blanks */
	ll = strcspn(s, "\n");		/* remaining line length */
	if (verbose)
	    fprintf(stderr, "Analysing: [%.*s]\n", ll, s);
	if (*s == '\0')			/* table complete */
	    break;
	ln = strcspn(s, " \t#\n");	/* module name length in table */
	if (ln == 0)
	    continue;			/* empty or comment-only line */
	if (mod_len == 0) {
	    /* no module name, assume a match */
	} else if (mod_len == ln && !strncmp(s, m->mod_name, mod_len)) {
	    /* exact match on module name */
	} else if (ln > 3 && mod_len == ln - 3 &&
		!strncmp(s, m->mod_name, mod_len) &&
		!strncmp(s + ln - 3, ".ko", 3) ) {
	    /* match without .ko suffix */
	} else {
	    continue;			/* no match */
	}
	c = s + ln;			/* skip the module name */
	c += strspn(c, " \t");		/* skip separators */
	ls = strcspn(c, " \t#\n");	/* length of symbol name */
	if (ls == 0)
	    continue;			/* no symbol on this line */
	if (!sym_len || (ls == sym_len && !strncmp(c, m->symbol, sym_len))) {
	    sym_len = ls;
	    mod_len = ln;
	    break;
	}
    }
    if (*s == '\0')
	return 0;
    /*
     * make a copy of the matching line to fetch module and symbol
     * names and comment fields (calloc leaves it NUL-terminated).
     */
    m->match_buf = s1 = my_calloc(ll + 1);
    memcpy(s1, s, ll);
    if (verbose) {
	fprintf(stderr, "found: [%.*s]\n", ll, s);
	fprintf(stderr, "copy: [%s]\n", s1);
        fprintf(stderr, "    name %.*s symbol %.*s\n", mod_len, s, sym_len, c);
    }
    /* store the matching values in the return struct */
    m->mod_name = s1;
    s1[mod_len] = '\0';			/* truncate module name */
    m->symbol = s1 + (c - s);
    sym_end = s1 + (c - s) + sym_len;	/* first byte after the symbol */
    /*
     * Step past the terminator only if it replaced a blank: at the end
     * of the line, or right before a '#' comment, there are no fields
     * (and stepping further would leave the copied buffer).
     */
    if (*sym_end == ' ' || *sym_end == '\t')
	*sym_end++ = '\0';
    else
	*sym_end = '\0';
    s1 = sym_end + strspn(sym_end, " \t");

    /* now scan the argument list extracting type, size and description */
    m->reclen = 0;			/* total record length so far */
    nmax = sizeof(m->args)/sizeof(m->args[0]);	 /* size of argument list */
    for (i = 0; *s1 != '\0' && *s1 != '#' && *s1 != '\n' && i < nmax; i++) {
	char cc, *endp;

	m->args[i] = strtoul(s1, &endp, 0);
	cc = *endp++;
	switch (cc) {
	default:
	    endp--;	/* not recognised */
	    break;

	case 'i':
	    m->args[i] = sizeof(int);
	    break;

	case 'b':
	    m->args[i] |= TAG_BE;
	    break;

	case 'l':
	    m->args[i] |= TAG_LE;
	    break;

	case 'p':
	case 's':
	    if (m->args[i] == 0)
		m->args[i] = sizeof(void *);
	    m->args[i] |= (cc == 's') ? TAG_STRING : TAG_PTR;
	    break;
	}
	if (*endp == ':')	/* comment up to the next space */
	    m->argdesc[i] = endp + 1;
	m->reclen += m->args[i] & TAG_LEN_MASK;
	if (verbose)
	    fprintf(stderr, "found entry 0x%x\n", m->args[i]);
	s1 += strcspn(s1, " \t#");	/* skip descriptor */
	/* terminate at a blank; a '#' is left in place to end the loop */
	if (*s1 == ' ' || *s1 == '\t')
	    *s1++ = '\0';
	s1 += strspn(s1, " \t");	/* skip whitespace */
    }
    *s1 = '\0';		/* truncate the string (drops any comment) */
    m->argcount = i;
    if (verbose)
	fprintf(stderr, "found: name [%s] sym [%s] reclen %d argcount %d\n",
		m->mod_name,
		m->symbol,
		m->reclen, m->argcount);
    return 1;
}


/*
 * read and possibly write into the module
 */
static int
do_rw(struct match_info *m, int argc, char *argv[])
{
    unsigned char *buf, *srcbuf;
    int fd = -1, i, l, bufp;

    if (verbose)
	fprintf(stderr, "kldfind %s %s\n", m->mod_name, m->symbol);
    m->sym.version = sizeof(struct kld_sym_lookup);
    m->sym.symname = (char *)(int)m->symbol;	/* XXX why not const ? */
    m->sym.symvalue = 0;
    m->sym.symsize = 0;
    if (m->filename) {		/* work on file, find offset and size */
	/*
	 * XXX working notes below.
	 * for a .ko file must load at PT_LOAD and p_offset
	 * entry 0 or entry 1 ? 
	 *
	 * for a kernel... entry 2 in the prog header
	 *
	 * for a .o file look at offset on .rodata
	 */
	struct elfdump_info *res;
	char *syms;

	asprintf(&syms, "%s,kernload,kernbase", m->sym.symname);
	if (!syms)
	    help(E_FILE, NULL);
	res = elfdump(ED_SYMTAB, m->filename, syms);
	free(syms);
	if (!res)
	    help(E_FILE, NULL);
	if (res->m[0].size == 0)	/* not found */
	    help(E_NO_SYMBOL, m);
	m->sym.symsize = res->m[0].size;
	m->sym.symvalue = res->m[0].value;
	if (res->m[1].value)
	    m->sym.symvalue -= res->m[1].value;
	if (res->m[2].value)
	    m->sym.symvalue -= res->m[2].value;
    } else {			/* work on memory, find offset and size */
        int kid = kldfind(m->mod_name);

	if (kid == -1) {
	    fprintf(stderr, "module %s not found, try kernel\n", m->mod_name);
	    kid = 0;
	}
	if (kldsym(kid, KLDSYM_LOOKUP, &m->sym))
	    help(E_NO_SYMBOL, m);
    }
    if (verbose)
	fprintf(stderr, "found %s at 0x%lx len %d\n",
	    m->symbol, m->sym.symvalue, (int)m->sym.symsize);
    if (m->sym.symsize % m->reclen != 0)
	fprintf(stderr, "struct size %d not multiple of reclen %d\n",
		(int)m->sym.symsize, m->reclen);
    if (argc > 3) {	/* compute the data offset, check bounds */
	m->data_ofs *= m->reclen;
	if (m->data_ofs < 0)
	    m->data_ofs = m->sym.symsize + m->data_ofs;
	if (verbose)
	    fprintf(stderr, " data ofs %d\n", m->data_ofs);
	if (m->data_ofs < 0 || m->data_ofs > (int)m->sym.symsize - m->reclen)
	    help(E_BAD_OFS, m);
    }
    buf = my_calloc(m->sym.symsize);
    if (m->filename) {		/* work on the file */
	fd = open(m->filename, (argc <= 4) ? O_RDONLY : O_RDWR);
	if (fd < 0)
	    help(E_NO_ROOT, m);
	if (lseek(fd, m->sym.symvalue, SEEK_SET) < 0 ||
		read(fd, buf, m->sym.symsize) != (int)m->sym.symsize)
	    help(E_BAD_OFS, m);
    } else {			/* work on memory */
	m->k = kvm_open(NULL, NULL, NULL,
		(argc <= 4) ? O_RDONLY : O_RDWR, "kldpatch");
	if (m->k == NULL || kvm_read(m->k, m->sym.symvalue, buf, m->sym.symsize)
		!= (int)m->sym.symsize)
	    help(E_NO_ROOT, m);
    }
    if (verbose)
	fprintf(stderr, "data read successfully\n");
    if (argc <= 4) {	/* read mode */
	m->dump_addr = (argc != 4) ? NULL : buf + m->data_ofs;
	table_dump(buf, m);
	return 0;
    }

    /*
     * allocate the write buffer, dup the existing content and then
     * copy arguments passed from the command line
     */
    srcbuf = my_calloc(m->reclen);
    bcopy(buf + m->data_ofs, srcbuf, m->reclen);
    /* prepare to copy in memory in endian-safe way */
    for (i = 0, bufp = 0; i < m->argcount && i+4 < argc; i++, bufp += l) {
	uint64_t d64;
	uint32_t d32;
	uint16_t d16;

	l = m->args[i] & TAG_LEN_MASK;
	if ( m->args[i] & (TAG_PTR | TAG_STRING) ) {
	    /* do not overwrite strings or pointers */
	    fprintf(stderr, "skip string/pointer %d:%s\n", i, argv[i+4]);
	    continue;
	}
	switch (l) {
	default:
	    fprintf(stderr, "unsupported format 0x%x, exit\n", m->args[i]);
	    exit(1);

	case 8:
	    d64 = strtoull(argv[i+4], NULL, 0);
	    if (m->args[i] & TAG_LE)
		d32 = htole64(d64);
	    else if (m->args[i] & TAG_BE)
		d64 = htobe32(d64);
	    bcopy(&d64, srcbuf+bufp, sizeof(d64));
	    break;

	case 4:
	    d32 = strtoul(argv[i+4], NULL, 0);
	    if (m->args[i] & TAG_LE)
		d32 = htole32(d32);
	    else if (m->args[i] & TAG_BE)
		d32 = htobe32(d32);
	    bcopy(&d32, srcbuf+bufp, sizeof(d32));
	    break;

	case 2:
	    d16 = strtoul(argv[i+4], NULL, 0);
	    if (m->args[i] & TAG_LE)
		d16 = htole16(d16);
	    else if (m->args[i] & TAG_BE)
		d16 = htobe16(d16);
	    bcopy(&d16, srcbuf+bufp, sizeof(d16));
	    break;

	case 1:
	    d16 = strtoul(argv[i+4], NULL, 0);
	    srcbuf[bufp] = d16 & 0xff;
	    break;
	}
    }
    if (bufp != m->reclen)
	help(E_BAD_ARGS, m);

    if (m->filename) {		/* write to file */
	if (lseek(fd, m->sym.symvalue + m->data_ofs, SEEK_SET) < 0 ||
		write(fd, srcbuf, m->reclen) != (int)m->reclen)
	    help(E_WRITE_ERROR, m);
	if (lseek(fd, m->sym.symvalue, SEEK_SET) < 0 ||
		read(fd, buf, m->sym.symsize) != (int)m->sym.symsize)
	    help(E_NO_ROOT, m);
    } else {			/* write to memory */
	if (kvm_write(m->k, m->sym.symvalue + m->data_ofs, srcbuf, m->reclen)
		!= m->reclen)
	    help(E_WRITE_ERROR, m);
	if (kvm_read(m->k, m->sym.symvalue, buf, m->sym.symsize)
		!= (int)m->sym.symsize)
	help(E_NO_ROOT, m);
    }
    table_dump(buf, m);
    return 0;
}

/*
 * usage: program [options] module_name symbol [@recno [val ... ]]
 * where options are
 *	-v:	verbose
 *	-t file		external table
 *	-m modulefile	patch this file, not the kernel
 */
int
main(int argc, char *argv[])
{
    char *module, *sym;
    const char *table = _default_table;
    struct match_info m;

    bzero(&m, sizeof(m));
    if (argc > 1 && !strcmp(argv[1], "-v")) {
	verbose = 1;
	argc--;
	argv++;
    }
    if (argc > 2 && !strcmp(argv[1], "-m")) {
	/* patch a file */
	m.filename = argv[2];
	argc -= 2;
	argv += 2;
    }
    if (argc > 2 && !strcmp(argv[1], "-t")) {
	/* either a filename or the table line */
	int fd = open(argv[2], O_RDONLY);
	if (fd < 0) {	/* the table is on the command line */
	    table = argv[2];
	} else {
	    int l = lseek(fd, 0, SEEK_END);
	    char *new_table;

	    if (l > 100000)	/* 100k is way too much for a table */
		help(E_FILE, NULL);
	    lseek(fd, 0, SEEK_SET);
	    new_table = my_calloc(l);
	    read(fd, new_table, l);
	    table = new_table;
	    close(fd);
	}
	argc -= 2;
	argv += 2;
    }

    if (argc < 3)
	help(E_NO_ARGS, NULL);	/* need module and symbol */
    /* argc==3 -> read all; argc==4 -> read one, argc>4 -> write */
    module = argv[1];
    sym = argv[2];
    if (argc > 3) {
	if (argv[3][0] != '@')
	    help(E_NO_OFS, &m);
	m.data_ofs = atoi(argv[3] + 1);
    }
    m.mod_name = module;
    m.symbol = sym;

    /* lookup the module/symbol in the description */
    if (!match_table(table, &m))
	help(E_INVALID_MODULE, &m);

    if (argc > 4 && argc - 4 != m.argcount)
	help(E_BAD_ARGS, &m);
    do_rw(&m, argc, argv);
    return 0;
}
