# Particle refinement localization module
For the refinement step, MiLoPYP learns to localize proteins of interest with high accuracy when trained on sparsely annotated data. This step can also be used without the previous exploration step.
## Input preparation
The training set should include two files:

- a `.txt` file with the tomogram names and the paths to the tomogram files
- a `.txt` file with the tomogram names and the corresponding `x,y,z`-coordinate files
When using the refinement module after the exploration module, we can use the same image file and the coordinates file generated by the exploration module.

When using the refinement module alone, some manual labeling is needed to generate the training coordinates, and a corresponding train image file needs to be generated as well.
In the first case, the `*.txt` files keep the format generated by the exploration module. In the second case, they should follow the format produced by the manual-labeling procedure described next.
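As a rough illustration of what these two files might contain (the layout, names, and values below are assumptions for illustration, not the project's documented format):

```
# Hypothetical image file: one tomogram name and path per line
tomo1   /data/tomos/tomo1.rec
tomo2   /data/tomos/tomo2.rec

# Hypothetical coordinates file: tomogram name and one x,y,z triplet per line
tomo1   512   120   300
tomo1   488   131   310
tomo2   255    98   221
```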
### Generate training set from manual labels
Here, we provide a simple procedure to generate the described training set from a selected folder.
First, create a `train/` directory to store all tomograms and their corresponding coordinate files. Each tomogram should have its own coordinate file, e.g., `train_1_img.rec` and `train_1_img_coord.txt`.
For training coordinates, manual picking is performed on selected tomograms using IMOD. Full annotation of a tomogram is not required: simply find a few subregions and pick around 10% to 70% of the particles in each one. The subregions do not need to be big. After manual annotation in IMOD, save the resulting `.mod` files containing the annotated coordinates. The `.mod` files can then be converted to `.txt` files using IMOD's `model2point` command. For example (a minimal invocation, reusing the file names from above):
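```
# Convert IMOD model points to a plain-text coordinate file
model2point train_1_img.mod train_1_img_coord.txt
```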
Warning

Depending on the x-y-z order of your input tomograms, the output coordinates generated using IMOD may be in a different order. The two most common orders are `x-y-z` and `x-z-y`. Make sure you specify the correct order.
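If a coordinate file ends up in the wrong order, the columns can simply be swapped. A minimal sketch, assuming a plain three-column file in x-z-y order (the file names are illustrative):

```
# Swap columns 2 and 3 to turn x-z-y points into x-y-z
awk '{print $1, $3, $2}' train_1_img_coord.txt > train_1_img_coord_xyz.txt
```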
Once all the `.mod` files are converted, move all the coordinate `.txt` files to the `train/` directory.
To generate the image and coordinate files for training, run `generate_train_files.py` under the `utils/` folder. Two input arguments are required: `-d/--dir` to indicate the path to the `train/` directory, and `-o/--out` to specify the name and location of the output training file. The default order for input coordinates is `x-z-y`; to specify a different order, add the option `-r/--ord`. Possible orders are `x-y-z`, `x-z-y`, or `z-x-y`. For example (a sketch, assuming the script is run from the repository root, with an illustrative output name):
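```
# Build the training image and coordinate files from train/
python utils/generate_train_files.py -d train/ -o sample_train -r xzy
```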
| Arguments | Purpose |
|---|---|
| `ext` | extension of tomogram files (`*.mrc` or `*.rec`) |
| `dir` | path to tomograms and labels for training |
| `out` | training file output name |
| `ord` | coordinate order (`xyz`, `xzy`, or `zxy`) |
| `inference` | generate input for the evaluation stage |
Once all files are generated, move the training files to the `data/` directory (create the directory if it doesn't exist).
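For example (a one-line sketch; the file names follow the earlier examples):

```
# Create data/ if needed and move the generated training files into it
mkdir -p data && mv sample_train*.txt data/
```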
## Training

### Globular-shaped targets
Here is a sample command to train using tomograms from the EMPIAR-10304 dataset, assuming a train image file `sample_train_explore_img.txt`, a train coordinates file `training_coordinates.txt`, a validation image file `sample_val_img.txt`, and a validation coordinates file `val_coordinates.txt` (validation files are optional):
```
python main.py semi --down_ratio 2 --num_epochs 10 --bbox 16 --exp_id sample_refinement --dataset semi --arch unet_5 --save_all --debug 4 --val_interval 1 --thresh 0.85 --cr_weight 0.1 --temp 0.07 --tau 0.01 --lr 5e-4 --train_img_txt sample_train_explore_img.txt --train_coord_txt training_coordinates.txt --val_img_txt sample_val_img.txt --val_coord_txt val_coordinates.txt --K 900 --compress --order xzy --gauss 0.8 --contrastive --last_k 3
```
| Arguments | Purpose |
|---|---|
| `num_epochs` | number of training epochs, 5 to 10 recommended |
| `exp_id` | experiment id used as a prefix for saving output files |
| `bbox` | box size for particles, used to generate the Gaussian kernel during training |
| `dataset` | sampling and dataloader mode, defaults to `semi` |
| `arch` | model backbone architecture (`name_numOfLayers` format), `unet_4` or `unet_5` recommended |
| `lr` | learning rate, 1e-3 to 5e-4 recommended (for fewer training examples, lower the learning rate) |
| `debug` | debug mode for visualization; currently only mode 4 is supported. Outputs are saved to the `debug/` folder, including a view of each slice, the ground-truth heatmap, the predicted heatmap, and detections predicted from the heatmap |
| `val_interval` | interval at which to perform validation and save intermediate models |
| `cr_weight` | weight for contrastive regularization (smaller values recommended for more samples, larger values for fewer samples) |
| `save_all` | whether to save all models at each `val_interval` |
| `gauss` | use a Gaussian filter to denoise tilt-series and tomograms during preprocessing |
| `temp` | InfoNCE temperature |
| `down_ratio` | downsampling ratio in the x-y direction, default is 2 |
| `tau` | class prior probability |
| `thresh` | threshold for soft/hard positives |
| `last_k` | size of the convolution filter for the last layer |
| `compress` | whether to combine 2 z-slices into 1, recommended |
| `K` | maximum number of particles |
| `fiber` | turn on for fiber/tubular-shaped particles |
A more detailed description of the arguments is included in the file `opt.py`.
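Assuming the script exposes standard argparse behavior (an assumption, not confirmed here), the full option list can also be printed from the command line:

```
python main.py semi --help
```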
All outputs will be saved to the folder `exp/semi/exp_id`. In our case, the output will be saved into `exp/semi/sample_refinement` and contain the following files:
- `opt.txt` where all options used for training will be saved
- `debug/` where all outputs from validation will be saved
- `model_xxx.pth` intermediate model checkpoints (weights for the final model will be saved in `modelxxx_last_contrastive.pth`)
- a directory with specific training/validation loss info for each run
Here are some sample outputs generated in the `debug/` folder:
How to select the best model and detection threshold?
The best model and detection threshold can be selected based on the validation loss and the outputs included in the `debug/` folder.

When fully labeled tomograms are available for validation, select the model with the lowest validation loss. When only partially labeled tomograms are available, select the model that generates the best heatmaps. Unless there is severe over-fitting, the model from the last epoch typically generates good results.
The threshold can be estimated from the detection output `.txt` file (which contains x, y, z coordinates and the corresponding detection scores). It can also be estimated from the `*_pred_out.png` images in the `debug/` folder, which mark identified particles above a certain threshold. If there are many false positives, consider using a higher threshold.
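One quick way to gauge a threshold is to count how many detections survive a candidate cutoff. A minimal sketch, assuming a detection file whose fourth column is the score (the file name is illustrative):

```
# Count detections scoring above a candidate threshold of 0.3
awk '$4 > 0.3' detections.txt | wc -l
```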
### Tubular-shaped targets
Here is a sample command to train using tomograms from the EMPIAR-10987 dataset, assuming a train image list `sample_train_microtubule_img.txt`, train coordinates `training_coordinates_microtubule.txt`, a validation image list `sample_val_microtubule.txt`, and validation coordinates `val_coordinates_microtubule.txt` (validation files are optional):
```
python main.py semi --down_ratio 2 --num_epochs 10 --bbox 12 --contrastive --exp_id fib_test --dataset semi --arch unet_5 --save_all --debug 4 --val_interval 1 --thresh 0.3 --cr_weight 1.0 --temp 0.07 --tau 0.01 --lr 1e-4 --train_img_txt sample_train_microtubule_img.txt --train_coord_txt training_coordinates_microtubule.txt --val_img_txt sample_val_microtubule.txt --val_coord_txt val_coordinates_microtubule.txt --K 550 --compress --gauss 1 --order xzy --last_k 5 --fiber
```
Note that the main difference is the use of the `--fiber` option.
Outputs generated by this command will be the same as those generated for the globular-shaped targets.
Here are some sample outputs saved in the `debug/` folder:
## Inference

### Globular-shaped targets
Once training is finished, we use the trained model for testing. A file `test_img.txt` containing all the tomograms can be generated with `generate_train_files.py` following a similar process as described above, this time adding the `--inference` option to generate input for the evaluation stage. For example (a sketch with illustrative paths):
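```
# Build the image list used for evaluation (paths are illustrative)
python utils/generate_train_files.py -d test/ -o test_img --inference
```

To run inference on all tomograms, run: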
```
python test.py semi --arch unet_5 --dataset semi --exp_id sample_refinement --load_model exp/semi/sample_refinement/model_4.pth --down_ratio 2 --K 900 --ord xzy --out_thresh 0.2 --test_img_txt test_img.txt --compress --gauss 0.8 --out_id all_out
```
Outputs will be saved to the folder `exp/semi/sample_refinement/all_out/`. For each tomogram, two outputs will be generated:
- a `.txt` file with particle coordinates in x-z-y order (if the option `--with_score` was used, a column with the scores will also be included)
- a `*hm.mrc` file with the 3D detection heatmap for the tomogram
| Arguments | Purpose |
|---|---|
| `load_model` | path to the trained model (the command above uses the model from the 4th epoch) |
| `out_thresh` | threshold used for detection |
| `out_id` | folder to save all outputs |
| `ord` | coordinate order of the tomogram |
| `with_score` | whether the generated output should include score values in addition to x-y-z coordinates |
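For example, to also write detection scores next to the coordinates, append `--with_score` to the inference command above:

```
python test.py semi --arch unet_5 --dataset semi --exp_id sample_refinement --load_model exp/semi/sample_refinement/model_4.pth --down_ratio 2 --K 900 --ord xzy --out_thresh 0.2 --test_img_txt test_img.txt --compress --gauss 0.8 --out_id all_out --with_score
```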
Warning

Make sure to use the same `--last_k`, `--gauss`, and `--arch` options used during training to ensure that the correct model is loaded.
### Tubular-shaped targets
For tubular-shaped targets, we simply add the option `--fiber` to the inference command above and specify a threshold used for curve fitting. To run inference on tomograms with fiber-specific post-processing, run:
```
python test.py semi --arch unet_5 --dataset semi --exp_id fib_test --load_model exp/semi/fib_test/model_10.pth --down_ratio 2 --K 550 --order xzy --out_thresh 0.205 --test_img_txt sample_train_microtubule_img.txt --compress --gauss 1 --cutoff_z 10 --out_id new_out --last_k 5 --fiber --curvature_cutoff 0.03 --nms 3
```
| Arguments | Purpose |
|---|---|
| `curvature_cutoff` | maximum curvature for the fitted curve; segments with higher curvature are discarded (microtubules should have small curvature) |
| `r2_cutoff` | maximum residual for the fitted curve; segments above this residual are discarded (poor fit) |
| `distance_cutoff` | distance cutoff that determines whether two points are connected in the graph |
Here are some example outputs:
## Convert coordinates to IMOD format
Make sure the output `*.txt` files do not include heatmap scores. Then, convert them with IMOD's `point2model` command, for example (a minimal sketch; the flags and file names are illustrative):
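```
# Convert a plain-text coordinate file into an IMOD model of scattered points
point2model -scat -sphere 5 tomo1.txt tomo1.mod
```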
The output coordinates obtained from the trained model can then be used to extract sub-volumes for sub-tomogram averaging.