Skip to content

Commit

Permalink
Merge pull request #79 from scipion-em/jj_fixes_subtomo_align
Browse files Browse the repository at this point in the history
Jj fixes subtomo align
  • Loading branch information
pconesa authored May 24, 2023
2 parents b43dac6 + 7633a7b commit 845a2a1
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 113 deletions.
3 changes: 3 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
V3.2.20:
- Subtomogram alignment: fix bug when using masks.
- Subtomogram alignment: fix output problem when the alignment removed some particles.
V3.1.19:
- Hot fix: one of the steps of the common list from the model workflow is not included in the Dynamo side. The plugin
now offers the same behavior, so this kind of models do not fail now.
Expand Down
4 changes: 1 addition & 3 deletions dynamo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,14 @@
import pyworkflow.utils as pwutils
from .constants import *

__version__ = '3.1.19'
__version__ = '3.1.20'
_logo = "icon.png"
_references = ['CASTANODIEZ2012139']


class Plugin(pwem.Plugin):
_homeVar = DYNAMO_HOME
_pathVars = [DYNAMO_HOME]
_url = "https://github.com/scipion-em/scipion-em-dynamo"
# _supportedVersions =

@classmethod
def _defineVariables(cls):
Expand Down
102 changes: 31 additions & 71 deletions dynamo/convert/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,71 +99,30 @@ def writeDynTable(fhTable, setOfSubtomograms):
% (subtomo.getObjId(), shiftx, shifty, shiftz, tdrot, tilt, narot, anglemin, anglemax, x, y, z))


def readDynTable(self, item, tomoSet=None):
nline = next(self.fhTable)
nline = nline.rstrip()
id = int(nline.split()[0])
item.setObjId(id)
shiftx = nline.split()[3]
shifty = nline.split()[4]
shiftz = nline.split()[5]
tdrot = nline.split()[6]
tilt = nline.split()[7]
narot = nline.split()[8]
A = eulerAngles2matrix(tdrot, tilt, narot, shiftx, shifty, shiftz)
transform = Transform()
transform.setMatrix(A)
item.setTransform(transform)
angleMin = nline.split()[13]
angleMax = nline.split()[14]
acquisition = TomoAcquisition()
acquisition.setAngleMin(angleMin)
acquisition.setAngleMax(angleMax)
item.setAcquisition(acquisition)
volId = int(nline.split()[19]) + 1
item.setVolId(volId)
classId = nline.split()[21]
item.setClassId(classId)
if tomoSet:
tomo = tomoSet[volId] if tomoSet.getSize() > 1 \
else tomoSet.getFirstItem()
tomoOrigin = tomo.getOrigin()
item.setVolName(tomo.getFileName())
item.setOrigin(tomoOrigin)
coordinate3d = Coordinate3D()
coordinate3d.setVolId(tomo.getObjId())
coordinate3d.setVolume(tomo)
x = nline.split()[23]
y = nline.split()[24]
z = nline.split()[25]
coordinate3d.setX(float(x), const.BOTTOM_LEFT_CORNER)
coordinate3d.setY(float(y), const.BOTTOM_LEFT_CORNER)
coordinate3d.setZ(float(z), const.BOTTOM_LEFT_CORNER)
item.setCoordinate3D(coordinate3d)


def dynTableLine2Subtomo(nline, subtomo, subtomoSet, tomo=None, coordSet=None):
nline = nline.rstrip()
subtomo.setObjId(int(nline.split()[0]))
shiftx = nline.split()[3]
shifty = nline.split()[4]
shiftz = nline.split()[5]
tdrot = nline.split()[6]
tilt = nline.split()[7]
narot = nline.split()[8]
def dynTableLine2Subtomo(inLine, subtomo, subtomoSet=None, tomo=None, coordSet=None):
if type(inLine) != list:
inLine = inLine.rstrip().split()
# inLine = inLine.rstrip()
subtomo.setObjId(int(inLine[0]))
shiftx = inLine[3]
shifty = inLine[4]
shiftz = inLine[5]
tdrot = inLine[6]
tilt = inLine[7]
narot = inLine[8]
A = eulerAngles2matrix(tdrot, tilt, narot, shiftx, shifty, shiftz)
transform = Transform()
transform.setMatrix(A)
subtomo.setTransform(transform)
angleMin = nline.split()[13]
angleMax = nline.split()[14]
angleMin = inLine[13]
angleMax = inLine[14]
acquisition = TomoAcquisition()
acquisition.setAngleMin(angleMin)
acquisition.setAngleMax(angleMax)
subtomo.setAcquisition(acquisition)
volId = int(nline.split()[19])
volId = int(inLine[19])
subtomo.setVolId(volId)
classId = nline.split()[21]
classId = inLine[21]
subtomo.setClassId(classId)
if tomo:
tomoOrigin = tomo.getOrigin()
Expand All @@ -172,33 +131,34 @@ def dynTableLine2Subtomo(nline, subtomo, subtomoSet, tomo=None, coordSet=None):
coordinate3d = Coordinate3D()
coordinate3d.setVolId(tomo.getObjId())
coordinate3d.setVolume(tomo)
x = nline.split()[23]
y = nline.split()[24]
z = nline.split()[25]
x = inLine[23]
y = inLine[24]
z = inLine[25]
coordinate3d.setX(float(x), const.BOTTOM_LEFT_CORNER)
coordinate3d.setY(float(y), const.BOTTOM_LEFT_CORNER)
coordinate3d.setZ(float(z), const.BOTTOM_LEFT_CORNER)
subtomo.setCoordinate3D(coordinate3d)
coordSet.append(coordinate3d)
subtomoSet.append(subtomo)
if subtomoSet is not None:
subtomoSet.append(subtomo)


def readDynCoord(tableFile, coord3DSet, tomo):
with open(tableFile) as fhTable:
for nline in fhTable:
coordinate3d = Coordinate3D()
nline = nline.rstrip()
shiftx = nline.split()[3]
shifty = nline.split()[4]
shiftz = nline.split()[5]
tdrot = nline.split()[6]
tilt = nline.split()[7]
narot = nline.split()[8]
nline = nline.rstrip().split()
shiftx = nline[3]
shifty = nline[4]
shiftz = nline[5]
tdrot = nline[6]
tilt = nline[7]
narot = nline[8]
A = eulerAngles2matrix(tdrot, tilt, narot, shiftx, shifty, shiftz)
x = nline.split()[23]
y = nline.split()[24]
z = nline.split()[25]
groupId = nline.split()[21]
x = nline[23]
y = nline[24]
z = nline[25]
groupId = nline[21]
coordinate3d.setVolume(tomo)
coordinate3d.setX(float(x), const.BOTTOM_LEFT_CORNER)
coordinate3d.setY(float(y), const.BOTTOM_LEFT_CORNER)
Expand Down
2 changes: 1 addition & 1 deletion dynamo/protocols/protocol_import_subtomos.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def _fillSubtomogram(line, subtomo, subtomoSet, newFileName, tomo=None, coordSet
""" adds a subtomogram to a set """
subtomo.cleanObjId()
subtomo.setFileName(newFileName)
dynTableLine2Subtomo(line, subtomo, subtomoSet, tomo=tomo, coordSet=coordSet)
dynTableLine2Subtomo(line, subtomo, subtomoSet=subtomoSet, tomo=tomo, coordSet=coordSet)

# --------------------------- INFO functions ------------------------------
def _validate(self):
Expand Down
78 changes: 49 additions & 29 deletions dynamo/protocols/protocol_subtomo_MRA.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from pyworkflow.utils import Message
from pyworkflow.utils.path import makePath
from dynamo import Plugin
from dynamo.convert import writeSetOfVolumes, writeDynTable, readDynTable
from dynamo.convert import writeSetOfVolumes, writeDynTable, dynTableLine2Subtomo
from tomo.protocols.protocol_base import ProtTomoSubtomogramAveraging
from tomo.objects import AverageSubTomogram, SetOfSubTomograms

Expand Down Expand Up @@ -89,17 +89,17 @@ class DynamoSubTomoMRA(ProtTomoSubtomogramAveraging):
_devStatus = BETA
_possibleOutputs = DynRefineOuts

@classmethod
def getUrl(cls):
return "https://wiki.dynamo.biozentrum.unibas.ch/w/index.php/Alignment_project"

def __init__(self, **args):
ProtTomoSubtomogramAveraging.__init__(self, **args)
self.dynTableDict = None
self.dimRounds = String()
self.fhTable = None
self.masksDir = None
self.doMra = None

@classmethod
def getUrl(cls):
return "https://wiki.dynamo.biozentrum.unibas.ch/w/index.php/Alignment_project"

# --------------------------- DEFINE param functions ------------------------

def _defineParams(self, form: Form):
Expand Down Expand Up @@ -461,13 +461,6 @@ def convertInputStep(self):

Plugin.runDynamo(self, IMPORT_CMD_FILE, cwd=self._getExtraPath())

def showDynamoGUI(self):
fhCommands2 = open(self._getExtraPath(SHOW_PROJECT_CMD_FILE), 'w')
content2 = "dcp '%s';" % DYNAMO_ALIGNMENT_PROJECT
fhCommands2.write(content2)
fhCommands2.close()
Plugin.runDynamo(self, SHOW_PROJECT_CMD_FILE, cwd=self._getExtraPath())

def alignStep(self):
with open(self._getExtraPath(ALIGNMENT_CMD_FILE), 'w') as fhCommands2:
alignmentCommands = self.get_computing_command()
Expand All @@ -487,16 +480,11 @@ def alignStep(self):

resultsDir = self.getLastIterResultsDir()
if not os.path.exists(resultsDir):
raise RuntimeError("Results folder (%s) no generated. "
raise RuntimeError("No results folder (%s) was generated. "
"Probably there has been an error while running the alignment in Dynamo. "
"Please, see run.stdout log for more details." % resultsDir)

def getTotalIterations(self):
iters = self.numberOfIters.getListFromValues()
return sum(iters)

def createOutputStep(self):

niters = self.getTotalIterations()

if self.doMra:
Expand Down Expand Up @@ -534,13 +522,9 @@ def createOutputStep(self):
outSubtomos = SetOfSubTomograms.create(self._getPath(), template='subtomograms%s.sqlite')
inputSet = self.inputVolumes.get()
outSubtomos.copyInfo(inputSet)
resTbl = join(self.getLastIterAvgsDir(), 'refined_table_ref_001_ite_%04d.tbl' % niters)
avgFile = join(self.getLastIterAvgsDir(), 'average_symmetrized_ref_001_ite_%04d.em' % niters)
# Open the final particles table, that will be used in the updateItemCallback
self.fhTable = open(resTbl, 'r')
outSubtomos.copyItems(inputSet, updateItemCallback=self._updateItem)
self.fhTable.close()
# Fill the resulting average object
avgFile = join(self.getLastIterAvgsDir(), 'average_symmetrized_ref_001_ite_%04d.em' % niters)
averageSubTomogram.setFileName(self.convertToMrc(avgFile))
averageSubTomogram.setSamplingRate(inputSet.getSamplingRate())
# Define outputs and relations
Expand All @@ -557,6 +541,17 @@ def closeSetsStep(self):
self._store()

# --------------------------- UTILS functions --------------------------------
def getTotalIterations(self):
iters = self.numberOfIters.getListFromValues()
return sum(iters)

def showDynamoGUI(self):
fhCommands2 = open(self._getExtraPath(SHOW_PROJECT_CMD_FILE), 'w')
content2 = "dcp '%s';" % DYNAMO_ALIGNMENT_PROJECT
fhCommands2.write(content2)
fhCommands2.close()
Plugin.runDynamo(self, SHOW_PROJECT_CMD_FILE, cwd=self._getExtraPath())

def getDimRounds(self, nRounds):
"""The number of rounds will be determined the same as Dynamo, which is through the number of elements
introduced in label 'Iterations'. The parameter 'Particle dimensions' present a specific behavior: if one round,
Expand Down Expand Up @@ -681,15 +676,21 @@ def get_computing_command(self):
return command

def _updateItem(self, item, row):
readDynTable(self, item)
row = self.getDynRow(item.getObjId())
if row is None:
# This is to consider possible particle removal carried out by Dynamo during the alignment
item._appendItem = False
else:
# row to subtomo
dynTableLine2Subtomo(row, item)

def prepareMask(self, maskObj):
if maskObj:
if isinstance(maskObj, SetOfVolumes):
writeSetOfVolumes(maskObj, join(self._getExtraPath(), 'fmasks/fmask_initial_ref_'), 'ix')
return self.get_dvput('fmask', self.masksDir)
else:
return self.get_dvput('fmask', maskObj.getFileName())
return self.get_dvput('fmask', abspath(maskObj.getFileName()))
else:
return ''

Expand Down Expand Up @@ -727,6 +728,24 @@ def anyValActiveInNumListParam(iParam) -> bool:
be active at least in one of the rounds"""
return np.any(np.array(iParam.getListFromValues()) > 0)

def getResultsTblFile(self):
niters = self.getTotalIterations()
return join(self.getLastIterAvgsDir(), 'refined_table_ref_001_ite_%04d.tbl' % niters)

def getDynRow(self, lineId):
if self.dynTableDict is None:
self._loadDynamoTable()
return self.dynTableDict.get(lineId, None)

def _loadDynamoTable(self):
"""Reads a Dynamo tbl file and stores it in a dictionary of type:
{key = objId (first number in a Dynamo table row), value = line}"""
tblFile = self.getResultsTblFile()
with open(tblFile, 'r') as dynTable:
self.dynTableDict = {}
for lineNum, line in enumerate(dynTable):
self.dynTableDict[int(line.split()[0])] = line

# --------------------------- INFO functions --------------------------------
def _validate(self):
validateMsgs = []
Expand All @@ -753,9 +772,10 @@ def _validate(self):
# Check the masks
if introducedMasks:
for mask in masks:
if not self.sizesOk(mask):
validateMsgs.append('The introduced masks must be of the same size as the template.')
break
if mask:
if not self.sizesOk(mask):
validateMsgs.append('The introduced masks must be of the same size as the template.')
break
# Check the dims values
dimValues = self.dim.getListFromValues()
nDims = len(dimValues)
Expand Down
43 changes: 34 additions & 9 deletions dynamo/tests/test_dynamo_align_subtomos.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
# * e-mail address '[email protected]'
# *
# **************************************************************************
from xmipp3.constants import MASK3D_CYLINDER
from xmipp3.protocols import XmippProtCreateMask3D
from xmipp3.protocols.protocol_preprocess.protocol_create_mask3d import SOURCE_GEOMETRY

from dynamo.protocols import DynamoSubTomoMRA
from dynamo.protocols.protocol_extraction import SAME_AS_PICKING
from dynamo.protocols.protocol_subtomo_MRA import FROM_PREVIOUS_ESTIMATION, NO_THRESHOLD
Expand Down Expand Up @@ -77,15 +81,17 @@ def runPreviousProtocols(cls):

@classmethod
def runAlignSubtomos(cls, nIters='3', dims='0', thMode=str(NO_THRESHOLD),
areaSearchMode=str(FROM_PREVIOUS_ESTIMATION), protLabel=None):
protAlign = cls.newProtocol(DynamoSubTomoMRA,
inputVolumes=cls.subtomosExtracted,
templateRef=cls.avg,
numberOfIters=nIters,
dim=dims,
thresholdMode=thMode,
limm=areaSearchMode,
useGpu=True)
areaSearchMode=str(FROM_PREVIOUS_ESTIMATION), alignMask=None, protLabel=None):
argsDict = {'inputVolumes': cls.subtomosExtracted,
'templateRef': cls.avg,
'numberOfIters': nIters,
'dim': dims,
'thresholdMode': thMode,
'limm': areaSearchMode,
'useGpu': True}
if alignMask:
argsDict['alignMask'] = alignMask
protAlign = cls.newProtocol(DynamoSubTomoMRA, **argsDict)
if protLabel:
protAlign.setObjLabel(protLabel)
cls.launchProtocol(protAlign)
Expand All @@ -98,6 +104,25 @@ def test_alignSubtomos_oneRound(self):
subtomos, avg = self.runAlignSubtomos(protLabel='Subtomo align, 1 round')
self.checkResults(avg, subtomos)

def test_alignSubtomos_oneRound_With_AlignMask(self):
print(magentaStr("\n==> aligning the subtomograms with alignment mask, 1 round:"))
# Generate the mask
protMask3D = self.newProtocol(XmippProtCreateMask3D,
source=SOURCE_GEOMETRY,
samplingRate=self.bin2SRate,
size=self.bin2BoxSize,
geo=MASK3D_CYLINDER,
radius=15,
shiftCenter=True,
centerZ=6,
height=20,
doSmooth=True)
self.launchProtocol(protMask3D)
alignMask = getattr(protMask3D, 'outputMask', None)
# Align the subtomograms
subtomos, avg = self.runAlignSubtomos(protLabel='Subtomo align, 1 round', alignMask=alignMask)
self.checkResults(avg, subtomos)

def test_alignSubtomos_twoRounds(self):
print(magentaStr("\n==> aligning the subtomograms, 2 rounds:"))
subtomos, avg = self.runAlignSubtomos(nIters='2 2', dims='0', protLabel='Subtomo align, 2 rounds, dim=0')
Expand Down

0 comments on commit 845a2a1

Please sign in to comment.