PyTorch Distributed

The pytorch_distributed_example.py script demonstrates integrating Trains into code that uses the PyTorch Distributed Communications Package (torch.distributed). This script initializes a main Task and spawns subprocesses, each running an instance of that Task. The Task in each subprocess trains a neural network over a partitioned dataset (the torchvision built-in MNIST dataset), and reports the following to the main Task:

  • Artifacts - A dictionary containing different key-value pairs is uploaded from the Task in each subprocess to the main Task.
  • Scalars - Loss reported as a scalar during training in each Task in a subprocess is logged in the main Task.
  • Hyperparameters - Hyperparameters created in each Task in a subprocess are added to the hyperparameters in the main Task.

Each Task in a subprocess references the main Task by calling Task.current_task, which always returns the main Task.
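The overall pattern can be sketched in a minimal, self-contained form: a main task launches workers, each worker receives a rank and trains on its own shard of the data, and all workers report back to the main task. In this sketch, threads and a Queue stand in for the script's subprocesses and torch.distributed, and all names (worker, WORLD_SIZE, partition) are illustrative, not taken from the script.

```python
# Sketch of the pattern only: threads and a Queue stand in for the
# script's subprocesses and torch.distributed; names are illustrative.
import queue
import threading

WORLD_SIZE = 4  # analogous to dist.get_world_size()

def partition(dataset, rank, world_size):
    # Each rank takes every world_size-th sample, offset by its rank,
    # so the shards are disjoint and together cover the whole dataset.
    return dataset[rank::world_size]

def worker(rank, dataset, results):
    # In the real script, this is where dist.init_process_group() runs
    # and Task.current_task() is used to report back to the main Task.
    results.put((rank, partition(dataset, rank, WORLD_SIZE)))

dataset = list(range(12))  # stand-in for the MNIST dataset
results = queue.Queue()
threads = [threading.Thread(target=worker, args=(r, dataset, results))
           for r in range(WORLD_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join()
shards = dict(results.get() for _ in range(WORLD_SIZE))
print(shards[0])  # rank 0's shard: [0, 4, 8]
```

The rank-based slicing is the same idea torch.distributed samplers use to keep each worker's shard disjoint.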

When the script runs, it creates an experiment named test torch distributed which is associated with the examples project in the Trains Web (UI).

Artifacts (dictionaries)

The example uploads a dictionary as an artifact in the main Task by calling the Task.upload_artifact method on Task.current_task (the main Task). The dictionary contains the rank of the subprocess (from dist.get_rank()), making each artifact unique.

Task.current_task().upload_artifact(
    'temp {:02d}'.format(dist.get_rank()), artifact_object={'worker_rank': dist.get_rank()})

All of these artifacts appear in the main Task, ARTIFACTS tab, OTHER area.

Scalars

We report loss to the main Task by calling the Logger.report_scalar method on Task.current_task().get_logger(), which is the logger for the main Task. Since we call Logger.report_scalar with the same title (loss), but a different series name (containing the subprocess's rank), all the loss series are logged together on the same plot.

Task.current_task().get_logger().report_scalar(
    'loss', 'worker {:02d}'.format(dist.get_rank()), value=loss.item(), iteration=i)

The single scalar plot for loss appears in the RESULTS tab, SCALARS sub-tab.

Hyperparameters

Trains automatically logs the argparse command line arguments. Since we call the Task.connect method on Task.current_task, they are logged in the main Task. We use a different hyperparameter key in each subprocess, so that the keys do not overwrite each other in the main Task.

param = {'worker_{}_stuff'.format(dist.get_rank()): 'some stuff ' + str(randint(0, 100))}
Task.current_task().connect(param)

All the hyperparameters appear in the HYPER PARAMETERS tab.

Log

Output to the console, including the text messages printed from the main Task object and from each subprocess, appears in the RESULTS tab, LOG sub-tab.

Artifacts (models)

Trains automatically logs the input model and output model, which appear in the ARTIFACTS tab.
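This automatic logging works by hooking the framework's save and load entry points, so a plain checkpoint call in the training script is enough for the model to be captured. The following is a simplified, illustrative stand-in for that mechanism, not Trains' actual implementation; all names in it are invented for the sketch.

```python
# Illustrative sketch of framework hooking: wrap the save entry point so
# every checkpoint the training script writes is also recorded as an
# output model. This is NOT Trains' implementation, just the idea.
import os
import pickle
import tempfile

logged_output_models = []  # what an auto-logger would report

def _original_save(obj, path):
    with open(path, 'wb') as f:
        pickle.dump(obj, f)

def patched_save(obj, path):
    logged_output_models.append(os.path.abspath(path))  # record the location
    _original_save(obj, path)                           # then save as usual

save = patched_save  # the hook transparently replaces the framework's save

# The training script just checkpoints as it normally would:
checkpoint = os.path.join(tempfile.mkdtemp(), 'model_checkpoint.pkl')
save({'weights': [0.1, 0.2]}, checkpoint)
```

Because the hook is transparent, the training code needs no changes for its input and output models to be tracked.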

Input model

In the model details for the input model (which appear when you click the model name in the ARTIFACTS tab, Input Model area), you can see the following in the GENERAL tab:

  • Input model location (URL)
  • Model snapshot / checkpoint model locations (URLs)
  • The experiment that created the model
  • Other general information about the model

Output model

In the model details for the output model, you can see the following in the GENERAL tab:

  • Output model location (URL)
  • Model snapshot / checkpoint model locations (URLs)
  • The experiment that created the model
  • Other general information about the model