summaryrefslogtreecommitdiff
path: root/threaded_sieve.cc
blob: 3e683c15d16b2a1c521e6b2162e8ef08748216fc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include "sieve.hh"

#include <pthread.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <new>
#include <vector>

/// Per-block work description and scratch space consumed by `worker`.
///
/// Owns a malloc'd scratch buffer of `s` bytes. It is therefore non-copyable:
/// an implicit copy would share the pointer and free it twice.
struct BlockArgs {
  /// @param p        sieving primes; copied into this object.
  /// @param block_sz length in bytes of the scratch buffer (numbers per segment).
  /// @throws std::bad_alloc if a non-empty scratch buffer cannot be allocated.
  BlockArgs(const std::vector<int> &p, size_t block_sz)
      : s(block_sz), primes(p) {
    // Allocate in the body: every member is already constructed, so throwing
    // here cannot leak anything.
    n = static_cast<unsigned char *>(std::malloc(s));
    // malloc(0) may legitimately return nullptr; only a non-empty request
    // that yields nullptr is a failure.
    if (n == nullptr && s != 0)
      throw std::bad_alloc();
  }
  ~BlockArgs() { std::free(n); }

  BlockArgs(const BlockArgs &) = delete;
  BlockArgs &operator=(const BlockArgs &) = delete;

  std::size_t l = 0;             // in: first number of the range (inclusive)
  std::size_t h = 0;             // in: end of the range (exclusive)
  const size_t s;                // in: segment length / scratch-buffer size
  unsigned char *n = nullptr;    // scratch: one "composite" flag per number
  const std::vector<int> primes; // in: primes whose multiples are struck out
  std::vector<int> r;            // out: unmarked numbers found in [l, h)
};

/// pthread entry point: segmented sieve over the half-open range [a->l, a->h).
///
/// Works through the range in segments of at most a->s numbers. It strikes out
/// every multiple of each entry in a->primes (including the prime itself when
/// it lies in the range). Every unmarked number >= 2 is appended to a->r in
/// ascending order.
///
/// @param args pointer to a caller-owned BlockArgs that must outlive the thread.
/// @return always nullptr.
static void *worker(void *args) {
  auto *const a = static_cast<BlockArgs *>(args);

  // A zero-length segment would never advance and loop forever.
  if (a->s == 0)
    return nullptr;

  for (std::size_t lo = a->l; lo < a->h; lo += a->s) {
    // Clamp the last segment to h. Numbers at or past h belong to the next
    // block or lie beyond the requested bound, so they must not be reported.
    const std::size_t hi = std::min(lo + a->s, a->h);
    const std::size_t len = hi - lo;
    std::memset(a->n, 0, len);

    for (const int p : a->primes) {
      if (p < 2)
        continue; // 0 would divide by zero; 1 would strike out everything
      const auto prime = static_cast<std::size_t>(p);

      // First multiple of `prime` that is >= lo.
      std::size_t first = (lo / prime) * prime;
      if (first < lo)
        first += prime;

      for (std::size_t i = first; i < hi; i += prime)
        a->n[i - lo] = 1;
    }

    for (std::size_t i = 0; i < len; ++i) {
      const std::size_t value = lo + i;
      // 0 and 1 are never prime, even if the range happens to start below 2.
      if (value >= 2 && a->n[i] == 0)
        a->r.emplace_back(static_cast<int>(value));
    }
  }

  return nullptr;
}

// Worker-thread count pthreadSieve uses when splitting its range.
constexpr int thread_count{16};

[[nodiscard]] std::vector<int> pthreadSieve(std::size_t size) {
  const int s = sqrt(size);
  const int block_sz = size / thread_count;

  const auto primes = sieve3(s);
  std::vector<int> r{primes};

  pthread_t threads[thread_count];
  BlockArgs *args[thread_count]{nullptr};
  for (int i = 0; i < thread_count; ++i) {
    args[i] = new BlockArgs(primes, s);
  }

  {
    int i = 0; // thread index
    for (std::size_t l = s; l < size; l += block_sz) {

      // start thread i
      args[i]->l = l;
      args[i]->h = std::min(l + block_sz, size);
      pthread_create(&threads[i], NULL, worker, args[i]);
      // worker(args[i]);

      ++i;
    }

    // wait out the threads
    for (int j = 0; j < i; ++j) {
      pthread_join(threads[j], NULL);
      r.reserve(r.size() + args[j]->r.size());
      r.insert(r.end(), args[j]->r.begin(), args[j]->r.end());
    }
  }
  // cleanup
  for (int i = 0; i < thread_count; ++i)
    delete args[i];

  return r;
}