Commit a6984177 authored by nmedfort's avatar nmedfort
Browse files

current progress; not quite right yet

parent 8e2d292e
......@@ -109,6 +109,14 @@ void PipelineAnalysis::makeTerminationPropagationGraph() {
HasTerminationSignal.resize(PipelineOutput + 1U);
const auto firstComputeKernelId = FirstKernelInPartition[FirstComputePartitionId];
const auto afterLastComputeKernelId = FirstKernelInPartition[LastComputePartitionId + 1];
if (LLVM_UNLIKELY(firstComputeKernelId != FirstKernel)) {
HasTerminationSignal.set(firstComputeKernelId - 1);
}
BitVector marks(PipelineOutput + 1U);
for (auto pid = KernelPartitionId[FirstKernel]; pid < PartitionCount; ++pid) {
......@@ -142,10 +150,23 @@ void PipelineAnalysis::makeTerminationPropagationGraph() {
if (kernelObj->canSetTerminateSignal()) {
add_edge(i, start, true, mTerminationPropagationGraph);
HasTerminationSignal.set(i);
} else {
// If we have a cross threaded buffer, we cannot rely on only storing the root's
// termination signal because we need to know exactly when the actual producer
// is terminated and not only whether there is any reason to execute the partition.
for (const auto e : make_iterator_range(out_edges(i , mBufferGraph))) {
const BufferNode & bn = mBufferGraph[target(e, mBufferGraph)];
if (LLVM_UNLIKELY(bn.isCrossThreaded())) {
HasTerminationSignal.set(i);
break;
}
}
}
}
}
ReverseTopologicalOrdering ordering;
ordering.reserve(num_vertices(mTerminationPropagationGraph));
topological_sort(mTerminationPropagationGraph, std::back_inserter(ordering));
......
......@@ -352,14 +352,23 @@ void PipelineCompiler::constructStreamSetBuffers(KernelBuilder & /* b */) {
* @brief readAvailableItemCounts
** ------------------------------------------------------------------------------------------------------------- */
void PipelineCompiler::readAvailableItemCounts(KernelBuilder & b) {
    // Reads the currently-available item count for every streamset consumed by the
    // current kernel (mKernelId), caching each into mLocallyAvailableItems. For
    // cross-threaded streamsets, the producer's termination signal is read first
    // and cached into mKernelIsClosed (keyed by producer kernel id).
    // NOTE: the leftover pre-image line from the diff ("mLocallyAvailableItems[streamSet]
    // == nullptr || mIsIOProcessThread") was removed; only the post-image logic remains.

    // Discard any closed-state values cached by a previous kernel's invocation.
    mKernelIsClosed.reset(FirstKernel, LastKernel);
    for (const auto e : make_iterator_range(in_edges(mKernelId, mBufferGraph))) {
        const auto streamSet = source(e, mBufferGraph);
        const BufferNode & bn = mBufferGraph[streamSet];
        if (bn.isCrossThreaded()) {
            // We need to check the terminated signal *before* the item count or we risk getting
            // an old available count and new termination signal. Do not rearrange this order.
            const auto producer = parent(streamSet, mBufferGraph);
            if (mKernelIsClosed[producer] == nullptr) {
                mKernelIsClosed[producer] = readTerminationSignal(b, producer);
            }
            // Always (re)read the count for cross-threaded buffers: another thread
            // may have advanced it since any cached value was taken.
            mLocallyAvailableItems[streamSet] = readAvailableItemCount(b, streamSet);
        } else if (mLocallyAvailableItems[streamSet] == nullptr) {
            // Only external or constant buffers are expected to reach this path
            // without an already-cached count.
            assert (bn.isExternal() || bn.isConstant());
            mLocallyAvailableItems[streamSet] = readAvailableItemCount(b, streamSet);
        }
    }
}
/** ------------------------------------------------------------------------------------------------------------- *
......@@ -533,6 +542,11 @@ void PipelineCompiler::writeCrossThreadedProducedItemCountAfterTermination(Kerne
const BufferNode & bn = mBufferGraph[streamSet];
if (bn.isCrossThreaded()) {
assert (bn.isInternal());
assert (bn.isNonThreadLocal());
if (!HasTerminationSignal.test(mKernelId)) {
errs() << mKernelId << " -> " << streamSet << "\n";
}
assert (HasTerminationSignal.test(mKernelId));
const BufferPort & br = mBufferGraph[e];
const auto outputPort = br.Port;
Value * const produced = mProducedAtTermination[outputPort];
......
......@@ -731,7 +731,12 @@ Value * PipelineCompiler::getAccessibleInputItems(KernelBuilder & b, const Buffe
const StreamSetBuffer * const buffer = bn.Buffer;
Value * const processed = mCurrentProcessedItemCountPhi[inputPort];
Value * const available = mLocallyAvailableItems[streamSet]; assert (available);
Value * const available = mLocallyAvailableItems[streamSet];
if (available == nullptr) {
errs() << "Missing avail " << mKernelId << "." << inputPort.Number << " -> " << streamSet << "\n";
}
assert (available);
#ifdef PRINT_DEBUG_MESSAGES
const auto prefix = makeBufferName(mKernelId, inputPort);
debugPrint(b, prefix + "_available = %" PRIu64, available);
......
......@@ -20,7 +20,6 @@ void PipelineCompiler::executeKernel(KernelBuilder & b) {
mFixedRateLCM = getLCMOfFixedRateInputs(mKernel);
mKernelIsInternallySynchronized = mIsInternallySynchronized.test(mKernelId);
mKernelCanTerminateEarly = mKernel->canSetTerminateSignal();
assert (HasTerminationSignal[mKernelId] == (mIsPartitionRoot || mKernelCanTerminateEarly));
mIsOptimizationBranch = isa<OptimizationBranch>(mKernel);
mRecordHistogramData = recordsAnyHistogramData();
mExecuteStridesIndividually =
......@@ -85,7 +84,7 @@ void PipelineCompiler::executeKernel(KernelBuilder & b) {
if (LLVM_UNLIKELY(mIsPartitionRoot || mKernelCanTerminateEarly)) {
mKernelInitiallyTerminated = b.CreateBasicBlock(prefix + "_initiallyTerminated", mNextPartitionEntryPoint);
if (LLVM_LIKELY(!mIsIOProcessThread)) {
if (LLVM_LIKELY(mIsPartitionRoot && !mIsIOProcessThread)) {
// if we are actually jumping over any kernels, create the basicblock for the code to perform it.
const auto jumpId = PartitionJumpTargetId[mCurrentPartitionId];
assert (jumpId > mCurrentPartitionId);
......@@ -212,18 +211,19 @@ void PipelineCompiler::executeKernel(KernelBuilder & b) {
splatMultiStepPartialSumValues(b);
if (LLVM_UNLIKELY(mCurrentKernelIsStateFree)) {
writeInternalProcessedAndProducedItemCounts(b, true);
}
if (mIsPartitionRoot || mKernelCanTerminateEarly) {
} else {
// When we have a cross-threaded buffer, we need to write out their produced counts
// prior to writing the termination signal otherwise we risk a consumer assuming
// a streamset is closed but does not know the correct item count. Rather than
// potentially slowing down the loop exit by writing an unnecessary value to the
// global state, we just write it here despite knowing it'll be written again later.
if (LLVM_LIKELY(!mCurrentKernelIsStateFree)) {
writeCrossThreadedProducedItemCountAfterTermination(b);
}
writeCrossThreadedProducedItemCountAfterTermination(b);
}
if (HasTerminationSignal.test(mKernelId)) {
writeTerminationSignal(b, mKernelId, mTerminatedSignalPhi);
propagateTerminationSignal(b);
} else {
}
// We do not release the pre-invocation synchronization lock in the execution phase
// when a kernel is terminating.
......@@ -720,10 +720,10 @@ void PipelineCompiler::writeInsufficientIOExit(KernelBuilder & b) {
assert (isFromCurrentFunction(b, mAlreadyProgressedPhi, false));
mAnyProgressedAtLoopExitPhi->addIncoming(mAlreadyProgressedPhi, exitBlock);
assert (mInitialTerminationSignal);
mTerminatedAtLoopExitPhi->addIncoming(mInitialTerminationSignal, exitBlock);
if (mKernelJumpToNextUsefulPartition) {
assert (mIsPartitionRoot);
for (const auto e : make_iterator_range(out_edges(mKernelId, mBufferGraph))) {
const auto & br = mBufferGraph[e];
const auto port = br.Port;
......@@ -850,11 +850,6 @@ void PipelineCompiler::updatePhisAfterTermination(KernelBuilder & b) {
for (const auto e : make_iterator_range(out_edges(mKernelId, mBufferGraph))) {
const auto port = mBufferGraph[e].Port;
Value * const produced = mProducedAtTermination[port];
#ifdef PRINT_DEBUG_MESSAGES
debugPrint(b, makeBufferName(mKernelId, port) + "_producedAtTermination = %" PRIu64, produced);
#endif
mUpdatedProducedPhi[port]->addIncoming(produced, exitBlock);
if (mUpdatedProducedDeferredPhi[port]) {
mUpdatedProducedDeferredPhi[port]->addIncoming(produced, exitBlock);
......
......@@ -229,7 +229,16 @@ void PipelineCompiler::generateMultiThreadKernelMethod(KernelBuilder & b) {
const auto resumePoint = b.saveIP();
const auto anyDebugOptionIsSet = codegen::AnyDebugOptionIsSet();
SmallVector<Type *, 2> csRetValFields(CheckAssertions ? 2 : 1, boolTy);
const auto hasTermSignal = !mIsNestedPipeline || PipelineHasTerminationSignal;
SmallVector<Type *, 2> csRetValFields;
csRetValFields.push_back(hasTermSignal ? sizeTy : boolTy);
if (CheckAssertions) {
csRetValFields.push_back(boolTy);
}
StructType * const csRetValType = StructType::get(b.getContext(), csRetValFields);
FixedArray<Type *, 2> csParams;
......@@ -249,8 +258,6 @@ void PipelineCompiler::generateMultiThreadKernelMethod(KernelBuilder & b) {
csDoSegmentProcessFuncType = csDoSegmentComputeFuncType;
}
const auto hasTermSignal = !mIsNestedPipeline || PipelineHasTerminationSignal;
// -------------------------------------------------------------------------------------------------------------------------
// GENERATE DO SEGMENT (KERNEL EXECUTION) FUNCTION CODE
// -------------------------------------------------------------------------------------------------------------------------
......@@ -355,22 +362,17 @@ void PipelineCompiler::generateMultiThreadKernelMethod(KernelBuilder & b) {
Value * const terminated = hasPipelineTerminated(b);
SmallVector<Value *, 2> retValFields;
retValFields.push_back(terminated);
if (hasTermSignal) {
retValFields.push_back(terminated);
} else {
retValFields.push_back(b.CreateIsNotNull(terminated));
}
if (LLVM_UNLIKELY(CheckAssertions)) {
retValFields.push_back(mPipelineProgress);
}
b.CreateAggregateRet(retValFields.data(), CheckAssertions ? 2U : 1U);
mIsIOProcessThread = false;
// if (LLVM_LIKELY(hasTermSignal)) {
// writeTerminationSignalToLocalState(b, threadStructTy, threadStruct, hasPipelineTerminated(b));
// }
// if (LLVM_UNLIKELY(CheckAssertions)) {
// b.CreateRet(mPipelineProgress);
// } else {
// b.CreateRetVoid();
// }
};
const auto outerFuncName = concat(mTarget->getName(), "_MultithreadedThread", tmp);
......
......@@ -333,7 +333,7 @@ void PipelineCompiler::generateInitializeMethod(KernelBuilder & b) {
// Is this the last kernel in a partition? If so, store the accumulated
// termination signal.
if (terminated && HasTerminationSignal[mKernelId]) {
if (terminated && HasTerminationSignal.test(mKernelId)) {
Value * const signal = b.CreateSelect(terminated, aborted, unterminated);
writeTerminationSignal(b, mKernelId, signal);
terminated = nullptr;
......
......@@ -192,8 +192,6 @@ void PipelineCompiler::waitUntilCurrentSegmentNumberIsLessThan(KernelBuilder & b
BasicBlock * const segmentCheckLoop = b.CreateBasicBlock(prefix + "_crossTheadWaitLoop", nextNode);
BasicBlock * const segmentCheckExit = b.CreateBasicBlock(prefix + "_crossTheadWaitExit", nextNode);
Value * const syncLockPtr = getSynchronizationLockPtrForKernel(b, kernelId, lockType);
Value * signalPtr; Type * signalTy;
std::tie(signalPtr, signalTy) = getKernelTerminationSignalPtr(b, kernelId);
b.CreateBr(segmentCheckLoop);
b.SetInsertPoint(segmentCheckLoop);
......@@ -205,10 +203,14 @@ void PipelineCompiler::waitUntilCurrentSegmentNumberIsLessThan(KernelBuilder & b
min = b.CreateAdd(syncNum, windowLength);
}
Value * const isProgressedFarEnough = b.CreateICmpULT(mSegNo, min);
Value * const isTerminated = b.CreateIsNotNull(b.CreateLoad(signalTy, signalPtr));
Value * const isTerminated = b.CreateIsNotNull(readTerminationSignal(b, kernelId));
b.CreateLikelyCondBr(b.CreateOr(isProgressedFarEnough, isTerminated), segmentCheckExit, segmentCheckLoop);
b.SetInsertPoint(segmentCheckExit);
#ifdef PRINT_DEBUG_MESSAGES
debugPrint(b, prefix + ": waited for cross thread %ssegment number %" PRIu64 " of %" PRIu64 " isProgressed=%" PRIu8 " isTerminated=%" PRIu8,
__getSyncLockName(b, lockType), syncNum, min, isProgressedFarEnough, isTerminated);
#endif
}
/** ------------------------------------------------------------------------------------------------------------- *
......
......@@ -23,11 +23,11 @@ void PipelineCompiler::addTerminationProperties(KernelBuilder & b, const size_t
* @brief getTerminationSignalIndex
** ------------------------------------------------------------------------------------------------------------- */
unsigned PipelineCompiler::getTerminationSignalIndex(const unsigned kernel) const {
    // Returns the kernel id whose state slot stores the termination signal that
    // governs `kernel`: either the kernel itself (if it owns a signal) or its
    // partition root, which must own one.
    // NOTE: the diff's pre-image line (operator[] form) was removed here; only the
    // post-image BitVector::test form remains, restoring balanced braces.
    if (HasTerminationSignal.test(kernel)) {
        return kernel;
    } else {
        const auto root = FirstKernelInPartition[KernelPartitionId[kernel]];
        assert (HasTerminationSignal.test(root));
        return root;
    }
}
......@@ -36,15 +36,17 @@ unsigned PipelineCompiler::getTerminationSignalIndex(const unsigned kernel) cons
* @brief hasKernelTerminated
** ------------------------------------------------------------------------------------------------------------- */
Value * PipelineCompiler::hasKernelTerminated(KernelBuilder & b, const size_t kernel, const bool normally) {
const auto idx = getTerminationSignalIndex(kernel);
const auto partitionId = KernelPartitionId[kernel];
Value * signal = nullptr;
const auto isComputeThreadPartition = (FirstComputePartitionId <= partitionId) && (partitionId <= LastComputePartitionId);
if (mIsIOProcessThread != isComputeThreadPartition) {
signal = mKernelTerminationSignal[idx];
assert (isFromCurrentFunction(b, signal, false));
} else {
signal = readIfKernelIsClosed(b, kernel);
Value * signal = mKernelIsClosed[kernel];
if (signal == nullptr) {
const auto partitionId = KernelPartitionId[kernel];
const auto isComputeThreadPartition = (FirstComputePartitionId <= partitionId) && (partitionId <= LastComputePartitionId);
if (mIsIOProcessThread != isComputeThreadPartition) {
const auto idx = getTerminationSignalIndex(kernel);
signal = mKernelTerminationSignal[idx];
assert (isFromCurrentFunction(b, signal, false));
} else {
signal = readTerminationSignal(b, kernel);
}
}
if (normally) {
Constant * const completed = getTerminationSignal(b, TerminationSignal::Completed);
......@@ -63,13 +65,13 @@ Value * PipelineCompiler::hasPipelineTerminated(KernelBuilder & b) {
Value * hard = nullptr;
Value * soft = nullptr;
Constant * const unterminated = getTerminationSignal(b, TerminationSignal::None);
// Constant * const aborted = getTerminationSignal(b, TerminationSignal::Aborted);
Constant * const fatal = getTerminationSignal(b, TerminationSignal::Fatal);
assert (KernelPartitionId[PipelineInput] == 0);
assert (KernelPartitionId[PipelineOutput] == (PartitionCount - 1));
Constant * const unterminated = getTerminationSignal(b, TerminationSignal::None);
Constant * const aborted = getTerminationSignal(b, TerminationSignal::Aborted);
Constant * const fatal = getTerminationSignal(b, TerminationSignal::Fatal);
for (auto partitionId = 1U; partitionId < (PartitionCount - 1); ++partitionId) {
if (const auto type = mTerminationCheck[partitionId]) {
const auto root = FirstKernelInPartition[partitionId];
......@@ -104,16 +106,12 @@ Value * PipelineCompiler::hasPipelineTerminated(KernelBuilder & b) {
}
}
assert (soft);
Value * signal = soft;
Value * signal = b.CreateSelect(soft, aborted, unterminated);
if (hard) {
signal = b.CreateOr(soft, hard);
signal = b.CreateSelect(hard, fatal, signal);
}
// Value * signal = b.CreateSelect(soft, aborted, unterminated);
// if (hard) {
// signal = b.CreateSelect(hard, fatal, signal);
// }
return signal;
}
......@@ -188,35 +186,15 @@ bool PipelineCompiler::kernelCanTerminateAbnormally(const unsigned kernel) const
* @brief checkIfKernelIsAlreadyTerminated
** ------------------------------------------------------------------------------------------------------------- */
void PipelineCompiler::checkIfKernelIsAlreadyTerminated(KernelBuilder & b) {
    // Loads the current kernel's prior termination state at the start of its
    // invocation, but only if this kernel owns a termination signal slot.
    // NOTE: the diff's pre-image condition (mIsPartitionRoot || mKernelCanTerminateEarly)
    // was removed; the post-image HasTerminationSignal test is authoritative,
    // restoring balanced braces.
    if (HasTerminationSignal.test(mKernelId)) {
        Value * const signal = readTerminationSignal(b, mKernelId);
        mKernelTerminationSignal[mKernelId] = signal;
        mInitialTerminationSignal = signal;
        // Invalidate any cached closed flag so it is lazily re-derived when needed.
        mKernelIsClosed[mKernelId] = nullptr;
        mInitiallyTerminated = hasKernelTerminated(b, mKernelId);
    }
}
/** ------------------------------------------------------------------------------------------------------------- *
* @brief readIfStreamSetlIsClosed
** ------------------------------------------------------------------------------------------------------------- */
Value * PipelineCompiler::readIfStreamSetlIsClosed(KernelBuilder & b, const size_t streamSet) {
    // A streamset is considered closed when its producing kernel has terminated,
    // so delegate to the kernel-level check using the streamset's producer.
    const auto producerId = parent(streamSet, mBufferGraph);
    return readIfKernelIsClosed(b, producerId);
}
/** ------------------------------------------------------------------------------------------------------------- *
* @brief readIfKernelIsClosed
** ------------------------------------------------------------------------------------------------------------- */
Value * PipelineCompiler::readIfKernelIsClosed(KernelBuilder & b, const size_t kernelId) {
    // Resolve which kernel actually stores the governing termination signal
    // (the kernel itself or its partition root), then prefer any value already
    // held in mKernelTerminationSignal before touching pipeline state.
    const auto idx = getTerminationSignalIndex(kernelId);
    if (Value * const cached = mKernelTerminationSignal[idx]) {
        return cached;
    }
    // No cached value: load the signal from state. NOTE(review): the original
    // deliberately left memoizing the loaded value commented out; that behavior
    // (no write-back to mKernelTerminationSignal) is preserved here.
    return readTerminationSignal(b, idx);
}
/** ------------------------------------------------------------------------------------------------------------- *
* @brief checkPropagatedTerminationSignals
** ------------------------------------------------------------------------------------------------------------- */
......
#pragma once
#define PRINT_DEBUG_MESSAGES
// #define PRINT_DEBUG_MESSAGES
// #define PRINT_DEBUG_MESSAGES_FOR_KERNEL_NUM 40
......
......@@ -370,8 +370,6 @@ public:
Value * readTerminationSignal(KernelBuilder & b, const unsigned kernelId) const;
ScalarRef getKernelTerminationSignalPtr(KernelBuilder & b, const unsigned kernelId) const;
void writeTerminationSignal(KernelBuilder & b, const unsigned kernelId, Value * const signal) const;
Value * readIfStreamSetlIsClosed(KernelBuilder & b, const size_t streamSet);
Value * readIfKernelIsClosed(KernelBuilder & b, const size_t kernelId);
Value * hasPipelineTerminated(KernelBuilder & b);
void signalAbnormalTermination(KernelBuilder & b);
LLVM_READNONE static Constant * getTerminationSignal(KernelBuilder & b, const TerminationSignal type);
......@@ -719,6 +717,7 @@ protected:
Vec<AllocaInst *, 16> mAddressableItemCountPtr;
Vec<AllocaInst *, 16> mVirtualBaseAddressPtr;
FixedVector<PHINode *> mInitiallyAvailableItemsPhi;
FixedVector<Value *> mKernelIsClosed;
FixedVector<Value *> mLocallyAvailableItems;
FixedVector<Value *> mScalarValue;
......@@ -998,6 +997,7 @@ inline PipelineCompiler::PipelineCompiler(PipelineKernel * const pipelineKernel,
, mZeroInputGraph(std::move(P.mZeroInputGraph))
, mInitiallyAvailableItemsPhi(FirstStreamSet, LastStreamSet, mAllocator)
, mKernelIsClosed(FirstKernel, LastKernel, mAllocator)
, mLocallyAvailableItems(FirstStreamSet, LastStreamSet, mAllocator)
, mScalarValue(FirstKernel, LastScalar, mAllocator)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment