|           Line data    Source code 
       1             : // Copyright (c) 2018 The Dash Core developers
       2             : // Copyright (c) 2021 The PIVX Core developers
       3             : // Distributed under the MIT/X11 software license, see the accompanying
       4             : // file COPYING or http://www.opensource.org/licenses/mit-license.php.
       5             : 
       6             : #include "bls/bls_worker.h"
       7             : #include "hash.h"
       8             : #include "serialize.h"
       9             : #include "util/system.h"
      10             : #include "util/threadnames.h"
      11             : 
      12             : 
      13             : template <typename T>
      14         369 : bool VerifyVectorHelper(const std::vector<T>& vec, size_t start, size_t count)
      15             : {
      16         369 :     if (start == 0 && count == 0) {
      17         192 :         count = vec.size();
      18             :     }
      19         369 :     std::set<uint256> set;
      20        4218 :     for (size_t i = start; i < start + count; i++) {
      21        3849 :         if (!vec[i].IsValid())
      22             :             return false;
      23             :         // check duplicates
      24        3849 :         if (!set.emplace(vec[i].GetHash()).second) {
      25             :             return false;
      26             :         }
      27             :     }
      28             :     return true;
      29             : }
      30             : 
      31             : // Creates a doneCallback and a future. The doneCallback simply finishes the future
      32             : template <typename T>
      33         410 : std::pair<std::function<void(const T&)>, std::future<T> > BuildFutureDoneCallback()
      34             : {
      35         410 :     auto p = std::make_shared<std::promise<T> >();
      36        2581 :     std::function<void(const T&)> f = [p](const T& v) {
      37         410 :         p->set_value(v);
      38             :     };
      39        1230 :     return std::make_pair(std::move(f), p->get_future());
      40             : }
      41             : template <typename T>
      42           0 : std::pair<std::function<void(T)>, std::future<T> > BuildFutureDoneCallback2()
      43             : {
      44           0 :     auto p = std::make_shared<std::promise<T> >();
      45           0 :     std::function<void(const T&)> f = [p](T v) {
      46           0 :         p->set_value(v);
      47             :     };
      48           0 :     return std::make_pair(std::move(f), p->get_future());
      49             : }
      50             : 
      51             : 
      52             : /////
      53             : 
      54         476 : CBLSWorker::CBLSWorker()
      55             : {
      56         476 : }
      57             : 
      58         476 : CBLSWorker::~CBLSWorker()
      59             : {
      60         476 :     Stop();
      61         476 : }
      62             : 
      63         348 : void CBLSWorker::Start()
      64             : {
      65         348 :     int workerCount = GetNumCores() / 2;
      66         348 :     workerCount = std::max(std::min(1, workerCount), 4);
      67         348 :     workerPool.resize(workerCount);
      68         348 :     RenameThreadPool(workerPool, "pivx-bls-work");
      69         348 : }
      70             : 
      71         834 : void CBLSWorker::Stop()
      72             : {
      73         834 :     workerPool.clear_queue();
      74         834 :     workerPool.stop(true);
      75         834 : }
      76             : 
      77         109 : bool CBLSWorker::GenerateContributions(int quorumThreshold, const BLSIdVector& ids, BLSVerificationVectorPtr& vvecRet, BLSSecretKeyVector& skShares)
      78             : {
      79         109 :     BLSSecretKeyVectorPtr svec = std::make_shared<BLSSecretKeyVector>((size_t)quorumThreshold);
      80         218 :     vvecRet = std::make_shared<BLSVerificationVector>((size_t)quorumThreshold);
      81         109 :     skShares.resize(ids.size());
      82             : 
      83        1447 :     for (int i = 0; i < quorumThreshold; i++) {
      84        1338 :         (*svec)[i].MakeNewKey();
      85             :     }
      86         218 :     std::list<std::future<bool> > futures;
      87         109 :     size_t batchSize = 8;
      88             : 
      89         338 :     for (size_t i = 0; i < (size_t)quorumThreshold; i += batchSize) {
      90         229 :         size_t start = i;
      91         229 :         size_t count = std::min(batchSize, quorumThreshold - start);
      92        2025 :         auto f = [&, start, count](int threadId) {
      93        1567 :             for (size_t j = start; j < start + count; j++) {
      94        1338 :                 (*vvecRet)[j] = (*svec)[j].GetPublicKey();
      95             :             }
      96         229 :             return true;
      97         229 :         };
      98         458 :         futures.emplace_back(workerPool.push(f));
      99             :     }
     100             : 
     101         378 :     for (size_t i = 0; i < ids.size(); i += batchSize) {
     102         269 :         size_t start = i;
     103         269 :         size_t count = std::min(batchSize, ids.size() - start);
     104        2614 :         auto f = [&, start, count](int threadId) {
     105        2076 :             for (size_t j = start; j < start + count; j++) {
     106        1807 :                 if (!skShares[j].SecretKeyShare(*svec, ids[j])) {
     107             :                     return false;
     108             :                 }
     109             :             }
     110             :             return true;
     111         269 :         };
     112         538 :         futures.emplace_back(workerPool.push(f));
     113             :     }
     114         109 :     bool success = true;
     115         607 :     for (auto& f : futures) {
     116         498 :         if (!f.get()) {
     117           0 :             success = false;
     118             :         }
     119             :     }
     120         218 :     return success;
     121             : }
     122             : 
     123             : // aggregates a single vector of BLS objects in parallel
     124             : // the input vector is split into batches and each batch is aggregated in parallel
     125             : // when enough batches are finished to form a new batch, the new batch is queued for further parallel aggregation
     126             : // when no more batches can be created from finished batch results, the final aggregated is created and the doneCallback
     127             : // called.
     128             : // The Aggregator object needs to be created on the heap and it will delete itself after calling the doneCallback
     129             : // The input vector is not copied into the Aggregator but instead a vector of pointers to the original entries from the
     130             : // input vector is stored. This means that the input vector must stay alive for the whole lifetime of the Aggregator
     131             : template <typename T>
     132             : struct Aggregator : public std::enable_shared_from_this<Aggregator<T>> {
     133             :     typedef T ElementType;
     134             : 
     135             :     size_t batchSize{16};
     136             :     ctpl::thread_pool& workerPool;
     137             :     bool parallel;
     138             : 
     139             :     std::shared_ptr<std::vector<const T*> > inputVec;
     140             : 
     141             :     std::mutex m;
     142             :     // items in the queue are all intermediate aggregation results of finished batches.
     143             :     // The intermediate results must be deleted by us again (which we do in SyncAggregateAndPushAggQueue)
     144             :     ctpl::detail::Queue<T*> aggQueue;
     145             :     std::atomic<size_t> aggQueueSize{0};
     146             : 
     147             :     typedef std::function<void(const T& agg)> DoneCallback;
     148             :     DoneCallback doneCallback;
     149             : 
     150             :     // keeps track of currently queued/in-progress batches. If it reaches 0, we are done
     151             :     std::atomic<size_t> waitCount{0};
     152             : 
     153             :     // TP can either be a pointer or a reference
     154             :     template <typename TP>
     155         574 :     Aggregator(const std::vector<TP>& _inputVec,
     156             :                size_t start, size_t count,
     157             :                bool _parallel,
     158             :                ctpl::thread_pool& _workerPool,
     159             :                DoneCallback _doneCallback) :
     160             :             workerPool(_workerPool),
     161             :             parallel(_parallel),
     162         574 :             doneCallback(std::move(_doneCallback))
     163             :     {
     164         574 :         inputVec = std::make_shared<std::vector<const T*> >(count);
     165        6237 :         for (size_t i = 0; i < count; i++) {
     166        5663 :             (*inputVec)[i] = pointer(_inputVec[start + i]);
     167             :         }
     168         574 :     }
     169             : 
     170        3619 :     const T* pointer(const T& v) { return &v; }
     171        2044 :     const T* pointer(const T* v) { return v; }
     172             : 
     173             :     // Starts aggregation.
     174             :     // If parallel=true, then this will return fast, otherwise this will block until aggregation is done
     175         574 :     void Start()
     176             :     {
     177         574 :         size_t batchCount = (inputVec->size() + batchSize - 1) / batchSize;
     178             : 
     179         574 :         if (!parallel) {
     180           0 :             if (inputVec->size() == 1) {
     181           0 :                 doneCallback(*(*inputVec)[0]);
     182             :             } else {
     183           0 :                 doneCallback(SyncAggregate(*inputVec, 0, inputVec->size()));
     184             :             }
     185           0 :             return;
     186             :         }
     187             : 
     188         574 :         if (batchCount == 1) {
     189             :             // just a single batch of work, take a shortcut.
     190        1392 :             auto self(this->shared_from_this());
     191        1702 :             PushWork([self](int threadId) {
     192         464 :               size_t vecSize = self->inputVec->size();
     193         464 :               if (vecSize == 1) {
     194          15 :                   self->doneCallback(*(*self->inputVec)[0]);
     195             :               } else {
     196         598 :                   self->doneCallback(self->SyncAggregate(*self->inputVec, 0, vecSize));
     197             :               }
     198             :             });
     199         464 :             return;
     200             :         }
     201             : 
     202             :         // increment wait counter as otherwise the first finished async aggregation might signal that we're done
     203         110 :         IncWait();
     204         440 :         for (size_t i = 0; i < batchCount; i++) {
     205         330 :             size_t start = i * batchSize;
     206         330 :             size_t count = std::min(batchSize, inputVec->size() - start);
     207         330 :             AsyncAggregateAndPushAggQueue(inputVec, start, count, false);
     208             :         }
     209             :         // this will decrement the wait counter and in most cases NOT finish, as async work is still in progress
     210         110 :         CheckDone();
     211             :     }
     212             : 
     213         440 :     void IncWait()
     214             :     {
     215         110 :         ++waitCount;
     216             :     }
     217             : 
     218         440 :     void CheckDone()
     219             :     {
     220         440 :         if (--waitCount == 0) {
     221         110 :             Finish();
     222             :         }
     223         440 :     }
     224             : 
     225         110 :     void Finish()
     226             :     {
     227             :         // All async work is done, but we might have items in the aggQueue which are the results of the async
     228             :         // work. This is the case when these did not add up to a new batch. In this case, we have to aggregate
     229             :         // the items into the final result
     230             : 
     231         110 :         std::vector<T*> rem(aggQueueSize);
     232         440 :         for (size_t i = 0; i < rem.size(); i++) {
     233         330 :             T* p = nullptr;
     234         330 :             bool s = aggQueue.pop(p);
     235         330 :             assert(s);
     236         330 :             rem[i] = p;
     237             :         }
     238             : 
     239         150 :         T r;
     240         110 :         if (rem.size() == 1) {
     241             :             // just one intermediate result, which is actually the final result
     242           0 :             r = *rem[0];
     243             :         } else {
     244             :             // multiple intermediate results left which did not add up to a new batch. aggregate them now
     245         220 :             r = SyncAggregate(rem, 0, rem.size());
     246             :         }
     247             : 
     248             :         // all items which are left in the queue are intermediate results, so we must delete them
     249         440 :         for (size_t i = 0; i < rem.size(); i++) {
     250         450 :             delete rem[i];
     251             :         }
     252         110 :         doneCallback(r);
     253         110 :     }
     254             : 
     255         330 :     void AsyncAggregateAndPushAggQueue(const std::shared_ptr<std::vector<const T*>>& vec, size_t start, size_t count, bool del)
     256             :     {
     257         330 :         IncWait();
     258         660 :         auto self(this->shared_from_this());
     259         330 :         PushWork([self, vec, start, count, del](int threadId){
     260         330 :           self->SyncAggregateAndPushAggQueue(vec, start, count, del);
     261             :         });
     262         330 :     }
     263             : 
     264         330 :     void SyncAggregateAndPushAggQueue(const std::shared_ptr<std::vector<const T*> >& vec, size_t start, size_t count, bool del)
     265             :     {
     266             :         // aggregate vec and push the intermediate result onto the work queue
     267         330 :         PushAggQueue(SyncAggregate(*vec, start, count));
     268         330 :         if (del) {
     269           0 :             for (size_t i = 0; i < count; i++) {
     270           0 :                 delete (*vec)[start + i];
     271             :             }
     272             :         }
     273         330 :         CheckDone();
     274         330 :     }
     275             : 
     276         330 :     void PushAggQueue(const T& v)
     277             :     {
     278         330 :         auto copyT = new T(v);
     279             :         try {
     280         330 :             aggQueue.push(copyT);
     281           0 :         } catch (...) {
     282           0 :             delete copyT;
     283           0 :             throw;
     284             :         }
     285             : 
     286         330 :         if (++aggQueueSize >= batchSize) {
     287             :             // we've collected enough intermediate results to form a new batch.
     288           0 :             std::shared_ptr<std::vector<const T*> > newBatch;
     289             :             {
     290           0 :                 std::unique_lock<std::mutex> l(m);
     291           0 :                 if (aggQueueSize < batchSize) {
     292             :                     // some other worker thread grabbed this batch
     293           0 :                     return;
     294             :                 }
     295           0 :                 newBatch = std::make_shared<std::vector<const T*> >(batchSize);
     296             :                 // collect items for new batch
     297           0 :                 for (size_t i = 0; i < batchSize; i++) {
     298           0 :                     T* p = nullptr;
     299           0 :                     bool s = aggQueue.pop(p);
     300           0 :                     assert(s);
     301           0 :                     (*newBatch)[i] = p;
     302             :                 }
     303           0 :                 aggQueueSize -= batchSize;
     304             :             }
     305             : 
     306             :             // push new batch to work queue. del=true this time as these items are intermediate results and need to be deleted
     307             :             // after aggregation is done
     308           0 :             AsyncAggregateAndPushAggQueue(newBatch, 0, newBatch->size(), true);
     309             :         }
     310             :     }
     311             : 
     312             :     template <typename TP>
     313         889 :     T SyncAggregate(const std::vector<TP>& vec, size_t start, size_t count)
     314             :     {
     315         889 :         T result = *vec[start];
     316        5978 :         for (size_t j = 1; j < count; j++) {
     317        5089 :             result.AggregateInsecure(*vec[start + j]);
     318             :         }
     319         889 :         return result;
     320             :     }
     321             : 
     322             :     template <typename Callable>
     323         794 :     void PushWork(Callable&& f)
     324             :     {
     325         794 :         workerPool.push(f);
     326         794 :     }
     327             : };
     328             : 
     329             : // Aggregates multiple input vectors into a single output vector
     330             : // Inputs are in the following form:
     331             : //   [
     332             : //     [a1, b1, c1, d1],
     333             : //     [a2, b2, c2, d2],
     334             : //     [a3, b3, c3, d3],
     335             : //     [a4, b4, c4, d4],
     336             : //   ]
     337             : // The result is in the following form:
     338             : //   [ a1+a2+a3+a4, b1+b2+b3+b4, c1+c2+c3+c4, d1+d2+d3+d4]
     339             : // Same rules for the input vectors apply to the VectorAggregator as for the Aggregator (they must stay alive)
     340             : template <typename T>
     341             : struct VectorAggregator : public std::enable_shared_from_this<VectorAggregator<T>> {
     342             :     typedef Aggregator<T> AggregatorType;
     343             :     typedef std::vector<T> VectorType;
     344             :     typedef std::shared_ptr<VectorType> VectorPtrType;
     345             :     typedef std::vector<VectorPtrType> VectorVectorType;
     346             :     typedef std::function<void(const VectorPtrType& agg)> DoneCallback;
     347             : 
     348             :     const VectorVectorType& vecs;
     349             :     bool parallel;
     350             :     size_t start;
     351             :     size_t count;
     352             : 
     353             :     ctpl::thread_pool& workerPool;
     354             : 
     355             :     DoneCallback doneCallback;
     356             :     std::atomic<size_t> doneCount;
     357             : 
     358             :     VectorPtrType result;
     359             :     size_t vecSize;
     360             : 
     361         156 :     VectorAggregator(const VectorVectorType& _vecs,
     362             :                      size_t _start, size_t _count,
     363             :                      bool _parallel, ctpl::thread_pool& _workerPool,
     364             :                      DoneCallback _doneCallback) :
     365             :             vecs(_vecs),
     366             :             parallel(_parallel),
     367             :             start(_start),
     368             :             count(_count),
     369             :             workerPool(_workerPool),
     370         156 :             doneCallback(std::move(_doneCallback))
     371             :     {
     372         156 :         assert(!vecs.empty());
     373         156 :         vecSize = vecs[0]->size();
     374         156 :         result = std::make_shared<VectorType>(vecSize);
     375         156 :         doneCount = 0;
     376         156 :     }
     377             : 
     378         156 :     void Start()
     379             :     {
     380         496 :         for (size_t i = 0; i < vecSize; i++) {
     381         680 :             std::vector<const T*> tmp(count);
     382        2384 :             for (size_t j = 0; j < count; j++) {
     383        2044 :                 tmp[j] = &(*vecs[start + j])[i];
     384             :             }
     385             : 
     386        1020 :             auto self(this->shared_from_this());
     387        2720 :             auto aggregator = std::make_shared<AggregatorType>(std::move(tmp), 0, count, parallel, workerPool, [self, i](const T& agg) {self->CheckDone(agg, i);});
     388         340 :             aggregator->Start();
     389             :         }
     390         156 :     }
     391             : 
     392         340 :     void CheckDone(const T& agg, size_t idx)
     393             :     {
     394         340 :         (*result)[idx] = agg;
     395         340 :         if (++doneCount == vecSize) {
     396         156 :             doneCallback(result);
     397             :         }
     398         340 :     }
     399             : };
     400             : 
     401             : // See comment of AsyncVerifyContributionShares for a description on what this does
     402             : // Same rules as in Aggregator apply for the inputs
     403             : struct ContributionVerifier : public std::enable_shared_from_this<ContributionVerifier> {
     404             :     struct BatchState {
     405             :         size_t start;
     406             :         size_t count;
     407             : 
     408             :         BLSVerificationVectorPtr vvec;
     409             :         CBLSSecretKey skShare;
     410             : 
     411             :         // starts with 0 and is incremented if either vvec or skShare aggregation finishes. If it reaches 2, we know
     412             :         // that aggregation for this batch is fully done. We can then start verification.
     413             :         std::unique_ptr<std::atomic<int> > aggDone;
     414             : 
     415             :         // we can't directly update a vector<bool> in parallel
     416             :         // as vector<bool> is not thread safe (uses bitsets internally)
     417             :         // so we must use vector<char> temporarily and concatenate/convert
     418             :         // each batch result into a final vector<bool>
     419             :         std::vector<char> verifyResults;
     420             :     };
     421             : 
     422             :     CBLSId forId;
     423             :     const std::vector<BLSVerificationVectorPtr>& vvecs;
     424             :     const BLSSecretKeyVector& skShares;
     425             :     size_t batchSize;
     426             :     bool parallel;
     427             :     bool aggregated;
     428             : 
     429             :     ctpl::thread_pool& workerPool;
     430             : 
     431             :     size_t batchCount;
     432             :     size_t verifyCount;
     433             : 
     434             :     std::vector<BatchState> batchStates;
     435             :     std::atomic<size_t> verifyDoneCount{0};
     436             :     std::function<void(const std::vector<bool>&)> doneCallback;
     437             : 
     438          57 :     ContributionVerifier(const CBLSId& _forId, const std::vector<BLSVerificationVectorPtr>& _vvecs,
     439             :                          const BLSSecretKeyVector& _skShares, size_t _batchSize,
     440             :                          bool _parallel, bool _aggregated, ctpl::thread_pool& _workerPool,
     441          57 :                          std::function<void(const std::vector<bool>&)> _doneCallback) :
     442             :         forId(_forId),
     443             :         vvecs(_vvecs),
     444             :         skShares(_skShares),
     445             :         batchSize(_batchSize),
     446             :         parallel(_parallel),
     447             :         aggregated(_aggregated),
     448             :         workerPool(_workerPool),
     449          57 :         doneCallback(std::move(_doneCallback))
     450             :     {
     451             :     }
     452             : 
     453          57 :     void Start()
     454             :     {
     455          57 :         if (!aggregated) {
     456             :             // treat all inputs as one large batch
     457           0 :             batchSize = vvecs.size();
     458           0 :             batchCount = 1;
     459             :         } else {
     460          57 :             batchCount = (vvecs.size() + batchSize - 1) / batchSize;
     461             :         }
     462          57 :         verifyCount = vvecs.size();
     463             : 
     464          57 :         batchStates.resize(batchCount);
     465         114 :         for (size_t i = 0; i < batchCount; i++) {
     466          57 :             auto& batchState = batchStates[i];
     467             : 
     468          57 :             batchState.aggDone.reset(new std::atomic<int>(0));
     469          57 :             batchState.start = i * batchSize;
     470          57 :             batchState.count = std::min(batchSize, vvecs.size() - batchState.start);
     471          57 :             batchState.verifyResults.assign(batchState.count, 0);
     472             :         }
     473             : 
     474          57 :         if (aggregated) {
     475         114 :             size_t batchCount2 = batchCount; // 'this' might get deleted while we're still looping
     476         114 :             for (size_t i = 0; i < batchCount2; i++) {
     477          57 :                 AsyncAggregate(i);
     478             :             }
     479             :         } else {
     480             :             // treat all inputs as a single batch and verify one-by-one
     481           0 :             AsyncVerifyBatchOneByOne(0);
     482             :         }
     483          57 :     }
     484             : 
     485          57 :     void Finish()
     486             :     {
     487          57 :         size_t batchIdx = 0;
     488          57 :         std::vector<bool> result(vvecs.size());
     489         114 :         for (size_t i = 0; i < vvecs.size(); i += batchSize) {
     490          57 :             auto& batchState = batchStates[batchIdx++];
     491         211 :             for (size_t j = 0; j < batchState.count; j++) {
     492         308 :                 result[batchState.start + j] = batchState.verifyResults[j] != 0;
     493             :             }
     494             :         }
     495          57 :         doneCallback(result);
     496          57 :     }
     497             : 
     498          57 :     void AsyncAggregate(size_t batchIdx)
     499             :     {
     500          57 :         auto& batchState = batchStates[batchIdx];
     501             : 
     502             :         // aggregate vvecs and skShares of batch in parallel
     503          57 :         auto self(this->shared_from_this());
     504         456 :         auto vvecAgg = std::make_shared<VectorAggregator<CBLSPublicKey>>(vvecs, batchState.start, batchState.count, parallel, workerPool, [self, batchIdx] (const BLSVerificationVectorPtr& vvec) {self->HandleAggVvecDone(batchIdx, vvec);});
     505         456 :         auto skShareAgg = std::make_shared<Aggregator<CBLSSecretKey>>(skShares, batchState.start, batchState.count, parallel, workerPool, [self, batchIdx] (const CBLSSecretKey& skShare) {self->HandleAggSkShareDone(batchIdx, skShare);});
     506             : 
     507          57 :         vvecAgg->Start();
     508          57 :         skShareAgg->Start();
     509          57 :     }
     510             : 
     511          57 :     void HandleAggVvecDone(size_t batchIdx, const BLSVerificationVectorPtr& vvec)
     512             :     {
     513          57 :         auto& batchState = batchStates[batchIdx];
     514          57 :         batchState.vvec = vvec;
     515          57 :         if (++(*batchState.aggDone) == 2) {
     516          20 :             HandleAggDone(batchIdx);
     517             :         }
     518          57 :     }
     519          57 :     void HandleAggSkShareDone(size_t batchIdx, const CBLSSecretKey& skShare)
     520             :     {
     521          57 :         auto& batchState = batchStates[batchIdx];
     522          57 :         batchState.skShare = skShare;
     523          57 :         if (++(*batchState.aggDone) == 2) {
     524          37 :             HandleAggDone(batchIdx);
     525             :         }
     526          57 :     }
     527             : 
     528          69 :     void HandleVerifyDone(size_t batchIdx, size_t count)
     529             :     {
     530          69 :         size_t c = verifyDoneCount += count;
     531          69 :         if (c == verifyCount) {
     532          57 :             Finish();
     533             :         }
     534          69 :     }
     535             : 
     536          57 :     void HandleAggDone(size_t batchIdx)
     537             :     {
     538          57 :         auto& batchState = batchStates[batchIdx];
     539             : 
     540          57 :         if (batchState.vvec == nullptr || batchState.vvec->empty() || !batchState.skShare.IsValid()) {
     541             :             // something went wrong while aggregating and there is nothing we can do now except mark the whole batch as failed
     542             :             // this can only happen if inputs were invalid in some way
     543           0 :             batchState.verifyResults.assign(batchState.count, 0);
     544           0 :             HandleVerifyDone(batchIdx, batchState.count);
     545           0 :             return;
     546             :         }
     547             : 
     548          57 :         AsyncAggregatedVerifyBatch(batchIdx);
     549             :     }
     550             : 
     551          57 :     void AsyncAggregatedVerifyBatch(size_t batchIdx)
     552             :     {
     553          57 :         auto self(shared_from_this());
     554         228 :         auto f = [self, batchIdx](int threadId) {
     555          57 :           auto& batchState = self->batchStates[batchIdx];
     556          57 :           bool result = self->Verify(batchState.vvec, batchState.skShare);
     557          57 :           if (result) {
     558             :               // whole batch is valid
     559          51 :               batchState.verifyResults.assign(batchState.count, 1);
     560          51 :               self->HandleVerifyDone(batchIdx, batchState.count);
     561             :           } else {
     562             :               // at least one entry in the batch is invalid, revert to per-contribution verification (but parallelized)
     563           6 :               self->AsyncVerifyBatchOneByOne(batchIdx);
     564             :           }
     565          57 :         };
     566          57 :         PushOrDoWork(std::move(f));
     567          57 :     }
     568             : 
     569           6 :     void AsyncVerifyBatchOneByOne(size_t batchIdx)
     570             :     {
     571           6 :         size_t count = batchStates[batchIdx].count;
     572           6 :         batchStates[batchIdx].verifyResults.assign(count, 0);
     573          24 :         for (size_t i = 0; i < count; i++) {
     574          36 :             auto self(this->shared_from_this());
     575          36 :             PushOrDoWork([self, i, batchIdx](int threadId) {
     576          18 :               auto& batchState = self->batchStates[batchIdx];
     577          18 :               batchState.verifyResults[i] = self->Verify(self->vvecs[batchState.start + i], self->skShares[batchState.start + i]);
     578          18 :               self->HandleVerifyDone(batchIdx, 1);
     579          18 :             });
     580             :         }
     581           6 :     }
     582             : 
     583          75 :     bool Verify(const BLSVerificationVectorPtr& vvec, const CBLSSecretKey& skShare)
     584             :     {
     585          75 :         CBLSPublicKey pk1;
     586          75 :         if (!pk1.PublicKeyShare(*vvec, forId)) {
     587             :             return false;
     588             :         }
     589             : 
     590          75 :         CBLSPublicKey pk2 = skShare.GetPublicKey();
     591          75 :         return pk1 == pk2;
     592             :     }
     593             : 
     594             :     template <typename Callable>
     595          75 :     void PushOrDoWork(Callable&& f)
     596             :     {
     597          75 :         if (parallel) {
     598         150 :             workerPool.push(std::move(f));
     599             :         } else {
     600           0 :             f(0);
     601             :         }
     602          75 :     }
     603             : };
     604             : 
     605         176 : void CBLSWorker::AsyncBuildQuorumVerificationVector(const std::vector<BLSVerificationVectorPtr>& vvecs,
     606             :                                                     size_t start, size_t count, bool parallel,
     607             :                                                     std::function<void(const BLSVerificationVectorPtr&)> doneCallback)
     608             : {
     609         176 :     if (start == 0 && count == 0) {
     610         176 :         count = vvecs.size();
     611             :     }
     612         176 :     if (vvecs.empty() || count == 0 || start > vvecs.size() || start + count > vvecs.size()) {
     613           0 :         doneCallback(nullptr);
     614          77 :         return;
     615             :     }
     616         176 :     if (!VerifyVerificationVectors(vvecs, start, count)) {
     617          77 :         doneCallback(nullptr);
     618          77 :         return;
     619             :     }
     620             : 
     621         198 :     auto agg = std::make_shared<VectorAggregator<CBLSPublicKey>>(vvecs, start, count, parallel, workerPool, std::move(doneCallback));
     622          99 :     agg->Start();
     623             : }
     624             : 
     625         176 : std::future<BLSVerificationVectorPtr> CBLSWorker::AsyncBuildQuorumVerificationVector(const std::vector<BLSVerificationVectorPtr>& vvecs,
     626             :                                                                                      size_t start, size_t count, bool parallel)
     627             : {
     628         352 :     auto p = BuildFutureDoneCallback<BLSVerificationVectorPtr>();
     629         176 :     AsyncBuildQuorumVerificationVector(vvecs, start, count, parallel, std::move(p.first));
     630         176 :     return std::move(p.second);
     631             : }
     632             : 
     633         176 : BLSVerificationVectorPtr CBLSWorker::BuildQuorumVerificationVector(const std::vector<BLSVerificationVectorPtr>& vvecs,
     634             :                                                                    size_t start, size_t count, bool parallel)
     635             : {
     636         352 :     return AsyncBuildQuorumVerificationVector(vvecs, start, count, parallel).get();
     637             : }
     638             : 
     639             : template <typename T>
     640         177 : void AsyncAggregateHelper(ctpl::thread_pool& workerPool,
     641             :                           const std::vector<T>& vec, size_t start, size_t count, bool parallel,
     642             :                           std::function<void(const T&)> doneCallback)
     643             : {
     644         177 :     if (start == 0 && count == 0) {
     645         177 :         count = vec.size();
     646             :     }
     647         177 :     if (vec.empty() || count == 0 || start > vec.size() || start + count > vec.size()) {
     648           0 :         doneCallback(T());
     649           0 :         return;
     650             :     }
     651         177 :     if (!VerifyVectorHelper(vec, start, count)) {
     652           0 :         doneCallback(T());
     653           0 :         return;
     654             :     }
     655             : 
     656         354 :     auto agg = std::make_shared<Aggregator<T>>(vec, start, count, parallel, workerPool, std::move(doneCallback));
     657         177 :     agg->Start();
     658             : }
     659             : 
     660         137 : void CBLSWorker::AsyncAggregateSecretKeys(const BLSSecretKeyVector& secKeys,
     661             :                                           size_t start, size_t count, bool parallel,
     662             :                                           std::function<void(const CBLSSecretKey&)> doneCallback)
     663             : {
     664         137 :     AsyncAggregateHelper(workerPool, secKeys, start, count, parallel, doneCallback);
     665         137 : }
     666             : 
     667         137 : std::future<CBLSSecretKey> CBLSWorker::AsyncAggregateSecretKeys(const BLSSecretKeyVector& secKeys,
     668             :                                                                 size_t start, size_t count, bool parallel)
     669             : {
     670         274 :     auto p = BuildFutureDoneCallback<CBLSSecretKey>();
     671         137 :     AsyncAggregateSecretKeys(secKeys, start, count, parallel, std::move(p.first));
     672         137 :     return std::move(p.second);
     673             : }
     674             : 
     675         137 : CBLSSecretKey CBLSWorker::AggregateSecretKeys(const BLSSecretKeyVector& secKeys,
     676             :                                               size_t start, size_t count, bool parallel)
     677             : {
     678         274 :     return AsyncAggregateSecretKeys(secKeys, start, count, parallel).get();
     679             : }
     680             : 
     681          40 : void CBLSWorker::AsyncAggregatePublicKeys(const BLSPublicKeyVector& pubKeys,
     682             :                                           size_t start, size_t count, bool parallel,
     683             :                                           std::function<void(const CBLSPublicKey&)> doneCallback)
     684             : {
     685          40 :     AsyncAggregateHelper(workerPool, pubKeys, start, count, parallel, doneCallback);
     686          40 : }
     687             : 
     688          40 : std::future<CBLSPublicKey> CBLSWorker::AsyncAggregatePublicKeys(const BLSPublicKeyVector& pubKeys,
     689             :                                                                 size_t start, size_t count, bool parallel)
     690             : {
     691          80 :     auto p = BuildFutureDoneCallback<CBLSPublicKey>();
     692          40 :     AsyncAggregatePublicKeys(pubKeys, start, count, parallel, std::move(p.first));
     693          40 :     return std::move(p.second);
     694             : }
     695             : 
     696          40 : CBLSPublicKey CBLSWorker::AggregatePublicKeys(const BLSPublicKeyVector& pubKeys,
     697             :                                               size_t start, size_t count, bool parallel)
     698             : {
     699          80 :     return AsyncAggregatePublicKeys(pubKeys, start, count, parallel).get();
     700             : }
     701             : 
     702           0 : void CBLSWorker::AsyncAggregateSigs(const BLSSignatureVector& sigs,
     703             :                                     size_t start, size_t count, bool parallel,
     704             :                                     std::function<void(const CBLSSignature&)> doneCallback)
     705             : {
     706           0 :     AsyncAggregateHelper(workerPool, sigs, start, count, parallel, doneCallback);
     707           0 : }
     708             : 
     709           0 : std::future<CBLSSignature> CBLSWorker::AsyncAggregateSigs(const BLSSignatureVector& sigs,
     710             :                                                           size_t start, size_t count, bool parallel)
     711             : {
     712           0 :     auto p = BuildFutureDoneCallback<CBLSSignature>();
     713           0 :     AsyncAggregateSigs(sigs, start, count, parallel, std::move(p.first));
     714           0 :     return std::move(p.second);
     715             : }
     716             : 
     717           0 : CBLSSignature CBLSWorker::AggregateSigs(const BLSSignatureVector& sigs,
     718             :                                         size_t start, size_t count, bool parallel)
     719             : {
     720           0 :     return AsyncAggregateSigs(sigs, start, count, parallel).get();
     721             : }
     722             : 
     723             : 
     724        1867 : CBLSPublicKey CBLSWorker::BuildPubKeyShare(const BLSVerificationVectorPtr& vvec, const CBLSId& id)
     725             : {
     726        1867 :     CBLSPublicKey pkShare;
     727        1867 :     pkShare.PublicKeyShare(*vvec, id);
     728        1867 :     return pkShare;
     729             : }
     730             : 
     731          57 : void CBLSWorker::AsyncVerifyContributionShares(const CBLSId& forId, const std::vector<BLSVerificationVectorPtr>& vvecs, const BLSSecretKeyVector& skShares,
     732             :                                                bool parallel, bool aggregated, std::function<void(const std::vector<bool>&)> doneCallback)
     733             : {
     734          57 :     if (!forId.IsValid() || !VerifyVerificationVectors(vvecs)) {
     735           0 :         std::vector<bool> result;
     736           0 :         result.assign(vvecs.size(), false);
     737           0 :         doneCallback(result);
     738           0 :         return;
     739             :     }
     740             : 
     741         114 :     auto verifier = std::make_shared<ContributionVerifier>(forId, vvecs, skShares, 8, parallel, aggregated, workerPool, std::move(doneCallback));
     742          57 :     verifier->Start();
     743             : }
     744             : 
     745          57 : std::future<std::vector<bool> > CBLSWorker::AsyncVerifyContributionShares(const CBLSId& forId, const std::vector<BLSVerificationVectorPtr>& vvecs, const BLSSecretKeyVector& skShares,
     746             :                                                                           bool parallel, bool aggregated)
     747             : {
     748         114 :     auto p = BuildFutureDoneCallback<std::vector<bool> >();
     749          57 :     AsyncVerifyContributionShares(forId, vvecs, skShares, parallel, aggregated, std::move(p.first));
     750          57 :     return std::move(p.second);
     751             : }
     752             : 
     753          57 : std::vector<bool> CBLSWorker::VerifyContributionShares(const CBLSId& forId, const std::vector<BLSVerificationVectorPtr>& vvecs, const BLSSecretKeyVector& skShares,
     754             :                                                        bool parallel, bool aggregated)
     755             : {
     756         114 :     return AsyncVerifyContributionShares(forId, vvecs, skShares, parallel, aggregated).get();
     757             : }
     758             : 
     759          18 : std::future<bool> CBLSWorker::AsyncVerifyContributionShare(const CBLSId& forId,
     760             :                                                            const BLSVerificationVectorPtr& vvec,
     761             :                                                            const CBLSSecretKey& skContribution)
     762             : {
     763          18 :     if (!forId.IsValid() || !VerifyVerificationVector(*vvec)) {
     764           0 :         auto p = BuildFutureDoneCallback<bool>();
     765           0 :         p.first(false);
     766           0 :         return std::move(p.second);
     767             :     }
     768             : 
     769          36 :     auto f = [this, &forId, &vvec, &skContribution](int threadId) {
     770          18 :          return VerifyContributionShare(forId, vvec, skContribution);
     771          18 :     };
     772          18 :     return workerPool.push(f);
     773             : }
     774             : 
     775        1618 : bool CBLSWorker::VerifyContributionShare(const CBLSId& forId, const BLSVerificationVectorPtr& vvec,
     776             :                                          const CBLSSecretKey& skContribution)
     777             : {
     778        1618 :     CBLSPublicKey pk1;
     779        1618 :     if (!pk1.PublicKeyShare(*vvec, forId)) {
     780             :         return false;
     781             :     }
     782             : 
     783        1618 :     CBLSPublicKey pk2 = skContribution.GetPublicKey();
     784        1618 :     return pk1 == pk2;
     785             : }
     786             : 
     787         192 : bool CBLSWorker::VerifyVerificationVector(const BLSVerificationVector& vvec, size_t start, size_t count)
     788             : {
     789         192 :     return VerifyVectorHelper(vvec, start, count);
     790             : }
     791             : 
     792         233 : bool CBLSWorker::VerifyVerificationVectors(const std::vector<BLSVerificationVectorPtr>& vvecs,
     793             :                                            size_t start, size_t count)
     794             : {
     795         233 :     if (start == 0 && count == 0) {
     796          57 :         count = vvecs.size();
     797             :     }
     798             : 
     799         466 :     std::set<uint256> set;
     800         695 :     for (size_t i = 0; i < count; i++) {
     801         539 :         auto& vvec = vvecs[start + i];
     802         539 :         if (vvec == nullptr) {
     803             :             return false;
     804             :         }
     805         462 :         if (vvec->size() != vvecs[start]->size()) {
     806             :             return false;
     807             :         }
     808        2506 :         for (size_t j = 0; j < vvec->size(); j++) {
     809        2044 :             if (!(*vvec)[j].IsValid()) {
     810             :                 return false;
     811             :             }
     812             :             // check duplicates
     813        2044 :             if (!set.emplace((*vvec)[j].GetHash()).second) {
     814             :                 return false;
     815             :             }
     816             :         }
     817             :     }
     818             : 
     819             :     return true;
     820             : }
     821             : 
     822           0 : bool CBLSWorker::VerifySecretKeyVector(const BLSSecretKeyVector& secKeys, size_t start, size_t count)
     823             : {
     824           0 :     return VerifyVectorHelper(secKeys, start, count);
     825             : }
     826             : 
     827           0 : bool CBLSWorker::VerifySignatureVector(const BLSSignatureVector& sigs, size_t start, size_t count)
     828             : {
     829           0 :     return VerifyVectorHelper(sigs, start, count);
     830             : }
     831             : 
     832           0 : void CBLSWorker::AsyncSign(const CBLSSecretKey& secKey, const uint256& msgHash, CBLSWorker::SignDoneCallback doneCallback)
     833             : {
     834           0 :     workerPool.push([secKey, msgHash, doneCallback](int threadId) {
     835           0 :         doneCallback(secKey.Sign(msgHash));
     836           0 :     });
     837           0 : }
     838             : 
     839           0 : std::future<CBLSSignature> CBLSWorker::AsyncSign(const CBLSSecretKey& secKey, const uint256& msgHash)
     840             : {
     841           0 :     auto p = BuildFutureDoneCallback<CBLSSignature>();
     842           0 :     AsyncSign(secKey, msgHash, std::move(p.first));
     843           0 :     return std::move(p.second);
     844             : }
     845             : 
     846           0 : void CBLSWorker::AsyncVerifySig(const CBLSSignature& sig, const CBLSPublicKey& pubKey, const uint256& msgHash,
     847             :                                 CBLSWorker::SigVerifyDoneCallback doneCallback, CancelCond cancelCond)
     848             : {
     849           0 :     if (!sig.IsValid() || !pubKey.IsValid()) {
     850           0 :         doneCallback(false);
     851           0 :         return;
     852             :     }
     853             : 
     854           0 :     std::unique_lock<std::mutex> l(sigVerifyMutex);
     855             : 
     856           0 :     bool foundDuplicate = false;
     857           0 :     for (auto& s : sigVerifyQueue) {
     858           0 :         if (s.msgHash == msgHash) {
     859             :             foundDuplicate = true;
     860             :             break;
     861             :         }
     862             :     }
     863             : 
     864           0 :     if (foundDuplicate) {
     865             :         // batched/aggregated verification does not allow duplicate hashes, so we push what we currently have and start
     866             :         // with a fresh batch
     867           0 :         PushSigVerifyBatch();
     868             :     }
     869             : 
     870           0 :     sigVerifyQueue.emplace_back(std::move(doneCallback), std::move(cancelCond), sig, pubKey, msgHash);
     871           0 :     if (sigVerifyBatchesInProgress == 0 || sigVerifyQueue.size() >= SIG_VERIFY_BATCH_SIZE) {
     872           0 :         PushSigVerifyBatch();
     873             :     }
     874             : }
     875             : 
     876           0 : std::future<bool> CBLSWorker::AsyncVerifySig(const CBLSSignature& sig, const CBLSPublicKey& pubKey, const uint256& msgHash, CancelCond cancelCond)
     877             : {
     878           0 :     auto p = BuildFutureDoneCallback2<bool>();
     879           0 :     AsyncVerifySig(sig, pubKey, msgHash, std::move(p.first), cancelCond);
     880           0 :     return std::move(p.second);
     881             : }
     882             : 
     883           0 : bool CBLSWorker::IsAsyncVerifyInProgress()
     884             : {
     885           0 :     std::unique_lock<std::mutex> l(sigVerifyMutex);
     886           0 :     return sigVerifyBatchesInProgress != 0;
     887             : }
     888             : 
     889             : // sigVerifyMutex must be held while calling
     890           0 : void CBLSWorker::PushSigVerifyBatch()
     891             : {
     892           0 :     auto f = [this](int threadId, std::shared_ptr<std::vector<SigVerifyJob> > _jobs) {
     893           0 :         auto& jobs = *_jobs;
     894           0 :         if (jobs.size() == 1) {
     895           0 :             auto& job = jobs[0];
     896           0 :             if (!job.cancelCond()) {
     897           0 :                 bool valid = job.sig.VerifyInsecure(job.pubKey, job.msgHash);
     898           0 :                 job.doneCallback(valid);
     899             :             }
     900           0 :             std::unique_lock<std::mutex> l(sigVerifyMutex);
     901           0 :             sigVerifyBatchesInProgress--;
     902           0 :             if (!sigVerifyQueue.empty()) {
     903           0 :                 PushSigVerifyBatch();
     904             :             }
     905           0 :             return;
     906             :         }
     907             : 
     908           0 :         CBLSSignature aggSig;
     909           0 :         std::vector<size_t> indexes;
     910           0 :         std::vector<CBLSPublicKey> pubKeys;
     911           0 :         std::vector<uint256> msgHashes;
     912           0 :         indexes.reserve(jobs.size());
     913           0 :         pubKeys.reserve(jobs.size());
     914           0 :         msgHashes.reserve(jobs.size());
     915           0 :         for (size_t i = 0; i < jobs.size(); i++) {
     916           0 :             auto& job = jobs[i];
     917           0 :             if (job.cancelCond()) {
     918           0 :                 continue;
     919             :             }
     920           0 :             if (pubKeys.empty()) {
     921           0 :                 aggSig = job.sig;
     922             :             } else {
     923           0 :                 aggSig.AggregateInsecure(job.sig);
     924             :             }
     925           0 :             indexes.emplace_back(i);
     926           0 :             pubKeys.emplace_back(job.pubKey);
     927           0 :             msgHashes.emplace_back(job.msgHash);
     928             :         }
     929             : 
     930           0 :         if (!pubKeys.empty()) {
     931           0 :             bool allValid = aggSig.VerifyInsecureAggregated(pubKeys, msgHashes);
     932           0 :             if (allValid) {
     933           0 :                 for (size_t i = 0; i < pubKeys.size(); i++) {
     934           0 :                     jobs[indexes[i]].doneCallback(true);
     935             :                 }
     936             :             } else {
     937             :                 // one or more sigs were not valid, revert to per-sig verification
     938             :                 // TODO this could be improved if we would cache pairing results in some way as the previous aggregated verification already calculated all the pairings for the hashes
     939           0 :                 for (size_t i = 0; i < pubKeys.size(); i++) {
     940           0 :                     auto& job = jobs[indexes[i]];
     941           0 :                     bool valid = job.sig.VerifyInsecure(job.pubKey, job.msgHash);
     942           0 :                     job.doneCallback(valid);
     943             :                 }
     944             :             }
     945             :         }
     946             : 
     947           0 :         std::unique_lock<std::mutex> l(sigVerifyMutex);
     948           0 :         sigVerifyBatchesInProgress--;
     949           0 :         if (!sigVerifyQueue.empty()) {
     950           0 :             PushSigVerifyBatch();
     951             :         }
     952           0 :     };
     953             : 
     954           0 :     auto batch = std::make_shared<std::vector<SigVerifyJob> >(std::move(sigVerifyQueue));
     955           0 :     sigVerifyQueue.reserve(SIG_VERIFY_BATCH_SIZE);
     956             : 
     957           0 :     sigVerifyBatchesInProgress++;
     958           0 :     workerPool.push(f, batch);
     959           0 : }
 |