# Common Initialization:
import torch
import torch.nn as nn
 
# --------------------------------------------------------------------------
# LinearRegression class
class LinearRegression(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearRegression, self).__init__()
        # define diferent layers
        self.lin = nn.Linear(input_dim, output_dim)
 
    def forward(self, x):
        return self.lin(x)
 
# --------------------------------------------------------------------------
# Model & DataSet Construction
X = torch.tensor([[1], [2], [3], [4]], dtype=torch.float32)
Y = torch.tensor([[2], [4], [6], [8]], dtype=torch.float32)
 
n_samples, n_features = X.shape
input_size = n_features
output_size = n_features
 
model = LinearRegression(input_size, output_size)
 
# --------------------------------------------------------------------------
# Test
X_test = torch.tensor([5], dtype=torch.float32)
print(f'Prediction before training: f(5) = {model(X_test).item():.3f}')
 
learning_rate = 0.01
n_iters = 100
 
loss = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
 
# --------------------------------------------------------------------------
# Training loop
for epoch in range(n_iters):
    # predict = forward pass with our model
    y_predicted = model(X)
 
    # loss
    l = loss(Y, y_predicted)
 
    # calculate gradients = backward pass
    l.backward()
 
    # update weights
    optimizer.step()
 
    # zero the gradients after updating
    optimizer.zero_grad()
 
    if epoch % 10 == 0:
        [w, b] = model.parameters() # unpack parameters
        print('epoch ', epoch+1, ': w = ', w[0][0].item(), ' loss = ', l)
 
print(f'Prediction after training: f(5) = {model(X_test).item():.3f}')