Source: https://github.com/cdluminate/leicht

cdluminate/leicht: Naive (Static Graph) Deep Learning Framework

Still under development. Stage: alpha5.

A lightweight neural network implementation in C++, written from scratch for educational purposes and inspired by Caffe, (Lua)Torch and PyTorch.

Features
  1. Lightweight and Simple. Only the necessary abstractions are made.
  2. Almost Dependency-Free (from Scratch). Only a few I/O libraries are used.
  3. Educational Purpose. Verify one's understanding of neural nets.
  4. Designed in a mixed style of Caffe and (Lua)Torch (and maybe PyTorch).

"Leicht" is a German word meaning "light".

There are four core concepts in this framework: Tensor, Blob, Layer, and Graph. Note that vectors in this project are column vectors by default.

Tensor

Generalized container of numerical data. Vectors, matrices, and higher-dimensional blocks of numbers are all regarded as Tensors, and the data is stored in a contiguous memory block.

See tensor.hpp for details.
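To make "contiguous memory block" concrete, here is a minimal sketch of such a container. It is not the class in tensor.hpp: the name SketchTensor, the row-major layout, and the at() accessor are assumptions made for illustration only.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical sketch, NOT the class in tensor.hpp.
// Every element lives in one contiguous std::vector; the shape is stored separately.
template <typename T>
struct SketchTensor {
    std::vector<std::size_t> shape;  // e.g. {rows, cols} for a matrix
    std::vector<T> data;             // contiguous storage (row-major is an assumption)

    explicit SketchTensor(std::vector<std::size_t> s)
        : shape(std::move(s)),
          // total element count = product of all dimensions
          data(std::accumulate(shape.begin(), shape.end(), std::size_t{1},
                               std::multiplies<std::size_t>()),
               T{}) {}

    std::size_t size() const { return data.size(); }

    // 2-D element access: offset = i * cols + j.
    T& at(std::size_t i, std::size_t j) {
        assert(shape.size() == 2 && i < shape[0] && j < shape[1]);
        return data[i * shape[1] + j];
    }
};
```

Under the project's column-vector convention, a vector of length n would be an {n, 1} tensor in this sketch.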

Blob

A pair of Tensors: one holds the value and the other holds its gradient. This is useful in a Caffe-style computation graph, where the backward pass reuses the forward graph instead of extending it with extra nodes for gradient computation and parameter updates.

See blob.hpp for details.
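The value/gradient pairing can be sketched as below. SketchBlob and its zeroGrad() method are hypothetical names, not the blob.hpp API; the point is only that both tensors have the same size and travel together through the graph.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch, NOT the blob.hpp API: a value and a gradient of the
// same size, kept side by side so the backward pass can write gradients
// next to the values produced by the forward pass.
template <typename T>
struct SketchBlob {
    std::vector<T> value;     // written during the forward pass
    std::vector<T> gradient;  // written during the backward pass

    explicit SketchBlob(std::size_t n) : value(n, T{}), gradient(n, T{}) {}

    std::size_t size() const {
        assert(value.size() == gradient.size());
        return value.size();
    }

    // Reset gradients to zero before the next backward pass.
    void zeroGrad() { std::fill(gradient.begin(), gradient.end(), T{}); }
};
```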

Layer

Network layers, including loss functions. In both the forward and the backward pass, each layer takes its input Blobs and output Blobs as arguments.

See layer.hpp for details.
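As an illustration of a layer that receives its input and output Blobs in both passes, here is a hypothetical Linear layer under the column-vector convention, y = W x + b. The names Blob, Layer and LinearSketch, the row-major storage of W, and the choice to accumulate (+=) gradients are assumptions, not the layer.hpp interface. The backward pass reuses the same two Blobs: it reads dL/dy from the output's gradient and writes dL/dx into the input's gradient.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical Blob: value + gradient (NOT the blob.hpp class).
struct Blob {
    std::vector<double> value, gradient;
    explicit Blob(std::size_t n) : value(n, 0.0), gradient(n, 0.0) {}
};

// Hypothetical layer interface: both passes receive the input and output Blobs.
struct Layer {
    virtual ~Layer() = default;
    virtual void forward(Blob& input, Blob& output) = 0;
    virtual void backward(Blob& input, Blob& output) = 0;
};

// Linear layer for one column vector: y = W x + b, W is (out x in), stored row-major.
struct LinearSketch : Layer {
    std::size_t in, out;
    Blob W, b;  // parameters are Blobs as well, so they carry their own gradients

    LinearSketch(std::size_t in_, std::size_t out_)
        : in(in_), out(out_), W(in_ * out_), b(out_) {}

    void forward(Blob& x, Blob& y) override {
        assert(x.value.size() == in && y.value.size() == out);
        for (std::size_t i = 0; i < out; ++i) {
            double acc = b.value[i];
            for (std::size_t j = 0; j < in; ++j) acc += W.value[i * in + j] * x.value[j];
            y.value[i] = acc;
        }
    }

    // Reads dL/dy from y.gradient; accumulates dL/dW = dy x^T, dL/db = dy, dL/dx = W^T dy.
    void backward(Blob& x, Blob& y) override {
        for (std::size_t i = 0; i < out; ++i) {
            b.gradient[i] += y.gradient[i];
            for (std::size_t j = 0; j < in; ++j) {
                W.gradient[i * in + j] += y.gradient[i] * x.value[j];
                x.gradient[j] += W.value[i * in + j] * y.gradient[i];
            }
        }
    }
};
```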

Graph

Graph, i.e. a Directed Acyclic Graph (DAG), is the computation-graph interpretation of the neural network: the nodes are Blobs and the edges (or groups of edges) are Layers. The graph is static.

See graph.hpp for details.
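A static graph of this kind can be sketched as a table of named Blobs plus a list of layers kept in insertion order. If every new layer reads only existing Blobs and creates a fresh top Blob, no cycle can form and insertion order is already a topological order, so forward can run the list front to back and backward back to front. StaticGraph, Edge and the std::function callbacks are hypothetical, not graph.hpp, and each layer is simplified to a single bottom (the real loss layers above also take a label Blob).

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical sketch, NOT graph.hpp. Blobs are the nodes (looked up by name),
// layers are the edges; each layer is simplified to a single bottom and top.
struct Blob {
    std::vector<double> value, gradient;
};

struct Edge {
    std::string bottom, top;                     // input / output blob names
    std::function<void(Blob&, Blob&)> forward;   // called as (bottom, top)
    std::function<void(Blob&, Blob&)> backward;  // called as (bottom, top)
};

struct StaticGraph {
    std::map<std::string, Blob> blobs;
    std::vector<Edge> layers;  // insertion order doubles as a topological order

    void addLayer(Edge e) {
        assert(blobs.count(e.bottom) == 1);  // input must already exist
        assert(blobs.count(e.top) == 0);     // output is a fresh node, so no cycle can form
        blobs[e.top] = Blob{};
        layers.push_back(std::move(e));
    }

    void forward() {
        for (auto& l : layers) l.forward(blobs.at(l.bottom), blobs.at(l.top));
    }

    // The backward pass walks the same, unchanged graph in reverse order.
    void backward() {
        for (auto it = layers.rbegin(); it != layers.rend(); ++it)
            it->backward(blobs.at(it->bottom), blobs.at(it->top));
    }
};
```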

Apart from the core, there are some auxiliary components:

Dataloader

Basically an I/O helper that reads a dataset or a data batch from disk into memory. It is not a key part of the project.

See dataloader.hpp for details.
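A batch reader of this kind could look like the sketch below. It assumes the dataset is a flat binary file of unsigned bytes stored sample after sample, as in the pixel payload of an MNIST IDX image file after its 16-byte header. The function name loadBatch, its parameters, and the scaling to [0, 1] are assumptions, not the dataloader.hpp API.

```cpp
#include <cassert>
#include <cstddef>
#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical loader, NOT the dataloader.hpp API. Assumes a flat binary file:
// `headerBytes` bytes of header, then samples of `sampleDim` unsigned bytes
// stored back to back (as in an MNIST IDX image file, whose header is 16 bytes).
std::vector<double> loadBatch(const std::string& path, std::size_t headerBytes,
                              std::size_t sampleDim, std::size_t batchSize,
                              std::size_t batchIndex) {
    std::ifstream file(path, std::ios::binary);
    if (!file) throw std::runtime_error("cannot open " + path);

    const std::size_t n = sampleDim * batchSize;  // bytes in one batch
    file.seekg(static_cast<std::streamoff>(headerBytes + batchIndex * n));
    std::vector<unsigned char> raw(n);
    file.read(reinterpret_cast<char*>(raw.data()), static_cast<std::streamsize>(n));
    if (static_cast<std::size_t>(file.gcount()) != n)
        throw std::runtime_error("short read from " + path);

    // Scale bytes to [0, 1]; the result is what would be copied into entryDataBlob.
    std::vector<double> batch(n);
    for (std::size_t i = 0; i < n; ++i) batch[i] = raw[i] / 255.0;
    return batch;
}
```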

Curve

Saves the curve data to an ASCII file and can optionally draw a plot for you. Alternatively, one can parse the screen output with UNIX black magic, e.g. awk.

See curve.hpp for details.
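Saving curve data as ASCII can be as simple as one "iteration value" pair per line, a whitespace-separated format that awk and gnuplot read directly. CurveSketch is a hypothetical name and the two-column format is an assumption, not necessarily what curve.hpp writes.

```cpp
#include <cassert>
#include <cstddef>
#include <fstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical sketch, NOT curve.hpp: collect (iteration, value) points and
// dump them as whitespace-separated ASCII, one point per line.
class CurveSketch {
public:
    void append(std::size_t iteration, double value) {
        points_.emplace_back(iteration, value);
    }

    void save(const std::string& path) const {
        std::ofstream out(path);
        if (!out) throw std::runtime_error("cannot open " + path);
        for (const auto& p : points_) out << p.first << ' ' << p.second << '\n';
    }

    std::size_t size() const { return points_.size(); }

private:
    std::vector<std::pair<std::size_t, double>> points_;
};
```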

Just include the header in your C++ file.

Example of network definition:

// create a network (static graph), input dim 784, label dim 1, batch 100
// There are two pre-defined blobs in the graph: entry{Data,Label}Blob
Graph<double> net (784, 1, 100);

// add a layer, name=fc1, type=Linear, bottom=entryDataBlob, top=fc1, out dim=10
net.addLayer("fc1", "Linear", "entryDataBlob", "fc1", 10);

// add a layer, name=sm1, type=Softmax, bottom=fc1, top=sm1
net.addLayer("sm1", "Softmax", "fc1", "sm1");

// add a layer, name=cls1, type=ClassNLLLoss, bottom=sm1, top=cls1, label=entryLabelBlob
net.addLayer("cls1", "ClassNLLLoss", "sm1", "cls1", "entryLabelBlob");

// add a layer, name=acc1, type=ClassAccuracy, bottom=sm1, top=acc1, label=entryLabelBlob
net.addLayer("acc1", "ClassAccuracy", "sm1", "acc1", "entryLabelBlob");
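The Softmax, ClassNLLLoss and ClassAccuracy layers above can be sketched for a single sample as follows. Because cls1 reads sm1, this sketch assumes ClassNLLLoss takes probabilities and computes -log p[label] (PyTorch's NLLLoss, by contrast, expects log-probabilities). The function names are hypothetical and batching is left out.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <vector>

// Hypothetical single-sample versions of Softmax -> ClassNLLLoss / ClassAccuracy,
// NOT the leicht implementations.
std::vector<double> softmax(const std::vector<double>& z) {
    const double m = *std::max_element(z.begin(), z.end());  // subtract max for stability
    std::vector<double> p(z.size());
    double sum = 0.0;
    for (std::size_t i = 0; i < z.size(); ++i) sum += (p[i] = std::exp(z[i] - m));
    for (double& v : p) v /= sum;
    return p;
}

// NLL on probabilities (the input is already softmax-ed).
double classNLL(const std::vector<double>& p, std::size_t label) {
    return -std::log(p[label]);
}

// Gradient of the NLL w.r.t. the probabilities: -1/p[label] at the label, 0 elsewhere.
std::vector<double> classNLLGrad(const std::vector<double>& p, std::size_t label) {
    std::vector<double> g(p.size(), 0.0);
    g[label] = -1.0 / p[label];
    return g;
}

// Accuracy for one sample: 1 if the most probable class equals the label, else 0.
double classAccuracy(const std::vector<double>& p, std::size_t label) {
    const auto best = static_cast<std::size_t>(
        std::distance(p.begin(), std::max_element(p.begin(), p.end())));
    return best == label ? 1.0 : 0.0;
}
```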

Example of network training:

for (int iteration = 0; iteration < MAX_ITERATION; iteration++) {
  // get a batch: input dim = 784, batch size = 100 (pseudo code)
  get_batch_to("entryDataBlob", 784*100);
  get_batch_to("entryLabelBlob", 100);

  // forward pass of the network (graph)
  net.forward();

  // clear gradient
  net.zeroGrad();

  // backward pass of the network (graph)
  net.backward();

  // report the loss and accuracy
  net.report();

  // parameter update (SGD), learning rate = 1e-3
  net.update(1e-3);
}
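The last two steps of the loop can be sketched as below, assuming net.update(lr) performs plain SGD, w <- w - lr * dL/dw, and that backward passes accumulate gradients with +=, which would explain why the loop calls net.zeroGrad() before net.backward(). ParamBlob, sgdUpdate and zeroGrad are hypothetical helpers, not the leicht API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical parameter Blob (value + gradient), NOT a leicht class.
struct ParamBlob {
    std::vector<double> value, gradient;
};

// Plain SGD over every parameter element: w <- w - lr * dL/dw.
void sgdUpdate(const std::vector<ParamBlob*>& params, double lr) {
    for (ParamBlob* p : params)
        for (std::size_t i = 0; i < p->value.size(); ++i)
            p->value[i] -= lr * p->gradient[i];
}

// If backward accumulates (+=) into gradients, they must be cleared each iteration.
void zeroGrad(const std::vector<ParamBlob*>& params) {
    for (ParamBlob* p : params)
        for (double& g : p->gradient) g = 0.0;
}
```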

The full example is test_graph_mnist_cls.cc.

This is a lightweight project; please just READ THE CODE.

Dependency and Compilation

This project is designed to use as few libraries as possible, i.e. it is written from scratch. The only libraries it needs are a few auxiliary I/O helper libraries.

The MIT License.

Want:

Postponed:

Not decided:
