/*
 *  signalfd-test by Davide Libenzi (test app for signalfd)
 *  Copyright (C) 2007  Davide Libenzi
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 *
 *  Davide Libenzi <davidel@xmailserver.org>
 *
 */

#define _GNU_SOURCE
#include <sys/syscall.h>
#include <sys/types.h>
#include <sys/signal.h>
#include <sys/wait.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <signal.h>
#include <sched.h>
#include <poll.h>
#include <fcntl.h>
#include <errno.h>
#include <pthread.h>


/*
 * Fallback system-call numbers for headers that do not define
 * __NR_signalfd.  These were correct as of Linux 2.6.22-rc3; on any other
 * architecture the number has to come from the system headers.
 */
#if !defined(__NR_signalfd)
# if defined(__x86_64__)
#  define __NR_signalfd 282
# elif defined(__i386__)
#  define __NR_signalfd 321
# else
#  error Cannot detect your architecture!
# endif
#endif


/* Signal-mask size in bytes derived from _NSIG (integer division). */
#define SIZEOF_SIG	(_NSIG / 8)
/* Mask size passed to signalfd(): never more than a sigset_t holds. */
#define SIZEOF_SIGSET	((SIZEOF_SIG) < sizeof(sigset_t) ? (SIZEOF_SIG) : sizeof(sigset_t))

/* Signal driving the multi-thread test (and the one dummy_sig() handles). */
#define TEST_SIG	SIGUSR2



/*
 * User-space copy of one record read(2) from a signalfd.  The kernel
 * returns whole 128-byte records and rejects smaller read buffers, so the
 * field layout, including the trailing padding, must match exactly.
 */
struct signalfd_siginfo {
	uint32_t ssi_signo;	/* signal number */
	int32_t ssi_errno;
	int32_t ssi_code;
	uint32_t ssi_pid;	/* sender's PID */
	uint32_t ssi_uid;	/* sender's real UID */
	int32_t ssi_fd;
	uint32_t ssi_tid;
	uint32_t ssi_band;
	uint32_t ssi_overrun;
	uint32_t ssi_trapno;
	int32_t ssi_status;
	int32_t ssi_int;
	uint64_t ssi_ptr;
	uint64_t ssi_utime;
	uint64_t ssi_stime;
	uint64_t ssi_addr;
	uint8_t __pad[48];	/* pads the record out to 128 bytes */
};

/* Fails to compile (negative array size) if the record is not 128 bytes. */
typedef char signalfd_siginfo_size_check[
	sizeof(struct signalfd_siginfo) == 128 ? 1 : -1];

/*
 * Handle type produced by thread_new(): a pthread_t when built with
 * -DUSE_PTHREAD, otherwise the pid_t returned by clone().
 */
#ifdef USE_PTHREAD
typedef pthread_t thread_id_t;
#else
typedef pid_t thread_id_t;
#endif


/* signalfd shared by main() and the worker threads. */
static int sfd = 0;
/* Handles returned by thread_new(), one per worker. */
static thread_id_t thid[8] = { 0 };
/* Kernel thread ids each worker records in thproc(). */
static unsigned long tids[8] = { 0 };


/*
 * Invoke the signalfd system call directly through syscall().
 *
 * ufc:      -1 to create a new descriptor, or an existing signalfd whose
 *           mask is to be replaced.
 * mask:     signals the descriptor should report.
 * sizemask: size in bytes of the mask as the kernel expects it.
 *
 * Returns the descriptor, or -1 with errno set.
 */
int signalfd(int ufc, sigset_t const *mask, size_t sizemask) {
	long rc = syscall(__NR_signalfd, ufc, mask, sizemask);

	return (int) rc;
}

/*
 * Wait up to timeo milliseconds (poll() semantics: -1 blocks forever,
 * 0 only checks) for fd to become readable, then dequeue one
 * signalfd_siginfo record from it.
 *
 * Returns the signal number read, 0 if read() returned 0, or -1 when
 * poll()/read() fail or no signal became available in time.
 */
long waitsig(int fd, int timeo) {
	struct pollfd pfd = { 0 };
	struct signalfd_siginfo info;
	int nread;

	pfd.fd = fd;
	pfd.events = POLLIN;
	if (poll(&pfd, 1, timeo) < 0) {
		perror("poll");
		return -1;
	}
	if (!(pfd.revents & POLLIN)) {
		fprintf(stdout, "no signals\n");
		return -1;
	}

	nread = read(fd, &info, sizeof(info));
	if (nread < 0) {
		perror("signal dequeue");
		return -1;
	}
	if (nread == 0) {
		fprintf(stdout, "task detached the sighand\n");
		return 0;
	}

	return info.ssi_signo;
}

/* Kernel thread id of the caller, fetched with the raw gettid syscall. */
int gettid(void) {
	long tid = syscall(__NR_gettid);

	return (int) tid;
}

/*
 * Send sig to the single thread tid with the raw tkill syscall.
 * Returns 0 on success, -1 with errno set on failure.
 */
int tkill(unsigned long tid, int sig) {
	long rc = syscall(__NR_tkill, tid, sig);

	return (int) rc;
}

#if defined(USE_PTHREAD)

/*
 * Start proc(data) in a new POSIX thread.
 *
 * Returns the new thread's handle, or 0 on failure after printing a
 * diagnostic.  pthread_create() reports errors through its return value
 * and does not set errno, so the code is copied into errno for perror().
 */
static thread_id_t thread_new(void *(*proc)(void *), void *data) {
	pthread_t tid;
	int err;

	err = pthread_create(&tid, NULL, proc, data);
	if (err != 0) {
		errno = err;
		perror("pthread_create()");
		return 0;
	}

	return tid;
}

/*
 * Join the thread started by thread_new().
 *
 * Returns 0 on success, or -1 after printing a diagnostic.
 * pthread_join() returns its error number instead of setting errno, so the
 * code is copied into errno for perror().
 */
static int thread_wait(thread_id_t tid) {
	int err;

	err = pthread_join(tid, NULL);
	if (err != 0) {
		errno = err;
		perror("pthread_join()");
		return -1;
	}

	return 0;
}

/*
 * Kernel thread id of the caller (from gettid()), converted to
 * thread_id_t.  Note this is not the pthread_t handle from thread_new().
 */
static thread_id_t thread_id(void) {
	const thread_id_t self = gettid();

	return self;
}

#else

#define THREAD_STK_SIZE (1024 * 64)

/*
 * Start proc(data) in a new clone()d task sharing the caller's address
 * space, open files, filesystem information and signal handlers
 * (CLONE_VM | CLONE_FILES | CLONE_FS | CLONE_SIGHAND), with SIGCHLD as
 * its exit signal.
 *
 * Returns the child's id as returned by clone(), or 0 on failure after
 * printing a diagnostic.
 *
 * The THREAD_STK_SIZE stack is not freed on success: the child may still
 * be running on it, and thread_wait() does not know its address.
 */
static thread_id_t thread_new(void *(*proc)(void *), void *data) {
	int tid;
	char *stk;

	stk = malloc(THREAD_STK_SIZE);
	if (stk == NULL) {
		perror("malloc()");
		return 0;
	}

	/*
	 * NOTE(review): proc returns void *, but clone() calls it through
	 * int (*)(void *).  This mismatch is undefined in ISO C and relies on
	 * the platform calling convention.
	 *
	 * The stack grows down, so the end of the buffer is the initial stack
	 * pointer (as in the clone(2) example).  It keeps malloc()'s alignment;
	 * subtracting sizeof(long) left it only 8-byte aligned on x86-64,
	 * where the psABI expects 16.
	 */
	tid = clone((int (*)(void *)) proc, stk + THREAD_STK_SIZE,
		    CLONE_FS | CLONE_FILES | CLONE_SIGHAND | CLONE_VM | SIGCHLD,
		    data);
	if (tid < 0) {
		perror("clone()");
		free(stk);
		return 0;
	}

	return tid;
}

/*
 * Reap the clone()d task started by thread_new().  __WALL makes waitpid()
 * wait for children of any type (clone or non-clone).
 * Returns 0 once tid has been reaped, or -1 after perror() otherwise.
 */
static int thread_wait(thread_id_t tid) {
	pid_t reaped;

	reaped = waitpid(tid, NULL, __WALL);
	if (reaped == tid)
		return 0;

	perror("waitpid()");
	return -1;
}

/* Kernel thread id of the calling task, obtained from gettid(). */
static thread_id_t thread_id(void) {
	const thread_id_t self = gettid();

	return self;
}

#endif

/*
 * Handler installed for TEST_SIG.  main() blocks TEST_SIG and reads it
 * through the signalfd, so reaching this handler means the signal was
 * delivered the conventional way; it reports that on stderr.
 *
 * NOTE(review): fprintf() and strsignal() are not async-signal-safe.
 * This is tolerated here as a test-only diagnostic.
 */
static void dummy_sig(int sig) {
	const char *desc = strsignal(sig);

	fprintf(stderr, "*** got REAL signal %d (%s)\n", sig, desc);
}

/*
 * Worker body for the multi-thread test.  data carries the worker index
 * (main() passes 0..7).  The worker records its kernel thread id in
 * tids[], then dequeues signals from the shared signalfd until it sees
 * TEST_SIG or waitsig() returns <= 0.  On exit it sends TEST_SIG to the
 * whole process group, presumably so another blocked worker wakes as well.
 */
static void *thproc(void *data) {
	long idx = (long) data;
	long signo;

	tids[idx] = thread_id();
	fprintf(stdout, "thread %ld tid is %lu pgrp=%d\n", idx, tids[idx], getpgrp());

	for (;;) {
		signo = waitsig(sfd, -1);
		if (signo <= 0)
			break;
		fprintf(stdout, "thread %ld got sig = %ld (%s)\n",
			idx, signo, strsignal(signo));
		if (signo == TEST_SIG)
			break;
	}

	fprintf(stdout, "thread %ld quit (sig = %ld)\n", idx, signo);
	kill(0, TEST_SIG);
	return NULL;
}

/*
 * Run the signalfd test scenarios in sequence.  Each step prints on stdout
 * what it received and, where there is one, the expected signal:
 *   1. SIGCHLD from a child that exits immediately;
 *   2. SIGUSR1 sent by a child to the parent, then that child's SIGCHLD;
 *   3. SIGUSR1 sent by the parent to a child, which reads it from the
 *      inherited signalfd; the parent then reads the child's SIGCHLD;
 *   4. re-masking the existing signalfd without SIGUSR1 (nothing must be read);
 *   5. two signalfds watching the same signals (only one should get it);
 *   6. eight threads blocked on the same signalfd, woken by TEST_SIG;
 *   7. a non-blocking read with nothing queued (must fail with EAGAIN).
 *
 * Returns 1 if a signalfd cannot be created or a thread cannot be started,
 * otherwise 0.  ac/av are unused.
 *
 * NOTE(review): fork() failures are not checked, forked children are never
 * waited for, and oset is saved but never restored.
 */
int main(int ac, char **av) {
	int i, sfd2, sigs;
	long lsig;
	pid_t pid;
	struct signalfd_siginfo info;
	sigset_t sset, oset;

	/* Unbuffered, so output is written immediately and no buffered text
	 * is duplicated into fork()ed children. */
	setvbuf(stdout, NULL, _IONBF, 0);
	setvbuf(stderr, NULL, _IONBF, 0);
	/* TEST_SIG is blocked below, so this handler should only run if a
	 * TEST_SIG is ever delivered the conventional way. */
	signal(TEST_SIG, dummy_sig);
	/* Block every signal except SIGINT and receive them via the signalfd. */
	sigfillset(&sset);
	sigdelset(&sset, SIGINT);
	sigprocmask(SIG_BLOCK, &sset, &oset);
	if ((sfd = signalfd(-1, &sset, SIZEOF_SIGSET)) == -1) {
		perror("signalfd");
		return 1;
	}
	fprintf(stdout, "signalfd = %d\n", sfd);

	/* Step 1: a child that exits at once should produce SIGCHLD. */
	fprintf(stdout, "creating child (SIGCHLD test) ...\n");
	if ((pid = fork()) == 0) {
		fprintf(stdout, "child exit\n");
		exit(0);
	}
	fprintf(stdout, "waiting  SIGCHLD ...\n");
	lsig = waitsig(sfd, 1000);
	fprintf(stdout, "got sig = %ld (%s)\n\n", lsig, strsignal(lsig));

	/* Step 2: the child signals the parent, then exits (SIGCHLD). */
	fprintf(stdout, "creating child (child send SIGUSR1 test) ...\n");
	if ((pid = fork()) == 0) {
		fprintf(stdout, "child sends SIGUSR1\n");
		kill(getppid(), SIGUSR1);
		exit(0);
	}
	fprintf(stdout, "waiting signal ...\n");
	lsig = waitsig(sfd, 1000);
	fprintf(stdout, "got sig = %ld (%s) - expect %d (%s)\n",
		lsig, strsignal(lsig), SIGUSR1, strsignal(SIGUSR1));
	fprintf(stdout, "waiting signal ...\n");
	lsig = waitsig(sfd, 1000);
	fprintf(stdout, "got sig = %ld (%s) - expect %d (%s)\n",
		lsig, strsignal(lsig), SIGCHLD, strsignal(SIGCHLD));
	fputs("\n", stdout);

	/* Step 3: the child reads the parent's SIGUSR1 from the inherited
	 * signalfd; the parent then expects the child's SIGCHLD. */
	fprintf(stdout, "creating child (parent send SIGUSR1 test) ...\n");
	if ((pid = fork()) == 0) {
		fprintf(stdout, "child waiting signal ...\n");
		lsig = waitsig(sfd, 1000);
		fprintf(stdout, "child got sig = %ld (%s) - expect %d (%s)\n",
			lsig, strsignal(lsig), SIGUSR1, strsignal(SIGUSR1));

		exit(0);
	}
	fprintf(stdout, "parent sends SIGUSR1\n");
	kill(pid, SIGUSR1);
	/* Give the child time to read the signal and exit. */
	usleep(250000);
	fprintf(stdout, "waiting signal ...\n");
	lsig = waitsig(sfd, 1000);
	fprintf(stdout, "got sig = %ld (%s) - expect %d (%s)\n\n",
		lsig, strsignal(lsig), SIGCHLD, strsignal(SIGCHLD));

	/* Step 4: replace the mask of the existing signalfd with one that
	 * excludes SIGUSR1.  The process mask is unchanged, so SIGUSR1 stays
	 * blocked and must not be readable from the fd. */
	fprintf(stdout, "setting new mask ...\n");
	sigfillset(&sset);
	sigdelset(&sset, SIGUSR1);
	if ((sfd = signalfd(sfd, &sset, SIZEOF_SIGSET)) == -1) {
		perror("signalfd");
		return 1;
	}
	fprintf(stdout, "new signalfd = %d\n", sfd);
	fprintf(stdout, "sending SIGUSR1\n");
	kill(0, SIGUSR1);
	fprintf(stdout, "waiting SIGUSR1 ...\n");
	if ((lsig = waitsig(sfd, 0)) > 0)
		fprintf(stdout, "whooops! got sig = %ld (%s)\n", lsig, strsignal(lsig));
	else
		fprintf(stdout, "no signal, correct\n");
	fputs("\n", stdout);

	/* Step 5: two signalfds on the full mask.  A single SIGUSR1 should be
	 * dequeued by only one of them.
	 * NOTE(review): the SIGUSR1 from step 4 is presumably still pending,
	 * so this kill() may merge with it. */
	fprintf(stdout, "creating new signalfd (multiple fd receive test) ...\n");
	sigfillset(&sset);
	if ((sfd = signalfd(sfd, &sset, SIZEOF_SIGSET)) == -1) {
		perror("signalfd");
		return 1;
	}
	fprintf(stdout, "new signalfd = %d\n", sfd);
	if ((sfd2 = signalfd(-1, &sset, SIZEOF_SIGSET)) == -1) {
		perror("signalfd");
		return 1;
	}
	fprintf(stdout, "signalfd2 = %d\n", sfd2);
	fprintf(stdout, "parent sends SIGUSR1\n");
	kill(0, SIGUSR1);
	sigs = 0;
	if ((lsig = waitsig(sfd, 0)) > 0)
		sigs++;
	fprintf(stdout, "1st fd got sig = %ld (%s)\n", lsig, strsignal(lsig));
	if ((lsig = waitsig(sfd2, 0)) > 0)
		sigs++;
	fprintf(stdout, "2nd fd got sig = %ld (%s)\n", lsig, strsignal(lsig));
	if (sigs > 1)
		fprintf(stdout, "whooops! got 2 sigs instead of one!\n");
	fputs("\n", stdout);
	close(sfd2);

	/* Step 6: eight workers (thproc()) block on the shared signalfd.  One
	 * TEST_SIG is sent here; each worker re-sends TEST_SIG to the process
	 * group when it quits. */
	fprintf(stdout, "multi-thread test ...\n");
	for (i = 0; i < 8; i++) {
		if ((thid[i] = thread_new(thproc, (void *) (long) i)) == 0)
			return 1;
		fprintf(stdout, "thread %d is %ld (%p)\n", i, (long) thid[i],
			(void *) (long) thid[i]);
	}
	/* Let the workers reach waitsig() before signalling. */
	sleep(1);

	fprintf(stdout, "sending signal %d (%s) pgrp=%d ...\n",
		TEST_SIG, strsignal(TEST_SIG), getpgrp());
	kill(0, TEST_SIG);
	for (i = 0; i < 8; i++) {
		fprintf(stdout, "waiting for thread %d\n", i);
		thread_wait(thid[i]);
	}

	/* Drain anything still queued, e.g. the TEST_SIG re-sent by the last
	 * worker to exit. */
	while ((i = waitsig(sfd, 0)) > 0)
		fprintf(stdout, "flushing signal %d (%s)\n", i, strsignal(i));

	/* Step 7: with nothing queued, a non-blocking read must fail with EAGAIN. */
	fprintf(stdout, "setting O_NONBLOCK (non blocking read test) ...\n");
	fcntl(sfd, F_SETFL, fcntl(sfd, F_GETFL, 0) | O_NONBLOCK);
	if (read(sfd, &info, sizeof(info)) > 0)
		fprintf(stdout, "whooops! read signal when should have not\n\n");
	else if (errno != EAGAIN)
		fprintf(stdout, "whooops! bad errno value (%d = '%s')!\n\n",
			errno, strerror(errno));
	else
		fprintf(stdout, "success\n\n");

	/* Restore blocking mode before closing the descriptor. */
	fcntl(sfd, F_SETFL, fcntl(sfd, F_GETFL, 0) & ~O_NONBLOCK);
	close(sfd);

	return 0;
}

