Line data Source code
1 : // Copyright (C) 2012 The Android Open Source Project 2 : // 3 : // Licensed under the Apache License, Version 2.0 (the "License"); 4 : // you may not use this file except in compliance with the License. 5 : // You may obtain a copy of the License at 6 : // 7 : // http://www.apache.org/licenses/LICENSE-2.0 8 : // 9 : // Unless required by applicable law or agreed to in writing, software 10 : // distributed under the License is distributed on an "AS IS" BASIS, 11 : // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 : // See the License for the specific language governing permissions and 13 : // limitations under the License. 14 : 15 : package com.google.gerrit.sshd; 16 : 17 : import com.google.common.base.Throwables; 18 : import com.google.gerrit.extensions.registration.DynamicSet; 19 : import com.google.gerrit.server.AccessPath; 20 : import com.google.gerrit.server.CancellationMetrics; 21 : import com.google.gerrit.server.DeadlineChecker; 22 : import com.google.gerrit.server.DynamicOptions; 23 : import com.google.gerrit.server.InvalidDeadlineException; 24 : import com.google.gerrit.server.RequestInfo; 25 : import com.google.gerrit.server.RequestListener; 26 : import com.google.gerrit.server.cancellation.RequestCancelledException; 27 : import com.google.gerrit.server.cancellation.RequestStateContext; 28 : import com.google.gerrit.server.config.GerritServerConfig; 29 : import com.google.gerrit.server.logging.PerformanceLogContext; 30 : import com.google.gerrit.server.logging.PerformanceLogger; 31 : import com.google.gerrit.server.logging.TraceContext; 32 : import com.google.gerrit.server.plugincontext.PluginSetContext; 33 : import com.google.inject.Inject; 34 : import java.io.IOException; 35 : import java.io.PrintWriter; 36 : import java.util.Optional; 37 : import org.apache.sshd.server.Environment; 38 : import org.apache.sshd.server.channel.ChannelSession; 39 : import org.eclipse.jgit.lib.Config; 40 : import org.kohsuke.args4j.Option; 41 : 42 6 : public abstract class SshCommand extends BaseCommand { 43 : @Inject private DynamicSet<PerformanceLogger> performanceLoggers; 44 : @Inject private PluginSetContext<RequestListener> requestListeners; 45 : @Inject @GerritServerConfig private Config config; 46 : @Inject private DeadlineChecker.Factory deadlineCheckerFactory; 47 : @Inject private CancellationMetrics cancellationMetrics; 48 : 49 : @Option(name = "--trace", usage = "enable request tracing") 50 : private boolean trace; 51 : 52 : @Option(name = "--trace-id", usage = "trace ID (can only be set if --trace was set too)") 53 : private String traceId; 54 : 55 : @Option(name = "--deadline", usage = "deadline after which the request should be aborted)") 56 : private String deadline; 57 : 58 : protected PrintWriter stdout; 59 : protected PrintWriter stderr; 60 : 61 : @Override 62 : public void start(ChannelSession channel, Environment env) throws IOException { 63 5 : startThread( 64 : () -> { 65 5 : try (DynamicOptions pluginOptions = new DynamicOptions(injector, dynamicBeans)) { 66 5 : parseCommandLine(pluginOptions); 67 5 : stdout = toPrintWriter(out); 68 5 : stderr = toPrintWriter(err); 69 5 : try (TraceContext traceContext = enableTracing(); 70 5 : PerformanceLogContext performanceLogContext = 71 : new PerformanceLogContext(config, performanceLoggers)) { 72 5 : RequestInfo requestInfo = 73 5 : RequestInfo.builder(RequestInfo.RequestType.SSH, user, traceContext).build(); 74 : try (RequestStateContext requestStateContext = 75 5 : RequestStateContext.open() 76 5 : .addRequestStateProvider( 77 5 : deadlineCheckerFactory.create(requestInfo, deadline))) { 78 5 : requestListeners.runEach(l -> l.onRequest(requestInfo)); 79 5 : SshCommand.this.run(); 80 1 : } catch (InvalidDeadlineException e) { 81 1 : stderr.println(e.getMessage()); 82 1 : } catch (RuntimeException e) { 83 1 : Optional<RequestCancelledException> requestCancelledException = 84 1 : RequestCancelledException.getFromCausalChain(e); 85 1 : if (!requestCancelledException.isPresent()) { 86 0 : Throwables.throwIfUnchecked(e); 87 : } 88 1 : cancellationMetrics.countCancelledRequest( 89 1 : requestInfo, requestCancelledException.get().getCancellationReason()); 90 1 : StringBuilder msg = 91 1 : new StringBuilder(requestCancelledException.get().formatCancellationReason()); 92 1 : if (requestCancelledException.get().getCancellationMessage().isPresent()) { 93 1 : msg.append( 94 1 : String.format( 95 1 : " (%s)", requestCancelledException.get().getCancellationMessage().get())); 96 : } 97 1 : stderr.println(msg.toString()); 98 5 : } 99 : } finally { 100 5 : stdout.flush(); 101 5 : stderr.flush(); 102 : } 103 : } 104 5 : }, 105 : AccessPath.SSH_COMMAND); 106 5 : } 107 : 108 : protected abstract void run() throws UnloggedFailure, Failure, Exception; 109 : 110 : private TraceContext enableTracing() throws UnloggedFailure { 111 5 : if (!trace && traceId != null) { 112 1 : throw die("A trace ID can only be set if --trace was specified."); 113 : } 114 5 : return TraceContext.newTrace( 115 : trace, 116 : traceId, 117 1 : (tagName, traceId) -> stderr.println(String.format("%s: %s", tagName, traceId))); 118 : } 119 : }