from itertools import tee
from xmm.pipeline import Pipeline
from xmm.pipeline.steps import PipelineStep
class SwitchStep(PipelineStep):
    """Pipeline step that routes each dataset to one of several sub-pipelines.

    ``func`` is called on every incoming dataset and its return value is
    compared with the keys of ``pipelines``; the dataset is fed to the
    sub-pipeline registered under the equal key.  A dataset whose value
    matches no key is not forwarded to any sub-pipeline.
    """

    def __init__(self, func, pipelines):
        """Store the selector and wrap every branch in a ``Pipeline``.

        :param func: callable taking a dataset and returning the key of the
            sub-pipeline that should process it.
        :param pipelines: mapping of key -> ``Pipeline`` instance, or any
            object accepted by the ``Pipeline`` constructor.
        """
        self.func = func
        # Build a fresh mapping instead of rewriting the caller's dict in
        # place: the caller's object stays untouched and can safely be reused
        # (for example to build a second SwitchStep).
        self.pipelines = {}
        for key, pipeline in pipelines.items():
            if not isinstance(pipeline, Pipeline):
                pipeline = Pipeline(pipeline)
            # Each branch is paired with its own, initially empty, context.
            self.pipelines[key] = (pipeline, {})

    def map_sub(self, sub_state, key):
        """Yield only the datasets of ``sub_state`` that ``func`` maps to ``key``."""
        for dataset in sub_state:
            if self.func(dataset) == key:
                yield dataset

    def map(self, master_state):
        """Fan ``master_state`` out to the sub-pipelines and merge their output.

        The outputs of the sub-pipelines are interleaved round-robin, in the
        iteration order of ``self.pipelines``, until every one is exhausted.

        :param master_state: iterable of datasets.
        :return: generator yielding the items produced by the sub-pipelines.
        """
        # One independent iterator over the input per branch.  Note that
        # itertools.tee buffers items that some branches have not consumed yet.
        sub_states = tee(master_state, len(self.pipelines))
        # Only the first element returned by ``setup`` (the output iterator)
        # is kept.  NOTE(review): any context returned by ``setup`` is
        # discarded here; ``cleanup`` later receives the context object stored
        # in ``self.pipelines`` -- presumably ``setup`` updates it in place,
        # confirm against Pipeline.setup.
        generators = {
            key: pipeline.setup(self.map_sub(sub_state, key), context)[0]
            for sub_state, (key, (pipeline, context))
            in zip(sub_states, self.pipelines.items())
        }
        while generators:
            # Iterate over a snapshot so exhausted branches can be removed
            # while the round is in progress.
            for key in tuple(generators):
                try:
                    yield next(generators[key])
                except StopIteration:
                    del generators[key]

    def process_step(self, state, context):
        """Replace ``state`` with the merged sub-pipeline output.

        :return: tuple ``(state, context)``; ``context`` is passed through
            unchanged.
        """
        state = self.map(state)
        return state, context

    def cleanup(self, context, exception=None):
        """Run ``cleanup`` on every sub-pipeline and collect the results.

        The value returned by each sub-pipeline's ``cleanup`` is stored in
        ``context['switch']`` under that sub-pipeline's key.

        :param context: context dict of the enclosing pipeline.
        :param exception: forwarded unchanged to every sub-pipeline's
            ``cleanup``.
        :return: the updated ``context``.
        """
        context['switch'] = {}
        for key, (pipeline, subcontext) in self.pipelines.items():
            context['switch'][key] = pipeline.cleanup(subcontext, exception)
        return context