/* See LICENSE file for copyright and license details. */
#include <sys/mount.h>
#include <pthread.h>
#include <libsimple.h>
#include <libsimple-arg.h>
/*
 * Command-line synopsis: optional start offset (-o), optional length (-l)
 * or end position (-e) — main() rejects giving both — followed by the device
 * to shred; random data may optionally be supplied on standard input.
 * NOTE(review): USAGE presumably comes from <libsimple-arg.h> and provides the
 * usage() function called elsewhere in this file — confirm.
 */
USAGE("[-o offset] [-l length | -e postend] device [< random-source]");
/*
 * Direction in which a pass walks through each span.
 * The values 0 and 1 matter: main() flips the direction between passes
 * with `direction ^= 1`.
 */
enum direction {
/* Start at span.start and work towards span.end */
FORWARDS = 0,
/* Start at span.end and work towards span.start */
BACKWARDS = 1
};
/*
 * A contiguous byte range of the device that still has to be overwritten.
 */
struct span {
/* Offset of the first byte in the range */
off_t start;
/* Offset just past the last byte in the range */
off_t end;
/* Number of bytes in the range that are still counted as bad: add_span()
 * sets it to the length of the range, the initial span created by main()
 * has 0, and shredspan() reduces it as bytes are successfully written */
off_t bad;
/* Largest number of bytes to write with a single pwrite(3) call */
size_t blocksize;
};
/* Set to 1 by signal_handler(); checked by the shredding and progress loops to stop early */
static _Atomic volatile sig_atomic_t exiting = 0;
/* Growable array of spans still to be shredded: the array itself, the number
 * of used entries, and the number of allocated entries (grown by add_span()) */
static struct span *spans = NULL;
static size_t nspans = 0;
static size_t spans_size = 0;
/* Number of bytes successfully written so far */
static off_t shredded = 0;
/* Number of bytes selected for shredding (end minus start of the initial span) */
static off_t total_size = 0;
/* Buffer of random bytes to write; the bytes before `reservoir_off` have
 * already been used. It starts out fully used so that the first call to
 * ensure_random() fills it */
static char reservoir[128U << 10];
static size_t reservoir_off = sizeof(reservoir);
/* Non-zero if random bytes are read from standard input instead of being
 * generated with libsimple_random_bytes() */
static int use_stdin;
/* `total_size` formatted once by humansize1000() and humansize1024() */
static char total_size_1000[256];
static char total_size_1024[256];
/* Statistics shown in the progress report: number of writes that failed or
 * were short, number of sections counted as bad, and number of bytes
 * counted as bad (all maintained by shredspan()) */
static uintmax_t bad_writes = 0;
static uintmax_t bad_sections = 0;
static off_t bad_bytes = 0;
/* Clock used for all time measurements, and its name for error messages;
 * main() switches to CLOCK_MONOTONIC if reading this clock fails */
static clockid_t clck = CLOCK_MONOTONIC_COARSE;
static const char *clkcstr = "CLOCK_MONOTONIC_COARSE";
/* Time stamp taken before the most recent successful write, and the longest
 * time observed since a successful write; a negative tv_sec means
 * "no value yet" (durationstr() prints "-" for it) */
static struct timespec last_success = {-1, 0};
static struct timespec max_success = {-1, 0};
/* Time at which shredding started */
static struct timespec start_time;
/* Turns a literal number of milliseconds into nanoseconds by token-pasting
 * six zeroes and an `L` suffix onto it; only works on plain integer literals */
#define MILLISECONDS(X) X##000000L
/* Minimum time between progress time stamps sent by shredspan() */
static const struct timespec progress_print_interval = {0, MILLISECONDS(500)};
/* How long the progress thread waits for a time stamp before printing anyway */
static const struct timespec poll_timeout = {0, MILLISECONDS(500)};
/* Pipe over which shredspan() sends time stamps to progress_print_loop();
 * [0] is the read end, [1] is the write end (set to -1 once closed) */
static int progress_print_sig_pipe[2];
/* Protects the progress statistics shared between the two threads */
static pthread_mutex_t progress_mutex;
/* TODO deal with machine and process suspension */
/* Sliding window used to calculate the write speed: each bucket has a begin
 * time and the number of bytes written since then; `write_average_i` is the
 * index of the bucket currently being filled */
static struct timespec write_average_begin_times[5];
static off_t write_average_amounts[ELEMSOF(write_average_begin_times)] = {0};
static int write_average_i = 0;
/* Direction of the current pass, flipped by main() after every pass */
static enum direction direction = FORWARDS;
/* Number of the current pass, starting at 1 */
static uintmax_t pass_nr = 1;
/*
 * Signal handler: requests an orderly termination by raising the
 * `exiting` flag, which the other loops in this file check.
 *
 * @param  signo  The caught signal; not used
 */
static void
signal_handler(int signo)
{
	exiting = 1;
	(void) signo; /* which signal was caught makes no difference */
}
static off_t
filesize(int fd, const char *fname)
{
struct stat st;
if (fstat(fd, &st))
eprintf("fstat %s:", fname);
switch (st.st_mode & S_IFMT) {
case S_IFREG:
break;
case S_IFBLK:
if (ioctl(fd, BLKGETSIZE64, &st.st_size) < 0)
eprintf("ioctl %s BLKGETSIZE64:", fname);
break;
default:
eprintf("%s: not a regular file or block device", fname);
}
return st.st_size;
}
/*
 * Make sure that at least `needed` unused random bytes are available at
 * `&reservoir[reservoir_off]`.
 *
 * If there are too few, the already used front part of the reservoir
 * (the first `reservoir_off` bytes) is overwritten with fresh random data,
 * taken from standard input if `use_stdin` is set and otherwise from
 * libsimple_random_bytes(), and `reservoir_off` is reset to 0. The bytes
 * that were still unused are kept.
 *
 * @param  needed  The number of random bytes the caller is about to use
 */
static void
ensure_random(size_t needed)
{
	size_t available = sizeof(reservoir) - reservoir_off;
	size_t filled = 0;
	ssize_t got;

	if (available >= needed)
		return;

	if (use_stdin) {
		while (filled < reservoir_off) {
			got = read(STDIN_FILENO, &reservoir[filled], reservoir_off - filled);
			if (got > 0) {
				filled += (size_t)got;
				continue;
			}
			if (got == 0)
				eprintf("random source depleted");
			/* Interrupted reads are simply retried */
			if (errno != EINTR)
				eprintf("read <stdin>:");
		}
	} else {
		libsimple_random_bytes(&libsimple_random_bits, NULL, reservoir, reservoir_off);
	}

	reservoir_off = 0;
}
/*
 * Append a range to the list of spans that must be (re)written.
 *
 * The new span covers [off, off + amount) and is marked as entirely bad.
 * The block size is halved until it is no more than twice as large as
 * necessary for the range, but never below 1.
 *
 * The `spans` array may be reallocated by this function, so pointers into
 * it are invalid after the call.
 *
 * @param  off        Offset of the first byte in the range
 * @param  amount     Number of bytes in the range
 * @param  blocksize  Upper limit for the size of each write to the range
 */
static void
add_span(off_t off, off_t amount, size_t blocksize)
{
	off_t end = off + amount;

	/* The `blocksize > 1U` bound keeps the loop finite, and the block size
	 * non-zero, even if `amount` is not positive; for positive amounts it
	 * changes nothing since the loop stops at 1 anyway */
	while (blocksize > 1U && (off_t)(blocksize >> 1) >= amount)
		blocksize >>= 1;

	if (nspans == spans_size) {
		spans_size += 1024;
		spans = ereallocarray(spans, spans_size, sizeof(*spans));
	}

	spans[nspans].start = off;
	spans[nspans].end = end;
	spans[nspans].bad = amount;
	spans[nspans].blocksize = blocksize;
	nspans++;
}
/*
 * Format a byte count with decimal (power of 1000) prefixes.
 *
 * Counts below 1000 are printed as whole bytes ("999 B"); larger counts
 * are printed with one decimal, truncated, and the largest fitting prefix
 * ("1.5 kB", "2.5 GB").
 *
 * @param   s    The number of bytes
 * @param   buf  Output buffer, must be large enough for the text
 * @return       `buf`
 */
static char *
humansize1000(off_t s, char *buf)
{
	static const char prefixes[] = "kMGTPEZYRQ";
	size_t p = 0;
	unsigned tenths;

	if (s < 1000) {
		sprintf(buf, "%u B", (unsigned)s);
		return buf;
	}

	/* From here on `s` counts tenths of the currently selected unit */
	for (s /= 100; prefixes[p + 1U] && s >= 10000; s /= 1000)
		p++;

	tenths = (unsigned)s;
	sprintf(buf, "%u.%u %cB", tenths / 10U, tenths % 10U, prefixes[p]);
	return buf;
}
/*
 * Format a byte count with binary (power of 1024) prefixes.
 *
 * Counts below 1024 are printed as whole bytes ("500 B"); larger counts
 * are printed with one decimal, truncated, and the largest fitting prefix
 * ("1.5 KiB", "1.0 MiB").
 *
 * @param   s    The number of bytes
 * @param   buf  Output buffer, must be large enough for the text
 * @return       `buf`
 */
static char *
humansize1024(off_t s, char *buf)
{
	static const char prefixes[] = "KMGTPEZYRQ";
	size_t p = 0;
	unsigned long int whole, tenth;

	if (s < 1024) {
		sprintf(buf, "%u B", (unsigned)s);
		return buf;
	}

	/* Scale down until `s` is below 1024 of the selected unit,
	 * expressed in units one step smaller */
	for (; prefixes[p + 1U] && s >= 1024 * 1024; p++)
		s /= 1024;

	whole = (unsigned long int)s / 1024UL;
	tenth = (unsigned long int)(s * 10 % 10240) / 1024UL;
	sprintf(buf, "%lu.%lu %ciB", whole, tenth, prefixes[p]);
	return buf;
}
/*
 * Parse a human-readable size given as the argument of a flag.
 *
 * The text is a sequence of terms that are added together. Each term is a
 * decimal number, optionally with a fractional part, optionally followed by
 * one of the prefixes k, K, M, G, T, P, E, Z, Y, R, Q. A prefix followed by
 * "iB" or by nothing is a power of 1024, a prefix followed by "B" is a power
 * of 1000. A term may also be suffixed with just "B", but then it must not
 * have a fractional part. Terms may be separated by spaces, commas and plus
 * signs. Each term is truncated to a whole number of bytes. Parsing stops at
 * the first character that cannot start a new term.
 *
 * usage() is called if the text does not start with a number, and an error
 * is printed with eprintf() if the value does not fit in an off_t.
 *
 * @param   s     The text to parse
 * @param   flag  The flag the text belongs to, used in error messages
 * @return        The size in bytes
 */
#if defined(__GNUC__)
__attribute__((__pure__))
#endif
static off_t
unhumansize(const char *s, char flag)
{
	off_t sum = 0, term, digit, divisor, power, base;

	/* Same condition as the one that continues the loop below:
	 * a term starts with a digit or with '.' followed by a digit */
	if (!isdigit((unsigned char)s[0]) && !(s[0] == '.' && isdigit((unsigned char)s[1])))
		usage();

	do {
		divisor = 1;
		term = 0;

		while (isdigit((unsigned char)*s)) {
			digit = (*s++ & 15);
			if (term > (OFF_MAX - digit) / 10)
				eprintf("value of -%c flag is too large", flag);
			term = term * 10 + digit;
		}

		if (*s == '.') {
			s++;
			while (isdigit((unsigned char)*s)) {
				digit = (*s++ & 15);
				/* Digits beyond what `divisor` can scale are
				 * dropped, which only truncates the value */
				if (divisor > OFF_MAX / 10)
					continue;
				if (term > (OFF_MAX - digit) / 10)
					eprintf("value of -%c flag is too large", flag);
				term = term * 10 + digit;
				divisor *= 10;
			}
		}

		power = 0;
		switch (*s) {
		case 'Q': power++; /* fall through */
		case 'R': power++; /* fall through */
		case 'Y': power++; /* fall through */
		case 'Z': power++; /* fall through */
		case 'E': power++; /* fall through */
		case 'P': power++; /* fall through */
		case 'T': power++; /* fall through */
		case 'G': power++; /* fall through */
		case 'M': power++; /* fall through */
		case 'k': case 'K': power++;
			/* `&&` so that s[2] is only inspected when s[1] is
			 * known not to be the end of the string */
			if (s[1] == 'i' && s[2] == 'B') {
				base = 1024;
				s = &s[3];
			} else if (s[1] == 'B') {
				base = 1000;
				s = &s[2];
			} else {
				base = 1024;
				s = &s[1];
			}
			while (power) {
				if (term > OFF_MAX / base)
					eprintf("value of -%c flag is too large", flag);
				term *= base;
				power--;
			}
			break;

		case 'B':
			/* A fractional number of bytes is meaningless */
			if (!power && divisor > 1)
				usage();
			s++;
			break;

		default:
			break;
		}

		term /= divisor;
		if (term > OFF_MAX - sum)
			eprintf("value of -%c flag is too large", flag);
		sum += term;

		while (*s == ' ' || *s == ',' || *s == '+')
			s++;

	} while (isdigit((unsigned char)s[0]) || (s[0] == '.' && isdigit((unsigned char)s[1])));

	return sum;
}
/*
 * Format a duration as text, for example "1 days, 1:01:01.12 hours",
 * "2:05.00 minutes" or "59 seconds".
 *
 * @param   dur              The duration; if either field is negative the
 *                           duration is regarded as unknown
 * @param   buf              Output buffer, must be large enough for the text
 * @param   second_decimals  Number of decimals to print for the seconds,
 *                           clamped to the range [0, 9]
 * @return                   `buf`, or the constant string "-" if the
 *                           duration is unknown
 */
static const char *
durationstr(const struct timespec *dur, char *buf, int second_decimals)
{
	uintmax_t total, frac, secs, mins, hours, days;
	uintmax_t frac_div = UINTMAX_C(1000000000);
	const char *unit;
	char *pos = buf;
	int n;

	if (dur->tv_sec < 0 || dur->tv_nsec < 0)
		return "-";

	if (second_decimals < 0)
		second_decimals = 0;
	else if (second_decimals > 9)
		second_decimals = 9;

	/* Divisor that reduces nanoseconds to the requested number of decimals */
	for (n = second_decimals; n > 0; n--)
		frac_div /= 10U;

	total = (uintmax_t)dur->tv_sec;
	frac = (uintmax_t)dur->tv_nsec / frac_div;
	secs = total % 60U;
	mins = total / 60U % 60U;
	hours = total / 3600U % 24U;
	days = total / 86400U;

	if (days)
		pos += sprintf(pos, "%ju days, ", days);

	/* The unit names the most significant field that is printed */
	if (hours) {
		pos += sprintf(pos, "%ju:%02ju:%02ju", hours, mins, secs);
		unit = "hours";
	} else if (mins) {
		pos += sprintf(pos, "%ju:%02ju", mins, secs);
		unit = "minutes";
	} else {
		pos += sprintf(pos, "%ju", secs);
		unit = "seconds";
	}

	if (second_decimals)
		pos += sprintf(pos, ".%0*ju", second_decimals, frac);

	sprintf(pos, " %s", unit);
	return buf;
}
#if defined(__linux__)
#include <linux/kd.h>
/*
 * Check whether the micro sign (U+00B5) can be displayed.
 *
 * The Unicode map is requested with the GIO_UNIMAP ioctl on standard
 * input: first with an empty buffer, to learn the number of entries, then
 * again with a buffer of that size. If the map cannot be retrieved the
 * symbol is assumed to be available. The result is cached, so the ioctls
 * are only performed on the first call.
 *
 * @return  0 if the map was retrieved and lacks U+00B5, 1 otherwise
 */
static int
have_micro_symbol(void)
{
	static int ret = -1;

	if (ret < 0) {
		struct unimapdesc desc;
		struct unipair *pairs = NULL;
		size_t i;

		ret = 1;

		desc.entry_ct = 0;
		desc.entries = NULL;
		if (ioctl(STDIN_FILENO, GIO_UNIMAP, &desc))
			if (!desc.entry_ct)
				goto out;

		desc.entries = pairs = ecalloc(desc.entry_ct, sizeof(*pairs));
		if (ioctl(STDIN_FILENO, GIO_UNIMAP, &desc))
			goto out;

		/* `i` is advanced by the loop header only, so that
		 * every entry is inspected */
		for (i = 0; i < desc.entry_ct; i++)
			if (desc.entries[i].unicode == 0xB5U)
				goto out;

		ret = 0;

	out:
		free(pairs);
	}

	return ret;
}
#else
/* Without a way to check, the micro sign is assumed to be available */
# define have_micro_symbol() 1
#endif
/*
 * Check whether a write-average bucket is (almost) too old.
 *
 * @param   i        Index of the bucket in `write_average_begin_times`
 * @param   now      The current time
 * @param   seconds  The age limit, in seconds
 * @return           1 if at least `seconds` seconds, less a tolerance of
 *                   0.1 seconds, have passed since the bucket was begun,
 *                   0 otherwise
 */
static int
was_write_average_overrun(int i, const struct timespec *now, int seconds)
{
	struct timespec age;

	libsimple_difftimespec(&age, now, &write_average_begin_times[i]);

	return age.tv_sec >= seconds ||
	       (age.tv_sec == seconds - 1 && age.tv_nsec >= 900000000L);
}
/*
 * Drop the oldest write-average bucket.
 *
 * `write_average_i` is decremented first, and that many buckets, starting
 * at index 1, are then moved one step towards the front of both
 * `write_average_begin_times` and `write_average_amounts`.
 */
static void
shift_write_average(void)
{
	size_t count;

	write_average_i -= 1;
	count = (size_t)write_average_i;

	memmove(write_average_begin_times, write_average_begin_times + 1,
	        count * sizeof(write_average_begin_times[0]));
	memmove(write_average_amounts, write_average_amounts + 1,
	        count * sizeof(write_average_amounts[0]));
}
/*
 * Print the progress report to standard error.
 *
 * The report is six lines long and is only written if its text differs
 * from the previously formatted report. progress_print_loop() calls this
 * function with `progress_mutex` held; main() calls it once after the
 * progress thread has been joined.
 *
 * @param  done  Zero if more reports will follow, in which case the report
 *               ends with an escape sequence that moves the cursor back up
 *               six lines so that the next report overwrites this one;
 *               non-zero for the final report
 * @param  now   The current time
 */
static void
print_progress(int done, const struct timespec *now)
{
/* Two buffers used alternately: the new report is formatted into one
 * and compared with the previous report, still held in the other */
static char buf1[2048] = {0};
static char buf2[2048] = {0};
static int bufi = 0;
char subbuf1[256];
char subbuf2[256];
char subbuf3[256];
char subbuf4[256];
char subbuf5[256];
char subbuf6[256];
char subbuf7[256];
char write_average_buf[512];
struct timespec since_success = {-1, 0};
struct timespec time_spent, write_average_time;
int i;
off_t write_average_sum = 0;
double write_average;
/* Write speed: bytes in all active buckets divided by the time since the oldest bucket was begun */
for (i = 0; i <= write_average_i; i++)
write_average_sum += write_average_amounts[i];
libsimple_difftimespec(&write_average_time, now, &write_average_begin_times[0]);
/* Less than 0.1 seconds of data is too little to base a speed on */
if (write_average_time.tv_sec < 0 || (write_average_time.tv_sec == 0 && write_average_time.tv_nsec < 100000000L)) {
stpcpy(write_average_buf, "-");
} else if (write_average_sum == 0) {
stpcpy(write_average_buf, "0");
} else {
const char *units_big = "kMGTPEZYRQ";
const char *units_small = "munpfazyrq";
/* Index into the selected prefix string; -1 means no prefix */
int unit = -1;
write_average = (double)write_average_time.tv_nsec;
write_average /= 1000000000L;
write_average += (double)write_average_time.tv_sec;
write_average = (double)write_average_sum / write_average;
if (write_average < (double)0.01f) {
/* Very slow: scale up by 1000 at a time, using the sub-unit prefixes */
do {
write_average *= 1000;
unit++;
} while (units_small[unit + 1] && write_average < (double)0.01f);
if (units_small[unit] == 'u' && have_micro_symbol())
sprintf(write_average_buf, "%.02lf µB/s", write_average);
else
sprintf(write_average_buf, "%.02lf %cB/s", write_average, units_small[unit]);
} else {
/* Scale down by 1000 at a time, using the multiple prefixes */
while (units_big[unit + 1] && write_average >= 1000) {
write_average /= 1000;
unit++;
}
if (unit < 0)
sprintf(write_average_buf, "%.02lf B/s", write_average);
else
sprintf(write_average_buf, "%.02lf %cB/s", write_average, units_big[unit]);
}
}
/* Time since the last successful write, also tracking the longest such time seen */
if (last_success.tv_sec >= 0) {
libsimple_difftimespec(&since_success, now, &last_success);
if (libsimple_cmptimespec(&since_success, &max_success) > 0)
max_success = since_success;
}
/* TODO deal with machine and process suspension */
libsimple_difftimespec(&time_spent, now, &start_time);
/* "\033[K" clears the rest of each line, in case the previous report was longer */
sprintf(bufi == 0 ? buf1 : buf2,
"%ji bytes (%s, %s, %.2lf %%) of %s (%s) shredded\033[K\n"
"failed writes: %ju; bad sections: %ju (%s, %s)\033[K\n"
"time spent shredding: %s; performance: %s\033[K\n"
"time since last successful write: %s\033[K\n"
"maximum time until a successful write: %s\033[K\n"
"pass: %ju; pass direction: %s\033[K\n"
"%s",
/* line 1 { */
(intmax_t)shredded,
humansize1000(shredded, subbuf1),
humansize1024(shredded, subbuf2),
100 * (double)shredded / (double)total_size,
total_size_1000,
total_size_1024,
/* } line 2 { */
bad_writes,
bad_sections,
humansize1000(bad_bytes, subbuf3),
humansize1024(bad_bytes, subbuf4),
/* } line 3 { */
durationstr(&time_spent, subbuf5, 1),
write_average_buf,
/* } line 4 { */
durationstr(&since_success, subbuf6, 2),
/* } line 5 { */
durationstr(&max_success, subbuf7, 2),
/* } line 6 { */
pass_nr,
direction == FORWARDS ? "forwards" : "backwards",
/* } */
done ? "" : "\033[6A");
/* Only write the report if it differs from the previous one */
if (strcmp(buf1, buf2)) {
fprintf(stderr, "%s", bufi == 0 ? buf1 : buf2);
fflush(stderr);
}
/* Format the next report into the other buffer */
bufi ^= 1;
}
/*
 * Maintain the sliding window of write-average buckets.
 *
 * A new bucket is begun once the current one is about one second old;
 * if all buckets are in use the oldest is dropped first to make room.
 * After that, buckets at the front are dropped for as long as the oldest
 * one is about as old as the number of buckets, in seconds, and it is
 * not the only bucket.
 *
 * NOTE(review): when shift_write_average() is called from the `while` loop
 * it moves one bucket fewer than are in use, so the newest bucket looks
 * like it is lost — confirm whether that is intended.
 *
 * @param  now  The current time
 */
static void
update_progress(const struct timespec *now)
{
	const int window_seconds = (int)ELEMSOF(write_average_amounts);

	if (was_write_average_overrun(write_average_i, now, 1)) {
		if (++write_average_i == ELEMSOF(write_average_amounts))
			shift_write_average();
		write_average_amounts[write_average_i] = 0;
		write_average_begin_times[write_average_i] = *now;
	}

	while (write_average_i != 0 && was_write_average_overrun(0, now, window_seconds))
		shift_write_average();
}
/*
 * Thread start routine that prints the progress report.
 *
 * SIGTERM and SIGINT are unblocked in this thread. The thread then waits
 * on the read end of `progress_print_sig_pipe` for time stamps written by
 * shredspan(); whenever one arrives, or when `poll_timeout` has elapsed
 * (in which case the clock is read here instead), the progress is printed
 * and the write-average window is updated, with `progress_mutex` held.
 *
 * The loop ends when the write end of the pipe has been closed, or when
 * `exiting` has been set by the signal handler.
 *
 * @param   user  Not used
 * @return        Always NULL
 */
static void *
progress_print_loop(void *user)
{
	/* POLLIN, as it is the read end of the pipe that is watched: it can
	 * never become writable, so waiting for POLLOUT would leave the time
	 * stamps unread until the pipe is full */
	struct pollfd pfd = {.fd = progress_print_sig_pipe[0], .events = POLLIN};
	struct timespec now;
	ssize_t r;
	int terminate = 0;
	sigset_t sigset;

	(void) user;

	sigemptyset(&sigset);
	sigaddset(&sigset, SIGTERM);
	sigaddset(&sigset, SIGINT);
	errno = pthread_sigmask(SIG_UNBLOCK, &sigset, NULL);
	if (errno)
		eprintf("pthread_sigmask SIG_UNBLOCK {SIGTERM, SIGINT} NULL:");

	do {
		switch (ppoll(&pfd, 1U, &poll_timeout, NULL)) {
		case -1:
			if (errno != EINTR)
				eprintf("ppoll:");
			/* Interrupted: print right away if it was a
			 * termination request, otherwise wait again */
			if (exiting)
				break;
			continue;

		case 0:
			/* Timed out: print using the current time */
			break;

		default:
			r = read(progress_print_sig_pipe[0], &now, sizeof(now));
			if (r == (ssize_t)sizeof(now)) {
				goto have_time;
			} else if (!r) {
				/* Write end closed: print one last time */
				terminate = 1;
			} else if (r < 0) {
				if (errno == EINTR)
					continue;
				eprintf("read <internal pipe>:");
			}
			break;
		}

		if (clock_gettime(clck, &now))
			eprintf("clock_gettime %s:", clkcstr);

	have_time:
		pthread_mutex_lock(&progress_mutex);
		if (exiting) {
			fprintf(stderr, "\033[K\nTermination initialised by user...\033[K\n\033[K\n");
			terminate = 1;
		}
		print_progress(0, &now);
		update_progress(&now);
		pthread_mutex_unlock(&progress_mutex);

	} while (!terminate);

	return NULL;
}
static void
shredspan(int fd, struct span *span, const char *fname)
{
off_t off, n;
ssize_t r;
struct timespec now, when = {0, 0};
int bad = span->bad > 0;
pthread_mutex_lock(&progress_mutex);
off = (direction == FORWARDS ? span->start : span->end);
while (direction == FORWARDS ? off < span->end : off > span->start) {
if (exiting) {
userexit:
if (direction == FORWARDS)
span->start = off;
else
span->end = off;
close(progress_print_sig_pipe[1]);
progress_print_sig_pipe[1] = -1;
goto out;
}
if (clock_gettime(clck, &now))
eprintf("clock_gettime %s:", clkcstr);
if (libsimple_cmptimespec(&now, &when) >= 0) {
libsimple_sumtimespec(&when, &now, &progress_print_interval);
write(progress_print_sig_pipe[1], &now, sizeof(now));
}
ensure_random(span->blocksize);
if (direction == FORWARDS) {
n = MIN((off_t)span->blocksize, span->end - off);
} else {
n = off--;
off &= ~(off_t)(span->blocksize - 1U);
if (off < span->start)
off = span->start;
n -= off;
if (!n)
break;
}
pthread_mutex_unlock(&progress_mutex);
pwrite_again:
r = pwrite(fd, &reservoir[reservoir_off], (size_t)n, off);
if (r < 0) {
if (errno == EINTR) {
if (exiting) {
pthread_mutex_lock(&progress_mutex);
goto userexit;
}
goto pwrite_again;
}
pthread_mutex_lock(&progress_mutex);
if (errno != EIO)
weprintf("pwrite %s <buffer> %zu %ji:", fname, (size_t)n, (intmax_t)off);
add_span(off, n, span->blocksize == 1U ? 1U : span->blocksize);
if (direction == FORWARDS)
off += n;
if (!span->bad)
bad_bytes += n;
if (bad)
bad = 0;
else
bad_sections += 1U;
when.tv_sec = 0;
when.tv_nsec = 0;
bad_writes += 1U;
continue;
}
pthread_mutex_lock(&progress_mutex);
if (direction == FORWARDS) {
off += (off_t)r;
} else if ((off_t)r < n) {
n -= (off_t)r;
add_span(off + (off_t)r, n, span->blocksize == 1U ? 1U : span->blocksize);
if (direction == FORWARDS)
off += n;
if (!span->bad)
bad_bytes += n;
if (bad)
bad = 0;
else
bad_sections += 1U;
when.tv_sec = 0;
when.tv_nsec = 0;
bad_writes += 1U;
}
shredded += (off_t)r;
reservoir_off += (size_t)r;
write_average_amounts[write_average_i] += (off_t)r;
last_success = now;
if (span->bad) {
bad_bytes -= (off_t)r;
span->bad -= (off_t)r;
}
}
if (bad && !span->bad)
bad_sections -= 1U;
out:
pthread_mutex_unlock(&progress_mutex);
}
static void
dump_map(int fd, const char *fname)
{
size_t i;
int r;
if (!nspans)
return;
for (i = 0; i < nspans; i++) {
r = dprintf(fd, "%s%jx-%jx/%zx",
i ? "," : "0x",
(uintmax_t)spans[i].start,
(uintmax_t)spans[i].end,
spans[i].blocksize);
if (r < 0)
goto fail;
}
r = dprintf(fd, "\n");
if (r < 0)
goto fail;
return;
fail:
eprintf("dprintf %s:", fname);
}
int
main(int argc, char *argv[])
{
off_t off = -1, len = -1, end = -1;
size_t i, j;
int fd;
struct timespec now;
struct sigaction sa;
pthread_t progress_print_thread;
sigset_t sigset;
ARGBEGIN {
case 'o':
if (off >= 0)
usage();
off = unhumansize(ARG(), FLAG());
break;
case 'l':
if (len >= 0 || end >= 0)
usage();
len = unhumansize(ARG(), FLAG());
break;
case 'e':
if (len >= 0 || end >= 0)
usage();
end = unhumansize(ARG(), FLAG());
break;
default:
usage();
} ARGEND;
if (argc != 1)
usage();
memset(&sa, 0, sizeof(sa));
sa.sa_handler = &signal_handler;
sigemptyset(&sa.sa_mask);
if (sigaction(SIGTERM, &sa, NULL))
eprintf("sigaction SIGTERM {.sa_handler=<function>, .sa_mask={}, .sa_flags=0} NULL:");
if (sigaction(SIGINT, &sa, NULL))
eprintf("sigaction SIGINT {.sa_handler=<function>, .sa_mask={}, .sa_flags=0} NULL:");
sigemptyset(&sigset);
sigaddset(&sigset, SIGTERM);
sigaddset(&sigset, SIGINT);
errno = pthread_sigmask(SIG_BLOCK, &sigset, NULL);
if (errno)
eprintf("pthread_sigmask SIG_BLOCK {SIGTERM, SIGINT} NULL:");
fd = open(argv[0], O_WRONLY | O_DSYNC);
if (fd < 0)
eprintf("open %s O_WRONLY|O_DSYNC:", argv[0]);
use_stdin = !isatty(STDIN_FILENO);
if (use_stdin) {
struct stat st;
if (fstat(STDIN_FILENO, &st)) {
if (errno == EBADF)
use_stdin = 0;
else
eprintf("fstat <stdin>:");
} else {
if (S_ISFIFO(st.st_mode) || S_ISCHR(st.st_mode) || S_ISSOCK(st.st_mode))
weprintf("stdin is open but is not a TTY, character device, FIFO, or socket");
}
}
if (!use_stdin) {
libsimple_srand();
}
spans = emalloc(sizeof(*spans));
spans[0].start = 0;
spans[0].end = filesize(fd, argv[0]);
spans[0].bad = 0;
spans[0].blocksize = sizeof(reservoir);
while (spans[0].blocksize & (spans[0].blocksize - 1U))
spans[0].blocksize &= (spans[0].blocksize - 1U);
nspans = 1U;
spans_size = 1U;
if (off >= 0) {
if (off > spans[0].end)
eprintf("value of -o flag is beyond the end of the file");
spans[0].start = off;
}
if (len >= 0) {
if (len > OFF_MAX - spans[0].start)
eprintf("the sum of the values of -o and -l flag is too large");
end = spans[0].start + len;
if (end > spans[0].end)
eprintf("the sum of the values of -o and -l flag is beyond the end of the file");
spans[0].end = end;
} else if (end >= 0) {
if (end > spans[0].end)
eprintf("the value of -e flag is beyond the end of the file");
spans[0].end = end;
}
total_size = spans[0].end - spans[0].start;
humansize1000(total_size, total_size_1000);
humansize1024(total_size, total_size_1024);
if (pipe(progress_print_sig_pipe))
eprintf("pipe:");
errno = pthread_mutex_init(&progress_mutex, NULL);
if (errno)
eprintf("pthread_mutex_init NULL:");
errno = pthread_create(&progress_print_thread, NULL, &progress_print_loop, NULL);
if (errno)
eprintf("pthread_create NULL:");
if (clock_gettime(clck, &start_time)) {
clck = CLOCK_MONOTONIC;
clkcstr = "CLOCK_MONOTONIC";
if (clock_gettime(clck, &start_time))
eprintf("clock_gettime %s:", clkcstr);
}
write_average_begin_times[0] = start_time;
while (nspans) {
size_t old_nspans = nspans;
for (i = 0; i < old_nspans; i++)
shredspan(fd, &spans[i], argv[0]);
for (i = 0, j = nspans, nspans -= old_nspans; i < nspans;)
spans[i++] = spans[--j];
if (exiting)
break;
direction ^= 1;
pass_nr++;
}
close(fd);
if (progress_print_sig_pipe[1] >= 0)
close(progress_print_sig_pipe[1]);
errno = pthread_join(progress_print_thread, NULL);
if (errno)
weprintf("pthread_join:");
pthread_mutex_destroy(&progress_mutex);
if (clock_gettime(clck, &now))
eprintf("clock_gettime %s:", clkcstr);
print_progress(1, &now);
if (nspans) {
/* TODO document in man page */
dump_map(STDOUT_FILENO, "<stdout>");
if (close(STDOUT_FILENO))
eprintf("write <stdout>");
}
return 0;
}