This is a seed project for distributed PyTorch training, built so that you can customize your network quickly.
Here is an overview of what this template can do; most of it can be customized through the configure file.
.json configure file for most parameter tuning

    torch.backends.cudnn.enabled = True
    # speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
    if seed >= 0 and gl_seed >= 0:  # slower, more reproducible
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:  # faster, less reproducible, default setting
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True
Taking network as an example:

    // import Network() class from models.network.py file with args
    "which_networks": [
        {
            "name": ["models.network", "Network"],
            "args": { "init_type": "kaiming" }
        }
    ],

    // import multiple Networks from default file with args
    "which_networks": [
        {"name": "Network1", "args": {"init_type": "kaiming"}},
        {"name": "Network2", "args": {"init_type": "kaiming"}},
    ],

    // import multiple Networks from default file without args
    "which_networks" : [
        "Network1", // equivalent to {"name": "Network1", "args": {}},
        "Network2"
    ]

    // more details can be found in the More Details part and the init_objs function in praser.py
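The three entry forms above (a bare class name, a dict with a string name, a dict with a [file, class] list) can all be reduced to one shape before anything is imported. The sketch below is only an illustration of that idea: normalize_entries is a hypothetical helper, not the init_objs function from praser.py, and it assumes models.network is the default file.

```python
def normalize_entries(entries, default_module="models.network"):
    """Expand every entry to {"name": [module, class], "args": {...}}."""
    normalized = []
    for entry in entries:
        if isinstance(entry, str):
            # "Network1" is shorthand for {"name": "Network1", "args": {}}
            entry = {"name": entry}
        name = entry["name"]
        if isinstance(name, str):
            # A single string means "class in the default file".
            name = [default_module, name]
        normalized.append({"name": list(name), "args": dict(entry.get("args", {}))})
    return normalized


which_networks = [
    "Network1",
    {"name": "Network2", "args": {"init_type": "kaiming"}},
    {"name": ["models.network", "Network"], "args": {"init_type": "kaiming"}},
]
result = normalize_entries(which_networks)
for item in result:
    print(item)
```

After this step every entry carries an explicit file, class name, and args dict, so the code that builds the objects only has to handle one case.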
Run run.py with your settings.
More choices can be found in run.py and config/base.json.
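The config snippets in this README carry // comments, which Python's standard json module rejects. Below is a minimal sketch of one way to load such a file. It assumes that // never appears inside a string value (a URL would be cut off), and load_commented_json is a hypothetical helper that may differ from what praser.py actually does.

```python
import json
from collections import OrderedDict


def load_commented_json(text):
    """Parse JSON text after dropping everything that follows // on each line."""
    stripped = "\n".join(line.split("//")[0] for line in text.splitlines())
    # OrderedDict keeps the keys in the order they appear in the file.
    return json.loads(stripped, object_pairs_hook=OrderedDict)


example = """
{
    "which_metrics": ["mae"],  // metrics to record
    "which_losses": ["mse_loss"]
}
"""
config = load_commented_json(example)
print(config["which_metrics"], config["which_losses"])
```

Note that trailing commas are still invalid after the comments are stripped, so a fragment must be completed into valid JSON before it is loaded this way.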
The Dataset part decides what data is fed into the network. You can define the dataset with the following steps:

Put your dataset class in the data folder. See dataset.py in this folder as an example.
Edit config/base.json to import and initialize the dataset.

    "datasets": { // train or test
        "train": {
            "which_dataset": { // import designated dataset using args
                "name": ["data.dataset", "Dataset"],
                "args": { // args to init dataset
                    "data_root": "/data/jlw/datasets/comofod"
                }
            },
            "dataloader": {
                "validation_split": 0.1, // percent or number
                "args": { // args to init dataloader
                    "batch_size": 2, // batch size in every gpu
                    "num_workers": 4,
                    "shuffle": true,
                    "pin_memory": true,
                    "drop_last": true
                }
            }
        },
    }
name can be a list giving the file name and the class/function name, or a single string giving the class name in the default file (data.dataset.py). An example is as follows:

    "name": ["data.dataset", "Dataset"], // import Dataset() class from data.dataset.py
    "name": "Dataset", // import Dataset() class from default file
Taking data_root as the example, you just need to add it to the args dict and edit the corresponding class to parse this value:

    "args": { // args to init dataset
        "data_root": "your data path"
    }

    class Dataset(data.Dataset):
        def __init__(self, data_root, phase='train', image_size=[256, 256], loader=pil_loader):
            imgs = make_dataset(data_root)  # data_root value is from configure file
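To make the name/args mechanism concrete, here is a minimal sketch of how such an entry could be turned into an object with importlib. It is an assumption about the general approach, not the repository's code: init_obj_sketch is a hypothetical name, and a standard-library class stands in for data.dataset.Dataset so the example runs on its own.

```python
import importlib


def init_obj_sketch(entry, default_module):
    """Import the class named in entry["name"] and call it with entry["args"]."""
    name = entry["name"]
    if isinstance(name, (list, tuple)):
        module_name, class_name = name           # ["file", "Class"] form
    else:
        module_name, class_name = default_module, name  # "Class" form
    module = importlib.import_module(module_name)
    cls = getattr(module, class_name)
    # Every key in "args" becomes a keyword argument of the constructor.
    return cls(**entry.get("args", {}))


# fractions.Fraction is only a stand-in for a dataset class here.
frac = init_obj_sketch(
    {"name": ["fractions", "Fraction"], "args": {"numerator": 1, "denominator": 4}},
    default_module="fractions",
)
print(frac)
```

Because args is unpacked into keyword arguments, adding a new key such as data_root to the config only works once the class constructor accepts a parameter of that name.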
The Network part defines your network structure. You can define your network with the following steps:

Put your network class in the models folder. See network.py in this folder as an example.
Edit config/base.json to import and initialize your networks; the value is a list.

    "which_networks": [ // import designated list of networks using args
        {
            "name": "Network",
            "args": { // args to init network
                "init_type": "kaiming"
            }
        }
    ],
name can be a list giving the file name and the class/function name, or a single string giving the class name in the default file (models.network.py). An example is as follows:

    "name": ["models.network", "Network"], // import Network() class from models.network.py
    "name": "Network", // import Network() class from default file
Taking init_type as the example, you just need to add it to the args dict and edit the corresponding class to parse this value:

    "args": { // args to init network
        "init_type": "kaiming"
    }

    class BaseNetwork(nn.Module):
        def __init__(self, init_type='kaiming', gain=0.02):
            super(BaseNetwork, self).__init__()  # init_type value is from configure file

    class Network(BaseNetwork):
        def __init__(self, in_channels=3, **kwargs):
            super(Network, self).__init__(**kwargs)  # get init_type value and pass it to base network
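The way init_type travels from the args dict through Network into BaseNetwork is plain keyword forwarding. The sketch below reproduces only that pattern with ordinary classes so it runs without torch; BaseNetworkSketch, NetworkSketch, and the value "normal" are illustrative assumptions, not part of the repository.

```python
class BaseNetworkSketch:
    """Stand-in for BaseNetwork, without nn.Module."""

    def __init__(self, init_type="kaiming", gain=0.02):
        self.init_type = init_type
        self.gain = gain


class NetworkSketch(BaseNetworkSketch):
    def __init__(self, in_channels=3, **kwargs):
        # Anything this class does not consume itself is passed to the base class.
        super().__init__(**kwargs)
        self.in_channels = in_channels


# As if read from the "args" dict of a which_networks entry.
args = {"init_type": "normal", "in_channels": 1}
net = NetworkSketch(**args)
print(net.init_type, net.gain, net.in_channels)
```

A key that neither class accepts raises a TypeError, which is a quick way to notice a misspelled argument in the configure file.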
"which_networks": [ {"name": "Network1", args: {}}, {"name": "Network2", args: {}}, ],
The Model part defines your training process, including optimizers, losses, and process control. You can define your model with the following steps:

Put your model class in the models folder. See model.py in this folder as an example.
Edit config/base.json to import and initialize your model.

    "which_model": { // import designated model(trainer) using args
        "name": ["models.model", "Model"],
        "args": { // args to init model
        }
    },
name can be a list giving the file name and the class/function name, or a single string giving the class name in the default file (models.model.py). An example is as follows:

    "name": ["models.model", "Model"], // import Model() class / function (not recommended) from models.model.py (default is [models.model.py])
    "name": "Model", // import Model() class from default file
Losses and metrics are defined in the configure file. You can also control and record more parameters through the configure file; please refer to the More details part above.

    "which_metrics": ["mae"],
    "which_losses": ["mse_loss"]
After the above steps, you need to rewrite several functions in base_model.py/model.py for your network and dataset.
See the __init__() functions as an example.
See the train_step()/val_step() functions as an example.
See the save_everything()/load_everything() functions as an example.
Sometimes we want to run through the whole process quickly to make sure the project works, so a debug mode is provided.
This mode reduces the dataset size and speeds up the training process. You just need to run the file with the -d option and edit the debug dict in the configure file.

    "debug": { // args in debug mode, which will replace args in train
        "val_epoch": 1,
        "save_checkpoint_epoch": 1,
        "log_iter": 30,
        "data_len": 50 // percent or number, change the size of dataloader to debug_split.
    }
You can choose the random seed and the experiment path in the configure file. We will add more useful basic functions with related instructions. Contributions for more extensive customization and code enhancements are welcome.
Here are some basic functions or examples that this repository is ready to implement:
We benefit a lot from the following projects: