@@ -1,36 +1,45 @@
 import torch
 from torch.autograd import Variable
-import torch.nn
+from torch import nn, optim
 import torchvision.transforms as transforms
 import torchvision.datasets as dsets
+import numpy as np
+from skorch import NeuralNet
+from skorch.callbacks import EpochScoring
+from matplotlib import pyplot as plt
 
-class LogisticRegression(torch.nn.Module):
+from spacecutter.callbacks import AscensionCallback
+from spacecutter.losses import CumulativeLinkLoss
+from spacecutter.models import OrdinalLogisticModel
+
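+# The model below returns raw logits; nn.CrossEntropyLoss applies log-softmax
+# internally, so no explicit softmax layer is needed.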
+class LogisticRegression(nn.Module):
     def __init__(self, input_dim, output_dim):
         super(LogisticRegression, self).__init__()
-        self.linear = torch.nn.Linear(input_dim, output_dim)
+        self.linear = nn.Linear(input_dim, output_dim)
 
     def forward(self, x):
         outputs = self.linear(x)
         return outputs
-
-batch_size = 100
-n_iters = 3000
-input_dim = 784
-output_dim = 10
-lr_rate = 0.001
 
 def regression_on_mnist():
-    train_dataset = dsets.MNIST(root='./torch_test/data', train=True, transform=transforms.ToTensor(), download=True)
+    batch_size = 100
+    n_iters = 5000
+    input_dim = 784
+    output_dim = 10
+    lr_rate = 0.001
+
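+    # download=False assumes the MNIST files are already present under ./torch_test/data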
+    train_dataset = dsets.MNIST(root='./torch_test/data', train=True, transform=transforms.ToTensor(), download=False)
     test_dataset = dsets.MNIST(root='./torch_test/data', train=False, transform=transforms.ToTensor())
-    _regression(train_dataset, test_dataset)
 
-def _regression(train_dataset, test_dataset):
     train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
     test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
 
     model = LogisticRegression(input_dim, output_dim)
-    criterion = torch.nn.CrossEntropyLoss()  # computes cross-entropy loss on top of a softmax distribution
-    optimizer = torch.optim.SGD(model.parameters(), lr=lr_rate)
+    criterion = nn.CrossEntropyLoss()  # computes cross-entropy loss on top of a softmax distribution
+    optimizer = optim.SGD(model.parameters(), lr=lr_rate)
 
     epochs = n_iters / (len(train_dataset) / batch_size)
     iter = 0
@@ -58,4 +67,77 @@ def _regression(train_dataset, test_dataset):
             # if running on a GPU, move the predictions and labels back to the CPU before computing with Python
             correct += (predicted == labels).sum()
         accuracy = 100 * correct / total
-        print("Iteration: {}. Loss: {}. Accuracy: {}.".format(iter, loss.item(), accuracy))
+        print("Iteration: {}. Loss: {}. Accuracy: {}.".format(iter, loss.item(), accuracy))
+
+def ordinal_regression():
+    X = np.array([
+        [0.5, 0.1, -0.1],
+        [1.0, 0.2, 0.6],
+        [-2.0, 0.4, 0.8]
+    ], dtype=np.float32)
+
+    y = np.array([0, 1, 2]).reshape(-1, 1)
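+    # spacecutter expects ordinal targets as an (n_samples, 1) column vector of class indices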
+
+    num_features = X.shape[1]
+    num_classes = len(np.unique(y))
+
+    predictor = nn.Sequential(
+        nn.Linear(num_features, num_features),
+        nn.ReLU(),
+        nn.Linear(num_features, 1)
+    )
+
+    model = OrdinalLogisticModel(predictor, num_classes)
+
+    y_pred = model(torch.as_tensor(X))
+
+    print(y_pred)
+
+    # tensor([[0.2325, 0.2191, 0.5485],
+    #         [0.2324, 0.2191, 0.5485],
+    #         [0.2607, 0.2287, 0.5106]], grad_fn=<CatBackward>)
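+    #
+    # Each row is a probability distribution over the three ordinal classes,
+    # derived from one scalar predictor score and the learned ascending cutpoints.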
+
+    skorch_model = NeuralNet(
+        module=OrdinalLogisticModel,
+        module__predictor=predictor,
+        module__num_classes=num_classes,
+        criterion=CumulativeLinkLoss,
+        train_split=None,
+        max_epochs=30,
+        callbacks=[
+            ('ascension', AscensionCallback())
+        ],
+    )
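+    # The AscensionCallback keeps the learned cutpoints in ascending order during
+    # training; train_split=None disables skorch's internal validation split.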
+
+    skorch_model.fit(X, y)
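+
+    # Hedged sketch: the EpochScoring import is currently unused. To record the
+    # accuracy histories plotted below, scoring callbacks could be added to the
+    # callbacks list above before fitting, e.g. with a custom scoring callable:
+    #
+    #     def train_accuracy(net, X=None, y=None):
+    #         return np.mean(net.predict(X).argmax(axis=1) == y.ravel())
+    #
+    #     EpochScoring(train_accuracy, lower_is_better=False, on_train=True, name='train_accuracy')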
+
+    # train_acc = skorch_model.history[:, 'train_accuracy']
+    # valid_acc = skorch_model.history[:, 'valid_accuracy']
+
+    # plt.plot(train_acc, label='Train Accuracy')
+    # plt.plot(valid_acc, label='Validation Accuracy')
+    # plt.xlabel('Epoch')
+    # plt.ylabel('Accuracy')
+    # plt.legend()
+    # plt.show()
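+
+
+if __name__ == '__main__':
+    # assumed entry point so both demos can be run as a script
+    regression_on_mnist()
+    ordinal_regression()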