"""
The ComputeWorkerManager abstract base class is for creating and managing
a set of compute workers. Each different computeWorkerKind is implemented
as a concrete subclass of this. All these classes are found in this module.
"""
import sys
import os
from abc import ABC, abstractmethod
from concurrent import futures
import queue
import subprocess
import time
import threading
import copy
import random
try:
import boto3
except ImportError:
boto3 = None
from . import rioserrors
from .structures import Timers, BlockAssociations, NetworkDataChannel
from .structures import WorkerErrorRecord
from .structures import CW_NONE, CW_THREADS, CW_PBS, CW_SLURM, CW_AWSBATCH
from .structures import CW_SUBPROC, CW_ECS
from .readerinfo import makeReaderInfo
def getComputeWorkerManager(cwKind):
"""
Returns a compute-worker manager object of the requested kind.
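For example, a minimal sketch (assuming the kind constants have been
imported from rios.structures)::

    cwMgr = getComputeWorkerManager(CW_THREADS)
    # ... then cwMgr.startWorkers(...) and, when finished, cwMgr.shutdown()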
"""
if cwKind in (CW_PBS, CW_SLURM):
cwMgrObj = ClassicBatchComputeWorkerMgr()
cwMgrObj.computeWorkerKind = cwKind
else:
cwMgrClass = None
subClasses = ComputeWorkerManager.__subclasses__()
for c in subClasses:
if c.computeWorkerKind == cwKind:
cwMgrClass = c
if cwMgrClass is None:
msg = "Unknown compute-worker kind '{}'".format(cwKind)
raise ValueError(msg)
cwMgrObj = cwMgrClass()
return cwMgrObj
class ComputeWorkerManager(ABC):
"""
Abstract base class for all compute-worker manager subclasses.
A subclass implements a particular way of managing RIOS
compute-workers. It should override all abstract methods given here.
"""
computeWorkerKind = CW_NONE
outObjList = None
outqueue = None
jobName = None
computeWorkersRead_default = None
@abstractmethod
def startWorkers(self, numWorkers=None, userFunction=None,
infiles=None, outfiles=None, otherArgs=None, controls=None,
blockList=None, inBlockBuffer=None, outBlockBuffer=None,
workinggrid=None, allInfo=None, singleBlockComputeWorkers=False,
tmpfileMgr=None, haveSharedTemp=True, exceptionQue=None):
"""
Start the specified compute workers
"""
@abstractmethod
def shutdown(self):
"""
Shut down the computeWorkerManager
"""
def setupNetworkCommunication(self, userFunction, infiles, outfiles,
otherArgs, controls, workinggrid, allInfo, blockList,
numWorkers, inBlockBuffer, outBlockBuffer, forceExit,
exceptionQue, workerBarrier):
"""
Set up the standard methods of network communication between
the workers and the main thread. This is expected to be the
same for all workers running on separate machines from the
main thread.
Creates the dataChan and outqueue attributes.
This routine is not needed for the Threads subclass, because it
does not use the network versions of these communications.
"""
# Divide the block list into a sublist for each worker
allSublists = [blockList[i::numWorkers] for i in range(numWorkers)]
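# e.g. 7 blocks across 3 workers gives round-robin sublists of block
# indexes [0, 3, 6], [1, 4] and [2, 5]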
# Set up the data which is common for all workers
workerInitData = {}
workerInitData['userFunction'] = userFunction
workerInitData['infiles'] = infiles
workerInitData['outfiles'] = outfiles
workerInitData['otherArgs'] = otherArgs
workerInitData['controls'] = controls
workerInitData['workinggrid'] = workinggrid
workerInitData['allInfo'] = allInfo
# Set up the data which is local to each worker
blockListByWorker = {}
workerInitData['blockListByWorker'] = blockListByWorker
for workerID in range(numWorkers):
blockListByWorker[workerID] = allSublists[workerID]
# Create the network-visible data channel
try:
self.dataChan = NetworkDataChannel(workerInitData, inBlockBuffer,
outBlockBuffer, forceExit, exceptionQue, workerBarrier)
except rioserrors.UnavailableError as e:
if str(e) == "Failed to import cloudpickle":
msg = ("computeWorkerKind '{}' requires the cloudpickle " +
"package, which appears to be unavailable")
msg = msg.format(self.computeWorkerKind)
raise rioserrors.UnavailableError(msg) from None
else:
raise
self.outqueue = self.dataChan.outqueue
self.exceptionQue = self.dataChan.exceptionQue
def makeOutObjList(self):
"""
Make a list of all the objects the workers put into outqueue
on completion
"""
self.outObjList = []
done = False
while not done:
try:
outObj = self.outqueue.get(block=False)
self.outObjList.append(outObj)
except queue.Empty:
done = True
def setJobName(self, jobName):
"""
Sets the job name string, which is made available to worker
processes. Defaults to None, and has only cosmetic effects.
"""
self.jobName = jobName
def getWorkerName(self, workerID):
"""
Return a string which uniquely identifies each worker, including
the jobName, if given.
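For example, workerID 3 with jobName "landsat" gives "RIOS_landsat_3",
or "RIOS_3" if no jobName was set.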
"""
if self.jobName is not None:
workerName = "RIOS_{}_{}".format(self.jobName, workerID)
else:
workerName = "RIOS_{}".format(workerID)
return workerName
class ThreadsComputeWorkerMgr(ComputeWorkerManager):
"""
Manage compute workers using threads within the current process.
"""
computeWorkerKind = CW_THREADS
computeWorkersRead_default = False
def __init__(self):
self.threadPool = None
self.workerList = None
self.outqueue = queue.Queue()
self.forceExit = threading.Event()
def startWorkers(self, numWorkers=None, userFunction=None,
infiles=None, outfiles=None, otherArgs=None, controls=None,
blockList=None, inBlockBuffer=None, outBlockBuffer=None,
workinggrid=None, allInfo=None, singleBlockComputeWorkers=False,
tmpfileMgr=None, haveSharedTemp=True, exceptionQue=None):
"""
Start <numWorkers> threads to process blocks of data
"""
# Divide the block list into a sublist for each worker
allSublists = [blockList[i::numWorkers] for i in range(numWorkers)]
self.threadPool = futures.ThreadPoolExecutor(max_workers=numWorkers)
self.workerList = []
for workerID in range(numWorkers):
# otherArgs are not thread-safe, so each worker gets its own copy
otherArgsCopy = copy.deepcopy(otherArgs)
subBlocklist = allSublists[workerID]
worker = self.threadPool.submit(self.worker, userFunction, infiles,
outfiles, otherArgsCopy, controls, allInfo,
workinggrid, subBlocklist, inBlockBuffer, outBlockBuffer,
self.outqueue, workerID, exceptionQue)
self.workerList.append(worker)
def worker(self, userFunction, infiles, outfiles, otherArgs,
controls, allInfo, workinggrid, blockList, inBlockBuffer,
outBlockBuffer, outqueue, workerID, exceptionQue):
"""
This function is a worker for a single thread. It does no file
reading or writing itself; all I/O is via the inBlockBuffer and
outBlockBuffer objects.
"""
numBlocks = len(blockList)
try:
timings = Timers()
blockNdx = 0
while blockNdx < numBlocks and not self.forceExit.is_set():
with timings.interval('pop_readbuffer'):
(blockDefn, inputs) = inBlockBuffer.popNextBlock()
readerInfo = makeReaderInfo(workinggrid, blockDefn, controls,
infiles, inputs, allInfo)
outputs = BlockAssociations()
userArgs = (readerInfo, inputs, outputs)
if otherArgs is not None:
userArgs += (otherArgs, )
with timings.interval('userfunction'):
userFunction(*userArgs)
with timings.interval('insert_computebuffer'):
outBlockBuffer.insertCompleteBlock(blockDefn, outputs)
blockNdx += 1
if otherArgs is not None:
outqueue.put(otherArgs)
outqueue.put(timings)
except Exception as e:
workerErr = WorkerErrorRecord(e, 'compute', workerID)
exceptionQue.put(workerErr)
def shutdown(self):
"""
Shut down the thread pool
"""
self.forceExit.set()
futures.wait(self.workerList)
self.threadPool.shutdown()
self.makeOutObjList()
class ECSComputeWorkerMgr(ComputeWorkerManager):
"""
Manage compute workers using AWS ECS (Elastic Container Service).
Requires some extra parameters in the ConcurrencyStyle constructor
(computeWorkerExtraParams), in order to configure the AWS infrastructure.
This class provides some helper functions for creating these for
various use cases.
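As a rough sketch, the dictionary has this shape. The keys are those
read by this class; the boto3 argument values shown are illustrative
assumptions only, and omit most of the real arguments::

    computeWorkerExtraParams = {
        # Optional; kwargs for ecs_client.create_cluster()
        'create_cluster': {'clusterName': 'rios-cluster'},
        # Optional; kwargs for ec2_client.run_instances()
        'run_instances': {'MinCount': 4, 'MaxCount': 4},
        # Optional; kwargs for ecs_client.register_task_definition()
        'register_task_definition': {'family': 'rios-worker'},
        # Required; kwargs for ecs_client.run_task(). The taskDefinition,
        # cluster and per-worker container command are filled in by
        # startWorkers()
        'run_task': {'overrides': {'containerOverrides': [{}]}},
        # Optional; seconds to wait for instances to join (default 300)
        'waitClusterInstanceCountTimeout': 300
    }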
"""
computeWorkerKind = CW_ECS
defaultWaitClusterInstanceCountTimeout = 300
computeWorkersRead_default = True
def startWorkers(self, numWorkers=None, userFunction=None,
infiles=None, outfiles=None, otherArgs=None, controls=None,
blockList=None, inBlockBuffer=None, outBlockBuffer=None,
workinggrid=None, allInfo=None, singleBlockComputeWorkers=False,
tmpfileMgr=None, haveSharedTemp=True, exceptionQue=None):
"""
Start <numWorkers> ECS tasks to process blocks of data
"""
if boto3 is None:
raise rioserrors.UnavailableError("boto3 is unavailable")
self.forceExit = threading.Event()
self.workerBarrier = threading.Barrier(numWorkers + 1)
self.setupNetworkCommunication(userFunction, infiles, outfiles,
otherArgs, controls, workinggrid, allInfo, blockList,
numWorkers, inBlockBuffer, outBlockBuffer, self.forceExit,
exceptionQue, self.workerBarrier)
channAddr = self.dataChan.addressStr()
self.jobIDstr = self.makeJobIDstr(controls.jobName)
self.createdTaskDef = False
self.createdCluster = False
self.createdInstances = False
self.instanceList = None
ecsClient = boto3.client("ecs")
self.ecsClient = ecsClient
extraParams = controls.concurrency.computeWorkerExtraParams
if extraParams is None:
msg = "ECSComputeWorkerMgr requires computeWorkerExtraParams"
raise ValueError(msg)
self.extraParams = extraParams
# Create ECS cluster (if requested)
try:
self.createCluster()
self.runInstances(numWorkers)
except Exception as e:
self.shutdownCluster()
raise e
# Create the ECS task definition (if requested)
self.createTaskDef()
# Now create a task for each compute worker
runTask_kwArgs = extraParams['run_task']
runTask_kwArgs['taskDefinition'] = self.taskDefArn
containerOverrides = runTask_kwArgs['overrides']['containerOverrides'][0]
if self.createdCluster:
runTask_kwArgs['cluster'] = self.clusterName
self.taskArnList = []
for workerID in range(numWorkers):
# Construct the command args entry with the current workerID
workerCmdArgs = ['-i', str(workerID), '--channaddr', channAddr]
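# The container is expected to invoke rios_computeworker with these
# arguments, which identify the worker and the network data channel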
containerOverrides['command'] = workerCmdArgs
runTaskResponse = ecsClient.run_task(**runTask_kwArgs)
if len(runTaskResponse['tasks']) > 0:
taskResp = runTaskResponse['tasks'][0]
self.taskArnList.append(taskResp['taskArn'])
failuresList = runTaskResponse['failures']
if len(failuresList) > 0:
self.dataChan.shutdown()
msgList = []
for failure in failuresList:
reason = failure.get('reason', 'UnknownReason')
detail = failure.get('detail')
msg = "Worker {}: Reason: {}".format(workerID, reason)
if detail is not None:
msg += "\nDetail: {}".format(detail)
msgList.append(msg)
fullMsg = '\n'.join(msgList)
raise rioserrors.ECSError(fullMsg)
# Do not proceed until all workers have started
computeBarrierTimeout = controls.concurrency.computeBarrierTimeout
self.workerBarrier.wait(timeout=computeBarrierTimeout)
def shutdown(self):
"""
Shut down the workers
"""
self.forceExit.set()
self.makeOutObjList()
self.waitClusterTasksFinished()
self.checkTaskErrors()
if hasattr(self, 'dataChan'):
self.dataChan.shutdown()
if self.createdTaskDef:
self.ecsClient.deregister_task_definition(taskDefinition=self.taskDefArn)
# Shut down the ECS cluster, if one was created.
self.shutdownCluster()
def shutdownCluster(self):
"""
Shut down the ECS cluster, if one has been created
"""
if self.createdInstances and self.instanceList is not None:
instIdList = [inst['InstanceId'] for inst in self.instanceList]
self.ec2client.terminate_instances(InstanceIds=instIdList)
self.waitClusterInstanceCount(self.clusterName, 0)
if self.createdCluster:
self.ecsClient.delete_cluster(cluster=self.clusterName)
@staticmethod
def makeJobIDstr(jobName):
"""
Make a job ID string to use in various generated names. It is unique
to this run, and also includes any human-readable information available.
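For example, jobName "landsat" might give "landsat-1f3a9c2e".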
"""
hexStr = random.randbytes(4).hex()
if jobName is None:
jobIDstr = hexStr
else:
jobIDstr = "{}-{}".format(jobName, hexStr)
return jobIDstr
def createCluster(self):
"""
If requested to do so, create an ECS cluster to run on.
"""
createCluster_kwArgs = self.extraParams.get('create_cluster')
if createCluster_kwArgs is not None:
self.clusterName = createCluster_kwArgs.get('clusterName')
self.ecsClient.create_cluster(**createCluster_kwArgs)
self.createdCluster = True
def runInstances(self, numWorkers):
"""
If requested to do so, run the instances required to populate
the cluster
"""
self.waitClusterInstanceCountTimeout = self.extraParams.get(
'waitClusterInstanceCountTimeout',
self.defaultWaitClusterInstanceCountTimeout)
runInstances_kwArgs = self.extraParams.get('run_instances')
if runInstances_kwArgs is not None:
self.ec2client = boto3.client('ec2')
response = self.ec2client.run_instances(**runInstances_kwArgs)
self.instanceList = response['Instances']
numInstances = len(self.instanceList)
self.createdInstances = True
self.waitClusterInstanceCount(self.clusterName, numInstances)
def getClusterInstanceCount(self, clusterName):
"""
Query the given cluster, and return the number of instances it has. If the
cluster does not exist, return None.
"""
count = None
response = self.ecsClient.describe_clusters(clusters=[clusterName])
if 'clusters' in response:
for descr in response['clusters']:
if descr['clusterName'] == clusterName:
count = descr['registeredContainerInstancesCount']
return count
def getClusterTaskCount(self):
"""
Query the cluster, and return the number of tasks it has.
This is the total of running and pending tasks.
If the cluster does not exist, return None.
"""
count = None
clusterName = self.clusterName
response = self.ecsClient.describe_clusters(clusters=[clusterName])
if 'clusters' in response:
for descr in response['clusters']:
if descr['clusterName'] == clusterName:
count = (descr['runningTasksCount'] +
descr['pendingTasksCount'])
return count
def waitClusterInstanceCount(self, clusterName, endInstanceCount):
"""
Poll the given cluster until the instanceCount is equal to the
given endInstanceCount
"""
instanceCount = self.getClusterInstanceCount(clusterName)
startTime = time.time()
timeout = self.waitClusterInstanceCountTimeout
timeExceeded = False
while ((instanceCount != endInstanceCount) and (not timeExceeded)):
time.sleep(5)
instanceCount = self.getClusterInstanceCount(clusterName)
timeExceeded = (time.time() > (startTime + timeout))
# If we exceeded timeout without reaching endInstanceCount,
# raise an exception
if timeExceeded and (instanceCount != endInstanceCount):
msg = ("Cluster instance count timeout ({} seconds). ".format(timeout) +
"See extraParams['waitClusterInstanceCountTimeout']")
raise rioserrors.TimeoutError(msg)
def waitClusterTasksFinished(self):
"""
Poll the given cluster until the number of tasks reaches zero
"""
taskCount = self.getClusterTaskCount()
startTime = time.time()
timeout = 20
timeExceeded = False
while ((taskCount is not None) and (taskCount > 0) and (not timeExceeded)):
time.sleep(5)
taskCount = self.getClusterTaskCount()
timeExceeded = (time.time() > (startTime + timeout))
# If we exceeded timeout without reaching zero,
# raise an exception
if timeExceeded and (taskCount is not None) and (taskCount > 0):
msg = "Cluster task count timeout ({} seconds)".format(timeout)
raise rioserrors.TimeoutError(msg)
def createTaskDef(self):
"""
If requested to do so, create a task definition for the worker tasks
"""
taskDef_kwArgs = self.extraParams.get('register_task_definition')
if taskDef_kwArgs is not None:
self.createdTaskDef = True
taskDefResponse = self.ecsClient.register_task_definition(**taskDef_kwArgs)
self.taskDefArn = taskDefResponse['taskDefinition']['taskDefinitionArn']
def checkTaskErrors(self):
"""
Check for errors in any of the worker tasks, and report to stderr.
"""
numTasks = len(self.taskArnList)
# The describe_tasks call will only take this many at a time, so we
# have to page through.
TASKS_PER_PAGE = 100
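# e.g. 250 tasks would be described in pages of 100, 100 and 50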
i = 0
failures = []
exitCodeList = []
while i < numTasks:
j = i + TASKS_PER_PAGE
descr = self.ecsClient.describe_tasks(cluster=self.clusterName,
tasks=self.taskArnList[i:j])
failures.extend(descr['failures'])
# Grab all the container exit codes/reasons. Note that we
# know we have only one container per task.
ctrDescrList = [t['containers'][0] for t in descr['tasks']]
for c in ctrDescrList:
if 'exitCode' in c:
exitCode = c['exitCode']
if exitCode != 0:
reason = c.get('reason', "UnknownReason")
exitCodeList.append((exitCode, reason))
i = j
for f in failures:
print("Failure in ECS task:", f.get('reason'), file=sys.stderr)
print(" ", f.get('details'), file=sys.stderr)
for (exitCode, reason) in exitCodeList:
if exitCode != 0:
print("Exit code {} from ECS task container: {}".format(
exitCode, reason), file=sys.stderr)
class AWSBatchComputeWorkerMgr(ComputeWorkerManager):
"""
Manage compute workers using AWS Batch.
"""
computeWorkerKind = CW_AWSBATCH
computeWorkersRead_default = True
def startWorkers(self, numWorkers=None, userFunction=None,
infiles=None, outfiles=None, otherArgs=None, controls=None,
blockList=None, inBlockBuffer=None, outBlockBuffer=None,
workinggrid=None, allInfo=None, singleBlockComputeWorkers=False,
tmpfileMgr=None, haveSharedTemp=True, exceptionQue=None):
"""
Start <numWorkers> AWS Batch jobs to process blocks of data
"""
self.forceExit = threading.Event()
self.workerBarrier = threading.Barrier(numWorkers + 1)
if boto3 is None:
raise rioserrors.UnavailableError("boto3 is unavailable")
self.STACK_NAME = os.getenv('RIOS_AWSBATCH_STACK', default='RIOS')
self.REGION = os.getenv('RIOS_AWSBATCH_REGION',
default='ap-southeast-2')
self.stackOutputs = self.getStackOutputs()
self.batchClient = boto3.client('batch', region_name=self.REGION)
# Work out the maximum number of jobs that can be run, based on the
# VCPUS and MaxVCPUS settings
maxBatchJobs = (int(self.stackOutputs['BatchMaxVCPUS']) //
int(self.stackOutputs['BatchVCPUS']))
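# e.g. BatchMaxVCPUS=256 with BatchVCPUS=2 would allow up to
# 128 workers (illustrative values only)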
if numWorkers > maxBatchJobs:
raise ValueError('Requested number of compute workers is ' +
'greater than (MaxVCPUS / VCPUS). Either increase ' +
'this ratio, or reduce numComputeWorkers')
self.setupNetworkCommunication(userFunction, infiles, outfiles,
otherArgs, controls, workinggrid, allInfo, blockList,
numWorkers, inBlockBuffer, outBlockBuffer, self.forceExit,
exceptionQue, self.workerBarrier)
channAddr = self.dataChan.addressStr()
jobQueue = self.stackOutputs['BatchProcessingJobQueueName']
jobDefinition = self.stackOutputs['BatchProcessingJobDefinitionName']
self.jobList = []
for workerID in range(numWorkers):
workerCmdArgs = ['-i', str(workerID), '--channaddr', channAddr]
containerOverrides = {"command": workerCmdArgs}
jobRtn = self.batchClient.submit_job(
jobName=self.getWorkerName(workerID),
jobQueue=jobQueue,
jobDefinition=jobDefinition,
containerOverrides=containerOverrides)
self.jobList.append(jobRtn)
if not singleBlockComputeWorkers:
# Do not proceed until all workers have started
computeBarrierTimeout = controls.concurrency.computeBarrierTimeout
self.workerBarrier.wait(timeout=computeBarrierTimeout)
def shutdown(self):
"""
Shut down the job pool
"""
self.forceExit.set()
self.workerBarrier.abort()
self.makeOutObjList()
self.dataChan.shutdown()
def getStackOutputs(self):
"""
Helper function to query the CloudFormation stack for outputs.
Uses the RIOS_AWSBATCH_STACK and RIOS_AWSBATCH_REGION env vars to
determine which stack and region to query.
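The outputs this class relies on are BatchVCPUS, BatchMaxVCPUS,
BatchProcessingJobQueueName and BatchProcessingJobDefinitionName.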
"""
client = boto3.client('cloudformation', region_name=self.REGION)
resp = client.describe_stacks(StackName=self.STACK_NAME)
if len(resp['Stacks']) == 0:
msg = "AWS Batch stack '{}' is not available".format(
self.STACK_NAME)
raise rioserrors.UnavailableError(msg)
outputsRaw = resp['Stacks'][0]['Outputs']
# convert to a normal dictionary
outputs = {}
for out in outputsRaw:
key = out['OutputKey']
value = out['OutputValue']
outputs[key] = value
return outputs
class ClassicBatchComputeWorkerMgr(ComputeWorkerManager):
"""
Manage compute workers using a classic batch queue, notably
PBS or SLURM. This class is initially constructed with
computeWorkerKind = None; it must be set to either CW_PBS or
CW_SLURM before use.
Will make use of the computeWorkerExtraParams argument to
ConcurrencyStyle, if given, but this is optional. If given, it should
be a dictionary; see :doc:`concurrency` for details.
"""
computeWorkerKind = None
computeWorkersRead_default = True
def startWorkers(self, numWorkers=None, userFunction=None,
infiles=None, outfiles=None, otherArgs=None, controls=None,
blockList=None, inBlockBuffer=None, outBlockBuffer=None,
workinggrid=None, allInfo=None, singleBlockComputeWorkers=False,
tmpfileMgr=None, haveSharedTemp=True, exceptionQue=None):
"""
Start <numWorkers> PBS or SLURM jobs to process blocks of data
"""
self.checkBatchSystemAvailable()
self.haveSharedTemp = haveSharedTemp
self.scriptfileList = []
self.logfileList = []
self.jobId = {}
self.forceExit = threading.Event()
self.workerBarrier = threading.Barrier(numWorkers + 1)
if singleBlockComputeWorkers:
# We ignore numWorkers, and have a worker for each block
numWorkers = len(blockList)
self.extraParams = controls.concurrency.computeWorkerExtraParams
if self.extraParams is None:
self.extraParams = {}
self.setupNetworkCommunication(userFunction, infiles, outfiles,
otherArgs, controls, workinggrid, allInfo, blockList,
numWorkers, inBlockBuffer, outBlockBuffer, self.forceExit,
exceptionQue, self.workerBarrier)
try:
self.addressFile = None
if self.haveSharedTemp:
self.addressFile = tmpfileMgr.mktempfile(prefix='rios_batch_',
suffix='.chnl')
address = self.dataChan.addressStr()
open(self.addressFile, 'w').write(address + '\n')
for workerID in range(numWorkers):
self.worker(workerID, tmpfileMgr)
except Exception as e:
self.dataChan.shutdown()
raise e
if not singleBlockComputeWorkers:
# Do not proceed until all workers have started
computeBarrierTimeout = controls.concurrency.computeBarrierTimeout
self.workerBarrier.wait(timeout=computeBarrierTimeout)
def checkBatchSystemAvailable(self):
"""
Check whether the selected batch queue system is available.
If not, raise UnavailableError
"""
cmd = self.getQueueCmd()
try:
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, universal_newlines=True)
# Wait for the command to complete, so we don't leave a zombie process
proc.communicate()
batchSysAvailable = True
except FileNotFoundError:
batchSysAvailable = False
if not batchSysAvailable:
if self.computeWorkerKind == CW_PBS:
msg = "PBS is not available"
elif self.computeWorkerKind == CW_SLURM:
msg = "SLURM is not available"
raise rioserrors.UnavailableError(msg)
def worker(self, workerID, tmpfileMgr):
"""
Assemble a worker job and submit it to the batch queue
"""
scriptfile = tmpfileMgr.mktempfile(prefix='rios_batch_',
suffix='.sh')
logfile = tmpfileMgr.mktempfile(prefix='rios_batch_',
suffix='.log')
self.scriptfileList.append(scriptfile)
self.logfileList.append(logfile)
scriptCmdList = self.beginScript(logfile, workerID)
computeWorkerCmd = ["rios_computeworker", "-i", str(workerID)]
if self.addressFile is not None:
addressArgs = ["--channaddrfile", self.addressFile]
else:
addressArgs = ["--channaddr", self.dataChan.addressStr()]
computeWorkerCmd.extend(addressArgs)
computeWorkerCmdStr = " ".join(computeWorkerCmd)
# Add any cmd prefix or suffix strings given
cmdPrefix = self.extraParams.get('cmdPrefix', '')
cmdSuffix = self.extraParams.get('cmdSuffix', '')
computeWorkerCmdStr = cmdPrefix + computeWorkerCmdStr + cmdSuffix
# Mark the start of outputs from the worker command in the log
scriptCmdList.append("echo 'Begin-rios-worker'")
scriptCmdList.append(computeWorkerCmdStr)
# Capture the exit status from the command
scriptCmdList.append("WORKERCMDSTAT=$?")
# Mark the end of outputs from the worker command in the log
scriptCmdList.append("echo 'End-rios-worker'")
# Make sure the log includes the exit status from the command
scriptCmdList.append("echo 'rios_computeworker status:' $WORKERCMDSTAT")
scriptStr = '\n'.join(scriptCmdList)
open(scriptfile, 'w').write(scriptStr + "\n")
submitCmdWords = self.getSubmitCmd()
submitCmdWords.append(scriptfile)
proc = subprocess.Popen(submitCmdWords, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, universal_newlines=True)
# The submit command exits almost immediately, printing the job id
# to stdout. So, we just wait for the command to finish, and grab
# the jobID string.
(stdout, stderr) = proc.communicate()
self.jobId[workerID] = self.getJobId(stdout)
# If there was something in stderr from the submit command, then
# probably something bad happened, so we pass it on to the user
# in the form of an exception.
if (len(stderr) > 0) or (self.jobId[workerID] is None):
msg = "Error from submit command. Message:\n" + stderr
raise rioserrors.JobMgrError(msg)
def waitOnJobs(self):
"""
Wait for all batch jobs to complete
"""
jobIdSet = set([jobId for jobId in self.jobId.values()])
numJobs = len(jobIdSet)
allFinished = (numJobs == 0)
while not allFinished:
qlistCmd = self.getQueueCmd()
proc = subprocess.Popen(qlistCmd, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, universal_newlines=True)
(stdout, stderr) = proc.communicate()
stdoutLines = [line for line in stdout.split('\n')
if len(line) > 0] # No blank lines
# Skip header lines, and grab first word on each line,
# which is the jobID
nskip = self.getQlistHeaderCount()
qlistJobIDlist = [line.split()[0] for line in
stdoutLines[nskip:]]
qlistJobIDset = set(qlistJobIDlist)
allFinished = jobIdSet.isdisjoint(qlistJobIDset)
if not allFinished:
# Sleep for a bit before checking again
time.sleep(60)
@staticmethod
def findLine(linelist, s):
"""
Find the first line which begins with the given string.
Return the index of that line, or None if not found.
"""
for (i, line) in enumerate(linelist):
if line.strip().startswith(s):
return i
return None
def beginScript(self, logfile, workerID):
"""
Return list of initial script commands, depending on
whether we are PBS or SLURM
"""
workerName = self.getWorkerName(workerID)
if self.computeWorkerKind == CW_PBS:
scriptCmdList = [
"#!/bin/bash",
"#PBS -j oe -o {}".format(logfile),
"#PBS -N {}".format(workerName)
]
if (('RIOS_PBSJOBMGR_QSUBOPTIONS' in os.environ) or
('RIOS_PBSJOBMGR_INITCMDS' in os.environ)):
msg = ("RIOS PBS environment variables no longer supported. " +
"Please use computeWorkerExtraParams argument of " +
"ConcurrencyStyle instead")
rioserrors.deprecationWarning(msg)
qsubOptions = self.extraParams.get('qsubOptions')
initCmds = self.extraParams.get('initCmds')
if qsubOptions is not None:
scriptCmdList.append("#PBS %s" % qsubOptions)
if initCmds is not None:
scriptCmdList.append(initCmds)
elif self.computeWorkerKind == CW_SLURM:
scriptCmdList = [
"#!/bin/bash",
"#SBATCH -o %s" % logfile,
"#SBATCH -e %s" % logfile,
"#SBATCH -J {}".format(workerName)
]
if (('RIOS_SLURMJOBMGR_SBATCHOPTIONS' in os.environ) or
('RIOS_SLURMJOBMGR_INITCMDS' in os.environ)):
msg = ("RIOS SLURM environment variables no longer supported. " +
"Please use computeWorkerExtraParams argument of " +
"ConcurrencyStyle instead")
rioserrors.deprecationWarning(msg)
sbatchOptions = self.extraParams.get('sbatchOptions')
initCmds = self.extraParams.get('initCmds')
if sbatchOptions is not None:
scriptCmdList.append("#SBATCH %s" % sbatchOptions)
if initCmds is not None:
scriptCmdList.append(initCmds)
return scriptCmdList
def getSubmitCmd(self):
"""
Return the command name for submitting a job, depending on
whether we are PBS or SLURM. Return as a list of words,
ready to give to Popen.
"""
if self.computeWorkerKind == CW_PBS:
cmd = ["qsub"]
elif self.computeWorkerKind == CW_SLURM:
cmd = ["sbatch"]
return cmd
def getQueueCmd(self):
"""
Return the command name for listing the current jobs in the
batch queue, depending on whether we are PBS or SLURM. Return
as a list of words, ready to give to Popen.
"""
if self.computeWorkerKind == CW_PBS:
cmd = ["qstat"]
elif self.computeWorkerKind == CW_SLURM:
cmd = ["squeue", "--noheader"]
return cmd
def getJobId(self, stdout):
"""
Extract the jobId from the string returned when the job is
submitted, depending on whether we are PBS or SLURM
"""
if self.computeWorkerKind == CW_PBS:
jobID = stdout.strip()
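# qsub prints just the job ID, e.g. '12345.pbsserver'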
if len(jobID) == 0:
jobID = None
elif self.computeWorkerKind == CW_SLURM:
slurmOutputList = stdout.strip().split()
jobID = None
# slurm prints a sentence to the stdout:
# 'Submitted batch job X'
if len(slurmOutputList) >= 4:
jobID = slurmOutputList[3]
return jobID
def shutdown(self):
"""
Shut down the compute manager. Wait on batch jobs, then
shut down the data channel
"""
self.forceExit.set()
self.waitOnJobs()
self.makeOutObjList()
self.findExtraErrors()
self.dataChan.shutdown()
class SubprocComputeWorkerManager(ComputeWorkerManager):
"""
Purely for testing, not for normal use.
This class manages compute workers run through subprocess.Popen.
This is not normally any improvement over using CW_THREADS, and
should be avoided. I am using this purely as a test framework
to emulate the batch queue types of compute worker, which are
similarly disconnected from the main process, so I can work out the
right mechanisms to use for exception handling and the like, and
make sure the rios_computeworker command line works.
"""
computeWorkerKind = CW_SUBPROC
computeWorkersRead_default = False
def startWorkers(self, numWorkers=None, userFunction=None,
infiles=None, outfiles=None, otherArgs=None, controls=None,
blockList=None, inBlockBuffer=None, outBlockBuffer=None,
workinggrid=None, allInfo=None, singleBlockComputeWorkers=False,
tmpfileMgr=None, haveSharedTemp=True, exceptionQue=None):
"""
Start the specified compute workers
"""
self.haveSharedTemp = haveSharedTemp
self.processes = {}
self.results = {}
self.forceExit = threading.Event()
self.workerBarrier = threading.Barrier(numWorkers + 1)
self.setupNetworkCommunication(userFunction, infiles, outfiles,
otherArgs, controls, workinggrid, allInfo, blockList,
numWorkers, inBlockBuffer, outBlockBuffer, self.forceExit,
exceptionQue, self.workerBarrier)
try:
self.addressFile = None
if self.haveSharedTemp:
self.addressFile = tmpfileMgr.mktempfile(prefix='rios_subproc_',
suffix='.chnl')
address = self.dataChan.addressStr()
open(self.addressFile, 'w').write(address + '\n')
for workerID in range(numWorkers):
self.worker(workerID)
except Exception as e:
self.dataChan.shutdown()
raise e
if not singleBlockComputeWorkers:
# Do not proceed until all workers have started
computeBarrierTimeout = controls.concurrency.computeBarrierTimeout
self.workerBarrier.wait(timeout=computeBarrierTimeout)
def worker(self, workerID):
"""
Start one worker
"""
cmdList = ["rios_computeworker", "-i", str(workerID),
"--channaddrfile", self.addressFile]
self.processes[workerID] = subprocess.Popen(cmdList,
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
universal_newlines=True)
def waitOnJobs(self):
"""
Wait for all worker subprocesses to complete
"""
for (workerID, proc) in self.processes.items():
(stdout, stderr) = proc.communicate()
results = {
'returncode': proc.returncode,
'stdoutstr': stdout,
'stderrstr': stderr
}
self.results[workerID] = results
def shutdown(self):
"""
Shut down the compute manager. Wait on worker subprocesses, then
shut down the data channel
"""
self.forceExit.set()
self.workerBarrier.abort()
self.waitOnJobs()
if self.addressFile is not None:
os.remove(self.addressFile)
self.makeOutObjList()
self.findExtraErrors()
self.dataChan.shutdown()