Creating layers, activation functions, and loss functions for a neural network¶
At the moment only a fairly limited set of layers, activation functions, and loss functions is available out of the box, so for any more or less serious project you need to write your own.
The example below is built with PyTorch: we will implement a linear layer.
The official tutorial is available on the PyTorch website.
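As an aside: a custom torch.autograd.Function, as implemented below, is only required when you want to define the backward pass yourself. If a new layer, activation, or loss can be expressed through existing differentiable operations, a plain nn.Module is enough and autograd derives the gradients automatically. A minimal sketch (MyReLU is a hypothetical name of my own, not something from the article):
In [ ]:
import torch

class MyReLU(torch.nn.Module):
    # Built entirely from differentiable ops, so autograd derives
    # the backward pass automatically - no Function subclass needed.
    def forward(self, input):
        return input.clamp(min=0)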
This notebook is available in my GitHub repository:
git clone https://github.com/andreiliphd/reinforcement-content.git
If Git is not installed, install it first.
Linux:
sudo apt-get update
sudo apt-get install git
Windows: download Git from git-scm.com.
If you find a mistake on the site, you can fix it yourself by submitting a Pull Request.
Solution steps¶
- Define a class with a custom differentiation algorithm.
- Override forward.
- Override backward.
- Define the module.
- Override __init__.
- Override forward.
- Testing.
In [15]:
import torch
from torch import nn
Define a class with a custom differentiation algorithm¶
In [9]:
# Inherit from Function
class LinearFunction(torch.autograd.Function):

    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a very convenient pattern - at the top of backward,
        # unpack saved_tensors and initialize all gradients w.r.t. inputs
        # to None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function
        # has optional inputs.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and are there only to
        # improve efficiency. If you want to make your code simpler, you
        # can skip them. Returning gradients for inputs that don't require
        # them is not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return grad_input, grad_weight, grad_bias
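The backward above simply applies the matrix-calculus rules for output = input · weightᵀ + bias: grad_input = grad_output · weight, grad_weight = grad_outputᵀ · input, and grad_bias is the column-wise sum of grad_output. A hand-written backward can be verified with torch.autograd.gradcheck, which compares it against numerical gradients, as the official tutorial recommends. A minimal sketch (the tensor shapes here are arbitrary):
In [ ]:
from torch.autograd import gradcheck

# gradcheck compares the analytical gradients from backward() with
# finite-difference estimates; it expects double-precision inputs
# that require gradients.
inp = torch.randn(4, 6, dtype=torch.double, requires_grad=True)
w = torch.randn(5, 6, dtype=torch.double, requires_grad=True)
b = torch.randn(5, dtype=torch.double, requires_grad=True)
print(gradcheck(LinearFunction.apply, (inp, w, b), eps=1e-6, atol=1e-4))
# Prints True when the custom backward matches the numerical gradients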
Define the module¶
A module is the component that we plug into our neural network.
In [10]:
class Linear(torch.nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super(Linear, self).__init__()
        self.input_features = input_features
        self.output_features = output_features

        # nn.Parameter is a special kind of Tensor that will get
        # automatically registered as the Module's parameter once it's
        # assigned as an attribute. Parameters and buffers need to be
        # registered, or they won't appear in .parameters() (doesn't apply
        # to buffers), and won't be converted when e.g. .cuda() is called.
        # You can use .register_buffer() to register buffers.
        # nn.Parameters require gradients by default.
        self.weight = nn.Parameter(torch.Tensor(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(output_features))
        else:
            # You should always register all possible parameters, but the
            # optional ones can be None if you want.
            self.register_parameter('bias', None)

        # Not a very smart way to initialize weights
        self.weight.data.uniform_(-0.1, 0.1)
        if self.bias is not None:
            self.bias.data.uniform_(-0.1, 0.1)

    def forward(self, input):
        # See the autograd section for an explanation of what happens here.
        return LinearFunction.apply(input, self.weight, self.bias)

    def extra_repr(self):
        # (Optional) Set the extra information about this module. You can
        # test it by printing an object of this class.
        return 'in_features={}, out_features={}, bias={}'.format(
            self.input_features, self.output_features, self.bias is not None
        )
Testing¶
In [16]:
tensor = torch.randn([3, 64])
In [12]:
fc = Linear(64, 100)
In [13]:
fc(tensor)
Out[13]:
tensor([[-1.5266e-01, -2.2930e-01,  1.3016e-01,  ..., -5.3702e-01,
          2.3171e-01,  6.9027e-01],
        [ 1.0974e+00, -6.1415e-03, -2.0058e-01,  ..., -2.1543e-01,
         -1.0236e+00,  5.5275e-01],
        [ 5.5619e-01,  1.1382e+00, -2.2339e-01,  ..., -3.4800e-02,
         -1.5970e-01, -1.4202e-01]], grad_fn=<LinearFunctionBackward>)
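As a quick follow-up (my own extension of the test): backpropagate through the custom layer and print the module, which exercises the extra_repr defined above.
In [ ]:
out = fc(tensor)
out.sum().backward()          # runs LinearFunction.backward under the hood
print(fc.weight.grad.shape)   # torch.Size([100, 64])
print(fc)                     # Linear(in_features=64, out_features=100, bias=True)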